#!/usr/bin/env python3 from __future__ import annotations import json import os import shutil import uuid from datetime import datetime from pathlib import Path from typing import Dict from urllib.parse import urlparse from urllib.request import urlopen import pandas as pd import streamlit as st from demo_backend.assets import ( PIPELINE_FIGURE_PATH, ensure_assets_from_hub, ) from demo_backend.constants import AVRA_COLS, CLINICAL_FEATURE_COLS from demo_backend.paths import ( DEFAULT_OUTPUT_DIR, EXAMPLE_EHR, PROJECT_ROOT, resolve_default_avra_script, resolve_default_checkpoint, resolve_default_stats, ) from demo_backend.pipeline import run_full_inference EXAMPLE_NIFTI = PROJECT_ROOT / "src" / "examples" / "example_case.nii.gz" st.set_page_config(page_title="HyperClinical Demo", layout="wide") EHR_UI_COLUMNS = [ "sex", "age", "edu", "race", "hispanic", "marriage", "declong", "decage", "smoke", "height", "weight", "bmi", "Heart attack/cardiac arrest", "Atrial fibrillation", "Angioplasty/endarterectomy/stent", "Cardiac bypass procedure", "Pacemaker", "Congestive heart failure", "Cardiovascular disease", "Cerebrovascular disease", "Parkinson", "Stroke/transient ischemic attack", ] UI_TO_INTERNAL = { "sex": "sex", "age": "age", "edu": "edu", "race": "race", "hispanic": "hispanic", "marriage": "marriage", "declong": "declong", "decage": "decage", "smoke": "smoke", "height": "height", "weight": "weight", "bmi": "bmi", "Heart attack/cardiac arrest": "health_history1", "Atrial fibrillation": "health_history2", "Angioplasty/endarterectomy/stent": "health_history3", "Cardiac bypass procedure": "health_history4", "Pacemaker": "health_history5", "Congestive heart failure": "health_history6", "Cardiovascular disease": "health_history7", "Cerebrovascular disease": "health_history10", "Parkinson": "health_history11", "Stroke/transient ischemic attack": "health_history12", } INTERNAL_TO_UI = {v: k for k, v in UI_TO_INTERNAL.items()} def _load_default_ehr() -> pd.DataFrame: if EXAMPLE_EHR.exists(): df = pd.read_csv(EXAMPLE_EHR).head(1) else: df = pd.DataFrame([{col: 0 for col in CLINICAL_FEATURE_COLS}]) one = {} for internal in CLINICAL_FEATURE_COLS: ui = INTERNAL_TO_UI[internal] if internal in df.columns: one[ui] = df.iloc[0][internal] else: one[ui] = 0 return pd.DataFrame([one])[EHR_UI_COLUMNS] def _resolve_user_path(raw_value: str, fallback: Path, allow_missing: bool = False) -> Path: path = Path(raw_value).expanduser() if not path.is_absolute(): path = (PROJECT_ROOT / path).resolve() if path.exists() or allow_missing: return path return fallback def _format_confidence(probabilities: Dict[str, float]) -> tuple[str, float]: top_label = max(probabilities, key=probabilities.get) return top_label, probabilities[top_label] def _download_nifti_from_url(url: str, target_dir: Path) -> Path: parsed = urlparse(url) suffix = ".nii.gz" if parsed.path.lower().endswith(".nii.gz") else ".nii" out_path = target_dir / f"uploaded_case_from_url{suffix}" with urlopen(url, timeout=60) as resp: with out_path.open("wb") as dst: shutil.copyfileobj(resp, dst) return out_path def _ehr_ui_to_internal_row(ui_row: Dict[str, float]) -> Dict[str, float]: return {UI_TO_INTERNAL[key]: ui_row.get(key, 0) for key in EHR_UI_COLUMNS} st.title("HyperClinical: Multimodal Dementia Subclassification") st.markdown( "HyperClinical combines AVRA atrophy scoring, structured clinical context, and multimodal fusion to predict " "fine-grained dementia subclasses. Fill the EHR table, provide MRI input, and inspect every stage output." ) if PIPELINE_FIGURE_PATH.exists(): st.image(str(PIPELINE_FIGURE_PATH), caption="HyperClinical Pipeline Overview", use_container_width=True) else: st.caption("Pipeline figure not found locally yet. It will be auto-downloaded from the configured HF assets repo.") st.caption("Runtime settings are fixed for demo use. Model assets are managed server-side.") output_root_input = "./outputs" assets_repo_id = os.getenv("HF_ASSETS_REPO_ID", "SalmaHassan/HyperClinical-assets").strip() assets_revision = (os.getenv("HF_ASSETS_REVISION", "main") or "main").strip() or "main" auto_download_assets = True device = "auto" reuse_registration = True use_fsl = False allow_foundation_backbones = True enable_remote_report = True use_hf_foundation_embeddings = True require_true_hf_embeddings = ( (os.getenv("HF_REQUIRE_FOUNDATION_MODELS", "0") or "0").strip().lower() in {"1", "true", "yes", "on"} ) st.subheader("1) MRI Input") input_mode = st.radio( "Select input mode", ["Upload file", "Public URL", "Built-in example"], horizontal=True, ) uploaded_file = None nifti_url = "" use_example_nifti = False if input_mode == "Upload file": uploaded_file = st.file_uploader("Upload a T1 MRI NIfTI (.nii or .nii.gz)", type=["nii", "gz"]) elif input_mode == "Public URL": nifti_url = st.text_input("Public URL to .nii or .nii.gz") else: use_example_nifti = True if EXAMPLE_NIFTI.exists(): st.caption(f"Using built-in sample: `{EXAMPLE_NIFTI.name}`") else: st.warning("Built-in sample MRI is not available in this deployment.") st.subheader("2) Enter EHR Features") if "ehr_df" not in st.session_state: st.session_state.ehr_df = _load_default_ehr() ehr_df = st.data_editor( st.session_state.ehr_df, num_rows="fixed", use_container_width=True, key="ehr_editor", ) if st.button("Run Full Inference", type="primary", use_container_width=True): if len(ehr_df) == 0: st.error("EHR table is empty. Add one row.") st.stop() if input_mode == "Upload file" and uploaded_file is None: st.error("Please upload a NIfTI file first.") st.stop() if input_mode == "Public URL" and not nifti_url.strip(): st.error("Please provide a public NIfTI URL.") st.stop() if input_mode == "Built-in example" and not EXAMPLE_NIFTI.exists(): st.error("Built-in sample MRI is missing. Please use upload or URL mode.") st.stop() if auto_download_assets: with st.spinner("Ensuring required assets are available..."): try: ensure_assets_from_hub(repo_id=assets_repo_id, revision=assets_revision) except Exception as exc: st.error(f"Could not download required assets: {exc}") st.caption("Check HF assets repo configuration and `HF_TOKEN` (if private).") st.stop() checkpoint_path = resolve_default_checkpoint() stats_path = resolve_default_stats() avra_script_path = resolve_default_avra_script() output_root = _resolve_user_path(output_root_input, DEFAULT_OUTPUT_DIR, allow_missing=True) if not checkpoint_path.exists(): st.error(f"Checkpoint not found: {checkpoint_path}") st.stop() if not stats_path.exists(): st.error(f"Stats file not found: {stats_path}") st.stop() if not avra_script_path.exists(): st.error(f"AVRA script not found: {avra_script_path}") st.stop() run_id = datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:6] run_dir = output_root / run_id run_dir.mkdir(parents=True, exist_ok=True) if input_mode == "Upload file": file_name = uploaded_file.name if not (file_name.endswith(".nii") or file_name.endswith(".nii.gz")): st.error("Uploaded file must end with .nii or .nii.gz") st.stop() safe_name = "uploaded_case.nii.gz" if file_name.endswith(".nii.gz") else "uploaded_case.nii" nifti_path = run_dir / safe_name with nifti_path.open("wb") as f: f.write(uploaded_file.getbuffer()) elif input_mode == "Public URL": with st.spinner("Downloading NIfTI from URL..."): try: nifti_path = _download_nifti_from_url(nifti_url.strip(), run_dir) except Exception as exc: st.error(f"Failed to download NIfTI from URL: {exc}") st.stop() else: nifti_path = run_dir / EXAMPLE_NIFTI.name shutil.copy2(EXAMPLE_NIFTI, nifti_path) ehr_row_internal = _ehr_ui_to_internal_row(ehr_df.iloc[0].to_dict()) with st.spinner("Running AVRA + HyperClinical inference..."): try: result = run_full_inference( nifti_path=nifti_path, ehr_row=ehr_row_internal, output_dir=run_dir, checkpoint_path=checkpoint_path, stats_path=stats_path, avra_script_path=avra_script_path, device=device, allow_foundation_backbones=allow_foundation_backbones, use_hf_foundation_embeddings=use_hf_foundation_embeddings, require_true_hf_embeddings=require_true_hf_embeddings, reuse_registration=reuse_registration, use_fsl=use_fsl, enable_remote_medgemma_report=enable_remote_report, ) except Exception as exc: st.exception(exc) st.stop() st.success("Inference complete") tab1, tab2, tab3, tab4 = st.tabs([ "Stage 1: Atrophy (AVRA)", "Stage 2: Clinical Report", "Stage 3: Diagnosis", "Raw JSON", ]) with tab1: st.markdown("### AVRA Atrophy Scores") avra_scores = result["avra_scores"] cols = st.columns(4) for i, key in enumerate(AVRA_COLS): cols[i].metric(label=key, value=f"{float(avra_scores[key]):.3f}") st.dataframe(pd.DataFrame([avra_scores]), use_container_width=True) coronal_path = Path(result["avra_coronal_image"]) if coronal_path.exists(): st.image(str(coronal_path), caption="AVRA coronal slice output") with tab2: st.markdown("### MedGemma-style Clinical Narrative") st.text_area("Clinical narrative", value=result["clinical_narrative"], height=220) st.markdown("### Foundation Embedding Sources") foundation_status = result.get("foundation_embeddings", {}) st.json(foundation_status) med_status = str(foundation_status.get("medgemma", "")).lower() if "fallback_mlp_only" in med_status or med_status.startswith("error:"): st.warning( "MedGemma did not load from HF for this run. " "Common causes: outdated `transformers` (need >=4.57.1), missing HF token for gated models, " "or insufficient runtime memory." ) st.markdown("### Clinical Report") report = result["medgemma_report"] st.caption(f"Source: {report.get('source', 'unknown')}") st.text_area("Report", value=report.get("report", ""), height=280) with tab3: st.markdown("### Final Diagnosis") pred = result["prediction"] probs = pred["class_probabilities"] _, top_prob = _format_confidence(probs) col_a, col_b = st.columns(2) col_a.metric("Predicted Class", pred["predicted_class_name"]) col_b.metric("Confidence", f"{top_prob * 100:.2f}%") probs_df = pd.DataFrame( {"Class": list(probs.keys()), "Probability": list(probs.values())} ).sort_values("Probability", ascending=False) st.bar_chart(probs_df.set_index("Class")) with tab4: st.json(result) json_bytes = json.dumps(result, indent=2).encode("utf-8") st.download_button( label="Download result JSON", data=json_bytes, file_name=f"hyperclinical_result_{run_id}.json", mime="application/json", )