mbochniak01
Replace ad-hoc refusal regexes with NOT IN DOCUMENTS sentinel
0ad5e39
"""
RAG pipeline: retrieve → generate → grade.
Retrieval: in-memory semantic search (sentence-transformers, encoded at first use per domain).
Generation: Llama-3 via HF Inference API with retrieved context injected as grounding.
Grading: L1 metrics via grader.py.
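Knowledge base formats supported (every file present under knowledge/<domain>/
is loaded, and the documents are merged in the order listed):
- features.yaml — {documents: [{id, title, content, tags}, ...]}
- features.json — [{id, title, content, tags}, ...]
- features.csv — columns: id, title, content, tags (tags as a comma-separated
  string), or a Kaggle-style drug dataset (drug name, condition, side effects)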
"""
import csv
import json
import logging
import random
import re
import time
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import yaml
from config import DISPLAY_NAMES, domain_for, features_path
from grader import GradeReport, get_embedder, grade
from huggingface_hub import InferenceClient
from rosetta import client_terms, client_terms_doc
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
log = logging.getLogger(__name__)
# Default number of KB documents returned by semantic search in run().
TOP_K = 3
# Retrieved documents are kept only if their cosine similarity to the query is
# strictly greater than this value.
MIN_RETRIEVAL_SCORE = 0.1
# System prompt for generation, filled in by _generate() with:
#   client_display — DISPLAY_NAMES entry for the client (falls back to the key),
#   domain         — domain the client maps to,
#   term_list      — client terminology, one "- term" line each, or "(none)".
# "NOT IN DOCUMENTS:" is the refusal sentinel the model is told to emit when the
# context is insufficient. NOTE(review): presumably the grader keys on this exact
# prefix — keep the wording in sync with it.
SYSTEM_PROMPT = """\
You are a helpful assistant for {client_display} ({domain} domain).
Answer the user's question using only the information in the provided context.
Be concise.
If the context does not contain enough information to answer, respond with exactly:
NOT IN DOCUMENTS: [one sentence explaining what information is missing]
Do not speculate, infer, or use knowledge outside the provided context.
You MUST use the following terminology. These are the only acceptable terms — do not substitute synonyms:
{term_list}"""
# ---------------------------------------------------------------------------
# Knowledge base loader — supports .yaml, .json, .csv
# ---------------------------------------------------------------------------
def _load_docs(domain: str) -> list[dict[str, Any]]:
    """
    Load and merge knowledge base documents for a domain from all present files.

    Every file that exists under ``knowledge/<domain>/`` is loaded, and the
    documents are concatenated in this order:
      1. features.yaml — a mapping with a ``documents`` list of document dicts
         (an empty file contributes nothing)
      2. features.json — a list of document dicts, or a mapping with a
         ``documents`` list (same shape as the YAML file)
      3. features.csv  — native id/title/content/tags columns or a Kaggle-style
         drug dataset; see _load_csv_docs

    YAML and JSON documents both go through _normalise_docs, so tags given as a
    comma-separated string become a list regardless of the source format.
    Files are decoded as UTF-8 (a leading BOM is tolerated), matching the CSV
    loader, rather than with the platform's default encoding.

    Returns:
        [{"id": str, "title": str, "content": str, "tags": list[str]}, ...]

    Raises:
        FileNotFoundError: if none of the files yields any documents.
        KeyError: if a non-empty YAML file, or a JSON mapping, lacks ``documents``.
    """
    base = features_path(domain).parent  # knowledge/<domain>/
    docs: list[dict[str, Any]] = []

    yaml_path = base / "features.yaml"
    if yaml_path.exists():
        data = yaml.safe_load(yaml_path.read_text(encoding="utf-8-sig"))
        if data:  # safe_load returns None for an empty file
            docs.extend(_normalise_docs(data["documents"] or []))

    json_path = base / "features.json"
    if json_path.exists():
        raw = json.loads(json_path.read_text(encoding="utf-8-sig"))
        if isinstance(raw, dict):
            # Accept the YAML-style {"documents": [...]} wrapper as well as a bare
            # list; previously a mapping crashed inside _normalise_docs.
            raw = raw["documents"] or []
        docs.extend(_normalise_docs(raw))

    csv_path = base / "features.csv"
    if csv_path.exists():
        docs.extend(_load_csv_docs(csv_path))

    if not docs:
        raise FileNotFoundError(
            f"No knowledge base found for domain '{domain}'. "
            f"Expected one of: {yaml_path}, {json_path}, {csv_path}"
        )
    return docs
def _normalise_docs(docs: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Ensure tags is always a list, not a string."""
for doc in docs:
if isinstance(doc.get("tags"), str):
doc["tags"] = [t.strip() for t in doc["tags"].split(",") if t.strip()]
elif "tags" not in doc:
doc["tags"] = []
return docs
def _slugify(s: str) -> str:
return re.sub(r"[^a-z0-9]+", "-", s.lower()).strip("-")
# Candidate header names for the Kaggle-style drug CSV, in priority order.
# _find_col tries each candidate as an exact match first, then compares it
# case-insensitively after stripping whitespace.
_CSV_DRUG_COLS: list[str] = ["drug_name", "name", "Drug Name", "drugName"]  # required by _load_csv_docs
_CSV_USE_COLS: list[str] = ["medical_condition", "condition", "use", "uses", "indication"]  # optional
_CSV_EFFECT_COLS: list[str] = ["side_effects", "Side_Effects", "sideEffects", "adverse_effects"]  # optional
def _find_col(headers: Sequence[str], candidates: list[str]) -> str | None:
h_lower = {h.lower().strip(): h for h in headers}
for c in candidates:
if c in headers:
return c
if c.lower().strip() in h_lower:
return h_lower[c.lower().strip()]
return None
def _truncate_list(text: str, max_items: int = 8) -> str:
sep = "," if text.count(",") >= text.count(";") else ";"
items = [i.strip() for i in text.split(sep) if i.strip()]
if len(items) > max_items:
return ", ".join(items[:max_items]) + ", and others"
return ", ".join(items)
def _load_csv_docs(path: Path) -> list[dict[str, Any]]:
    """
    Load a CSV knowledge base into the standard document format.

    Two layouts are recognised from the header row:

    * Native: ``id``, ``title`` and ``content`` columns are all present
      (``tags`` is optional, comma-separated). Each row becomes one document.
    * Kaggle-style drug dataset: otherwise, a drug-name column is located via
      _CSV_DRUG_COLS (plus optional condition and side-effect columns via
      _CSV_USE_COLS / _CSV_EFFECT_COLS). One synthesized "Drug Profile"
      document is produced per distinct drug name (case-insensitive, first
      occurrence wins); rows with a blank drug name are skipped.

    Args:
        path: CSV file to read; decoded as UTF-8 with an optional BOM.

    Returns:
        [{"id": str, "title": str, "content": str, "tags": list[str]}, ...]

    Raises:
        ValueError: if the file is not in native format and has no
            recognisable drug-name column.
    """
    docs: list[dict[str, Any]] = []
    with path.open(newline="", encoding="utf-8-sig") as f:
        # restval="" pads short (ragged) rows with empty strings. DictReader's
        # default restval is None, which made the .split()/.strip() calls below
        # raise AttributeError on such rows.
        reader = csv.DictReader(f, restval="")
        headers = reader.fieldnames or []

        # Check if it's already in the native format
        if all(c in headers for c in ("id", "title", "content")):
            log.info("CSV detected as native format (id/title/content)")
            for row in reader:
                tags_raw = row.get("tags", "")
                tags = [t.strip() for t in tags_raw.split(",") if t.strip()]
                docs.append({
                    "id": row["id"],
                    "title": row["title"],
                    "content": row["content"],
                    "tags": tags,
                })
            # Log the count here too, so both layouts report what they loaded.
            log.info("Loaded %d documents from CSV %s", len(docs), path)
            return docs

        # Otherwise treat as Kaggle-style drug dataset
        drug_col = _find_col(headers, _CSV_DRUG_COLS)
        use_col = _find_col(headers, _CSV_USE_COLS)
        effect_col = _find_col(headers, _CSV_EFFECT_COLS)
        if not drug_col:
            raise ValueError(
                f"CSV at {path} has no recognisable drug name column. "
                f"Available columns: {headers}"
            )
        log.info(
            "CSV drug dataset — drug: %r, use: %r, effects: %r",
            drug_col, use_col, effect_col,
        )

        seen: set[str] = set()  # lower-cased drug names already emitted
        for i, row in enumerate(reader):
            # NOTE(review): str.title() also capitalises after apostrophes
            # (e.g. "St. John's Wort" -> "St. John'S Wort").
            drug = row.get(drug_col, "").strip().title()
            if not drug or drug.lower() in seen:
                continue
            seen.add(drug.lower())

            condition = row.get(use_col, "").strip() if use_col else ""
            effects = row.get(effect_col, "").strip() if effect_col else ""
            condition_str = condition or "the indicated condition"
            effects_str = _truncate_list(effects) if effects else "not listed"
            # NOTE(review): the last sentence is fixed boilerplate appended to
            # every profile; it does not come from the CSV.
            content = (
                f"{drug} is indicated for the treatment or management of {condition_str}. "
                f"Known adverse events associated with {drug} include: {effects_str}. "
                f"Prescribers should monitor patients for these adverse events and report "
                f"serious unexpected occurrences to the regulatory authority within 15 days."
            )
            tags = list(filter(None, [
                _slugify(drug),
                _slugify(condition) if condition else None,
                "adverse-event",
                "drug-profile",
            ]))
            docs.append({
                # Numbered by data-row position: stable for a given file, but
                # numbers are skipped where duplicate/blank rows were dropped.
                "id": f"pharma_{i + 1:03d}",
                "title": f"{drug} — Drug Profile",
                "content": content,
                "tags": tags,
            })
    log.info("Loaded %d documents from CSV %s", len(docs), path)
    return docs
# ---------------------------------------------------------------------------
# Result types and index cache
# ---------------------------------------------------------------------------
@dataclass(slots=True)
class RetrievedDoc:
    """A knowledge-base document selected for a query, with its retrieval score."""

    id: str
    title: str
    content: str
    # Cosine similarity to the query as computed in run(); the pinned
    # terminology document built there is given a fixed score of 1.0.
    score: float
@dataclass(slots=True)
class PipelineResult:
    """
    Everything produced by one pipeline run: the answer, its sources and its grade.

    As populated by run(), ``context_used`` is the KB-only context the answer was
    graded against (empty when a joke/reset query short-circuits the pipeline).
    """

    query: str
    client: str
    answer: str
    retrieved_docs: list[RetrievedDoc]
    grade_report: GradeReport
    context_used: str

    @property
    def response_payload(self) -> dict[str, Any]:
        """
        API-facing view of the result.

        ``flagged`` is the inverse of the grade report's ``overall`` verdict,
        and each source carries its retrieval score rounded to 3 decimals.
        """
        report = self.grade_report
        display_name = DISPLAY_NAMES.get(self.client, self.client)
        flagged = not report.overall

        sources: list[dict[str, Any]] = []
        for doc in self.retrieved_docs:
            sources.append({"id": doc.id, "title": doc.title, "score": round(doc.score, 3)})

        return {
            "query": self.query,
            "client": self.client,
            "client_display": display_name,
            "answer": self.answer,
            "flagged": flagged,
            "sources": sources,
            "evaluation": report.summary,
        }
@dataclass(slots=True)
class KBIndex:
    """Embedded knowledge base for one domain, as built by _build_index()."""

    docs: list[dict[str, Any]]  # KB documents in load order
    embeddings: np.ndarray  # row i is the embedding of docs[i] ("<title>. <content>")
# Built indexes keyed by domain name; filled lazily by _build_index() and
# emptied by clear_index_cache().
_index_cache: dict[str, KBIndex] = {}
def clear_index_cache() -> list[str]:
    """
    Drop every cached KB index so the next query for each domain rebuilds it.

    Returns:
        The domain names that were evicted, in cache insertion order.
    """
    evicted = [*_index_cache]
    _index_cache.clear()
    return evicted
def _build_index(domain: str, embedder: SentenceTransformer) -> KBIndex:
    """
    Return the embedded KB index for ``domain``, building it on first use.

    Each document is embedded as "<title>. <content>". The result is stored in
    the module cache and reused until clear_index_cache() is called.
    """
    cached = _index_cache.get(domain)
    if cached is not None:
        return cached

    documents = _load_docs(domain)
    corpus = [f"{doc['title']}. {doc['content']}" for doc in documents]
    vectors = np.array(embedder.encode(corpus, show_progress_bar=False))
    built = KBIndex(docs=documents, embeddings=vectors)
    _index_cache[domain] = built
    log.info("Built KB index for domain=%s (%d docs)", domain, len(documents))
    return built
def _build_context(docs: list[RetrievedDoc]) -> str:
    """
    Render documents as prompt context.

    Each document becomes its title in square brackets on one line, followed by
    its whitespace-stripped content; documents are separated by a blank line.
    """
    sections = []
    for doc in docs:
        sections.append(f"[{doc.title}]\n{doc.content.strip()}")
    return "\n\n".join(sections)
# ---------------------------------------------------------------------------
# Joke short-circuit
# ---------------------------------------------------------------------------
# Queries matching this pattern skip retrieval/generation and get a canned joke
# (see run()). NOTE(review): everyday words such as "funny" or "laugh" also occur
# in genuine questions (e.g. "my tablet tastes funny") and will short-circuit too.
_JOKE_PATTERN = re.compile(
    r"\b(joke|jokes|funny|laugh|humor|humour|dad joke|daddy joke|pun|make me (laugh|smile))\b",
    re.IGNORECASE,
)
# Canned jokes per domain, chosen at random by _joke_response().
_DOMAIN_JOKES: dict[str, list[str]] = {
    "retail": [
        "Why did the inventory go to therapy? It had too many unresolved stockouts.",
        "What do you call a shelf that never runs out? A shelf-fulfilling prophecy.",
        "Why don't retail systems ever get lonely? Because they're always in stock-ing.",
        "I tried to write a joke about planograms… but I couldn't find the right placement.",
        "Why did the out-of-stock alert break up with the replenishment order? It said: 'You always come too late.'",
    ],
    "pharma": [
        "Why did the adverse event report go to the doctor? It had too many side effects.",
        "What do you call a drug that's always on time? Punctu-pill.",
        "Why are pharmacovigilance teams great at parties? They always follow up.",
        "I asked the formulary for a joke. It said: 'Prior authorization required.'",
        "Why did the clinical trial fail to finish the joke? Insufficient sample size.",
    ],
}
# Used by _joke_response() for domains with no entry in _DOMAIN_JOKES.
_FALLBACK_JOKES = [
    "Why did the AI refuse to tell a joke? Insufficient context. Please provide grounding documents.",
    "I'm not great at jokes — my training data was mostly compliance documents.",
    "Why did the chatbot cross the road? To reduce hallucination on the other side.",
]
_RESET_PATTERNS = re.compile(
r"\b(forget|ignore|pretend|blank slate|no context|disregard|system prompt|jailbreak)\b",
re.IGNORECASE,
)
def _is_joke_query(query: str) -> bool:
    """Return True when ``query`` matches _JOKE_PATTERN anywhere."""
    return _JOKE_PATTERN.search(query) is not None
def _joke_response(domain: str) -> str:
    """Pick a random canned joke for ``domain``, using the generic pool for unknown domains."""
    if domain in _DOMAIN_JOKES:
        candidates = _DOMAIN_JOKES[domain]
    else:
        candidates = _FALLBACK_JOKES
    return random.choice(candidates)
def _is_reset_attempt(query: str) -> bool:
    """Return True when ``query`` matches _RESET_PATTERNS anywhere."""
    return _RESET_PATTERNS.search(query) is not None
# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
HF_GENERATION_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"


def _generate(
    query: str,
    context: str,
    client: str,
    domain: str,
    hf_client: InferenceClient,
) -> str:
    """
    Ask the chat model for an answer grounded in ``context``.

    The system turn is SYSTEM_PROMPT filled with the client's display name, the
    domain, and the client's required terminology (one "- term" line per value
    returned by client_terms(client), or "(none)" when that is empty). The user
    turn carries the context followed by the question.

    Returns:
        The stripped text of the first completion choice ("" if it has no content).
    """
    glossary = client_terms(client)
    if glossary:
        term_list = "\n".join(f"- {term}" for term in glossary.values())
    else:
        term_list = "(none)"

    system_message = {
        "role": "system",
        "content": SYSTEM_PROMPT.format(
            client_display=DISPLAY_NAMES.get(client, client),
            domain=domain,
            term_list=term_list,
        ),
    }
    user_message = {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}

    completion = hf_client.chat.completions.create(
        model=HF_GENERATION_MODEL,
        messages=[system_message, user_message],
        max_tokens=512,
    )
    content = completion.choices[0].message.content
    return str(content or "").strip()
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
def _retrieve(
    query: str,
    index: KBIndex,
    embedder: SentenceTransformer,
    top_k: int,
) -> list[RetrievedDoc]:
    """
    Semantic search over ``index`` for ``query``.

    Returns at most ``top_k`` documents, highest cosine similarity first,
    keeping only those scoring strictly above MIN_RETRIEVAL_SCORE. A
    non-positive ``top_k`` yields no documents.
    """
    if top_k <= 0:
        # Without this guard, slicing with a negative bound would mean "all but
        # the last |top_k| docs" and silently return almost the whole KB.
        return []
    q_vec = embedder.encode([query])
    scores = cosine_similarity(q_vec, index.embeddings)[0]
    top_indices = np.argsort(scores)[::-1][:top_k]
    return [
        RetrievedDoc(
            id=index.docs[i]["id"],
            title=index.docs[i]["title"],
            content=index.docs[i]["content"],
            score=float(scores[i]),
        )
        for i in top_indices
        if scores[i] > MIN_RETRIEVAL_SCORE
    ]


def run(
    query: str,
    client: str,
    hf_client: InferenceClient,
    top_k: int = TOP_K,
) -> PipelineResult:
    """
    Retrieve relevant KB docs, generate a grounded answer, and grade it.

    Joke requests and prompt-reset attempts (_is_joke_query / _is_reset_attempt)
    short-circuit: a canned domain joke is returned and graded against itself,
    with no retrieval, generation or telemetry.

    Otherwise the answer is generated from the client's pinned terminology doc
    plus the retrieved KB docs, graded against the KB docs only, and per-stage
    latencies (ms) are recorded via telemetry.

    Args:
        query: The user's question.
        client: Client key; determines the domain, display name and terminology.
        hf_client: Hugging Face inference client used for generation.
        top_k: Maximum number of KB documents to retrieve (<= 0 retrieves none).

    Returns:
        A PipelineResult whose ``context_used`` is the KB-only grading context
        (empty for short-circuited queries).
    """
    domain = domain_for(client)

    if _is_joke_query(query) or _is_reset_attempt(query):
        answer = _joke_response(domain)
        # Graded with the joke itself as context so the canned reply is not
        # penalised for lacking KB grounding.
        report = grade(query=query, response=answer, context=answer, client=client)
        return PipelineResult(
            query=query, client=client, answer=answer,
            retrieved_docs=[], grade_report=report, context_used="",
        )

    embedder = get_embedder()
    # Index build (first call per domain) happens before t0, so it is not
    # counted in the "retrieve" latency.
    index = _build_index(domain, embedder)

    t0 = time.perf_counter()
    retrieved = _retrieve(query, index, embedder, top_k)
    t_retrieve = (time.perf_counter() - t0) * 1000

    terms_doc = client_terms_doc(client)
    pinned = RetrievedDoc(
        id=terms_doc["id"],
        title=terms_doc["title"],
        content=terms_doc["content"],
        score=1.0,
    )
    # Generation context includes glossary; grading context uses KB docs only
    generation_context = _build_context([pinned, *retrieved])
    grading_context = _build_context(retrieved)

    t1 = time.perf_counter()
    answer = _generate(query, generation_context, client, domain, hf_client)
    t_generate = (time.perf_counter() - t1) * 1000

    t2 = time.perf_counter()
    report = grade(
        query=query,
        response=answer,
        context=grading_context,
        client=client,
    )
    t_grade = (time.perf_counter() - t2) * 1000

    # NOTE(review): imported lazily here — presumably to avoid an import cycle or
    # start-up cost; confirm before hoisting to module level.
    import telemetry
    telemetry.record(
        client=client,
        domain=domain,
        query_len=len(query),
        latency_ms={"retrieve": t_retrieve, "generate": t_generate, "grade": t_grade},
        report=report,
        docs_retrieved=len(retrieved),
        min_retrieval_score=float(min(s.score for s in retrieved)) if retrieved else 0.0,
    )

    return PipelineResult(
        query=query,
        client=client,
        answer=answer,
        retrieved_docs=retrieved,
        grade_report=report,
        context_used=grading_context,
    )