salmasoma committed on
Commit
feb0b0a
·
1 Parent(s): a11e8f7

Use local foundation MedGemma generation when remote API fails

Browse files
src/demo_backend/foundation_embeddings.py CHANGED
@@ -86,6 +86,7 @@ def extract_foundation_embeddings(
86
  status: Dict[str, str] = {}
87
  siglib_embedding: Optional[torch.Tensor] = None
88
  gemma_embedding: Optional[torch.Tensor] = None
 
89
  use_cache = _cache_foundation_models()
90
 
91
  # Extract MedGemma first (typically larger memory footprint), then release.
@@ -118,6 +119,21 @@ def extract_foundation_embeddings(
118
  gemma_embedding = clinical_encoder.extract_embeddings([narrative], device=device).float()
119
  model_type = getattr(clinical_encoder, "model_type", "unknown")
120
  status["medgemma"] = f"{model_type}:{medgemma_model_name}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  if require_true_hf_models and model_type != "medgemma":
122
  raise RuntimeError(
123
  f"Expected MedGemma but got fallback model_type='{model_type}' "
@@ -165,5 +181,6 @@ def extract_foundation_embeddings(
165
  return {
166
  "siglib_embedding": siglib_embedding,
167
  "gemma_embedding": gemma_embedding,
 
168
  "status": status,
169
  }
 
86
  status: Dict[str, str] = {}
87
  siglib_embedding: Optional[torch.Tensor] = None
88
  gemma_embedding: Optional[torch.Tensor] = None
89
+ medgemma_local_output: Optional[str] = None
90
  use_cache = _cache_foundation_models()
91
 
92
  # Extract MedGemma first (typically larger memory footprint), then release.
 
119
  gemma_embedding = clinical_encoder.extract_embeddings([narrative], device=device).float()
120
  model_type = getattr(clinical_encoder, "model_type", "unknown")
121
  status["medgemma"] = f"{model_type}:{medgemma_model_name}"
122
+
123
+ if model_type == "medgemma" and _is_true(os.getenv("HF_LOCAL_MEDGEMMA_REPORT"), default=True):
124
+ try:
125
+ local_prompt = (
126
+ "Given this patient summary and class probabilities, write a concise clinical report "
127
+ "with key evidence and one-line impression.\n\n"
128
+ f"Patient summary:\n{narrative}"
129
+ )
130
+ medgemma_local_output = clinical_encoder.generate_local_report(
131
+ prompt=local_prompt,
132
+ device=device,
133
+ max_new_tokens=160,
134
+ )
135
+ except Exception as local_exc:
136
+ status["medgemma_local_generation"] = f"error:{type(local_exc).__name__}: {_short_error(local_exc)}"
137
  if require_true_hf_models and model_type != "medgemma":
138
  raise RuntimeError(
139
  f"Expected MedGemma but got fallback model_type='{model_type}' "
 
181
  return {
182
  "siglib_embedding": siglib_embedding,
183
  "gemma_embedding": gemma_embedding,
184
+ "medgemma_local_output": medgemma_local_output,
185
  "status": status,
186
  }
src/demo_backend/neurofusion/medgemma_encoder.py CHANGED
@@ -293,6 +293,42 @@ class MedGemmaEncoder(nn.Module):
293
  self.eval()
294
  return self.encode_text(narratives, device).float()
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
  class StructuredClinicalEncoder(nn.Module):
298
  """MLP encoder for structured clinical features (demographics + health history).
 
293
  self.eval()
294
  return self.encode_text(narratives, device).float()
295
 
296
@torch.no_grad()
def generate_local_report(
    self,
    prompt: str,
    device: torch.device,
    max_new_tokens: int = 160,
) -> str:
    """Generate text locally with MedGemma when remote inference is unavailable.

    Args:
        prompt: Instruction plus patient narrative to condition generation on.
        device: Device the tokenized inputs are moved to (must match the model).
        max_new_tokens: Upper bound on tokens generated beyond the prompt.

    Returns:
        The generated continuation with the prompt excluded, or "" when the
        local MedGemma backbone/tokenizer is not loaded.
    """
    # Local generation is only meaningful with the true MedGemma backbone;
    # fallback encoders have no causal LM head to generate with.
    if self.model_type != "medgemma" or self.tokenizer is None or self.lm_backbone is None:
        return ""

    # Some causal-LM tokenizers ship without a pad token; reuse EOS for padding
    # so `generate` does not fail on a missing pad_token_id.
    if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
        self.tokenizer.pad_token = self.tokenizer.eos_token

    inputs = self.tokenizer(
        prompt,
        truncation=True,
        max_length=self.max_length,
        return_tensors="pt",
    ).to(device)

    # NOTE: `temperature` is intentionally omitted — it is ignored (and warned
    # about, or rejected by newer transformers) when do_sample=False.
    generated = self.lm_backbone.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # deterministic greedy decoding
        pad_token_id=self.tokenizer.pad_token_id,
        eos_token_id=self.tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens. Slicing by the input length is
    # robust even when decode(encode(prompt)) does not reproduce the prompt
    # verbatim, which would defeat a startswith-based strip.
    prompt_len = inputs["input_ids"].shape[-1]
    text = self.tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
    return text.strip()
331
+
332
 
333
  class StructuredClinicalEncoder(nn.Module):
334
  """MLP encoder for structured clinical features (demographics + health history).
src/demo_backend/pipeline.py CHANGED
@@ -67,6 +67,7 @@ def run_full_inference(
67
  foundation = {
68
  "siglib_embedding": None,
69
  "gemma_embedding": None,
 
70
  "status": {"medsiglip": "disabled", "medgemma": "disabled"},
71
  }
72
  if use_hf_foundation_embeddings:
@@ -101,6 +102,7 @@ def run_full_inference(
101
  prediction=prediction,
102
  enable_remote_llm=enable_remote_medgemma_report,
103
  foundation_status=foundation["status"],
 
104
  )
105
 
106
  final_payload = {
@@ -113,6 +115,7 @@ def run_full_inference(
113
  "avra_scores": avra_scores,
114
  "clinical_narrative": narrative,
115
  "foundation_embeddings": foundation["status"],
 
116
  "medgemma_report": report,
117
  "prediction": prediction,
118
  }
 
67
  foundation = {
68
  "siglib_embedding": None,
69
  "gemma_embedding": None,
70
+ "medgemma_local_output": None,
71
  "status": {"medsiglip": "disabled", "medgemma": "disabled"},
72
  }
73
  if use_hf_foundation_embeddings:
 
102
  prediction=prediction,
103
  enable_remote_llm=enable_remote_medgemma_report,
104
  foundation_status=foundation["status"],
105
+ local_medgemma_output=foundation.get("medgemma_local_output"),
106
  )
107
 
108
  final_payload = {
 
115
  "avra_scores": avra_scores,
116
  "clinical_narrative": narrative,
117
  "foundation_embeddings": foundation["status"],
118
+ "medgemma_local_output": foundation.get("medgemma_local_output"),
119
  "medgemma_report": report,
120
  "prediction": prediction,
121
  }
src/demo_backend/reporting.py CHANGED
@@ -117,6 +117,7 @@ def generate_medgemma_report(
117
  prediction: Dict,
118
  enable_remote_llm: bool = True,
119
  foundation_status: Optional[Mapping[str, str]] = None,
 
120
  ) -> Dict[str, str]:
121
  """Generate clinical report.
122
 
@@ -138,6 +139,16 @@ def generate_medgemma_report(
138
  "medgemma_available": "false",
139
  }
140
 
 
 
 
 
 
 
 
 
 
 
141
  token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
142
  configured = os.getenv("MEDGEMMA_MODEL_ID", "").strip()
143
  model_candidates = _build_model_candidates(configured)
 
117
  prediction: Dict,
118
  enable_remote_llm: bool = True,
119
  foundation_status: Optional[Mapping[str, str]] = None,
120
+ local_medgemma_output: Optional[str] = None,
121
  ) -> Dict[str, str]:
122
  """Generate clinical report.
123
 
 
139
  "medgemma_available": "false",
140
  }
141
 
142
+ # If local MedGemma generation is available from the foundation encoder, use it.
143
+ if local_medgemma_output and local_medgemma_output.strip():
144
+ med_out = local_medgemma_output.strip()
145
+ return {
146
+ "report": _compose_report(base_narrative, med_out, prediction),
147
+ "source": "local_foundation_medgemma",
148
+ "medgemma_output": med_out,
149
+ "medgemma_available": "true",
150
+ }
151
+
152
  token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
153
  configured = os.getenv("MEDGEMMA_MODEL_ID", "").strip()
154
  model_candidates = _build_model_candidates(configured)