#!/usr/bin/env python3
# HyperClinical — src/streamlit_app.py
# Last change (commit a19ac32, salmasoma): fix Gemma3 hidden-size handling
# and add a built-in example NIfTI.
from __future__ import annotations
import json
import os
import shutil
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict
from urllib.parse import urlparse
from urllib.request import urlopen
import pandas as pd
import streamlit as st
from demo_backend.assets import (
PIPELINE_FIGURE_PATH,
ensure_assets_from_hub,
)
from demo_backend.constants import AVRA_COLS, CLINICAL_FEATURE_COLS
from demo_backend.paths import (
DEFAULT_OUTPUT_DIR,
EXAMPLE_EHR,
PROJECT_ROOT,
resolve_default_avra_script,
resolve_default_checkpoint,
resolve_default_stats,
)
from demo_backend.pipeline import run_full_inference
# Built-in sample T1 scan shipped with the repo; used by the
# "Built-in example" input mode below.
EXAMPLE_NIFTI = PROJECT_ROOT / "src" / "examples" / "example_case.nii.gz"

# NOTE: set_page_config must be the first Streamlit command in the script.
st.set_page_config(page_title="HyperClinical Demo", layout="wide")

# Human-readable column labels, in the order shown in the editable EHR table.
EHR_UI_COLUMNS = [
    "sex",
    "age",
    "edu",
    "race",
    "hispanic",
    "marriage",
    "declong",
    "decage",
    "smoke",
    "height",
    "weight",
    "bmi",
    "Heart attack/cardiac arrest",
    "Atrial fibrillation",
    "Angioplasty/endarterectomy/stent",
    "Cardiac bypass procedure",
    "Pacemaker",
    "Congestive heart failure",
    "Cardiovascular disease",
    "Cerebrovascular disease",
    "Parkinson",
    "Stroke/transient ischemic attack",
]

# Maps each UI label to the internal feature name the pipeline expects
# (cf. CLINICAL_FEATURE_COLS). Note the health_history indices are not
# contiguous: 8 and 9 are intentionally skipped (10/11/12 follow 7).
UI_TO_INTERNAL = {
    "sex": "sex",
    "age": "age",
    "edu": "edu",
    "race": "race",
    "hispanic": "hispanic",
    "marriage": "marriage",
    "declong": "declong",
    "decage": "decage",
    "smoke": "smoke",
    "height": "height",
    "weight": "weight",
    "bmi": "bmi",
    "Heart attack/cardiac arrest": "health_history1",
    "Atrial fibrillation": "health_history2",
    "Angioplasty/endarterectomy/stent": "health_history3",
    "Cardiac bypass procedure": "health_history4",
    "Pacemaker": "health_history5",
    "Congestive heart failure": "health_history6",
    "Cardiovascular disease": "health_history7",
    "Cerebrovascular disease": "health_history10",
    "Parkinson": "health_history11",
    "Stroke/transient ischemic attack": "health_history12",
}

# Reverse lookup: internal feature name -> UI label.
INTERNAL_TO_UI = {v: k for k, v in UI_TO_INTERNAL.items()}
def _load_default_ehr() -> pd.DataFrame:
    """Build the one-row default EHR table, keyed by UI column labels.

    Uses the first row of the bundled example CSV when present; otherwise
    starts from an all-zero row. Any internal feature missing from the
    example CSV defaults to 0.
    """
    if EXAMPLE_EHR.exists():
        source = pd.read_csv(EXAMPLE_EHR).head(1)
    else:
        source = pd.DataFrame([{col: 0 for col in CLINICAL_FEATURE_COLS}])
    first = source.iloc[0]
    ui_row = {
        INTERNAL_TO_UI[internal]: (first[internal] if internal in source.columns else 0)
        for internal in CLINICAL_FEATURE_COLS
    }
    # Reindex so the editor always shows columns in the canonical UI order.
    return pd.DataFrame([ui_row])[EHR_UI_COLUMNS]
