SamarpeetGarad committed on
Commit
3a161e0
·
verified ·
1 Parent(s): a0ba97a

Upload agents/report_generator.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. agents/report_generator.py +19 -51
agents/report_generator.py CHANGED
@@ -9,20 +9,20 @@ from datetime import datetime
9
 
10
  from .base_agent import BaseAgent, AgentResult
11
 
12
- # Try to import torch and transformers
13
  try:
14
- import torch
15
- from transformers import AutoModelForCausalLM, AutoTokenizer
16
- TORCH_AVAILABLE = True
17
  except ImportError:
18
- TORCH_AVAILABLE = False
19
 
20
 
21
  class ReportGeneratorAgent(BaseAgent):
22
  """
23
  Agent 3: MedGemma Report Generator
24
 
25
- Generates structured radiology reports from interpreted findings.
 
26
  """
27
 
28
  def __init__(self, demo_mode: bool = False):
@@ -31,39 +31,20 @@ class ReportGeneratorAgent(BaseAgent):
31
  model_name="google/medgemma-4b-it"
32
  )
33
  self.demo_mode = demo_mode
34
- self.tokenizer = None
35
 
36
  def load_model(self) -> bool:
37
- """Load MedGemma model for report generation."""
38
- if self.demo_mode:
39
- self.is_loaded = True
40
- return True
41
-
42
- if not TORCH_AVAILABLE:
43
- self.demo_mode = True
44
  self.is_loaded = True
45
  return True
46
 
47
  try:
48
- self.tokenizer = AutoTokenizer.from_pretrained(
49
- self.model_name,
50
- trust_remote_code=True
51
- )
52
-
53
- self.model = AutoModelForCausalLM.from_pretrained(
54
- self.model_name,
55
- trust_remote_code=True,
56
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
57
- device_map="auto" if torch.cuda.is_available() else None,
58
- low_cpu_mem_usage=True
59
- )
60
-
61
- self.model.eval()
62
- self.is_loaded = True
63
  return True
64
-
65
  except Exception as e:
66
- print(f"Failed to load model for report generation: {e}")
67
  self.demo_mode = True
68
  self.is_loaded = True
69
  return True
@@ -95,13 +76,13 @@ class ReportGeneratorAgent(BaseAgent):
95
  clinical_summary = input_data.get("clinical_summary", "")
96
  key_concerns = input_data.get("key_concerns", [])
97
 
98
- # Process based on mode
99
- if self.demo_mode:
100
- report = self._simulate_report_generation(
101
  interpreted_findings, clinical_summary, key_concerns, context
102
  )
103
  else:
104
- report = self._run_model_inference(
105
  interpreted_findings, clinical_summary, key_concerns, context
106
  )
107
 
@@ -121,27 +102,14 @@ class ReportGeneratorAgent(BaseAgent):
121
  key_concerns: List[str],
122
  context: Optional[Dict]
123
  ) -> Dict:
124
- """Generate report using MedGemma."""
125
  try:
126
  prompt = self._build_report_prompt(
127
  interpreted_findings, clinical_summary, key_concerns, context
128
  )
129
 
130
- inputs = self.tokenizer(prompt, return_tensors="pt")
131
- if torch.cuda.is_available():
132
- inputs = {k: v.cuda() for k, v in inputs.items()}
133
-
134
- with torch.no_grad():
135
- outputs = self.model.generate(
136
- **inputs,
137
- max_new_tokens=1024,
138
- temperature=0.3,
139
- top_p=0.9,
140
- do_sample=True,
141
- pad_token_id=self.tokenizer.eos_token_id
142
- )
143
-
144
- report_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
145
 
146
  return self._structure_report(report_text, interpreted_findings, context)
147
 
 
9
 
10
  from .base_agent import BaseAgent, AgentResult
11
 
12
+ # Import the unified MedGemma engine
13
  try:
14
+ from .medgemma_engine import get_engine, MedGemmaEngine
15
+ ENGINE_AVAILABLE = True
 
16
  except ImportError:
17
+ ENGINE_AVAILABLE = False
18
 
19
 
20
  class ReportGeneratorAgent(BaseAgent):
21
  """
22
  Agent 3: MedGemma Report Generator
23
 
24
+ Generates structured radiology reports from interpreted findings
25
+ using the unified MedGemma engine.
26
  """
27
 
28
  def __init__(self, demo_mode: bool = False):
 
31
  model_name="google/medgemma-4b-it"
32
  )
33
  self.demo_mode = demo_mode
34
+ self.engine = None
35
 
36
  def load_model(self) -> bool:
37
+ """Load MedGemma model via unified engine."""
38
+ if self.demo_mode or not ENGINE_AVAILABLE:
 
 
 
 
 
39
  self.is_loaded = True
40
  return True
41
 
42
  try:
43
+ self.engine = get_engine(force_demo=self.demo_mode)
44
+ self.is_loaded = self.engine.is_loaded
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  return True
 
46
  except Exception as e:
47
+ print(f"Failed to load MedGemma engine: {e}")
48
  self.demo_mode = True
49
  self.is_loaded = True
50
  return True
 
76
  clinical_summary = input_data.get("clinical_summary", "")
77
  key_concerns = input_data.get("key_concerns", [])
78
 
79
+ # Process - always try to use real model if available
80
+ if self.engine and self.engine.is_loaded and self.engine.backend != "demo":
81
+ report = self._run_model_inference(
82
  interpreted_findings, clinical_summary, key_concerns, context
83
  )
84
  else:
85
+ report = self._simulate_report_generation(
86
  interpreted_findings, clinical_summary, key_concerns, context
87
  )
88
 
 
102
  key_concerns: List[str],
103
  context: Optional[Dict]
104
  ) -> Dict:
105
+ """Generate report using MedGemma via unified engine."""
106
  try:
107
  prompt = self._build_report_prompt(
108
  interpreted_findings, clinical_summary, key_concerns, context
109
  )
110
 
111
+ # Use the unified engine to generate report
112
+ report_text = self.engine.generate(prompt, max_tokens=500)
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  return self._structure_report(report_text, interpreted_findings, context)
115