File size: 9,655 Bytes
91b591f
85020ae
91b591f
ad9e267
 
 
 
 
91b591f
 
 
 
 
 
 
 
 
ad9e267
91b591f
 
 
 
 
ad9e267
91b591f
 
 
ad9e267
91b591f
 
 
 
 
 
ad9e267
91b591f
 
 
ad9e267
 
91b591f
ad9e267
91b591f
 
 
 
 
ad9e267
91b591f
ad9e267
91b591f
 
 
 
 
 
 
 
 
 
ad9e267
91b591f
 
 
 
ad9e267
91b591f
 
 
ad9e267
91b591f
 
ad9e267
91b591f
 
 
 
 
 
 
 
 
 
 
ad9e267
91b591f
 
 
ad9e267
 
91b591f
 
 
 
 
 
ad9e267
91b591f
 
 
 
 
 
 
 
 
 
2cec50c
 
91b591f
 
2cec50c
 
 
 
 
 
 
 
91b591f
2cec50c
 
 
 
 
 
 
 
 
 
 
 
91b591f
 
 
 
 
 
 
 
 
 
 
ad9e267
91b591f
 
 
ad9e267
91b591f
 
 
 
 
 
ad9e267
91b591f
 
 
 
 
 
 
 
 
 
 
 
ad9e267
91b591f
ad9e267
 
 
 
 
91b591f
ad9e267
 
91b591f
 
ad9e267
91b591f
ad9e267
91b591f
ad9e267
 
91b591f
 
ad9e267
91b591f
ad9e267
91b591f
ad9e267
91b591f
 
 
ad9e267
91b591f
 
ad9e267
91b591f
ad9e267
91b591f
ad9e267
91b591f
ad9e267
91b591f
ad9e267
91b591f
 
ad9e267
 
91b591f
ad9e267
91b591f
ad9e267
91b591f
ad9e267
91b591f
ad9e267
91b591f
 
 
ad9e267
 
91b591f
ad9e267
91b591f
 
 
ad9e267
91b591f
 
 
 
 
 
 
ad9e267
 
91b591f
 
 
 
 
ad9e267
 
 
91b591f
 
 
ad9e267
91b591f
ad9e267
91b591f
 
 
 
 
 
ad9e267
91b591f
 
 
 
ad9e267
91b591f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
"""
RAG module for AMR-Guard.

Retrieves context from four ChromaDB collections:
- idsa_treatment_guidelines: IDSA 2024 AMR guidance
- mic_reference_docs: EUCAST v16.0 breakpoint tables
- drug_safety: Drug interactions and contraindications
- pathogen_resistance: ATLAS regional susceptibility data
"""

import logging
from typing import Any, Dict, List, Optional

from .config import get_settings

logger = logging.getLogger(__name__)

# Lazily-populated module singletons. Both stay None until
# get_chroma_client() / get_embedding_function() first run, which keeps
# importing this module free of database and embedding-function setup.
_chroma_client = _embedding_function = None


def get_chroma_client():
    """Lazily build and cache the persistent ChromaDB client.

    The storage directory is taken from ``get_settings().chroma_db_dir`` and is
    created (with parents) if it does not exist yet. Subsequent calls return
    the cached client without touching settings or the filesystem again.
    """
    global _chroma_client
    if _chroma_client is not None:
        return _chroma_client

    # Deferred import: chromadb is only loaded once a client is actually needed.
    import chromadb

    db_dir = get_settings().chroma_db_dir
    db_dir.mkdir(parents=True, exist_ok=True)
    _chroma_client = chromadb.PersistentClient(path=str(db_dir))
    return _chroma_client


def get_embedding_function():
    """Lazily build and cache the SentenceTransformer embedding function.

    Only the last ``/``-separated segment of the configured
    ``embedding_model_name`` is handed to ChromaDB (``org/model`` -> ``model``).
    Subsequent calls return the cached instance.
    """
    global _embedding_function
    if _embedding_function is not None:
        return _embedding_function

    from chromadb.utils import embedding_functions

    configured_name = get_settings().embedding_model_name
    # Drop any HuggingFace organisation prefix and keep the bare model name.
    # NOTE(review): bare names are presumably resolved by sentence-transformers
    # only for models published under its own org — confirm for the
    # configured model.
    short_name = configured_name.rpartition("/")[2]
    _embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=short_name
    )
    return _embedding_function


