# (Hugging Face Spaces status banner captured during page extraction: "Spaces: Running")
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import shutil | |
| import uuid | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Dict | |
| from urllib.parse import urlparse | |
| from urllib.request import urlopen | |
| import pandas as pd | |
| import streamlit as st | |
| from demo_backend.assets import ( | |
| PIPELINE_FIGURE_PATH, | |
| ensure_assets_from_hub, | |
| ) | |
| from demo_backend.constants import AVRA_COLS, CLINICAL_FEATURE_COLS | |
| from demo_backend.paths import ( | |
| DEFAULT_OUTPUT_DIR, | |
| EXAMPLE_EHR, | |
| PROJECT_ROOT, | |
| resolve_default_avra_script, | |
| resolve_default_checkpoint, | |
| resolve_default_stats, | |
| ) | |
| from demo_backend.pipeline import run_full_inference | |
# Bundled sample T1 MRI used by the "Built-in example" input mode.
EXAMPLE_NIFTI = PROJECT_ROOT / "src" / "examples" / "example_case.nii.gz"

# Must run before any other Streamlit call in this script.
st.set_page_config(page_title="HyperClinical Demo", layout="wide")
# Column order of the editable EHR table shown in the UI.  The first twelve are
# demographics/anthropometrics; the rest are human-readable health-history flags.
EHR_UI_COLUMNS = [
    "sex",
    "age",
    "edu",
    "race",
    "hispanic",
    "marriage",
    "declong",
    "decage",
    "smoke",
    "height",
    "weight",
    "bmi",
    "Heart attack/cardiac arrest",
    "Atrial fibrillation",
    "Angioplasty/endarterectomy/stent",
    "Cardiac bypass procedure",
    "Pacemaker",
    "Congestive heart failure",
    "Cardiovascular disease",
    "Cerebrovascular disease",
    "Parkinson",
    "Stroke/transient ischemic attack",
]
# Maps UI-facing column names to the internal feature names expected by the
# model pipeline.  Demographic columns map to themselves; health-history flags
# map to numbered `health_history*` features.
# NOTE(review): health_history8/9 are skipped — presumably dropped features in
# the training data; confirm against CLINICAL_FEATURE_COLS.
UI_TO_INTERNAL = {
    "sex": "sex",
    "age": "age",
    "edu": "edu",
    "race": "race",
    "hispanic": "hispanic",
    "marriage": "marriage",
    "declong": "declong",
    "decage": "decage",
    "smoke": "smoke",
    "height": "height",
    "weight": "weight",
    "bmi": "bmi",
    "Heart attack/cardiac arrest": "health_history1",
    "Atrial fibrillation": "health_history2",
    "Angioplasty/endarterectomy/stent": "health_history3",
    "Cardiac bypass procedure": "health_history4",
    "Pacemaker": "health_history5",
    "Congestive heart failure": "health_history6",
    "Cardiovascular disease": "health_history7",
    "Cerebrovascular disease": "health_history10",
    "Parkinson": "health_history11",
    "Stroke/transient ischemic attack": "health_history12",
}
# Inverse mapping: internal feature name -> UI column name.
INTERNAL_TO_UI = {v: k for k, v in UI_TO_INTERNAL.items()}
def _load_default_ehr() -> pd.DataFrame:
    """Build a one-row EHR table keyed by UI-facing column names.

    Values come from the first row of the bundled example CSV when it exists;
    otherwise every clinical feature defaults to 0.  Features absent from the
    CSV also default to 0.  The result is ordered per EHR_UI_COLUMNS.
    """
    if EXAMPLE_EHR.exists():
        source = pd.read_csv(EXAMPLE_EHR).head(1)
    else:
        source = pd.DataFrame([dict.fromkeys(CLINICAL_FEATURE_COLS, 0)])
    first_row = source.iloc[0]
    ui_values = {
        INTERNAL_TO_UI[feature]: (first_row[feature] if feature in source.columns else 0)
        for feature in CLINICAL_FEATURE_COLS
    }
    return pd.DataFrame([ui_values])[EHR_UI_COLUMNS]
