Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, Mapping | |
| from .avra import nifti_uid, run_avra_inference | |
| from .feature_engineering import ( | |
| build_clinical_narrative, | |
| build_model_tensors, | |
| build_raw_feature_dict, | |
| load_preprocessing_stats, | |
| normalize_feature_dict, | |
| ) | |
| from .foundation_embeddings import extract_foundation_embeddings | |
| from .modeling import load_model_from_checkpoint, predict_single, resolve_device | |
| from .reporting import generate_medgemma_report | |
def _env_str(name: str, default: str | None) -> str | None:
    """Return the stripped value of env var *name*, or *default* if unset/blank.

    ``os.getenv(name, default)`` alone returns ``""`` for a variable that is
    set but empty, which would then be forwarded as a model id; blank values
    are treated as "not set" instead, mirroring ``_env_flag``.
    """
    value = (os.getenv(name) or "").strip()
    return value or default


def _env_flag(name: str) -> bool:
    """Return True if env var *name* is 1/true/yes/on (case-insensitive).

    Unset or blank variables are treated as False.
    """
    return (os.getenv(name) or "").strip().lower() in {"1", "true", "yes", "on"}


def _disabled_foundation() -> Dict:
    """Return the placeholder foundation result used when HF embeddings are off."""
    return {
        "siglib_embedding": None,
        "gemma_embedding": None,
        "medgemma_local_output": None,
        "status": {"medsiglip": "disabled", "medgemma": "disabled"},
    }


def _foundation_embeddings(
    mri,
    narrative,
    torch_device,
    model_cfg,
    require_true_hf_embeddings: bool,
) -> Dict:
    """Run ``extract_foundation_embeddings`` with env-overridable model ids.

    Model ids come from ``HF_MEDSIGLIP_MODEL_ID`` / ``HF_MEDGEMMA_MODEL_ID``
    when those are set to a non-blank value, otherwise from the checkpoint's
    ``model_cfg``. ``HF_ALLOW_BIOCLINICAL_FALLBACK`` (1/true/yes/on) is
    forwarded as ``allow_bioclinical_fallback``.
    """
    return extract_foundation_embeddings(
        mri_slices=mri,
        narrative=narrative,
        device=torch_device,
        medsiglip_model_name=_env_str(
            "HF_MEDSIGLIP_MODEL_ID", model_cfg.medsiglib_model_name
        ),
        medgemma_model_name=_env_str(
            "HF_MEDGEMMA_MODEL_ID", model_cfg.medgemma_model_name
        ),
        require_true_hf_models=require_true_hf_embeddings,
        allow_bioclinical_fallback=_env_flag("HF_ALLOW_BIOCLINICAL_FALLBACK"),
    )


def _write_payload_json(payload: Dict, out_path: Path) -> None:
    """Write *payload* to *out_path* as indented (indent=2) UTF-8 JSON.

    The payload is serialised *before* the file is opened, so a value that
    ``json`` cannot encode raises without truncating or half-writing an
    existing output file at the same path.
    """
    text = json.dumps(payload, indent=2)
    out_path.write_text(text, encoding="utf-8")


