SamarpeetGarad committed on
Commit
a0ba97a
·
verified ·
1 Parent(s): 6992528

Upload agents/finding_interpreter.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. agents/finding_interpreter.py +68 -60
agents/finding_interpreter.py CHANGED
@@ -9,13 +9,12 @@ from PIL import Image
9
 
10
  from .base_agent import BaseAgent, AgentResult
11
 
12
- # Try to import torch and transformers
13
  try:
14
- import torch
15
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
16
- TORCH_AVAILABLE = True
17
  except ImportError:
18
- TORCH_AVAILABLE = False
19
 
20
 
21
  class FindingInterpreterAgent(BaseAgent):
@@ -23,7 +22,7 @@ class FindingInterpreterAgent(BaseAgent):
23
  Agent 2: MedGemma Finding Interpreter
24
 
25
  Takes CXR analysis results and generates clinical interpretations
26
- using Google's MedGemma model.
27
  """
28
 
29
  def __init__(self, demo_mode: bool = False):
@@ -32,42 +31,20 @@ class FindingInterpreterAgent(BaseAgent):
32
  model_name="google/medgemma-4b-it"
33
  )
34
  self.demo_mode = demo_mode
35
- self.tokenizer = None
36
 
37
  def load_model(self) -> bool:
38
- """Load MedGemma model."""
39
- if self.demo_mode:
40
- self.is_loaded = True
41
- return True
42
-
43
- if not TORCH_AVAILABLE:
44
- print("Warning: PyTorch not available. Running in demo mode.")
45
- self.demo_mode = True
46
  self.is_loaded = True
47
  return True
48
 
49
  try:
50
- self.tokenizer = AutoTokenizer.from_pretrained(
51
- self.model_name,
52
- trust_remote_code=True
53
- )
54
-
55
- # Load with appropriate settings for memory efficiency
56
- self.model = AutoModelForCausalLM.from_pretrained(
57
- self.model_name,
58
- trust_remote_code=True,
59
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
60
- device_map="auto" if torch.cuda.is_available() else None,
61
- low_cpu_mem_usage=True
62
- )
63
-
64
- self.model.eval()
65
- self.is_loaded = True
66
  return True
67
-
68
  except Exception as e:
69
- print(f"Failed to load MedGemma model: {e}")
70
- print("Falling back to demo mode.")
71
  self.demo_mode = True
72
  self.is_loaded = True
73
  return True
@@ -98,11 +75,11 @@ class FindingInterpreterAgent(BaseAgent):
98
  findings = input_data.get("findings", [])
99
  region_analysis = input_data.get("region_analysis", {})
100
 
101
- # Process based on mode
102
- if self.demo_mode:
103
- interpretation = self._simulate_interpretation(findings, region_analysis, context)
104
- else:
105
  interpretation = self._run_model_inference(findings, region_analysis, context)
 
 
106
 
107
  processing_time = (time.time() - start_time) * 1000
108
 
@@ -119,37 +96,68 @@ class FindingInterpreterAgent(BaseAgent):
119
  region_analysis: Dict,
120
  context: Optional[Dict]
121
  ) -> Dict:
122
- """Run actual MedGemma inference."""
123
  try:
124
- # Prepare prompt
125
- prompt = self._build_prompt(findings, region_analysis, context)
126
 
127
- # Tokenize
128
- inputs = self.tokenizer(prompt, return_tensors="pt")
129
- if torch.cuda.is_available():
130
- inputs = {k: v.cuda() for k, v in inputs.items()}
131
-
132
- # Generate
133
- with torch.no_grad():
134
- outputs = self.model.generate(
135
- **inputs,
136
- max_new_tokens=512,
137
- temperature=0.3,
138
- top_p=0.9,
139
- do_sample=True,
140
- pad_token_id=self.tokenizer.eos_token_id
141
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
- # Decode
144
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
145
 
146
- # Parse response
147
- return self._parse_model_response(response, findings)
 
 
 
 
 
 
 
 
 
148
 
149
  except Exception as e:
150
  print(f"MedGemma inference error: {e}")
151
  return self._simulate_interpretation(findings, region_analysis, context)
152
 
 
 
 
 
 
 
 
 
153
  def _simulate_interpretation(
154
  self,
155
  findings: List[Dict],
 
9
 
10
  from .base_agent import BaseAgent, AgentResult
11
 
12
+ # Import the unified MedGemma engine
13
  try:
14
+ from .medgemma_engine import get_engine, MedGemmaEngine
15
+ ENGINE_AVAILABLE = True
 
16
  except ImportError:
17
+ ENGINE_AVAILABLE = False
18
 
19
 
20
  class FindingInterpreterAgent(BaseAgent):
 
22
  Agent 2: MedGemma Finding Interpreter
23
 
24
  Takes CXR analysis results and generates clinical interpretations
25
+ using Google's MedGemma model via the unified engine.
26
  """