def get_collection(name: str):
    """Return the ChromaDB collection *name*, or None if it cannot be opened.

    The collection is opened with ``get_embedding_function()`` so that
    ``query_texts`` passed to it are embedded with the configured model.

    Any exception is logged as a warning and turned into ``None``. This covers
    a missing collection, but also a failure while creating the client or the
    embedding function. Callers in this module treat ``None`` as "no results".
    """
    try:
        return get_chroma_client().get_collection(
            name=name, embedding_function=get_embedding_function()
        )
    except Exception as e:
        # Include the underlying error: a bare "not found" would hide setup
        # failures (client or embedding function) behind a misleading message.
        logger.warning(f"Collection '{name}' not found or unavailable: {e}")
        return None


def search_antibiotic_guidelines(
    query: str,
    n_results: int = 5,
    pathogen_filter: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Semantic search over the IDSA treatment guidelines collection.

    Args:
        query: Free-text query matched against the guideline documents.
        n_results: Maximum number of hits to request.
        pathogen_filter: When truthy, only documents whose ``pathogen_type``
            metadata equals this value are considered.

    Returns:
        Hits formatted by ``_format_results``; an empty list when the
        collection is unavailable or the query raises.
    """
    guidelines = get_collection("idsa_treatment_guidelines")
    if guidelines is None:
        return []

    metadata_filter = None
    if pathogen_filter:
        metadata_filter = {"pathogen_type": pathogen_filter}

    try:
        raw_hits = guidelines.query(
            query_texts=[query],
            n_results=n_results,
            where=metadata_filter,
            include=["documents", "metadatas", "distances"],
        )
        return _format_results(raw_hits)
    except Exception as e:
        logger.error(f"Error querying guidelines: {e}")
        return []


def search_mic_breakpoints(
    query: str,
    n_results: int = 5,
    organism: Optional[str] = None,
    antibiotic: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Semantic search over the EUCAST MIC breakpoint reference collection.

    ``organism`` and ``antibiotic``, when truthy, are prefixed onto the query
    text (in that order) to steer the semantic match. Falsy parts are skipped.

    Returns:
        Hits formatted by ``_format_results``; an empty list when the
        collection is unavailable or the query raises.
    """
    breakpoint_docs = get_collection("mic_reference_docs")
    if breakpoint_docs is None:
        return []

    search_text = " ".join(term for term in (organism, antibiotic, query) if term)
    try:
        raw_hits = breakpoint_docs.query(
            query_texts=[search_text],
            n_results=n_results,
            include=["documents", "metadatas", "distances"],
        )
        return _format_results(raw_hits)
    except Exception as e:
        logger.error(f"Error querying breakpoints: {e}")
        return []


def _escape_like(text: str) -> str:
    """Escape backslash, ``%`` and ``_`` so *text* matches literally in a LIKE with a backslash ESCAPE."""
    # Backslash must be doubled first so the escapes added below are not re-escaped.
    return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def search_drug_safety(
    query: str,
    n_results: int = 5,
    drug_name: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Look up drug interactions for *drug_name* in the SQLite ``drug_interactions`` table.

    A row matches when ``drug_1`` or ``drug_2`` contains *drug_name* as a
    case-insensitive substring. ``%`` and ``_`` in the name are matched
    literally, not as LIKE wildcards.

    Args:
        query: Unused. It is accepted so the signature parallels the other
            ``search_*`` functions.
        n_results: Maximum number of rows returned (SQL ``LIMIT``).
        drug_name: Drug to look up. ``None``, empty or whitespace-only values
            short-circuit to an empty result. Surrounding whitespace is ignored.

    Returns:
        One dict per row, with the same keys as ``_format_results`` hits
        (``content``, ``metadata``, ``distance``, ``source``,
        ``relevance_score``). Returns an empty list if the lookup raises.
    """
    needle = (drug_name or "").strip().lower()
    if not needle:
        return []
    pattern = f"%{_escape_like(needle)}%"
    try:
        from .db.database import execute_query

        rows = execute_query(
            """SELECT drug_1, drug_2, interaction_description, severity
               FROM drug_interactions
               WHERE LOWER(drug_1) LIKE ? ESCAPE '\\'
                  OR LOWER(drug_2) LIKE ? ESCAPE '\\'
               LIMIT ?""",
            (pattern, pattern, n_results),
        )
        # NOTE(review): rows are assumed to be dict-like (``r.get`` is used
        # below); sqlite3.Row has no ``.get`` — confirm execute_query's row type.
        return [
            {
                "content": (
                    f"{r['drug_1']} + {r['drug_2']}: {r['interaction_description']}"
                ),
                "metadata": {"severity": r.get("severity", "unknown")},
                "distance": None,
                "source": "drug_interactions (SQLite)",
                "relevance_score": 1.0,
            }
            for r in rows
        ]
    except Exception as e:
        logger.error(f"Error querying drug safety: {e}")
        return []


def search_resistance_patterns(
    query: str,
    n_results: int = 5,
    organism: Optional[str] = None,
    region: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Semantic search over the ATLAS pathogen resistance collection.

    ``region`` and ``organism``, when truthy, are prefixed onto the query text
    (in that order) to steer the semantic match. Falsy parts are skipped.

    Returns:
        Hits formatted by ``_format_results``; an empty list when the
        collection is unavailable or the query raises.
    """
    resistance_docs = get_collection("pathogen_resistance")
    if resistance_docs is None:
        return []

    search_text = " ".join(term for term in (region, organism, query) if term)
    try:
        raw_hits = resistance_docs.query(
            query_texts=[search_text],
            n_results=n_results,
            include=["documents", "metadatas", "distances"],
        )
        return _format_results(raw_hits)
    except Exception as e:
        logger.error(f"Error querying resistance patterns: {e}")
        return []


def _truncate(text: str, limit: int) -> str:
    """Return *text* cut to *limit* characters, appending "..." only when it was actually cut."""
    return text if len(text) <= limit else f"{text[:limit]}..."


def get_context_for_agent(
    agent_name: str,
    query: str,
    patient_context: Optional[Dict[str, Any]] = None,
    n_results: int = 3,
) -> str:
    """
    Return a formatted context string for a specific agent.

    Each agent draws from the sources most relevant to its task:
    - intake_historian: IDSA guidelines (filtered by ``pathogen_type``)
    - vision_specialist: MIC breakpoints (narrowed by ``organism``/``antibiotic``)
    - trend_analyst: MIC breakpoints + resistance patterns
    - clinical_pharmacologist: guidelines + drug safety (``proposed_antibiotic``)
    - any other name: IDSA guidelines, without a section header

    Args:
        agent_name: Which agent the context is for (see above).
        query: Free-text query forwarded to the searches.
        patient_context: Optional dict. Depending on the agent, the keys read
            are ``pathogen_type``, ``organism``, ``antibiotic``, ``region`` and
            ``proposed_antibiotic``.
        n_results: Number of hits requested from each source.

    Returns:
        Newline-joined sections of ``- <snippet>`` lines. Each snippet is capped
        at 400 or 500 characters, with "..." appended only when it was cut.
        Returns "No relevant context found in knowledge base." when every
        search came back empty.
    """
    ctx = patient_context or {}
    parts: List[str] = []

    def add_section(
        header: Optional[str],
        hits: List[Dict[str, Any]],
        limit: int,
        show_source: bool = False,
    ) -> None:
        # Skip empty result sets entirely so no dangling header is emitted.
        if not hits:
            return
        if header:
            parts.append(header)
        for hit in hits:
            parts.append(f"- {_truncate(hit['content'], limit)}")
            if show_source:
                parts.append(f"  [Source: {hit.get('source', 'IDSA Guidelines')}]")

    if agent_name == "intake_historian":
        guidelines = search_antibiotic_guidelines(
            query, n_results=n_results, pathogen_filter=ctx.get("pathogen_type")
        )
        add_section("RELEVANT TREATMENT GUIDELINES:", guidelines, 500, show_source=True)

    elif agent_name == "vision_specialist":
        breakpoints = search_mic_breakpoints(
            query,
            n_results=n_results,
            organism=ctx.get("organism"),
            antibiotic=ctx.get("antibiotic"),
        )
        add_section("RELEVANT BREAKPOINT INFORMATION:", breakpoints, 400)

    elif agent_name == "trend_analyst":
        organism = ctx.get("organism")
        antibiotic = ctx.get("antibiotic")
        # Skip absent/None values. ctx.get(key, "") would still return None
        # for a key present with a None value, putting the literal "None"
        # into the search text.
        breakpoint_query = " ".join(
            str(term) for term in ("breakpoint", organism, antibiotic) if term
        )
        breakpoints = search_mic_breakpoints(breakpoint_query, n_results=n_results)
        resistance = search_resistance_patterns(
            query, n_results=n_results, organism=organism, region=ctx.get("region")
        )
        add_section("EUCAST BREAKPOINT DATA:", breakpoints, 400)
        add_section("\nRESISTANCE PATTERN DATA:", resistance, 400)

    elif agent_name == "clinical_pharmacologist":
        guidelines = search_antibiotic_guidelines(query, n_results=n_results)
        safety = search_drug_safety(
            query, n_results=n_results, drug_name=ctx.get("proposed_antibiotic")
        )
        add_section("TREATMENT GUIDELINES:", guidelines, 400)
        add_section("\nDRUG SAFETY INFORMATION:", safety, 400)

    else:
        add_section(None, search_antibiotic_guidelines(query, n_results=n_results), 500)

    return "\n".join(parts) if parts else "No relevant context found in knowledge base."


def _format_results(results: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten a single-query ChromaDB result into a list of hit dicts.

    Each field of a ChromaDB query result is nested per query text. Only the
    first query's lists are read, which matches the single-element
    ``query_texts`` used by every caller in this module.

    Each hit dict has:
        content: the document text.
        metadata: the document's metadata dict ({} when absent or None).
        distance: the reported distance, or None when absent.
        source: ``metadata["source"]``, or "Unknown".
        relevance_score: ``1 - distance`` (1 when no distance is available).

    A missing, ``None`` or empty ``metadatas``/``distances`` field, or a
    ``None`` metadata entry, is tolerated instead of raising.
    """
    if not results:
        return []

    def first_query(key: str) -> List[Any]:
        # Take the first query's list for *key*; treat missing/None/empty as no values.
        per_query = results.get(key)
        if not per_query:
            return []
        values = per_query[0]
        return [] if values is None else values

    documents = first_query("documents")
    if not documents:
        return []
    metadatas = first_query("metadatas")
    distances = first_query("distances")

    formatted = []
    for i, doc in enumerate(documents):
        # Individual metadata entries may be None; normalise to {} so .get works.
        metadata = (metadatas[i] if i < len(metadatas) else None) or {}
        distance = distances[i] if i < len(distances) else None
        formatted.append(
            {
                "content": doc,
                "metadata": metadata,
                "distance": distance,
                "source": metadata.get("source", "Unknown"),
                "relevance_score": 1 - (distance if distance is not None else 0),
            }
        )
    return formatted


def list_available_collections() -> List[str]:
    """Return the names of all collections in the ChromaDB store.

    ``client.list_collections()`` has returned Collection objects in some
    chromadb releases and plain name strings in others (0.6.x). Both shapes
    are accepted here. Without this, the ``.name`` lookup fails on strings and
    the function silently returns an empty list.

    Returns an empty list (after logging) if listing raises.
    """
    try:
        entries = get_chroma_client().list_collections()
        return [entry if isinstance(entry, str) else entry.name for entry in entries]
    except Exception as e:
        logger.error(f"Error listing collections: {e}")
        return []


def get_collection_info(name: str) -> Optional[Dict[str, Any]]:
    """Summarise a collection as a dict with ``name``, ``count`` and ``metadata``.

    Returns None when the collection cannot be opened or when reading its
    count/metadata raises (the latter is logged as an error).
    """
    coll = get_collection(name)
    if coll is None:
        return None
    try:
        summary = {
            "name": coll.name,
            "count": coll.count(),
            "metadata": coll.metadata,
        }
    except Exception as e:
        logger.error(f"Error getting collection info: {e}")
        return None
    return summary