# Mediscribe / rag/retriever.py
# Header copied from the hosting site's file view (not code): author Fred-Rcky,
# commit c32bf13 ("all done"), file size 5.2 kB.
from __future__ import annotations
import json
import os
from pathlib import Path
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
# Filesystem layout: the bundled JSON sources live in ./data next to this
# module; the persistent Chroma store lives one directory further up.
DATA_DIR = Path(__file__).parent.joinpath("data")
DB_DIR = Path(__file__).parent.parent.joinpath("chroma_db")
# Sentence-Transformers model name used for every collection's embeddings.
EMBED_MODEL = "all-MiniLM-L6-v2"

# Module-level caches, filled on first use by _get_client(),
# _icd_collection() and _drug_collection() respectively.
_client: chromadb.PersistentClient | None = None
_icd_col = _drug_col = None
def _get_client():
    """Return the module-level Chroma client, creating it on first call.

    The client persists its data under ``DB_DIR``; every later call returns
    the same cached instance.
    """
    global _client
    if _client is not None:
        return _client
    _client = chromadb.PersistentClient(path=str(DB_DIR))
    return _client
def _embedding_fn():
    """Create a SentenceTransformer embedding function for ``EMBED_MODEL``.

    Used whenever a collection is created or opened; a new wrapper object is
    returned on every call.
    """
    fn_kwargs = {"model_name": EMBED_MODEL}
    return SentenceTransformerEmbeddingFunction(**fn_kwargs)
def _load_records(filename: str) -> list[dict]:
    """Read one JSON source file from DATA_DIR (UTF-8, as JSON requires)."""
    with open(DATA_DIR / filename, encoding="utf-8") as f:
        return json.load(f)


def _recreate_collection(client, name: str, existing, ef):
    """Drop collection ``name`` if it is in ``existing``, then create it afresh."""
    if name in existing:
        client.delete_collection(name)
    return client.create_collection(name, embedding_function=ef)


def build_knowledge_base(force: bool = False):
    """Embed the ICD-10 codes and essential medicines into ChromaDB.

    Each collection ("icd10", "medicines") is built only when it is missing,
    so repeated calls are cheap; ``force=True`` drops and rebuilds both.
    The source JSON is read *before* the old collection is dropped, so a
    missing or malformed data file leaves the existing index intact.

    Args:
        force: Rebuild the collections even if they already exist.
    """
    global _icd_col, _drug_col
    client = _get_client()
    ef = _embedding_fn()
    # list_collections() yields plain names (a str subclass) on chromadb
    # 0.6.x and Collection objects on other versions; accept both.
    existing = {
        c if isinstance(c, str) else c.name for c in client.list_collections()
    }

    # ── ICD-10 ──────────────────────────────────────────────────────────────
    if force or "icd10" not in existing:
        records = _load_records("icd10_common.json")
        col = _recreate_collection(client, "icd10", existing, ef)
        col.add(
            ids=[r["code"] for r in records],
            documents=[f"{r['description']} {r['keywords']}" for r in records],
            metadatas=[
                {"code": r["code"], "description": r["description"]}
                for r in records
            ],
        )
        # A handle cached by _icd_collection() referred to the dropped
        # collection; clear it so the next lookup fetches the new one.
        _icd_col = None
        print(f"[RAG] Indexed {len(records)} ICD-10 codes")

    # ── Medicines ────────────────────────────────────────────────────────────
    if force or "medicines" not in existing:
        records = _load_records("essential_medicines.json")
        col = _recreate_collection(client, "medicines", existing, ef)
        col.add(
            ids=[str(i) for i in range(len(records))],
            documents=[
                f"{r['name']} {r['class']} {r['indications']}"
                for r in records
            ],
            # The whole record is stored as metadata; retrieve_drug_info()
            # reads its fields back. NOTE(review): Chroma only accepts scalar
            # metadata values — confirm the JSON has no list/None fields.
            metadatas=records,
        )
        # Same stale-handle reset as above, for _drug_collection().
        _drug_col = None
        print(f"[RAG] Indexed {len(records)} essential medicines")
def _icd_collection():
    """Return the "icd10" collection, opening it only on the first call.

    The knowledge base should already exist (see ensure_kb()); the
    get_collection call is expected to fail otherwise.
    """
    global _icd_col
    if _icd_col is not None:
        return _icd_col
    client = _get_client()
    _icd_col = client.get_collection("icd10", embedding_function=_embedding_fn())
    return _icd_col
def _drug_collection():
    """Return the cached "medicines" collection, opening it on first use.

    Like _icd_collection(), this assumes the knowledge base has been built.
    """
    global _drug_col
    if _drug_col is None:
        client = _get_client()
        embedder = _embedding_fn()
        _drug_col = client.get_collection("medicines", embedding_function=embedder)
    return _drug_col