27
 
28
  def __init__(self, demo_mode: bool = False):
 
31
  model_name="google/medgemma-4b-it"
32
  )
33
  self.demo_mode = demo_mode
34
+ self.engine = None
35
 
36
def load_model(self) -> bool:
    """Load the MedGemma model through the unified engine.

    Returns:
        bool: always True — on any failure the agent degrades to demo
        mode instead of refusing to load.
    """
    # Guard clause: explicit demo mode, or the engine module failing to
    # import, both mean there is nothing to load.
    if self.demo_mode or not ENGINE_AVAILABLE:
        self.is_loaded = True
        return True

    try:
        # Delegate all model/backends handling to the shared engine.
        self.engine = get_engine(force_demo=self.demo_mode)
        self.is_loaded = self.engine.is_loaded
    except Exception as e:
        # Engine construction failed — fall back to simulated output
        # rather than surfacing an error to the pipeline.
        print(f"Failed to load MedGemma engine: {e}")
        self.demo_mode = True
        self.is_loaded = True
    return True
 
75
  findings = input_data.get("findings", [])
76
  region_analysis = input_data.get("region_analysis", {})
77
 
78
+ # Process - always try to use real model if available
79
+ if self.engine and self.engine.is_loaded and self.engine.backend != "demo":
 
 
80
  interpretation = self._run_model_inference(findings, region_analysis, context)
81
+ else:
82
+ interpretation = self._simulate_interpretation(findings, region_analysis, context)
83
 
84
  processing_time = (time.time() - start_time) * 1000
85
 
 
96
  region_analysis: Dict,
97
  context: Optional[Dict]
98
  ) -> Dict:
99
+ """Run actual MedGemma inference using the unified engine."""
100
  try:
101
+ clinical_context = context.get("clinical_history", "Not provided") if context else "Not provided"
 
102
 
103
+ # Generate interpretations for each finding using real MedGemma
104
+ interpreted_findings = []
105
+ for finding in findings:
106
+ prompt = f"""As a radiologist, interpret this chest X-ray finding:
107
+
108
+ Finding: {finding.get('type', 'Unknown')}
109
+ Region: {finding.get('region', 'Unknown')}
110
+ Severity: {finding.get('severity', 'Unknown')}
111
+ Description: {finding.get('description', 'No description')}
112
+ Clinical History: {clinical_context}
113
+
114
+ Provide:
115
+ 1. Clinical significance (1-2 sentences)
116
+ 2. Top 3 differential diagnoses
117
+ 3. Recommended follow-up
118
+
119
+ Be concise and clinically relevant."""
120
+
121
+ response = self.engine.generate(prompt, max_tokens=200)
122
+
123
+ interpreted = {
124
+ "original": finding,
125
+ "clinical_significance": self._extract_significance(response, finding),
126
+ "differential_diagnoses": self._get_differentials(finding),
127
+ "recommended_followup": self._get_followup(finding),
128
+ "medgemma_interpretation": response,
129
+ "correlation_notes": f"MedGemma analysis: {response[:100]}..."
130
+ }
131
+ interpreted_findings.append(interpreted)
132
 
133
+ # Generate clinical summary
134
+ clinical_summary = self._generate_clinical_summary(interpreted_findings, clinical_context)
135
+ key_concerns = self._identify_key_concerns(interpreted_findings)
136
 
137
+ return {
138
+ "interpreted_findings": interpreted_findings,
139
+ "clinical_summary": clinical_summary,
140
+ "key_concerns": key_concerns,
141
+ "abnormal_regions": [
142
+ region for region, data in region_analysis.items()
143
+ if data.get("status") == "abnormal"
144
+ ],
145
+ "confidence_level": "high",
146
+ "model_used": f"MedGemma ({self.engine.backend})"
147
+ }
148
 
149
  except Exception as e:
150
  print(f"MedGemma inference error: {e}")
151
  return self._simulate_interpretation(findings, region_analysis, context)
152
 
153
+ def _extract_significance(self, response: str, finding: Dict) -> str:
154
+ """Extract clinical significance from MedGemma response."""
155
+ # Take first meaningful sentence
156
+ sentences = response.split('.')
157
+ if sentences and len(sentences[0]) > 10:
158
+ return sentences[0].strip() + "."
159
+ return self._get_significance(finding)
160
+
161
  def _simulate_interpretation(
162
  self,
163
  findings: List[Dict],