# HyperClinical/src/demo_backend/pipeline.py
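"""End-to-end demo inference pipeline.

Chains AVRA atrophy scoring, EHR feature engineering, optional Hugging Face
foundation embeddings (MedSigLIP / MedGemma), multimodal prediction, and
MedGemma report generation, then writes a single JSON payload to disk.
"""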
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Dict, Mapping

from .avra import nifti_uid, run_avra_inference
from .feature_engineering import (
    build_clinical_narrative,
    build_model_tensors,
    build_raw_feature_dict,
    load_preprocessing_stats,
    normalize_feature_dict,
)
from .foundation_embeddings import extract_foundation_embeddings
from .modeling import load_model_from_checkpoint, predict_single, resolve_device
from .reporting import generate_medgemma_report

def run_full_inference(
    nifti_path: Path,
    ehr_row: Mapping[str, float],
    output_dir: Path,
    checkpoint_path: Path,
    stats_path: Path,
    avra_script_path: Path,
    device: str = "auto",
    allow_foundation_backbones: bool = False,
    use_hf_foundation_embeddings: bool = False,
    require_true_hf_embeddings: bool = False,
    reuse_registration: bool = False,
    use_fsl: bool = False,
    enable_remote_medgemma_report: bool = True,
) -> Dict:
    """Run the full demo pipeline for one scan: AVRA scoring, feature
    engineering, multimodal prediction, and MedGemma report generation.

    Returns the result payload, which is also written to ``output_dir``
    as JSON.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Step 1: AVRA atrophy scoring on the raw NIfTI volume.
    avra_json, avra_scores, coronal_path = run_avra_inference(
        nifti_path=nifti_path,
        output_dir=output_dir,
        avra_script=avra_script_path,
        reuse_registration=reuse_registration,
        use_fsl=use_fsl,
    )

    # Step 2: merge EHR fields with AVRA scores, normalize against the
    # training-set statistics, and build a free-text clinical narrative.
    raw_features = build_raw_feature_dict(ehr_row, avra_scores)
    stats = load_preprocessing_stats(stats_path)
    normalized_features = normalize_feature_dict(raw_features, stats)
    narrative = build_clinical_narrative(raw_features)

    # Step 3: load the fusion model and assemble its input tensors.
    torch_device = resolve_device(device)
    model, _, model_cfg = load_model_from_checkpoint(
        checkpoint_path=checkpoint_path,
        device=torch_device,
        force_lightweight_backbones=not allow_foundation_backbones,
    )
    mri, avra, clinical = build_model_tensors(
        nifti_path=nifti_path,
        normalized_features=normalized_features,
        image_size=int(model_cfg.image_size),
        device=torch_device,
    )
    # Step 4: optionally extract MedSigLIP / MedGemma foundation embeddings.
    # Defaults keep both backbones disabled so the demo runs without HF access.
    foundation = {
        "siglib_embedding": None,
        "gemma_embedding": None,
        "medgemma_local_output": None,
        "status": {"medsiglip": "disabled", "medgemma": "disabled"},
    }
    if use_hf_foundation_embeddings:
        # Environment variables override the checkpoint's configured model IDs.
        medsiglip_model_id = os.getenv("HF_MEDSIGLIP_MODEL_ID", model_cfg.medsiglib_model_name)
        medgemma_model_id = os.getenv("HF_MEDGEMMA_MODEL_ID", model_cfg.medgemma_model_name)
        allow_bioclinical_fallback = (
            (os.getenv("HF_ALLOW_BIOCLINICAL_FALLBACK", "0") or "0").strip().lower()
            in {"1", "true", "yes", "on"}
        )
        foundation = extract_foundation_embeddings(
            mri_slices=mri,
            narrative=narrative,
            device=torch_device,
            medsiglip_model_name=medsiglip_model_id,
            medgemma_model_name=medgemma_model_id,
            require_true_hf_models=require_true_hf_embeddings,
            allow_bioclinical_fallback=allow_bioclinical_fallback,
        )
    # Step 5: run the multimodal prediction.
    prediction = predict_single(
        model=model,
        mri=mri,
        avra=avra,
        clinical=clinical,
        narrative=narrative,
        siglib_embedding=foundation["siglib_embedding"],
        gemma_embedding=foundation["gemma_embedding"],
    )

    # Step 6: generate the MedGemma report, falling back to local MedGemma
    # output when the remote API is disabled or unavailable.
    report = generate_medgemma_report(
        base_narrative=narrative,
        prediction=prediction,
        enable_remote_llm=enable_remote_medgemma_report,
        foundation_status=foundation["status"],
        local_medgemma_output=foundation.get("medgemma_local_output"),
    )
    # Step 7: persist the full result payload next to the AVRA outputs.
    final_payload = {
        "input_nifti": str(nifti_path),
        "checkpoint": str(checkpoint_path),
        "stats": str(stats_path),
        "avra_script": str(avra_script_path),
        "avra_json": str(avra_json),
        "avra_coronal_image": str(coronal_path),
        "avra_scores": avra_scores,
        "clinical_narrative": narrative,
        "foundation_embeddings": foundation["status"],
        "medgemma_local_output": foundation.get("medgemma_local_output"),
        "medgemma_report": report,
        "prediction": prediction,
    }
    out_path = output_dir / f"{nifti_uid(nifti_path)}_streamlit_demo_output.json"
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(final_payload, f, indent=2)

    # Record the output path in the returned payload only, not in the file itself.
    final_payload["output_json"] = str(out_path)
    return final_payload
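

# --- Usage sketch (illustrative only) -------------------------------------
# A minimal example of invoking run_full_inference from a script. All paths
# and EHR field names below are hypothetical placeholders, not part of the
# repository; substitute your own scan, checkpoint, preprocessing stats, and
# AVRA script locations.
if __name__ == "__main__":
    demo_result = run_full_inference(
        nifti_path=Path("data/sample_scan.nii.gz"),           # hypothetical path
        ehr_row={"age": 72.0, "mmse": 26.0},                  # hypothetical EHR fields
        output_dir=Path("outputs/demo"),
        checkpoint_path=Path("checkpoints/fusion_model.pt"),  # hypothetical path
        stats_path=Path("checkpoints/preproc_stats.json"),    # hypothetical path
        avra_script_path=Path("scripts/run_avra.py"),         # hypothetical path
        device="auto",
    )
    print(demo_result["output_json"])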