def _resolve_user_path(raw_value: str, fallback: Path, allow_missing: bool = False) -> Path:
path = Path(raw_value).expanduser()
if not path.is_absolute():
path = (PROJECT_ROOT / path).resolve()
if path.exists() or allow_missing:
return path
return fallback
def _format_confidence(probabilities: Dict[str, float]) -> tuple[str, float]:
top_label = max(probabilities, key=probabilities.get)
return top_label, probabilities[top_label]
def _download_nifti_from_url(url: str, target_dir: Path) -> Path:
parsed = urlparse(url)
suffix = ".nii.gz" if parsed.path.lower().endswith(".nii.gz") else ".nii"
out_path = target_dir / f"uploaded_case_from_url{suffix}"
with urlopen(url, timeout=60) as resp:
with out_path.open("wb") as dst:
shutil.copyfileobj(resp, dst)
return out_path
def _ehr_ui_to_internal_row(ui_row: Dict[str, float]) -> Dict[str, float]:
    """Translate a UI-labelled EHR row into internal feature names.

    Columns absent from *ui_row* default to 0.
    """
    internal_row = {}
    for ui_key in EHR_UI_COLUMNS:
        internal_row[UI_TO_INTERNAL[ui_key]] = ui_row.get(ui_key, 0)
    return internal_row
# --- Page header ---
st.title("HyperClinical: Multimodal Dementia Subclassification")
st.markdown(
    "HyperClinical combines AVRA atrophy scoring, structured clinical context, and multimodal fusion to predict "
    "fine-grained dementia subclasses. Fill the EHR table, provide MRI input, and inspect every stage output."
)
# The pipeline figure is one of the assets fetched from the HF assets repo
# (see ensure_assets_from_hub in the run handler), so it can be absent on a
# cold start before the first inference run.
if PIPELINE_FIGURE_PATH.exists():
    st.image(str(PIPELINE_FIGURE_PATH), caption="HyperClinical Pipeline Overview", use_container_width=True)
else:
    st.caption("Pipeline figure not found locally yet. It will be auto-downloaded from the configured HF assets repo.")
st.caption("Runtime settings are fixed for demo use. Model assets are managed server-side.")

# --- Fixed runtime configuration (intentionally not user-editable in the demo) ---
output_root_input = "./outputs"  # resolved relative to PROJECT_ROOT in the run handler
assets_repo_id = os.getenv("HF_ASSETS_REPO_ID", "SalmaHassan/HyperClinical-assets").strip()
# Double fallback: the env var may be set but empty, which getenv would return as "".
assets_revision = (os.getenv("HF_ASSETS_REVISION", "main") or "main").strip() or "main"
auto_download_assets = True
device = "auto"
reuse_registration = True
use_fsl = False
allow_foundation_backbones = True
enable_remote_report = True
use_hf_foundation_embeddings = True
# Opt-in strict mode: when truthy, the pipeline must load real HF foundation
# models instead of silently falling back. Accepts common truthy spellings.
require_true_hf_embeddings = (
    (os.getenv("HF_REQUIRE_FOUNDATION_MODELS", "0") or "0").strip().lower() in {"1", "true", "yes", "on"}
)

# --- Step 1: choose how the MRI volume is provided ---
st.subheader("1) MRI Input")
input_mode = st.radio(
    "Select input mode",
    ["Upload file", "Public URL", "Built-in example"],
    horizontal=True,
)
uploaded_file = None
nifti_url = ""
use_example_nifti = False
if input_mode == "Upload file":
    uploaded_file = st.file_uploader("Upload a T1 MRI NIfTI (.nii or .nii.gz)", type=["nii", "gz"])
elif input_mode == "Public URL":
    nifti_url = st.text_input("Public URL to .nii or .nii.gz")
else:
    # NOTE(review): use_example_nifti is assigned but never read again — the
    # run handler branches on input_mode directly; candidate for removal.
    use_example_nifti = True
    if EXAMPLE_NIFTI.exists():
        st.caption(f"Using built-in sample: `{EXAMPLE_NIFTI.name}`")
    else:
        st.warning("Built-in sample MRI is not available in this deployment.")

