SamarpeetGarad committed
Commit ff529e6 · verified · 1 Parent(s): 3a161e0

Upload agents/priority_router.py with huggingface_hub

Files changed (1):
  1. agents/priority_router.py +18 -51
agents/priority_router.py CHANGED
@@ -8,13 +8,12 @@ from typing import Any, Dict, Optional, List
 
 from .base_agent import BaseAgent, AgentResult
 
-# Try to import torch and transformers
+# Import the unified MedGemma engine
 try:
-    import torch
-    from transformers import AutoModelForCausalLM, AutoTokenizer
-    TORCH_AVAILABLE = True
+    from .medgemma_engine import get_engine, MedGemmaEngine
+    ENGINE_AVAILABLE = True
 except ImportError:
-    TORCH_AVAILABLE = False
+    ENGINE_AVAILABLE = False
 
 
 class PriorityRouterAgent(BaseAgent):
@@ -22,7 +21,7 @@ class PriorityRouterAgent(BaseAgent):
     Agent 4: MedGemma Priority Router
 
     Assesses case urgency and determines appropriate routing
-    based on radiology report and findings.
+    based on radiology report and findings using MedGemma.
     """
 
     # Priority level definitions
@@ -68,39 +67,20 @@
             model_name="google/medgemma-4b-it"
         )
         self.demo_mode = demo_mode
-        self.tokenizer = None
+        self.engine = None
 
     def load_model(self) -> bool:
-        """Load MedGemma model for priority assessment."""
-        if self.demo_mode:
-            self.is_loaded = True
-            return True
-
-        if not TORCH_AVAILABLE:
-            self.demo_mode = True
+        """Load MedGemma model via unified engine."""
+        if self.demo_mode or not ENGINE_AVAILABLE:
             self.is_loaded = True
             return True
 
         try:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                self.model_name,
-                trust_remote_code=True
-            )
-
-            self.model = AutoModelForCausalLM.from_pretrained(
-                self.model_name,
-                trust_remote_code=True,
-                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-                device_map="auto" if torch.cuda.is_available() else None,
-                low_cpu_mem_usage=True
-            )
-
-            self.model.eval()
-            self.is_loaded = True
+            self.engine = get_engine(force_demo=self.demo_mode)
+            self.is_loaded = self.engine.is_loaded
             return True
-
         except Exception as e:
-            print(f"Failed to load model for priority routing: {e}")
+            print(f"Failed to load MedGemma engine: {e}")
            self.demo_mode = True
            self.is_loaded = True
            return True
@@ -135,13 +115,13 @@
         # Get original findings if passed through context
         original_findings = context.get("original_findings", []) if context else []
 
-        # Process based on mode
-        if self.demo_mode:
-            routing = self._simulate_priority_assessment(
+        # Process - always try to use real model if available
+        if self.engine and self.engine.is_loaded and self.engine.backend != "demo":
+            routing = self._run_model_inference(
                 report_sections, full_report, findings_count, original_findings, context
             )
         else:
-            routing = self._run_model_inference(
+            routing = self._simulate_priority_assessment(
                 report_sections, full_report, findings_count, original_findings, context
             )
 
@@ -162,25 +142,12 @@
         original_findings: List[Dict],
         context: Optional[Dict]
     ) -> Dict:
-        """Use MedGemma to assess priority."""
+        """Use MedGemma to assess priority via unified engine."""
        try:
            prompt = self._build_priority_prompt(full_report, original_findings)
 
-            inputs = self.tokenizer(prompt, return_tensors="pt")
-            if torch.cuda.is_available():
-                inputs = {k: v.cuda() for k, v in inputs.items()}
-
-            with torch.no_grad():
-                outputs = self.model.generate(
-                    **inputs,
-                    max_new_tokens=512,
-                    temperature=0.2,  # Lower temperature for more deterministic output
-                    top_p=0.9,
-                    do_sample=True,
-                    pad_token_id=self.tokenizer.eos_token_id
-                )
-
-            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            # Use the unified engine to assess priority
+            response = self.engine.generate(prompt, max_tokens=256)
 
            return self._parse_priority_response(response, original_findings)
 
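The new code delegates all model handling to a sibling module, agents/medgemma_engine.py, which is not part of this commit. For orientation, here is a minimal sketch of the interface the call sites above assume: get_engine(force_demo=...) returning a shared MedGemmaEngine that exposes is_loaded, backend, and generate(prompt, max_tokens=...). Only those names come from the diff; everything else in the sketch is an assumption, not the repository's actual implementation.

# Hypothetical sketch of agents/medgemma_engine.py. Only get_engine,
# MedGemmaEngine, is_loaded, backend, and generate are taken from the
# call sites in the diff; all other details are illustrative.
from functools import lru_cache


class MedGemmaEngine:
    """Shared wrapper around one MedGemma model, with a demo fallback."""

    def __init__(self, force_demo: bool = False):
        self.backend = "demo"   # "cuda", "cpu", or "demo" when no model is loaded
        self.is_loaded = False
        if force_demo:
            self.is_loaded = True  # demo mode is always "ready"
        else:
            self._load()

    def _load(self) -> None:
        try:
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer

            self.tokenizer = AutoTokenizer.from_pretrained("google/medgemma-4b-it")
            self.model = AutoModelForCausalLM.from_pretrained(
                "google/medgemma-4b-it",
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
            )
            self.model.eval()
            self.backend = "cuda" if torch.cuda.is_available() else "cpu"
            self.is_loaded = True
        except Exception:
            self.backend = "demo"  # fall back; callers check backend != "demo"
            self.is_loaded = True

    def generate(self, prompt: str, max_tokens: int = 256) -> str:
        """Return decoded text for the prompt; empty string in demo mode."""
        if self.backend == "demo":
            return ""
        import torch

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=max_tokens)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


@lru_cache(maxsize=1)
def get_engine(force_demo: bool = False) -> MedGemmaEngine:
    """Process-wide cache so all agents share a single loaded model."""
    return MedGemmaEngine(force_demo=force_demo)

Centralizing the load this way is what lets the commit delete 51 lines from the agent: the tokenizer setup, dtype and device selection, and error handling now live in one place instead of being repeated per agent.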
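On the inference path, the commit replaces the tokenizer and model.generate boilerplate with a single self.engine.generate(prompt, max_tokens=256) call, also trimming the generation budget from 512 to 256 new tokens. Prompt construction and response parsing stay in the existing _build_priority_prompt and _parse_priority_response helpers, which the diff does not show. A sketch of what such a parser might do; the label set and returned keys are assumptions:

# Illustrative only: the real _parse_priority_response is unchanged by this
# commit and not shown here. The labels and dict keys below are assumed.
import re
from typing import Dict, List


def parse_priority_response(response: str, original_findings: List[Dict]) -> Dict:
    """Pull the first recognized urgency label out of free-text model output."""
    match = re.search(r"\b(STAT|URGENT|ROUTINE)\b", response.upper())
    priority = match.group(1) if match else "ROUTINE"  # conservative default
    return {
        "priority": priority,
        "findings_count": len(original_findings),
        "rationale": response.strip(),
    }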
 
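Taken together, load_model() can no longer fail: a missing engine module, a load error, or an explicit demo_mode all leave the agent in demo mode with is_loaded set, and process() picks the branch at call time by checking engine.backend. A hypothetical end-to-end driver; the process() payload keys are inferred from variable names in the diff (report_sections, full_report, original_findings), not from a documented schema:

# Hypothetical driver exercising the fallback behavior added in this commit.
from agents.priority_router import PriorityRouterAgent

agent = PriorityRouterAgent(demo_mode=False)
agent.load_model()  # always True: real engine if available, demo fallback otherwise

result = agent.process(
    {
        "report_sections": {"impression": "No acute cardiopulmonary process."},
        "full_report": "IMPRESSION: No acute cardiopulmonary process.",
    },
    context={"original_findings": []},
)
# With engine.backend == "demo" (or no engine at all), process() routes through
# _simulate_priority_assessment(); otherwise it calls _run_model_inference().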