Spaces:
Sleeping
Sleeping
salmasoma committed on
Commit ·
feb0b0a
1
Parent(s): a11e8f7
Use local foundation MedGemma generation when remote API fails
Browse files
src/demo_backend/foundation_embeddings.py
CHANGED
|
@@ -86,6 +86,7 @@ def extract_foundation_embeddings(
|
|
| 86 |
status: Dict[str, str] = {}
|
| 87 |
siglib_embedding: Optional[torch.Tensor] = None
|
| 88 |
gemma_embedding: Optional[torch.Tensor] = None
|
|
|
|
| 89 |
use_cache = _cache_foundation_models()
|
| 90 |
|
| 91 |
# Extract MedGemma first (typically larger memory footprint), then release.
|
|
@@ -118,6 +119,21 @@ def extract_foundation_embeddings(
|
|
| 118 |
gemma_embedding = clinical_encoder.extract_embeddings([narrative], device=device).float()
|
| 119 |
model_type = getattr(clinical_encoder, "model_type", "unknown")
|
| 120 |
status["medgemma"] = f"{model_type}:{medgemma_model_name}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
if require_true_hf_models and model_type != "medgemma":
|
| 122 |
raise RuntimeError(
|
| 123 |
f"Expected MedGemma but got fallback model_type='{model_type}' "
|
|
@@ -165,5 +181,6 @@ def extract_foundation_embeddings(
|
|
| 165 |
return {
|
| 166 |
"siglib_embedding": siglib_embedding,
|
| 167 |
"gemma_embedding": gemma_embedding,
|
|
|
|
| 168 |
"status": status,
|
| 169 |
}
|
|
|
|
| 86 |
status: Dict[str, str] = {}
|
| 87 |
siglib_embedding: Optional[torch.Tensor] = None
|
| 88 |
gemma_embedding: Optional[torch.Tensor] = None
|
| 89 |
+
medgemma_local_output: Optional[str] = None
|
| 90 |
use_cache = _cache_foundation_models()
|
| 91 |
|
| 92 |
# Extract MedGemma first (typically larger memory footprint), then release.
|
|
|
|
| 119 |
gemma_embedding = clinical_encoder.extract_embeddings([narrative], device=device).float()
|
| 120 |
model_type = getattr(clinical_encoder, "model_type", "unknown")
|
| 121 |
status["medgemma"] = f"{model_type}:{medgemma_model_name}"
|
| 122 |
+
|
| 123 |
+
if model_type == "medgemma" and _is_true(os.getenv("HF_LOCAL_MEDGEMMA_REPORT"), default=True):
|
| 124 |
+
try:
|
| 125 |
+
local_prompt = (
|
| 126 |
+
"Given this patient summary and class probabilities, write a concise clinical report "
|
| 127 |
+
"with key evidence and one-line impression.\n\n"
|
| 128 |
+
f"Patient summary:\n{narrative}"
|
| 129 |
+
)
|
| 130 |
+
medgemma_local_output = clinical_encoder.generate_local_report(
|
| 131 |
+
prompt=local_prompt,
|
| 132 |
+
device=device,
|
| 133 |
+
max_new_tokens=160,
|
| 134 |
+
)
|
| 135 |
+
except Exception as local_exc:
|
| 136 |
+
status["medgemma_local_generation"] = f"error:{type(local_exc).__name__}: {_short_error(local_exc)}"
|
| 137 |
if require_true_hf_models and model_type != "medgemma":
|
| 138 |
raise RuntimeError(
|
| 139 |
f"Expected MedGemma but got fallback model_type='{model_type}' "
|
|
|
|
| 181 |
return {
|
| 182 |
"siglib_embedding": siglib_embedding,
|
| 183 |
"gemma_embedding": gemma_embedding,
|
| 184 |
+
"medgemma_local_output": medgemma_local_output,
|
| 185 |
"status": status,
|
| 186 |
}
|
src/demo_backend/neurofusion/medgemma_encoder.py
CHANGED
|
@@ -293,6 +293,42 @@ class MedGemmaEncoder(nn.Module):
|
|
| 293 |
self.eval()
|
| 294 |
return self.encode_text(narratives, device).float()
|
| 295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
class StructuredClinicalEncoder(nn.Module):
|
| 298 |
"""MLP encoder for structured clinical features (demographics + health history).
|
|
|
|
| 293 |
self.eval()
|
| 294 |
return self.encode_text(narratives, device).float()
|
| 295 |
|
| 296 |
+
@torch.no_grad()
def generate_local_report(
    self,
    prompt: str,
    device: torch.device,
    max_new_tokens: int = 160,
) -> str:
    """Generate text locally with MedGemma when remote inference is unavailable.

    Args:
        prompt: Instruction plus patient context to condition generation on.
        device: Device on which tokenization tensors and generation run.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The generated continuation with the prompt removed, or "" when the
        loaded backbone is not a true MedGemma model or a required component
        (tokenizer / LM backbone) is missing.
    """
    # Only a genuine MedGemma backbone can produce a useful clinical report;
    # fallback encoders short-circuit to an empty string so callers can
    # fall through to their remote/other report paths.
    if self.model_type != "medgemma" or self.tokenizer is None or self.lm_backbone is None:
        return ""

    # generate() requires a pad token id; reuse EOS when none is defined.
    if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
        self.tokenizer.pad_token = self.tokenizer.eos_token

    inputs = self.tokenizer(
        prompt,
        truncation=True,
        max_length=self.max_length,
        return_tensors="pt",
    ).to(device)

    generated = self.lm_backbone.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # Greedy decoding. Do NOT pass `temperature` here: transformers
        # ignores (and warns about) sampling params when do_sample=False.
        do_sample=False,
        pad_token_id=self.tokenizer.pad_token_id,
        eos_token_id=self.tokenizer.eos_token_id,
    )

    # Strip the prompt by token count rather than string-prefix matching:
    # truncation at `max_length` (or tokenize/decode round-trip differences)
    # can make the decoded prefix differ from `prompt`, which would leave the
    # prompt embedded in the returned report.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = generated[0][prompt_len:]
    return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
| 331 |
+
|
| 332 |
|
| 333 |
class StructuredClinicalEncoder(nn.Module):
|
| 334 |
"""MLP encoder for structured clinical features (demographics + health history).
|
src/demo_backend/pipeline.py
CHANGED
|
@@ -67,6 +67,7 @@ def run_full_inference(
|
|
| 67 |
foundation = {
|
| 68 |
"siglib_embedding": None,
|
| 69 |
"gemma_embedding": None,
|
|
|
|
| 70 |
"status": {"medsiglip": "disabled", "medgemma": "disabled"},
|
| 71 |
}
|
| 72 |
if use_hf_foundation_embeddings:
|
|
@@ -101,6 +102,7 @@ def run_full_inference(
|
|
| 101 |
prediction=prediction,
|
| 102 |
enable_remote_llm=enable_remote_medgemma_report,
|
| 103 |
foundation_status=foundation["status"],
|
|
|
|
| 104 |
)
|
| 105 |
|
| 106 |
final_payload = {
|
|
@@ -113,6 +115,7 @@ def run_full_inference(
|
|
| 113 |
"avra_scores": avra_scores,
|
| 114 |
"clinical_narrative": narrative,
|
| 115 |
"foundation_embeddings": foundation["status"],
|
|
|
|
| 116 |
"medgemma_report": report,
|
| 117 |
"prediction": prediction,
|
| 118 |
}
|
|
|
|
| 67 |
foundation = {
|
| 68 |
"siglib_embedding": None,
|
| 69 |
"gemma_embedding": None,
|
| 70 |
+
"medgemma_local_output": None,
|
| 71 |
"status": {"medsiglip": "disabled", "medgemma": "disabled"},
|
| 72 |
}
|
| 73 |
if use_hf_foundation_embeddings:
|
|
|
|
| 102 |
prediction=prediction,
|
| 103 |
enable_remote_llm=enable_remote_medgemma_report,
|
| 104 |
foundation_status=foundation["status"],
|
| 105 |
+
local_medgemma_output=foundation.get("medgemma_local_output"),
|
| 106 |
)
|
| 107 |
|
| 108 |
final_payload = {
|
|
|
|
| 115 |
"avra_scores": avra_scores,
|
| 116 |
"clinical_narrative": narrative,
|
| 117 |
"foundation_embeddings": foundation["status"],
|
| 118 |
+
"medgemma_local_output": foundation.get("medgemma_local_output"),
|
| 119 |
"medgemma_report": report,
|
| 120 |
"prediction": prediction,
|
| 121 |
}
|
src/demo_backend/reporting.py
CHANGED
|
@@ -117,6 +117,7 @@ def generate_medgemma_report(
|
|
| 117 |
prediction: Dict,
|
| 118 |
enable_remote_llm: bool = True,
|
| 119 |
foundation_status: Optional[Mapping[str, str]] = None,
|
|
|
|
| 120 |
) -> Dict[str, str]:
|
| 121 |
"""Generate clinical report.
|
| 122 |
|
|
@@ -138,6 +139,16 @@ def generate_medgemma_report(
|
|
| 138 |
"medgemma_available": "false",
|
| 139 |
}
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 142 |
configured = os.getenv("MEDGEMMA_MODEL_ID", "").strip()
|
| 143 |
model_candidates = _build_model_candidates(configured)
|
|
|
|
| 117 |
prediction: Dict,
|
| 118 |
enable_remote_llm: bool = True,
|
| 119 |
foundation_status: Optional[Mapping[str, str]] = None,
|
| 120 |
+
local_medgemma_output: Optional[str] = None,
|
| 121 |
) -> Dict[str, str]:
|
| 122 |
"""Generate clinical report.
|
| 123 |
|
|
|
|
| 139 |
"medgemma_available": "false",
|
| 140 |
}
|
| 141 |
|
| 142 |
+
# If local MedGemma generation is available from the foundation encoder, use it.
|
| 143 |
+
if local_medgemma_output and local_medgemma_output.strip():
|
| 144 |
+
med_out = local_medgemma_output.strip()
|
| 145 |
+
return {
|
| 146 |
+
"report": _compose_report(base_narrative, med_out, prediction),
|
| 147 |
+
"source": "local_foundation_medgemma",
|
| 148 |
+
"medgemma_output": med_out,
|
| 149 |
+
"medgemma_available": "true",
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 153 |
configured = os.getenv("MEDGEMMA_MODEL_ID", "").strip()
|
| 154 |
model_candidates = _build_model_candidates(configured)
|