Yogeshwarirj commited on
Commit
ef4b6ba
Β·
verified Β·
1 Parent(s): d22ea8d

Create medpanel.py

Browse files
Files changed (1) hide show
  1. medpanel.py +377 -0
medpanel.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # medpanel.py
2
+ # Core logic for the MedPanel multi-agent diagnostic system.
3
+ # This file contains all 4 agents + orchestrator + RAG pipeline.
4
+ # Imported by app.py which runs the Gradio interface on HuggingFace Spaces.
5
+
6
+ import os
7
+ import json
8
+ import re
9
+ import torch
10
+ import numpy as np
11
+ import faiss
12
+
13
+ from transformers import AutoProcessor, AutoModelForImageTextToText
14
+ from sentence_transformers import SentenceTransformer
15
+ from Bio import Entrez
16
+ from PIL import Image
17
+
18
+
19
+ # ── Model Configuration ──────────────────────────────────────────────
20
+ # We load these once at startup so they're ready for every request
21
+ MODEL_ID = "google/medgemma-4b-it"
22
+
23
+ # NCBI requires an email for PubMed access β€” just for identification purposes
24
+ Entrez.email = "medpanel@example.com"
25
+
26
+
27
+ # ── Load Models ──────────────────────────────────────────────────────
28
+
29
+ def load_models():
30
+ """
31
+ Loads MedGemma and the PubMed embedding model into memory.
32
+ Called once when the app starts up on HuggingFace Spaces.
33
+ Returns processor, model, and embed_model.
34
+ """
35
+
36
+ print("Loading MedGemma model...")
37
+
38
+ # Load the processor β€” handles both text tokenization and image preprocessing
39
+ processor = AutoProcessor.from_pretrained(
40
+ MODEL_ID,
41
+ token=os.environ.get("HF_TOKEN")
42
+
43
+ )
44
+
45
+ # Load MedGemma in bfloat16 to fit within GPU memory limits
46
+ model = AutoModelForImageTextToText.from_pretrained(
47
+ MODEL_ID,
48
+ torch_dtype=torch.bfloat16,
49
+ device_map="auto",
50
+ token=os.environ.get("HF_TOKEN"),
51
+ low_cpu_mem_usage=True,
52
+ attn_implementation="eager"
53
+ )
54
+ model.eval()
55
+ print("βœ… MedGemma loaded!")
56
+
57
+ # Load the PubMed-specific embedding model for semantic search
58
+ print("Loading PubMed embedding model...")
59
+ embed_model = SentenceTransformer("pritamdeka/S-PubMedBert-MS-MARCO")
60
+ print("βœ… Embedding model loaded!")
61
+
62
+ return processor, model, embed_model
63
+
64
+
65
+ # Initialize all models at module load time
66
+ processor, model, embed_model = load_models()
67
+
68
+
69
+ # ── Base Caller ──────────────────────────────────────────────────────
70
+
71
+ def call_medgemma(prompt, image=None, max_tokens=400):
72
+ """
73
+ Sends a prompt (and optional image) to MedGemma and returns the response.
74
+ This is the single point of contact with the model for all agents.
75
+ """
76
+
77
+ # Build message in MedGemma's expected chat format
78
+ messages = [
79
+ {
80
+ "role": "user",
81
+ "content": [
82
+ {"type": "text", "text": prompt},
83
+ *([{"type": "image", "image": image}] if image else [])
84
+ ]
85
+ }
86
+ ]
87
+
88
+ # Tokenize and move to the same device as the model
89
+ inputs = processor.apply_chat_template(
90
+ messages,
91
+ add_generation_prompt=True,
92
+ tokenize=True,
93
+ return_dict=True,
94
+ return_tensors="pt"
95
+ ).to(model.device)
96
+
97
+ # Generate response β€” no_grad saves memory, do_sample=False is deterministic
98
+ with torch.no_grad():
99
+ output_tokens = model.generate(
100
+ **inputs,
101
+ max_new_tokens=max_tokens,
102
+ do_sample=False
103
+ )
104
+
105
+ # Decode and strip the echoed prompt β€” we only want the model's reply
106
+ full_response = processor.decode(output_tokens[0], skip_special_tokens=True)
107
+ return full_response.split("model\n")[-1].strip()
108
+
109
+
110
+ def safe_json(text):
111
+ """
112
+ Safely extracts a JSON object from the model's response.
113
+ Handles markdown code fences, extra text, and malformed JSON.
114
+ Always returns a dict β€” never crashes.
115
+ """
116
+
117
+ # Strip markdown fences like ```json ... ``` if present
118
+ for fence_start, fence_end in [("```json", "```"), ("```", "```")]:
119
+ if fence_start in text:
120
+ text = text.split(fence_start)[1].split(fence_end)[0].strip()
121
+ break
122
+
123
+ # Try standard JSON parsing first
124
+ try:
125
+ return json.loads(text)
126
+ except json.JSONDecodeError:
127
+ pass
128
+
129
+ # Fall back to regex β€” find any { ... } block in the response
130
+ json_match = re.search(r'\{.*\}', text, re.DOTALL)
131
+ try:
132
+ return json.loads(json_match.group()) if json_match else {"raw_response": text}
133
+ except json.JSONDecodeError:
134
+ return {"raw_response": text}
135
+
136
+
137
+ # ── PubMed RAG ───────────────────────────────────────────────────────
138
+
139
+ def fetch_and_retrieve(query, top_k=3):
140
+ """
141
+ Searches PubMed for relevant abstracts using the given query.
142
+ Uses FAISS + PubMedBERT embeddings to find the most semantically
143
+ similar abstracts rather than just keyword matching.
144
+ Returns a list of abstract strings.
145
+ """
146
+
147
+ try:
148
+ # Search PubMed for matching paper IDs
149
+ handle = Entrez.esearch(db="pubmed", term=query, retmax=8)
150
+ ids = Entrez.read(handle)["IdList"]
151
+
152
+ if not ids:
153
+ return []
154
+
155
+ # Fetch the actual abstract text for those papers
156
+ handle = Entrez.efetch(
157
+ db="pubmed",
158
+ id=ids,
159
+ rettype="abstract",
160
+ retmode="text"
161
+ )
162
+
163
+ # Split the bulk text into individual abstracts, filter out short chunks
164
+ raw_text = handle.read()
165
+ abstracts = [
166
+ chunk.strip()
167
+ for chunk in raw_text.split("\n\n")
168
+ if len(chunk.strip()) > 100
169
+ ]
170
+
171
+ if not abstracts:
172
+ return []
173
+
174
+ # Build FAISS index from abstract embeddings
175
+ embeddings = embed_model.encode(abstracts)
176
+ index = faiss.IndexFlatL2(embeddings.shape[1])
177
+ index.add(np.array(embeddings))
178
+
179
+ # Find the top_k most relevant abstracts for our query
180
+ query_embedding = embed_model.encode([query])
181
+ _, best_indices = index.search(
182
+ np.array(query_embedding),
183
+ min(top_k, len(abstracts))
184
+ )
185
+
186
+ return [abstracts[i] for i in best_indices[0]]
187
+
188
+ except Exception as e:
189
+ # If PubMed is unavailable, return empty rather than crashing
190
+ print(f"PubMed fetch failed for '{query}': {e}")
191
+ return []
192
+
193
+
194
+ # ── Agent 1: Radiologist ─────────────────────────────────────────────
195
+
196
+ def radiologist_agent(image, notes):
197
+ """
198
+ Analyzes the medical image and returns structured radiology findings.
199
+ If no image is provided, returns a safe empty result.
200
+ """
201
+
202
+ if not image:
203
+ return {
204
+ "suspected_conditions": [],
205
+ "note": "No image provided β€” skipping radiology analysis"
206
+ }
207
+
208
+ # Convert to RGB if the image is grayscale β€” MedGemma requires RGB
209
+ if image.mode != "RGB":
210
+ image = image.convert("RGB")
211
+
212
+ prompt = f"""You are an experienced radiologist reviewing a medical image.
213
+ Patient clinical notes: {notes}
214
+ Carefully analyze the image and return your findings as a JSON object with:
215
+ - image_findings: list of observed features (e.g. "upper lobe opacity")
216
+ - suspected_conditions: list of possible diagnoses based on what you see
217
+ - abnormalities_detected: true or false
218
+ - severity: one of "mild", "moderate", "severe", or "normal"
219
+ - confidence: your confidence level from 0.0 to 1.0
220
+ Return only the JSON object, no extra explanation."""
221
+
222
+ return safe_json(call_medgemma(prompt, image))
223
+
224
+
225
+ # ── Agent 2: Internist ───────────────────────────────────────────────
226
+
227
+ def internist_agent(notes):
228
+ """
229
+ Analyzes clinical notes as an internal medicine physician.
230
+ Returns differential diagnoses, risk factors, and urgency level.
231
+ Works from text only β€” no image.
232
+ """
233
+
234
+ prompt = f"""You are an experienced internal medicine physician.
235
+ Patient clinical notes: {notes}
236
+ Based on the symptoms and clinical details, return your assessment as a JSON object with:
237
+ - differential_diagnoses: list of 3 most likely diagnoses, ordered by likelihood
238
+ - risk_factors: list of relevant patient risk factors
239
+ - urgency: one of "routine", "urgent", or "emergent"
240
+ - confidence: your overall confidence from 0.0 to 1.0
241
+ Return only the JSON object, no extra explanation."""
242
+
243
+ return safe_json(call_medgemma(prompt))
244
+
245
+
246
+ # ── Agent 3: Evidence Reviewer ───────────────────────────────────────
247
+
248
+ def evidence_agent(r1, r2):
249
+ """
250
+ Fetches supporting medical literature from PubMed based on what
251
+ the Radiologist and Internist suspected.
252
+ Returns up to 4 relevant abstracts.
253
+ """
254
+
255
+ # Combine top conditions from both agents into search queries
256
+ queries = (
257
+ r1.get("suspected_conditions", [])[:2] +
258
+ r2.get("differential_diagnoses", [])[:2]
259
+ )
260
+
261
+ # Search PubMed for each condition and collect abstracts
262
+ evidence = []
263
+ for query in queries:
264
+ results = fetch_and_retrieve(str(query), top_k=2)
265
+ evidence.extend(results)
266
+
267
+ # Cap at 4 to avoid overflowing the model's context window
268
+ return evidence[:4]
269
+
270
+
271
+ # ── Agent 4: Devil's Advocate ────────────────────────────────────────
272
+
273
+ def devils_advocate_agent(image, notes, r1, r2, evidence):
274
+ """
275
+ Adversarial agent that challenges the other agents' conclusions.
276
+ Specifically looks for dangerous diagnoses that were missed.
277
+ This is the agent that catches TB when base MedGemma misses it.
278
+ """
279
+
280
+ # Short evidence snippet so we don't overflow the prompt
281
+ evidence_snippet = "\n".join(evidence[:2]) if evidence else "None available"
282
+
283
+ prompt = f"""You are a critical care specialist and patient safety advocate.
284
+ Your job is NOT to agree β€” your job is to find what everyone else missed.
285
+ Patient clinical notes: {notes}
286
+ The radiologist suspected: {r1.get('suspected_conditions', [])}
287
+ The internist concluded: {r2.get('differential_diagnoses', [])}
288
+ Relevant medical literature:
289
+ {evidence_snippet[:500]}
290
+ Challenge these conclusions. Look for dangerous diagnoses that were missed,
291
+ rare but life-threatening alternatives, and overlooked red flags.
292
+ Return a JSON object with:
293
+ - missed_diagnoses: list of diagnoses the other agents may have overlooked
294
+ - dangerous_alternatives: list of serious conditions that must be ruled out
295
+ - challenge_statement: one sentence explaining your biggest concern
296
+ - requires_human_review: true or false
297
+ Return only the JSON object, no extra explanation."""
298
+
299
+ # Pass image if available so the devil's advocate can see it too
300
+ if image and image.mode != "RGB":
301
+ image = image.convert("RGB")
302
+
303
+ return safe_json(call_medgemma(prompt, image))
304
+
305
+
306
+ # ── Orchestrator ─────────────────────────────────────────────────────
307
+
308
+ def orchestrator_agent(notes, r1, r2, evidence, devil):
309
+ """
310
+ Synthesizes all four agents' outputs into a single final report.
311
+ Decides on the primary diagnosis, confidence, escalation, and next steps.
312
+ """
313
+
314
+ prompt = f"""You are the lead physician synthesizing a multi-specialist panel review.
315
+ RADIOLOGIST findings:
316
+ {json.dumps(r1, indent=2)}
317
+ INTERNIST findings:
318
+ {json.dumps(r2, indent=2)}
319
+ DEVIL'S ADVOCATE concerns:
320
+ {json.dumps(devil, indent=2)}
321
+ Supporting evidence: {len(evidence)} PubMed abstracts retrieved.
322
+ Synthesize everything into a final clinical report as a JSON object with:
323
+ - primary_diagnosis: the single most likely diagnosis
324
+ - differential_diagnoses: list of other possibilities
325
+ - panel_agreement_score: 0-100, how much the specialists agreed
326
+ - red_flags: list of warning signs needing immediate attention
327
+ - recommended_next_steps: list of tests or actions to take
328
+ - escalate_to_human: true if a real doctor needs to review this urgently
329
+ - escalation_reason: why escalation is needed (or "Not required")
330
+ - patient_summary: 2-sentence plain English summary for the patient
331
+ Return only the JSON object, no extra explanation."""
332
+
333
+ return safe_json(call_medgemma(prompt))
334
+
335
+
336
+ # ── Master Pipeline ──────────────────────────────────────────────────
337
+
338
+ def run_medpanel(image, notes):
339
+ """
340
+ Runs the full MedPanel multi-agent pipeline.
341
+ Accepts a PIL image (or None) and a string of clinical notes.
342
+ Returns a dict with panel_trace (each agent's output) and final_report.
343
+ """
344
+
345
+ trace = []
346
+
347
+ # Step 1: Radiologist β€” analyze the image
348
+ print("🩻 Running Radiologist agent...")
349
+ r1 = radiologist_agent(image, notes)
350
+ trace.append({"agent": "Radiologist", "output": r1})
351
+
352
+ # Step 2: Internist β€” analyze the clinical notes
353
+ print("🩺 Running Internist agent...")
354
+ r2 = internist_agent(notes)
355
+ trace.append({"agent": "Internist", "output": r2})
356
+
357
+ # Step 3: Evidence Reviewer β€” fetch PubMed literature
358
+ print("πŸ“š Fetching PubMed evidence...")
359
+ evidence = evidence_agent(r1, r2)
360
+ trace.append({"agent": "Evidence Reviewer", "abstracts_retrieved": len(evidence)})
361
+
362
+ # Step 4: Devil's Advocate β€” challenge the findings
363
+ print("😈 Running Devil's Advocate agent...")
364
+ devil = devils_advocate_agent(image, notes, r1, r2, evidence)
365
+ trace.append({"agent": "Devil's Advocate", "output": devil})
366
+
367
+ # Step 5: Orchestrator β€” synthesize the final report
368
+ print("πŸ₯ Synthesizing final report...")
369
+ final_report = orchestrator_agent(notes, r1, r2, evidence, devil)
370
+ trace.append({"agent": "Orchestrator", "output": final_report})
371
+
372
+ print("βœ… MedPanel analysis complete!")
373
+
374
+ return {
375
+ "panel_trace": trace,
376
+ "final_report": final_report
377
+ }