# --- Step 2: one-row editable EHR table, persisted across Streamlit reruns ---
st.subheader("2) Enter EHR Features")
if "ehr_df" not in st.session_state:
    st.session_state.ehr_df = _load_default_ehr()
ehr_df = st.data_editor(
    st.session_state.ehr_df,
    num_rows="fixed",  # exactly one patient row; no add/delete
    use_container_width=True,
    key="ehr_editor",
)
# --- Run handler: executes the full pipeline on button click and renders
# --- all stage outputs in tabs. Everything below only exists for the rerun
# --- in which the button was pressed (Streamlit reruns the whole script).
if st.button("Run Full Inference", type="primary", use_container_width=True):
    # Validate inputs up front; st.stop() aborts this rerun cleanly.
    if len(ehr_df) == 0:
        st.error("EHR table is empty. Add one row.")
        st.stop()
    if input_mode == "Upload file" and uploaded_file is None:
        st.error("Please upload a NIfTI file first.")
        st.stop()
    if input_mode == "Public URL" and not nifti_url.strip():
        st.error("Please provide a public NIfTI URL.")
        st.stop()
    if input_mode == "Built-in example" and not EXAMPLE_NIFTI.exists():
        st.error("Built-in sample MRI is missing. Please use upload or URL mode.")
        st.stop()
    # Fetch model assets (checkpoint, stats, AVRA script, figure) from the
    # configured HF repo when missing locally. auto_download_assets is
    # currently always True (fixed demo config above).
    if auto_download_assets:
        with st.spinner("Ensuring required assets are available..."):
            try:
                ensure_assets_from_hub(repo_id=assets_repo_id, revision=assets_revision)
            except Exception as exc:
                st.error(f"Could not download required assets: {exc}")
                st.caption("Check HF assets repo configuration and `HF_TOKEN` (if private).")
                st.stop()
    # Resolve required artifact paths and verify each exists before running.
    checkpoint_path = resolve_default_checkpoint()
    stats_path = resolve_default_stats()
    avra_script_path = resolve_default_avra_script()
    output_root = _resolve_user_path(output_root_input, DEFAULT_OUTPUT_DIR, allow_missing=True)
    if not checkpoint_path.exists():
        st.error(f"Checkpoint not found: {checkpoint_path}")
        st.stop()
    if not stats_path.exists():
        st.error(f"Stats file not found: {stats_path}")
        st.stop()
    if not avra_script_path.exists():
        st.error(f"AVRA script not found: {avra_script_path}")
        st.stop()
    # Per-run output directory: timestamp plus short random hex so
    # concurrent sessions cannot collide.
    run_id = datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:6]
    run_dir = output_root / run_id
    run_dir.mkdir(parents=True, exist_ok=True)
    # Materialize the MRI volume inside run_dir regardless of its origin
    # (upload buffer, remote URL, or bundled example).
    if input_mode == "Upload file":
        file_name = uploaded_file.name
        if not (file_name.endswith(".nii") or file_name.endswith(".nii.gz")):
            st.error("Uploaded file must end with .nii or .nii.gz")
            st.stop()
        # Fixed, sanitized filename — the client-supplied name is never
        # used as a path component.
        safe_name = "uploaded_case.nii.gz" if file_name.endswith(".nii.gz") else "uploaded_case.nii"
        nifti_path = run_dir / safe_name
        with nifti_path.open("wb") as f:
            f.write(uploaded_file.getbuffer())
    elif input_mode == "Public URL":
        with st.spinner("Downloading NIfTI from URL..."):
            try:
                nifti_path = _download_nifti_from_url(nifti_url.strip(), run_dir)
            except Exception as exc:
                st.error(f"Failed to download NIfTI from URL: {exc}")
                st.stop()
    else:
        nifti_path = run_dir / EXAMPLE_NIFTI.name
        shutil.copy2(EXAMPLE_NIFTI, nifti_path)
    # Translate UI column labels to the internal feature names the model
    # pipeline expects.
    ehr_row_internal = _ehr_ui_to_internal_row(ehr_df.iloc[0].to_dict())
    with st.spinner("Running AVRA + HyperClinical inference..."):
        try:
            result = run_full_inference(
                nifti_path=nifti_path,
                ehr_row=ehr_row_internal,
                output_dir=run_dir,
                checkpoint_path=checkpoint_path,
                stats_path=stats_path,
                avra_script_path=avra_script_path,
                device=device,
                allow_foundation_backbones=allow_foundation_backbones,
                use_hf_foundation_embeddings=use_hf_foundation_embeddings,
                require_true_hf_embeddings=require_true_hf_embeddings,
                reuse_registration=reuse_registration,
                use_fsl=use_fsl,
                enable_remote_medgemma_report=enable_remote_report,
            )
        except Exception as exc:
            # Surface the full traceback in the UI for debugging.
            st.exception(exc)
            st.stop()
    st.success("Inference complete")
    # --- Results: one tab per pipeline stage plus the raw payload ---
    tab1, tab2, tab3, tab4 = st.tabs([
        "Stage 1: Atrophy (AVRA)",
        "Stage 2: Clinical Report",
        "Stage 3: Diagnosis",
        "Raw JSON",
    ])
    with tab1:
        st.markdown("### AVRA Atrophy Scores")
        avra_scores = result["avra_scores"]
        # NOTE(review): assumes len(AVRA_COLS) <= 4; a fifth score would
        # raise IndexError on cols[i] — confirm AVRA_COLS size.
        cols = st.columns(4)
        for i, key in enumerate(AVRA_COLS):
            cols[i].metric(label=key, value=f"{float(avra_scores[key]):.3f}")
        st.dataframe(pd.DataFrame([avra_scores]), use_container_width=True)
        coronal_path = Path(result["avra_coronal_image"])
        if coronal_path.exists():
            st.image(str(coronal_path), caption="AVRA coronal slice output")
    with tab2:
        st.markdown("### MedGemma-style Clinical Narrative")
        st.text_area("Clinical narrative", value=result["clinical_narrative"], height=220)
        st.markdown("### Foundation Embedding Sources")
        foundation_status = result.get("foundation_embeddings", {})
        st.json(foundation_status)
        # Warn when MedGemma fell back to the MLP-only path or errored,
        # based on the status string reported by the pipeline.
        med_status = str(foundation_status.get("medgemma", "")).lower()
        if "fallback_mlp_only" in med_status or med_status.startswith("error:"):
            st.warning(
                "MedGemma did not load from HF for this run. "
                "Common causes: outdated `transformers` (need >=4.57.1), missing HF token for gated models, "
                "or insufficient runtime memory."
            )
        st.markdown("### Clinical Report")
        report = result["medgemma_report"]
        st.caption(f"Source: {report.get('source', 'unknown')}")
        st.text_area("Report", value=report.get("report", ""), height=280)
    with tab3:
        st.markdown("### Final Diagnosis")
        pred = result["prediction"]
        probs = pred["class_probabilities"]
        _, top_prob = _format_confidence(probs)
        col_a, col_b = st.columns(2)
        col_a.metric("Predicted Class", pred["predicted_class_name"])
        col_b.metric("Confidence", f"{top_prob * 100:.2f}%")
        # Bar chart of class probabilities, highest first.
        probs_df = pd.DataFrame(
            {"Class": list(probs.keys()), "Probability": list(probs.values())}
        ).sort_values("Probability", ascending=False)
        st.bar_chart(probs_df.set_index("Class"))
    with tab4:
        # Raw pipeline payload, viewable and downloadable as JSON.
        st.json(result)
        json_bytes = json.dumps(result, indent=2).encode("utf-8")
        st.download_button(
            label="Download result JSON",
            data=json_bytes,
            file_name=f"hyperclinical_result_{run_id}.json",
            mime="application/json",
        )