def run_full_inference(
    nifti_path: Path,
    ehr_row: Mapping[str, float],
    output_dir: Path,
    checkpoint_path: Path,
    stats_path: Path,
    avra_script_path: Path,
    device: str = "auto",
    allow_foundation_backbones: bool = False,
    use_hf_foundation_embeddings: bool = False,
    require_true_hf_embeddings: bool = False,
    reuse_registration: bool = False,
    use_fsl: bool = False,
    enable_remote_medgemma_report: bool = True,
) -> Dict:
    """Run the end-to-end single-scan pipeline and persist the result as JSON.

    Pipeline: AVRA inference on the NIfTI volume -> raw/normalised features
    and clinical narrative -> model load from checkpoint -> model tensors ->
    optional Hugging Face foundation embeddings -> prediction -> report. The
    payload is written to
    ``<output_dir>/<nifti_uid(nifti_path)>_streamlit_demo_output.json``.

    Args:
        nifti_path: Input image; passed to AVRA and to ``build_model_tensors``.
        ehr_row: Per-subject tabular values consumed by
            ``build_raw_feature_dict`` (expected keys defined there).
        output_dir: Directory for AVRA artefacts and the output JSON;
            created if missing.
        checkpoint_path: Model checkpoint for ``load_model_from_checkpoint``.
        stats_path: Preprocessing statistics used to normalise features.
        avra_script_path: AVRA script forwarded to ``run_avra_inference``.
        device: Device spec resolved via ``resolve_device``.
        allow_foundation_backbones: When False (default), the model is loaded
            with ``force_lightweight_backbones=True``.
        use_hf_foundation_embeddings: Extract foundation embeddings; when
            False, the embeddings passed to the model are ``None`` and status
            is "disabled". Env vars ``HF_MEDSIGLIP_MODEL_ID``,
            ``HF_MEDGEMMA_MODEL_ID`` and ``HF_ALLOW_BIOCLINICAL_FALLBACK``
            are consulted only in this mode.
        require_true_hf_embeddings: Forwarded as ``require_true_hf_models``.
        reuse_registration: Forwarded to ``run_avra_inference``.
        use_fsl: Forwarded to ``run_avra_inference``.
        enable_remote_medgemma_report: Forwarded to
            ``generate_medgemma_report`` as ``enable_remote_llm``.

    Returns:
        The payload that was written, plus an ``"output_json"`` key holding
        the written file's path (that key is not stored in the file itself).
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # 1) AVRA scoring of the input volume.
    avra_json, avra_scores, coronal_path = run_avra_inference(
        nifti_path=nifti_path,
        output_dir=output_dir,
        avra_script=avra_script_path,
        reuse_registration=reuse_registration,
        use_fsl=use_fsl,
    )

    # 2) Tabular features (normalised with stored stats) and the narrative,
    #    which is built from the *raw* (un-normalised) features.
    raw_features = build_raw_feature_dict(ehr_row, avra_scores)
    stats = load_preprocessing_stats(stats_path)
    normalized_features = normalize_feature_dict(raw_features, stats)
    narrative = build_clinical_narrative(raw_features)

    # 3) Model and input tensors on the resolved device.
    torch_device = resolve_device(device)
    model, _, model_cfg = load_model_from_checkpoint(
        checkpoint_path=checkpoint_path,
        device=torch_device,
        force_lightweight_backbones=not allow_foundation_backbones,
    )
    mri, avra, clinical = build_model_tensors(
        nifti_path=nifti_path,
        normalized_features=normalized_features,
        image_size=int(model_cfg.image_size),
        device=torch_device,
    )

    # 4) Optional foundation embeddings.
    if use_hf_foundation_embeddings:
        foundation = _foundation_embeddings(
            mri, narrative, torch_device, model_cfg, require_true_hf_embeddings
        )
    else:
        foundation = _disabled_foundation()

    # 5) Prediction and report.
    prediction = predict_single(
        model=model,
        mri=mri,
        avra=avra,
        clinical=clinical,
        narrative=narrative,
        siglib_embedding=foundation["siglib_embedding"],
        gemma_embedding=foundation["gemma_embedding"],
    )
    local_medgemma_output = foundation.get("medgemma_local_output")
    report = generate_medgemma_report(
        base_narrative=narrative,
        prediction=prediction,
        enable_remote_llm=enable_remote_medgemma_report,
        foundation_status=foundation["status"],
        local_medgemma_output=local_medgemma_output,
    )

    # 6) Persist. Paths are stringified so the payload is JSON-encodable.
    final_payload = {
        "input_nifti": str(nifti_path),
        "checkpoint": str(checkpoint_path),
        "stats": str(stats_path),
        "avra_script": str(avra_script_path),
        "avra_json": str(avra_json),
        "avra_coronal_image": str(coronal_path),
        "avra_scores": avra_scores,
        "clinical_narrative": narrative,
        "foundation_embeddings": foundation["status"],
        "medgemma_local_output": local_medgemma_output,
        "medgemma_report": report,
        "prediction": prediction,
    }
    out_path = output_dir / f"{nifti_uid(nifti_path)}_streamlit_demo_output.json"
    _write_payload_json(final_payload, out_path)
    # Added after writing on purpose: the file does not reference itself.
    final_payload["output_json"] = str(out_path)
    return final_payload