| def _resolve_user_path(raw_value: str, fallback: Path, allow_missing: bool = False) -> Path: | |
| path = Path(raw_value).expanduser() | |
| if not path.is_absolute(): | |
| path = (PROJECT_ROOT / path).resolve() | |
| if path.exists() or allow_missing: | |
| return path | |
| return fallback | |
| def _format_confidence(probabilities: Dict[str, float]) -> tuple[str, float]: | |
| top_label = max(probabilities, key=probabilities.get) | |
| return top_label, probabilities[top_label] | |
| def _download_nifti_from_url(url: str, target_dir: Path) -> Path: | |
| parsed = urlparse(url) | |
| suffix = ".nii.gz" if parsed.path.lower().endswith(".nii.gz") else ".nii" | |
| out_path = target_dir / f"uploaded_case_from_url{suffix}" | |
| with urlopen(url, timeout=60) as resp: | |
| with out_path.open("wb") as dst: | |
| shutil.copyfileobj(resp, dst) | |
| return out_path | |
def _ehr_ui_to_internal_row(ui_row: Dict[str, float]) -> Dict[str, float]:
    """Translate a UI-named EHR row into internal model feature names.

    Columns missing from ``ui_row`` default to 0.
    """
    internal_row: Dict[str, float] = {}
    for ui_name in EHR_UI_COLUMNS:
        internal_row[UI_TO_INTERNAL[ui_name]] = ui_row.get(ui_name, 0)
    return internal_row
# --- Page header and fixed runtime configuration ---
st.title("HyperClinical: Multimodal Dementia Subclassification")
st.markdown(
    "HyperClinical combines AVRA atrophy scoring, structured clinical context, and multimodal fusion to predict "
    "fine-grained dementia subclasses. Fill the EHR table, provide MRI input, and inspect every stage output."
)

# Show the pipeline overview figure if the asset has already been fetched.
if PIPELINE_FIGURE_PATH.exists():
    st.image(str(PIPELINE_FIGURE_PATH), caption="HyperClinical Pipeline Overview", use_container_width=True)
else:
    st.caption("Pipeline figure not found locally yet. It will be auto-downloaded from the configured HF assets repo.")
st.caption("Runtime settings are fixed for demo use. Model assets are managed server-side.")

# Fixed (non-user-editable) settings for the hosted demo.
output_root_input = "./outputs"
assets_repo_id = os.getenv("HF_ASSETS_REPO_ID", "SalmaHassan/HyperClinical-assets").strip()
# Fall back to "main" when the env var is unset, empty, or whitespace-only.
assets_revision = (os.getenv("HF_ASSETS_REVISION", "main") or "main").strip() or "main"
auto_download_assets = True
device = "auto"
reuse_registration = True
use_fsl = False
allow_foundation_backbones = True
enable_remote_report = True
use_hf_foundation_embeddings = True
# Opt-in strict mode: require real HF foundation-model embeddings instead of
# allowing a fallback.  Accepts the usual truthy spellings of the env var.
require_true_hf_embeddings = (
    (os.getenv("HF_REQUIRE_FOUNDATION_MODELS", "0") or "0").strip().lower() in {"1", "true", "yes", "on"}
)
# --- Section 1: choose how the MRI volume is supplied ---
st.subheader("1) MRI Input")
input_mode = st.radio(
    "Select input mode",
    ["Upload file", "Public URL", "Built-in example"],
    horizontal=True,
)
# Only one of these is populated, depending on the selected mode.
uploaded_file = None
nifti_url = ""
use_example_nifti = False  # NOTE(review): set but never read later; input_mode drives the logic
if input_mode == "Upload file":
    uploaded_file = st.file_uploader("Upload a T1 MRI NIfTI (.nii or .nii.gz)", type=["nii", "gz"])
