| | """ |
| | RAG module for AMR-Guard. |
| | |
| | Retrieves context from four ChromaDB collections: |
| | - idsa_treatment_guidelines: IDSA 2024 AMR guidance |
| | - mic_reference_docs: EUCAST v16.0 breakpoint tables |
| | - drug_safety: Drug interactions and contraindications |
| | - pathogen_resistance: ATLAS regional susceptibility data |
| | """ |
| |
|
| | import logging |
| | from typing import Any, Dict, List, Optional |
| |
|
| | from .config import get_settings |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | |
| | _chroma_client = None |
| | _embedding_function = None |
| |
|
| |
|
| | def get_chroma_client(): |
| | """Return the ChromaDB persistent client, creating it on first call.""" |
| | global _chroma_client |
| | if _chroma_client is None: |
| | import chromadb |
| | chroma_path = get_settings().chroma_db_dir |
| | chroma_path.mkdir(parents=True, exist_ok=True) |
| | _chroma_client = chromadb.PersistentClient(path=str(chroma_path)) |
| | return _chroma_client |
| |
|
| |
|
| | def get_embedding_function(): |
| | """Return the SentenceTransformer embedding function, creating it on first call.""" |
| | global _embedding_function |
| | if _embedding_function is None: |
| | from chromadb.utils import embedding_functions |
| | |
| | model_short_name = get_settings().embedding_model_name.split("/")[-1] |
| | _embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction( |
| | model_name=model_short_name |
| | ) |
| | return _embedding_function |
| |
|
| |
|
| | def get_collection(name: str): |
| | """Return a ChromaDB collection by name, or None if it does not exist.""" |
| | try: |
| | return get_chroma_client().get_collection(name=name, embedding_function=get_embedding_function()) |
| | except Exception: |
| | logger.warning(f"Collection '{name}' not found") |
| | return None |
| |
|
| |
|
| | def search_antibiotic_guidelines( |
| | query: str, |
| | n_results: int = 5, |
| | pathogen_filter: Optional[str] = None, |
| | ) -> List[Dict[str, Any]]: |
| | """Search the IDSA treatment guidelines collection.""" |
| | collection = get_collection("idsa_treatment_guidelines") |
| | if collection is None: |
| | return [] |
| | try: |
| | where = {"pathogen_type": pathogen_filter} if pathogen_filter else None |
| | results = collection.query( |
| | query_texts=[query], |
| | n_results=n_results, |
| | where=where, |
| | include=["documents", "metadatas", "distances"], |
| | ) |
| | return _format_results(results) |
| | except Exception as e: |
| | logger.error(f"Error querying guidelines: {e}") |
| | return [] |
| |
|
| |
|
| | def search_mic_breakpoints( |
| | query: str, |
| | n_results: int = 5, |
| | organism: Optional[str] = None, |
| | antibiotic: Optional[str] = None, |
| | ) -> List[Dict[str, Any]]: |
| | """Search the EUCAST MIC breakpoint reference collection.""" |
| | collection = get_collection("mic_reference_docs") |
| | if collection is None: |
| | return [] |
| | |
| | enhanced_query = " ".join(filter(None, [organism, antibiotic, query])) |
| | try: |
| | results = collection.query( |
| | query_texts=[enhanced_query], |
| | n_results=n_results, |
| | include=["documents", "metadatas", "distances"], |
| | ) |
| | return _format_results(results) |
| | except Exception as e: |
| | logger.error(f"Error querying breakpoints: {e}") |
| | return [] |
| |
|
| |
|
| | def search_drug_safety( |
| | query: str, |
| | n_results: int = 5, |
| | drug_name: Optional[str] = None, |
| | ) -> List[Dict[str, Any]]: |
| | """Search drug interactions from SQLite (drug_interactions table).""" |
| | if not drug_name: |
| | return [] |
| | try: |
| | from .db.database import execute_query |
| |
|
| | rows = execute_query( |
| | """SELECT drug_1, drug_2, interaction_description, severity |
| | FROM drug_interactions |
| | WHERE LOWER(drug_1) LIKE ? OR LOWER(drug_2) LIKE ? |
| | LIMIT ?""", |
| | (f"%{drug_name.lower()}%", f"%{drug_name.lower()}%", n_results), |
| | ) |
| | return [ |
| | { |
| | "content": ( |
| | f"{r['drug_1']} + {r['drug_2']}: {r['interaction_description']}" |
| | ), |
| | "metadata": {"severity": r.get("severity", "unknown")}, |
| | "distance": None, |
| | "source": "drug_interactions (SQLite)", |
| | "relevance_score": 1.0, |
| | } |
| | for r in rows |
| | ] |
| | except Exception as e: |
| | logger.error(f"Error querying drug safety: {e}") |
| | return [] |
| |
|
| |
|
| | def search_resistance_patterns( |
| | query: str, |
| | n_results: int = 5, |
| | organism: Optional[str] = None, |
| | region: Optional[str] = None, |
| | ) -> List[Dict[str, Any]]: |
| | """Search the ATLAS pathogen resistance collection.""" |
| | collection = get_collection("pathogen_resistance") |
| | if collection is None: |
| | return [] |
| | enhanced_query = " ".join(filter(None, [region, organism, query])) |
| | try: |
| | results = collection.query( |
| | query_texts=[enhanced_query], |
| | n_results=n_results, |
| | include=["documents", "metadatas", "distances"], |
| | ) |
| | return _format_results(results) |
| | except Exception as e: |
| | logger.error(f"Error querying resistance patterns: {e}") |
| | return [] |
| |
|
| |
|
| | def get_context_for_agent( |
| | agent_name: str, |
| | query: str, |
| | patient_context: Optional[Dict[str, Any]] = None, |
| | n_results: int = 3, |
| | ) -> str: |
| | """ |
| | Return a formatted context string for a specific agent. |
| | |
| | Each agent draws from the collections most relevant to its task: |
| | - intake_historian: IDSA guidelines |
| | - vision_specialist: MIC breakpoints |
| | - trend_analyst: MIC breakpoints + resistance patterns |
| | - clinical_pharmacologist: guidelines + drug safety |
| | """ |
| | ctx = patient_context or {} |
| | parts = [] |
| |
|
| | if agent_name == "intake_historian": |
| | guidelines = search_antibiotic_guidelines(query, n_results=n_results, pathogen_filter=ctx.get("pathogen_type")) |
| | if guidelines: |
| | parts.append("RELEVANT TREATMENT GUIDELINES:") |
| | for g in guidelines: |
| | parts.append(f"- {g['content'][:500]}...") |
| | parts.append(f" [Source: {g.get('source', 'IDSA Guidelines')}]") |
| |
|
| | elif agent_name == "vision_specialist": |
| | breakpoints = search_mic_breakpoints(query, n_results=n_results, organism=ctx.get("organism"), antibiotic=ctx.get("antibiotic")) |
| | if breakpoints: |
| | parts.append("RELEVANT BREAKPOINT INFORMATION:") |
| | for b in breakpoints: |
| | parts.append(f"- {b['content'][:400]}...") |
| |
|
| | elif agent_name == "trend_analyst": |
| | breakpoints = search_mic_breakpoints( |
| | f"breakpoint {ctx.get('organism', '')} {ctx.get('antibiotic', '')}", |
| | n_results=n_results, |
| | ) |
| | resistance = search_resistance_patterns(query, n_results=n_results, organism=ctx.get("organism"), region=ctx.get("region")) |
| | if breakpoints: |
| | parts.append("EUCAST BREAKPOINT DATA:") |
| | for b in breakpoints: |
| | parts.append(f"- {b['content'][:400]}...") |
| | if resistance: |
| | parts.append("\nRESISTANCE PATTERN DATA:") |
| | for r in resistance: |
| | parts.append(f"- {r['content'][:400]}...") |
| |
|
| | elif agent_name == "clinical_pharmacologist": |
| | guidelines = search_antibiotic_guidelines(query, n_results=n_results) |
| | safety = search_drug_safety(query, n_results=n_results, drug_name=ctx.get("proposed_antibiotic")) |
| | if guidelines: |
| | parts.append("TREATMENT GUIDELINES:") |
| | for g in guidelines: |
| | parts.append(f"- {g['content'][:400]}...") |
| | if safety: |
| | parts.append("\nDRUG SAFETY INFORMATION:") |
| | for s in safety: |
| | parts.append(f"- {s['content'][:400]}...") |
| |
|
| | else: |
| | guidelines = search_antibiotic_guidelines(query, n_results=n_results) |
| | for g in guidelines: |
| | parts.append(f"- {g['content'][:500]}...") |
| |
|
| | return "\n".join(parts) if parts else "No relevant context found in knowledge base." |
| |
|
| |
|
| | def _format_results(results: Dict[str, Any]) -> List[Dict[str, Any]]: |
| | """Flatten ChromaDB query results into a list of dicts.""" |
| | if not results or not results.get("documents"): |
| | return [] |
| |
|
| | documents = results["documents"][0] if results["documents"] else [] |
| | metadatas = results.get("metadatas", [[]])[0] |
| | distances = results.get("distances", [[]])[0] |
| |
|
| | return [ |
| | { |
| | "content": doc, |
| | "metadata": metadatas[i] if i < len(metadatas) else {}, |
| | "distance": distances[i] if i < len(distances) else None, |
| | "source": metadatas[i].get("source", "Unknown") if i < len(metadatas) else "Unknown", |
| | "relevance_score": 1 - (distances[i] if i < len(distances) else 0), |
| | } |
| | for i, doc in enumerate(documents) |
| | ] |
| |
|
| |
|
| | def list_available_collections() -> List[str]: |
| | """Return names of all ChromaDB collections that exist.""" |
| | try: |
| | return [c.name for c in get_chroma_client().list_collections()] |
| | except Exception as e: |
| | logger.error(f"Error listing collections: {e}") |
| | return [] |
| |
|
| |
|
| | def get_collection_info(name: str) -> Optional[Dict[str, Any]]: |
| | """Return count and metadata for a collection, or None if it does not exist.""" |
| | collection = get_collection(name) |
| | if collection is None: |
| | return None |
| | try: |
| | return {"name": collection.name, "count": collection.count(), "metadata": collection.metadata} |
| | except Exception as e: |
| | logger.error(f"Error getting collection info: {e}") |
| | return None |
| |
|