""" RAG pipeline: retrieve → generate → grade. Retrieval: in-memory semantic search (sentence-transformers, encoded at first use per domain). Generation: Llama-3 via HF Inference API with retrieved context injected as grounding. Grading: L1 metrics via grader.py. Knowledge base formats supported: - features.yaml (default) — {documents: [{id, title, content, tags}, ...]} - features.json — [{id, title, content, tags}, ...] - features.csv — columns: id, title, content, tags (tags as comma-sep string) """ import csv import json import logging import random import re import time from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path from typing import Any import numpy as np import yaml from config import DISPLAY_NAMES, domain_for, features_path from grader import GradeReport, get_embedder, grade from huggingface_hub import InferenceClient from rosetta import client_terms, client_terms_doc from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity log = logging.getLogger(__name__) TOP_K = 3 MIN_RETRIEVAL_SCORE = 0.1 SYSTEM_PROMPT = """\ You are a helpful assistant for {client_display} ({domain} domain). Answer the user's question using only the information in the provided context. Be concise. If the context does not contain enough information to answer, respond with exactly: NOT IN DOCUMENTS: [one sentence explaining what information is missing] Do not speculate, infer, or use knowledge outside the provided context. You MUST use the following terminology. These are the only acceptable terms — do not substitute synonyms: {term_list}""" # --------------------------------------------------------------------------- # Knowledge base loader — supports .yaml, .json, .csv # --------------------------------------------------------------------------- def _load_docs(domain: str) -> list[dict[str, Any]]: """ Load and merge knowledge base documents for a domain from all present files. All formats loaded if they exist (merged in this order): 1. features.yaml — curated process/workflow docs 2. features.json — flat list of document dicts 3. features.csv — Kaggle-style drug/product dataset All formats normalised to: [{"id": str, "title": str, "content": str, "tags": list[str]}, ...] """ base = features_path(domain).parent # knowledge// docs: list[dict[str, Any]] = [] yaml_path = base / "features.yaml" if yaml_path.exists(): data = yaml.safe_load(yaml_path.read_text()) docs.extend(data["documents"]) json_path = base / "features.json" if json_path.exists(): docs.extend(_normalise_docs(json.loads(json_path.read_text()))) csv_path = base / "features.csv" if csv_path.exists(): docs.extend(_load_csv_docs(csv_path)) if not docs: raise FileNotFoundError( f"No knowledge base found for domain '{domain}'. " f"Expected one of: {yaml_path}, {json_path}, {csv_path}" ) return docs def _normalise_docs(docs: list[dict[str, Any]]) -> list[dict[str, Any]]: """Ensure tags is always a list, not a string.""" for doc in docs: if isinstance(doc.get("tags"), str): doc["tags"] = [t.strip() for t in doc["tags"].split(",") if t.strip()] elif "tags" not in doc: doc["tags"] = [] return docs def _slugify(s: str) -> str: return re.sub(r"[^a-z0-9]+", "-", s.lower()).strip("-") _CSV_DRUG_COLS = ["drug_name", "name", "Drug Name", "drugName"] _CSV_USE_COLS = ["medical_condition", "condition", "use", "uses", "indication"] _CSV_EFFECT_COLS = ["side_effects", "Side_Effects", "sideEffects", "adverse_effects"] def _find_col(headers: Sequence[str], candidates: list[str]) -> str | None: h_lower = {h.lower().strip(): h for h in headers} for c in candidates: if c in headers: return c if c.lower().strip() in h_lower: return h_lower[c.lower().strip()] return None def _truncate_list(text: str, max_items: int = 8) -> str: sep = "," if text.count(",") >= text.count(";") else ";" items = [i.strip() for i in text.split(sep) if i.strip()] if len(items) > max_items: return ", ".join(items[:max_items]) + ", and others" return ", ".join(items) def _load_csv_docs(path: Path) -> list[dict[str, Any]]: """ Load a drug/side-effects CSV into the standard document format. Handles Kaggle-style CSVs with drug name, condition, and side effects columns. Falls back to generic id/title/content/tags columns if detected. """ docs = [] with path.open(newline="", encoding="utf-8-sig") as f: reader = csv.DictReader(f) headers = reader.fieldnames or [] # Check if it's already in the native format if all(c in headers for c in ("id", "title", "content")): log.info("CSV detected as native format (id/title/content)") for row in reader: tags_raw = row.get("tags", "") tags = [t.strip() for t in tags_raw.split(",") if t.strip()] docs.append({ "id": row["id"], "title": row["title"], "content": row["content"], "tags": tags, }) return docs # Otherwise treat as Kaggle-style drug dataset drug_col = _find_col(headers, _CSV_DRUG_COLS) use_col = _find_col(headers, _CSV_USE_COLS) effect_col = _find_col(headers, _CSV_EFFECT_COLS) if not drug_col: raise ValueError( f"CSV at {path} has no recognisable drug name column. " f"Available columns: {headers}" ) log.info( "CSV drug dataset — drug: %r, use: %r, effects: %r", drug_col, use_col, effect_col, ) seen: set[str] = set() for i, row in enumerate(reader): drug = row.get(drug_col, "").strip().title() if not drug or drug.lower() in seen: continue seen.add(drug.lower()) condition = row.get(use_col, "").strip() if use_col else "" effects = row.get(effect_col, "").strip() if effect_col else "" condition_str = condition or "the indicated condition" effects_str = _truncate_list(effects) if effects else "not listed" content = ( f"{drug} is indicated for the treatment or management of {condition_str}. " f"Known adverse events associated with {drug} include: {effects_str}. " f"Prescribers should monitor patients for these adverse events and report " f"serious unexpected occurrences to the regulatory authority within 15 days." ) tags = list(filter(None, [ _slugify(drug), _slugify(condition) if condition else None, "adverse-event", "drug-profile", ])) docs.append({ "id": f"pharma_{i + 1:03d}", "title": f"{drug} — Drug Profile", "content": content, "tags": tags, }) log.info("Loaded %d documents from CSV %s", len(docs), path) return docs # --------------------------------------------------------------------------- # Index cache # --------------------------------------------------------------------------- @dataclass(slots=True) class RetrievedDoc: id: str title: str content: str score: float @dataclass(slots=True) class PipelineResult: query: str client: str answer: str retrieved_docs: list[RetrievedDoc] grade_report: GradeReport context_used: str @property def response_payload(self) -> dict[str, Any]: return { "query": self.query, "client": self.client, "client_display": DISPLAY_NAMES.get(self.client, self.client), "answer": self.answer, "flagged": not self.grade_report.overall, "sources": [ {"id": d.id, "title": d.title, "score": round(d.score, 3)} for d in self.retrieved_docs ], "evaluation": self.grade_report.summary, } @dataclass(slots=True) class KBIndex: docs: list[dict[str, Any]] embeddings: np.ndarray _index_cache: dict[str, KBIndex] = {} def clear_index_cache() -> list[str]: """Evict all cached KB indexes. Returns list of evicted domain names.""" evicted = list(_index_cache.keys()) _index_cache.clear() return evicted def _build_index(domain: str, embedder: SentenceTransformer) -> KBIndex: if domain not in _index_cache: docs = _load_docs(domain) texts = [f"{d['title']}. {d['content']}" for d in docs] embeddings = embedder.encode(texts, show_progress_bar=False) _index_cache[domain] = KBIndex(docs=docs, embeddings=np.array(embeddings)) log.info("Built KB index for domain=%s (%d docs)", domain, len(docs)) return _index_cache[domain] def _build_context(docs: list[RetrievedDoc]) -> str: return "\n\n".join(f"[{d.title}]\n{d.content.strip()}" for d in docs) # --------------------------------------------------------------------------- # Joke short-circuit # --------------------------------------------------------------------------- _JOKE_PATTERN = re.compile( r"\b(joke|jokes|funny|laugh|humor|humour|dad joke|daddy joke|pun|make me (laugh|smile))\b", re.IGNORECASE, ) _DOMAIN_JOKES: dict[str, list[str]] = { "retail": [ "Why did the inventory go to therapy? It had too many unresolved stockouts.", "What do you call a shelf that never runs out? A shelf-fulfilling prophecy.", "Why don't retail systems ever get lonely? Because they're always in stock-ing.", "I tried to write a joke about planograms… but I couldn't find the right placement.", "Why did the out-of-stock alert break up with the replenishment order? It said: 'You always come too late.'", ], "pharma": [ "Why did the adverse event report go to the doctor? It had too many side effects.", "What do you call a drug that's always on time? Punctu-pill.", "Why are pharmacovigilance teams great at parties? They always follow up.", "I asked the formulary for a joke. It said: 'Prior authorization required.'", "Why did the clinical trial fail to finish the joke? Insufficient sample size.", ], } _FALLBACK_JOKES = [ "Why did the AI refuse to tell a joke? Insufficient context. Please provide grounding documents.", "I'm not great at jokes — my training data was mostly compliance documents.", "Why did the chatbot cross the road? To reduce hallucination on the other side.", ] _RESET_PATTERNS = re.compile( r"\b(forget|ignore|pretend|blank slate|no context|disregard|system prompt|jailbreak)\b", re.IGNORECASE, ) def _is_joke_query(query: str) -> bool: return bool(_JOKE_PATTERN.search(query)) def _joke_response(domain: str) -> str: jokes = _DOMAIN_JOKES.get(domain, _FALLBACK_JOKES) return random.choice(jokes) def _is_reset_attempt(query: str) -> bool: return bool(_RESET_PATTERNS.search(query)) # --------------------------------------------------------------------------- # Generation # --------------------------------------------------------------------------- HF_GENERATION_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" def _generate( query: str, context: str, client: str, domain: str, hf_client: InferenceClient, ) -> str: terms = client_terms(client) term_list = "\n".join(f"- {v}" for v in terms.values()) if terms else "(none)" system = SYSTEM_PROMPT.format( client_display=DISPLAY_NAMES.get(client, client), domain=domain, term_list=term_list, ) response = hf_client.chat.completions.create( model=HF_GENERATION_MODEL, messages=[ {"role": "system", "content": system}, {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}, ], max_tokens=512, ) return str(response.choices[0].message.content or "").strip() # --------------------------------------------------------------------------- # Public entry point # --------------------------------------------------------------------------- def run( query: str, client: str, hf_client: InferenceClient, top_k: int = TOP_K, ) -> PipelineResult: """Retrieve relevant KB docs, generate a grounded answer, and grade it.""" domain = domain_for(client) if _is_joke_query(query) or _is_reset_attempt(query): answer = _joke_response(domain) report = grade(query=query, response=answer, context=answer, client=client) return PipelineResult( query=query, client=client, answer=answer, retrieved_docs=[], grade_report=report, context_used="", ) embedder = get_embedder() index = _build_index(domain, embedder) t0 = time.perf_counter() q_vec = embedder.encode([query]) scores = cosine_similarity(q_vec, index.embeddings)[0] top_indices = np.argsort(scores)[::-1][:top_k] retrieved = [ RetrievedDoc( id=index.docs[i]["id"], title=index.docs[i]["title"], content=index.docs[i]["content"], score=float(scores[i]), ) for i in top_indices if scores[i] > MIN_RETRIEVAL_SCORE ] t_retrieve = (time.perf_counter() - t0) * 1000 terms_doc = client_terms_doc(client) pinned = RetrievedDoc( id=terms_doc["id"], title=terms_doc["title"], content=terms_doc["content"], score=1.0, ) # Generation context includes glossary; grading context uses KB docs only generation_context = _build_context([pinned, *retrieved]) grading_context = _build_context(retrieved) t1 = time.perf_counter() answer = _generate(query, generation_context, client, domain, hf_client) t_generate = (time.perf_counter() - t1) * 1000 t2 = time.perf_counter() report = grade( query=query, response=answer, context=grading_context, client=client, ) t_grade = (time.perf_counter() - t2) * 1000 import telemetry telemetry.record( client=client, domain=domain, query_len=len(query), latency_ms={"retrieve": t_retrieve, "generate": t_generate, "grade": t_grade}, report=report, docs_retrieved=len(retrieved), min_retrieval_score=float(min(s.score for s in retrieved)) if retrieved else 0.0, ) return PipelineResult( query=query, client=client, answer=answer, retrieved_docs=retrieved, grade_report=report, context_used=grading_context, )