"""
RAG pipeline: retrieve → generate → grade.

Retrieval: in-memory semantic search (sentence-transformers, encoded at first use per domain).
Generation: Llama-3 via HF Inference API with retrieved context injected as grounding.
Grading: L1 metrics via grader.py.

Knowledge base formats supported (every file that exists is loaded and the results are merged):
- features.yaml (default) — {documents: [{id, title, content, tags}, ...]}
- features.json — [{id, title, content, tags}, ...]
- features.csv — columns: id, title, content, tags (tags as comma-sep string),
  or a Kaggle-style drug dataset (drug name / condition / side-effects columns)
"""
|
|
| import csv |
| import json |
| import logging |
| import random |
| import re |
| import time |
| from collections.abc import Sequence |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import yaml |
| from config import DISPLAY_NAMES, domain_for, features_path |
| from grader import GradeReport, get_embedder, grade |
| from huggingface_hub import InferenceClient |
| from rosetta import client_terms, client_terms_doc |
| from sentence_transformers import SentenceTransformer |
| from sklearn.metrics.pairwise import cosine_similarity |
|
|
| log = logging.getLogger(__name__) |
|
|
| TOP_K = 3 |
| MIN_RETRIEVAL_SCORE = 0.1 |
|
|
| SYSTEM_PROMPT = """\ |
| You are a helpful assistant for {client_display} ({domain} domain). |
| Answer the user's question using only the information in the provided context. |
| Be concise. |
| |
| If the context does not contain enough information to answer, respond with exactly: |
| NOT IN DOCUMENTS: [one sentence explaining what information is missing] |
| Do not speculate, infer, or use knowledge outside the provided context. |
| |
| You MUST use the following terminology. These are the only acceptable terms — do not substitute synonyms: |
| {term_list}""" |
|
|
|
|
| |
| |
| |
|
|
def _load_docs(domain: str) -> list[dict[str, Any]]:
    """
    Load and merge knowledge base documents for a domain from all present files.

    Every file that exists next to the domain's features file is loaded, and
    the results are concatenated in this order:
      1. features.yaml — mapping with a top-level ``documents`` list
      2. features.json — flat list of document dicts
      3. features.csv  — parsed by ``_load_csv_docs``

    YAML and JSON documents are passed through ``_normalise_docs``, so every
    document ends up as:
      {"id": str, "title": str, "content": str, "tags": list[str]}

    Raises:
        FileNotFoundError: if none of the files exists, or the files that do
            exist yield no documents at all.
        KeyError: if a non-empty features.yaml has no ``documents`` key.
    """
    base = features_path(domain).parent
    yaml_path = base / "features.yaml"
    json_path = base / "features.json"
    csv_path = base / "features.csv"

    docs: list[dict[str, Any]] = []

    if yaml_path.exists():
        # Read as UTF-8 explicitly; the platform default encoding is not
        # reliable for knowledge-base text.
        data = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
        # safe_load returns None for an empty file — treat that as "no docs".
        if data:
            # `or []` covers a `documents:` key with a null value.
            docs.extend(_normalise_docs(data["documents"] or []))

    if json_path.exists():
        docs.extend(_normalise_docs(json.loads(json_path.read_text(encoding="utf-8"))))

    if csv_path.exists():
        docs.extend(_load_csv_docs(csv_path))

    if not docs:
        raise FileNotFoundError(
            f"No knowledge base found for domain '{domain}'. "
            f"Expected one of: {yaml_path}, {json_path}, {csv_path}"
        )

    return docs