def retrieve_icd_codes(query: str, n: int = 5) -> list[dict]:
    """Return up to ``n`` ICD-10 codes whose indexed text best matches ``query``.

    Args:
        query: Free-text clinical description. Empty, blank or ``None``
            queries return ``[]`` without touching the database.
        n: Maximum number of codes to return; values <= 0 return ``[]``.

    Returns:
        Dicts with ``code``, ``description`` and ``score`` (``1 - distance``
        rounded to 3 decimals; higher means a closer match).
    """
    if not query or not query.strip() or n <= 0:
        return []
    results = _icd_collection().query(query_texts=[query], n_results=n)
    # NOTE(review): the collection is created without an "hnsw:space" setting,
    # so distances use Chroma's default metric (presumably squared L2). "score"
    # is therefore a ranking value, not a cosine similarity, and can go negative.
    return [
        {
            "code": meta["code"],
            "description": meta["description"],
            "score": round(1 - dist, 3),
        }
        for meta, dist in zip(results["metadatas"][0], results["distances"][0])
    ]
def retrieve_drug_info(drug_names: list[str], n: int = 3) -> list[dict]:
    """Return reference info for the named medications, nearest match first.

    Each name is searched separately (one batched Chroma query), so one
    drug's text cannot dominate the embedding of a combined query string.
    Hits are merged round-robin by rank: first the best match for every name,
    then the second-best for every name, and so on. Collection entries that
    were already taken are skipped. When a name has no exact entry, its
    nearest neighbours are returned instead. With a single name the result is
    the same top-``n`` list a plain query would give.

    Args:
        drug_names: Medication names to look up; blank entries are ignored.
        n: Number of nearest matches fetched per name. The merged result holds
            at most ``max(n, number of names)`` entries, so every name can
            contribute at least its best match.

    Returns:
        Dicts with ``name``, ``class``, ``adult_dose``, ``indications``,
        ``contraindications`` and ``notes`` (``""`` when absent from metadata).
    """
    names = [name.strip() for name in drug_names if name and name.strip()]
    if not names or n <= 0:
        return []
    results = _drug_collection().query(query_texts=names, n_results=n)
    limit = max(n, len(names))
    seen_ids: set[str] = set()
    drugs: list[dict] = []
    for rank in range(n):
        for ids, metas in zip(results["ids"], results["metadatas"]):
            # A name can get fewer than n hits when the collection is small.
            if rank >= len(ids) or ids[rank] in seen_ids:
                continue
            seen_ids.add(ids[rank])
            meta = metas[rank]
            drugs.append({
                "name": meta["name"],
                "class": meta["class"],
                "adult_dose": meta["adult_dose"],
                "indications": meta["indications"],
                "contraindications": meta["contraindications"],
                "notes": meta.get("notes", ""),
            })
            if len(drugs) >= limit:
                return drugs
    return drugs
def format_icd_context(codes: list[dict]) -> str:
    """Format ICD codes as text context for injection into prompts.

    Args:
        codes: Dicts with at least ``code`` and ``description`` keys, e.g. the
            output of retrieve_icd_codes().

    Returns:
        ``""`` for an empty list. Otherwise a header line followed by one
        ``" CODE — description"`` line per code, joined with newlines.
    """
    if not codes:
        return ""
    lines = ["Relevant ICD-10 codes to consider:"]
    # The separator is a real em dash (U+2014); the previous literal was a
    # mis-decoded byte sequence that leaked garbage into the prompt.
    lines.extend(f" {c['code']} \u2014 {c['description']}" for c in codes)
    return "\n".join(lines)
def format_drug_context(drugs: list[dict]) -> str:
    """Render medication records as a reference block for a prompt.

    Returns ``""`` for an empty list. Otherwise returns a header line followed
    by one line per drug giving its name, class, adult dose and indications.
    Other keys (e.g. contraindications, notes) are not included.
    """
    if not drugs:
        return ""
    entries = (
        f" {drug['name']} ({drug['class']}): {drug['adult_dose']}. "
        f"Indications: {drug['indications']}."
        for drug in drugs
    )
    return "\n".join(["Relevant medication reference:", *entries])
def ensure_kb():
    """Called at app startup — builds the KB only if it doesn't exist yet.

    Runs build_knowledge_base() when either the "icd10" or the "medicines"
    collection is missing. Otherwise it only reports that the KB is ready.
    """
    client = _get_client()
    # list_collections() yields plain names (a str subclass) on chromadb
    # 0.6.x and Collection objects on other versions; accept both.
    existing = {
        c if isinstance(c, str) else c.name for c in client.list_collections()
    }
    if {"icd10", "medicines"} <= existing:
        print("[RAG] Knowledge base ready.")
    else:
        print("[RAG] Building knowledge base for the first time...")
        build_knowledge_base()