elif input_mode == "Public URL":
    nifti_url = st.text_input("Public URL to .nii or .nii.gz")
else:
    use_example_nifti = True
    if EXAMPLE_NIFTI.exists():
        st.caption(f"Using built-in sample: `{EXAMPLE_NIFTI.name}`")
    else:
        st.warning("Built-in sample MRI is not available in this deployment.")

# --- Section 2: editable one-row EHR table ---
st.subheader("2) Enter EHR Features")
# Seed the table once per session; subsequent edits live in the editor state.
if "ehr_df" not in st.session_state:
    st.session_state.ehr_df = _load_default_ehr()
ehr_df = st.data_editor(
    st.session_state.ehr_df,
    num_rows="fixed",
    use_container_width=True,
    key="ehr_editor",
)
# --- Main action: validate inputs, stage the MRI, run the pipeline, render results.
# Everything below executes only on the script rerun triggered by the click.
if st.button("Run Full Inference", type="primary", use_container_width=True):
    # Validate the inputs for the selected mode; st.stop() aborts this rerun.
    if len(ehr_df) == 0:
        st.error("EHR table is empty. Add one row.")
        st.stop()
    if input_mode == "Upload file" and uploaded_file is None:
        st.error("Please upload a NIfTI file first.")
        st.stop()
    if input_mode == "Public URL" and not nifti_url.strip():
        st.error("Please provide a public NIfTI URL.")
        st.stop()
    if input_mode == "Built-in example" and not EXAMPLE_NIFTI.exists():
        st.error("Built-in sample MRI is missing. Please use upload or URL mode.")
        st.stop()

    # Ensure model assets exist locally (fetched from the HF Hub if missing).
    if auto_download_assets:
        with st.spinner("Ensuring required assets are available..."):
            try:
                ensure_assets_from_hub(repo_id=assets_repo_id, revision=assets_revision)
            except Exception as exc:
                st.error(f"Could not download required assets: {exc}")
                st.caption("Check HF assets repo configuration and `HF_TOKEN` (if private).")
                st.stop()

    # Resolve all required file paths and verify each asset is present.
    checkpoint_path = resolve_default_checkpoint()
    stats_path = resolve_default_stats()
    avra_script_path = resolve_default_avra_script()
    output_root = _resolve_user_path(output_root_input, DEFAULT_OUTPUT_DIR, allow_missing=True)
    if not checkpoint_path.exists():
        st.error(f"Checkpoint not found: {checkpoint_path}")
        st.stop()
    if not stats_path.exists():
        st.error(f"Stats file not found: {stats_path}")
        st.stop()
    if not avra_script_path.exists():
        st.error(f"AVRA script not found: {avra_script_path}")
        st.stop()

    # Create a unique per-run output directory (timestamp + random suffix).
    run_id = datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:6]
    run_dir = output_root / run_id
    run_dir.mkdir(parents=True, exist_ok=True)

    # Stage the MRI volume inside run_dir, whatever its origin.
    if input_mode == "Upload file":
        file_name = uploaded_file.name
        if not (file_name.endswith(".nii") or file_name.endswith(".nii.gz")):
            st.error("Uploaded file must end with .nii or .nii.gz")
            st.stop()
        # Normalize the filename so downstream tools see a predictable name.
        safe_name = "uploaded_case.nii.gz" if file_name.endswith(".nii.gz") else "uploaded_case.nii"
        nifti_path = run_dir / safe_name
        with nifti_path.open("wb") as f:
            f.write(uploaded_file.getbuffer())
    elif input_mode == "Public URL":
        with st.spinner("Downloading NIfTI from URL..."):
            try:
                nifti_path = _download_nifti_from_url(nifti_url.strip(), run_dir)
            except Exception as exc:
                st.error(f"Failed to download NIfTI from URL: {exc}")
                st.stop()
    else:
        nifti_path = run_dir / EXAMPLE_NIFTI.name
        shutil.copy2(EXAMPLE_NIFTI, nifti_path)

    # Translate the single UI row into the model's internal feature names.
    ehr_row_internal = _ehr_ui_to_internal_row(ehr_df.iloc[0].to_dict())

    # Run the full AVRA + HyperClinical pipeline; any failure is surfaced
    # verbatim via st.exception.
    with st.spinner("Running AVRA + HyperClinical inference..."):
        try:
            result = run_full_inference(
                nifti_path=nifti_path,
                ehr_row=ehr_row_internal,
                output_dir=run_dir,
                checkpoint_path=checkpoint_path,
                stats_path=stats_path,
                avra_script_path=avra_script_path,
                device=device,
                allow_foundation_backbones=allow_foundation_backbones,
                use_hf_foundation_embeddings=use_hf_foundation_embeddings,
                require_true_hf_embeddings=require_true_hf_embeddings,
                reuse_registration=reuse_registration,
                use_fsl=use_fsl,
                enable_remote_medgemma_report=enable_remote_report,
            )
        except Exception as exc:
            st.exception(exc)
            st.stop()

    # Render each pipeline stage's output in its own tab.
    st.success("Inference complete")
    tab1, tab2, tab3, tab4 = st.tabs([
        "Stage 1: Atrophy (AVRA)",
        "Stage 2: Clinical Report",
        "Stage 3: Diagnosis",
        "Raw JSON",
    ])
    with tab1:
        st.markdown("### AVRA Atrophy Scores")
        avra_scores = result["avra_scores"]
        # NOTE(review): assumes len(AVRA_COLS) <= 4 — confirm, otherwise
        # cols[i] raises IndexError.
        cols = st.columns(4)
        for i, key in enumerate(AVRA_COLS):
            cols[i].metric(label=key, value=f"{float(avra_scores[key]):.3f}")
        st.dataframe(pd.DataFrame([avra_scores]), use_container_width=True)
        coronal_path = Path(result["avra_coronal_image"])
        if coronal_path.exists():
            st.image(str(coronal_path), caption="AVRA coronal slice output")
    with tab2:
        st.markdown("### MedGemma-style Clinical Narrative")
        st.text_area("Clinical narrative", value=result["clinical_narrative"], height=220)
        st.markdown("### Foundation Embedding Sources")
        foundation_status = result.get("foundation_embeddings", {})
        st.json(foundation_status)
        # Warn when the pipeline reported a fallback/error for MedGemma.
        med_status = str(foundation_status.get("medgemma", "")).lower()
        if "fallback_mlp_only" in med_status or med_status.startswith("error:"):
            st.warning(
                "MedGemma did not load from HF for this run. "
                "Common causes: outdated `transformers` (need >=4.57.1), missing HF token for gated models, "
                "or insufficient runtime memory."
            )
        st.markdown("### Clinical Report")
        report = result["medgemma_report"]
        st.caption(f"Source: {report.get('source', 'unknown')}")
        st.text_area("Report", value=report.get("report", ""), height=280)
    with tab3:
        st.markdown("### Final Diagnosis")
        pred = result["prediction"]
        probs = pred["class_probabilities"]
        _, top_prob = _format_confidence(probs)
        col_a, col_b = st.columns(2)
        col_a.metric("Predicted Class", pred["predicted_class_name"])
        col_b.metric("Confidence", f"{top_prob * 100:.2f}%")
        probs_df = pd.DataFrame(
            {"Class": list(probs.keys()), "Probability": list(probs.values())}
        ).sort_values("Probability", ascending=False)
        st.bar_chart(probs_df.set_index("Class"))
    with tab4:
        st.json(result)
        json_bytes = json.dumps(result, indent=2).encode("utf-8")
        st.download_button(
            label="Download result JSON",
            data=json_bytes,
            file_name=f"hyperclinical_result_{run_id}.json",
            mime="application/json",
        )