|
|
|
|
def _normalise_docs(docs: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """
    Ensure every document's ``tags`` value is a list.

    - a comma-separated string is split, stripped, and emptied of blank items
    - a missing key or an explicit null (``None``) becomes ``[]``
    - any other value (normally already a list) is left untouched

    The documents are updated in place and the same list is returned.
    """
    for doc in docs:
        tags = doc.get("tags")
        if isinstance(tags, str):
            doc["tags"] = [t.strip() for t in tags.split(",") if t.strip()]
        elif tags is None:
            # Covers both an absent key and `tags: null` in YAML/JSON.
            doc["tags"] = []
    return docs
|
|
|
|
def _slugify(s: str) -> str:
    """Lower-case *s*, turn each run of characters outside a-z/0-9 into one hyphen, and trim edge hyphens."""
    lowered = s.lower()
    hyphenated = re.sub(r"[^a-z0-9]+", "-", lowered)
    return hyphenated.strip("-")
|
|
|
|
# Candidate header names for each logical column of a drug dataset CSV,
# listed in priority order (earlier names are tried first).
_CSV_DRUG_COLS: list[str] = [
    "drug_name",
    "name",
    "Drug Name",
    "drugName",
]
_CSV_USE_COLS: list[str] = [
    "medical_condition",
    "condition",
    "use",
    "uses",
    "indication",
]
_CSV_EFFECT_COLS: list[str] = [
    "side_effects",
    "Side_Effects",
    "sideEffects",
    "adverse_effects",
]
|
|
|
|
| def _find_col(headers: Sequence[str], candidates: list[str]) -> str | None: |
| h_lower = {h.lower().strip(): h for h in headers} |
| for c in candidates: |
| if c in headers: |
| return c |
| if c.lower().strip() in h_lower: |
| return h_lower[c.lower().strip()] |
| return None |
|
|
|
|
def _truncate_list(text: str, max_items: int = 8) -> str:
    """
    Re-join a delimited list with ", ", keeping at most *max_items* entries.

    The delimiter is ";" only when semicolons outnumber commas; otherwise it
    is ",". Blank entries are dropped. When entries are cut off, ", and others"
    is appended.
    """
    commas, semicolons = text.count(","), text.count(";")
    delimiter = ";" if semicolons > commas else ","

    entries = []
    for piece in text.split(delimiter):
        piece = piece.strip()
        if piece:
            entries.append(piece)

    if len(entries) <= max_items:
        return ", ".join(entries)
    return ", ".join(entries[:max_items]) + ", and others"
|
|
|
|
def _load_csv_docs(path: Path) -> list[dict[str, Any]]:
    """
    Load a CSV knowledge base into the standard document format.

    Two layouts are recognised:

    * Native — the header contains ``id``, ``title`` and ``content`` (and
      optionally ``tags`` as a comma-separated string). Rows are copied as is.
    * Drug dataset — Kaggle-style file with a drug-name column and optional
      condition / side-effects columns (see the ``_CSV_*_COLS`` candidates).
      One synthetic "drug profile" document is generated per distinct drug
      name (case-insensitive; the first row wins). Ids are ``pharma_NNN``,
      built from the 1-based row position, so they may skip numbers where
      rows were dropped.

    The file is read as UTF-8 with an optional BOM.

    Raises:
        ValueError: if the file is not in native layout and no drug-name
            column can be found.
    """

    def _cell(row: dict[str, Any], column: str | None) -> str:
        # csv.DictReader fills fields missing from short rows with None, so a
        # default passed to dict.get is not enough to guarantee a string.
        if column is None:
            return ""
        return (row.get(column) or "").strip()

    docs: list[dict[str, Any]] = []
    with path.open(newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        headers = reader.fieldnames or []

        # Native layout: pass rows through with tags split into a list.
        if all(c in headers for c in ("id", "title", "content")):
            log.info("CSV detected as native format (id/title/content)")
            for row in reader:
                tags_raw = row.get("tags") or ""
                docs.append({
                    "id": row["id"],
                    "title": row["title"],
                    "content": row["content"],
                    "tags": [t.strip() for t in tags_raw.split(",") if t.strip()],
                })
            log.info("Loaded %d documents from CSV %s", len(docs), path)
            return docs

        # Drug-dataset layout: locate the relevant columns by candidate name.
        drug_col = _find_col(headers, _CSV_DRUG_COLS)
        use_col = _find_col(headers, _CSV_USE_COLS)
        effect_col = _find_col(headers, _CSV_EFFECT_COLS)

        if not drug_col:
            raise ValueError(
                f"CSV at {path} has no recognisable drug name column. "
                f"Available columns: {headers}"
            )

        log.info(
            "CSV drug dataset — drug: %r, use: %r, effects: %r",
            drug_col, use_col, effect_col,
        )

        seen: set[str] = set()
        for i, row in enumerate(reader):
            drug = _cell(row, drug_col).title()
            # Skip rows without a drug name and repeats of an earlier drug.
            if not drug or drug.lower() in seen:
                continue
            seen.add(drug.lower())

            condition = _cell(row, use_col)
            effects = _cell(row, effect_col)
            condition_str = condition or "the indicated condition"
            effects_str = _truncate_list(effects) if effects else "not listed"

            content = (
                f"{drug} is indicated for the treatment or management of {condition_str}. "
                f"Known adverse events associated with {drug} include: {effects_str}. "
                f"Prescribers should monitor patients for these adverse events and report "
                f"serious unexpected occurrences to the regulatory authority within 15 days."
            )

            tags = list(filter(None, [
                _slugify(drug),
                _slugify(condition) if condition else None,
                "adverse-event",
                "drug-profile",
            ]))

            docs.append({
                "id": f"pharma_{i + 1:03d}",
                "title": f"{drug} — Drug Profile",
                "content": content,
                "tags": tags,
            })

    log.info("Loaded %d documents from CSV %s", len(docs), path)
    return docs
|
|
|
|
| |
| |
| |
|
|
| @dataclass(slots=True) |
| class RetrievedDoc: |
| id: str |
| title: str |
| content: str |
| score: float |
|
|
|
|
@dataclass(slots=True)
class PipelineResult:
    """Everything produced by one pipeline run: the answer, its sources and its grade."""

    query: str                          # the user's question
    client: str                         # client key the query was run for
    answer: str                         # text returned to the user
    retrieved_docs: list[RetrievedDoc]  # documents the answer was grounded on
    grade_report: GradeReport           # grading outcome for the answer
    context_used: str                   # context string the answer was graded against

    @property
    def response_payload(self) -> dict[str, Any]:
        """
        Build the JSON-friendly dict describing this result.

        ``flagged`` is true when the grade report's ``overall`` is falsy, and
        source scores are rounded to three decimals. ``client_display`` falls
        back to the raw client key when no display name is registered.
        """
        display_name = DISPLAY_NAMES.get(self.client, self.client)
        is_flagged = not self.grade_report.overall

        sources = []
        for doc in self.retrieved_docs:
            sources.append({"id": doc.id, "title": doc.title, "score": round(doc.score, 3)})

        return {
            "query": self.query,
            "client": self.client,
            "client_display": display_name,
            "answer": self.answer,
            "flagged": is_flagged,
            "sources": sources,
            "evaluation": self.grade_report.summary,
        }
|
|
|
|
| @dataclass(slots=True) |
| class KBIndex: |
| docs: list[dict[str, Any]] |
| embeddings: np.ndarray |
|
|
|
|
# Cache of built knowledge-base indexes, keyed by domain name.
_index_cache: dict[str, KBIndex] = {}


def clear_index_cache() -> list[str]:
    """Drop every cached KB index and return the domain names that were dropped."""
    dropped = [*_index_cache]
    _index_cache.clear()
    return dropped
|
|
|
|
def _build_index(domain: str, embedder: SentenceTransformer) -> KBIndex:
    """
    Return the KB index for *domain*, building and caching it on first use.

    On a cache miss the domain's documents are loaded, each is embedded as
    "<title>. <content>", and the result is stored in ``_index_cache`` so
    later calls for the same domain reuse it.
    """
    cached = _index_cache.get(domain)
    if cached is not None:
        return cached

    kb_docs = _load_docs(domain)
    corpus = [f"{d['title']}. {d['content']}" for d in kb_docs]
    vectors = embedder.encode(corpus, show_progress_bar=False)

    built = KBIndex(docs=kb_docs, embeddings=np.array(vectors))
    _index_cache[domain] = built
    log.info("Built KB index for domain=%s (%d docs)", domain, len(kb_docs))
    return built
|
|
|
|
def _build_context(docs: list[RetrievedDoc]) -> str:
    """Render documents as "[title]" + newline + stripped content, separated by blank lines."""
    sections = []
    for doc in docs:
        sections.append(f"[{doc.title}]\n{doc.content.strip()}")
    return "\n\n".join(sections)
|
|
|
|
| |
| |
| |
|
|
# Whole-word, case-insensitive match for queries that ask for humour.
_JOKE_PATTERN: re.Pattern[str] = re.compile(
    r"\b(joke|jokes|funny|laugh|humor|humour|dad joke|daddy joke|pun|make me (laugh|smile))\b",
    re.IGNORECASE,
)

# Canned jokes, keyed by domain name.
_DOMAIN_JOKES: dict[str, list[str]] = {
    "retail": [
        "Why did the inventory go to therapy? It had too many unresolved stockouts.",
        "What do you call a shelf that never runs out? A shelf-fulfilling prophecy.",
        "Why don't retail systems ever get lonely? Because they're always in stock-ing.",
        "I tried to write a joke about planograms… but I couldn't find the right placement.",
        "Why did the out-of-stock alert break up with the replenishment order? It said: 'You always come too late.'",
    ],
    "pharma": [
        "Why did the adverse event report go to the doctor? It had too many side effects.",
        "What do you call a drug that's always on time? Punctu-pill.",
        "Why are pharmacovigilance teams great at parties? They always follow up.",
        "I asked the formulary for a joke. It said: 'Prior authorization required.'",
        "Why did the clinical trial fail to finish the joke? Insufficient sample size.",
    ],
}

# Jokes used for any domain without its own entry above.
_FALLBACK_JOKES: list[str] = [
    "Why did the AI refuse to tell a joke? Insufficient context. Please provide grounding documents.",
    "I'm not great at jokes — my training data was mostly compliance documents.",
    "Why did the chatbot cross the road? To reduce hallucination on the other side.",
]

# Whole-word, case-insensitive match for phrases typical of attempts to
# reset or override the assistant's instructions.
_RESET_PATTERNS: re.Pattern[str] = re.compile(
    r"\b(forget|ignore|pretend|blank slate|no context|disregard|system prompt|jailbreak)\b",
    re.IGNORECASE,
)
|
|
|
|
def _is_joke_query(query: str) -> bool:
    """Return True when *query* contains any of the humour keywords in ``_JOKE_PATTERN``."""
    return _JOKE_PATTERN.search(query) is not None
|
|
|
|
def _joke_response(domain: str) -> str:
    """Pick a random canned joke for *domain*, using the fallback list for unknown domains."""
    pool = _DOMAIN_JOKES[domain] if domain in _DOMAIN_JOKES else _FALLBACK_JOKES
    return random.choice(pool)
|
|
|
|
def _is_reset_attempt(query: str) -> bool:
    """Return True when *query* contains any phrase matched by ``_RESET_PATTERNS``."""
    return _RESET_PATTERNS.search(query) is not None
|
|
|
|
| |
| |
| |
|
|
# Model identifier passed to the chat-completion call in _generate.
HF_GENERATION_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"


def _generate(
    query: str,
    context: str,
    client: str,
    domain: str,
    hf_client: InferenceClient,
) -> str:
    """
    Ask the generation model to answer *query* from *context*.

    The system prompt is ``SYSTEM_PROMPT`` filled with the client's display
    name (falling back to the raw client key), the domain, and a bullet list
    of the client's terminology values ("(none)" when there are none). The
    user message carries the context followed by the question.

    Returns the first choice's message content, stripped; an empty string
    when the content is missing.
    """
    terms = client_terms(client)
    if terms:
        term_list = "\n".join(f"- {preferred}" for preferred in terms.values())
    else:
        term_list = "(none)"

    system_message = {
        "role": "system",
        "content": SYSTEM_PROMPT.format(
            client_display=DISPLAY_NAMES.get(client, client),
            domain=domain,
            term_list=term_list,
        ),
    }
    user_message = {
        "role": "user",
        "content": f"Context:\n{context}\n\nQuestion: {query}",
    }

    completion = hf_client.chat.completions.create(
        model=HF_GENERATION_MODEL,
        messages=[system_message, user_message],
        max_tokens=512,
    )
    text = completion.choices[0].message.content
    return str(text or "").strip()
|
|
|
|
| |
| |
| |
|
|
def _retrieve_top_k(
    query: str,
    index: KBIndex,
    embedder: SentenceTransformer,
    top_k: int,
) -> list[RetrievedDoc]:
    """Rank the indexed documents against *query*; keep the best *top_k* that clear the score floor."""
    # A negative top_k would make the slice below drop documents from the
    # wrong end of the ranking, so treat anything below zero as zero.
    limit = max(top_k, 0)
    if limit == 0:
        return []

    # Progress bar disabled to match how the index itself is encoded.
    q_vec = embedder.encode([query], show_progress_bar=False)
    scores = cosine_similarity(q_vec, index.embeddings)[0]
    best_first = np.argsort(scores)[::-1][:limit]
    return [
        RetrievedDoc(
            id=index.docs[i]["id"],
            title=index.docs[i]["title"],
            content=index.docs[i]["content"],
            score=float(scores[i]),
        )
        for i in best_first
        if scores[i] > MIN_RETRIEVAL_SCORE
    ]


def run(
    query: str,
    client: str,
    hf_client: InferenceClient,
    top_k: int = TOP_K,
) -> PipelineResult:
    """
    Retrieve relevant KB docs, generate a grounded answer, and grade it.

    Args:
        query: the user's question.
        client: client key; selects the domain, terminology and display name.
        hf_client: Hugging Face inference client used for generation.
        top_k: maximum number of KB documents to retrieve; values below zero
            are treated as zero.

    Returns:
        A ``PipelineResult`` with the answer, the retrieved documents, the
        grade report, and the context the answer was graded against.

    Joke requests and reset/override attempts are answered with a canned joke
    without retrieval or generation; that answer is graded against itself and
    no telemetry is recorded for it.
    """
    domain = domain_for(client)

    if _is_joke_query(query) or _is_reset_attempt(query):
        answer = _joke_response(domain)
        report = grade(query=query, response=answer, context=answer, client=client)
        return PipelineResult(
            query=query, client=client, answer=answer,
            retrieved_docs=[], grade_report=report, context_used="",
        )

    embedder = get_embedder()
    index = _build_index(domain, embedder)

    t0 = time.perf_counter()
    retrieved = _retrieve_top_k(query, index, embedder, top_k)
    t_retrieve = (time.perf_counter() - t0) * 1000

    # The client's terminology document is always placed first in the
    # generation context, but is left out of the grading context.
    terms_doc = client_terms_doc(client)
    pinned = RetrievedDoc(
        id=terms_doc["id"],
        title=terms_doc["title"],
        content=terms_doc["content"],
        score=1.0,
    )

    generation_context = _build_context([pinned, *retrieved])
    grading_context = _build_context(retrieved)

    t1 = time.perf_counter()
    answer = _generate(query, generation_context, client, domain, hf_client)
    t_generate = (time.perf_counter() - t1) * 1000

    t2 = time.perf_counter()
    report = grade(
        query=query,
        response=answer,
        context=grading_context,
        client=client,
    )
    t_grade = (time.perf_counter() - t2) * 1000

    # NOTE(review): imported here rather than at module top, presumably to
    # avoid an import cycle — confirm before moving it.
    import telemetry
    telemetry.record(
        client=client,
        domain=domain,
        query_len=len(query),
        latency_ms={"retrieve": t_retrieve, "generate": t_generate, "grade": t_grade},
        report=report,
        docs_retrieved=len(retrieved),
        min_retrieval_score=float(min(d.score for d in retrieved)) if retrieved else 0.0,
    )

    return PipelineResult(
        query=query,
        client=client,
        answer=answer,
        retrieved_docs=retrieved,
        grade_report=report,
        context_used=grading_context,
    )
|
|