Spaces:
Sleeping
Sleeping
salmasoma commited on
Commit ·
a19ac32
1
Parent(s): 044452f
Fix Gemma3 hidden-size handling and add built-in example NIfTI
Browse files- README.md +2 -0
- src/demo_backend/neurofusion/medgemma_encoder.py +19 -1
- src/examples/example_case.nii.gz +3 -0
- src/streamlit_app.py +21 -3
README.md
CHANGED
|
@@ -25,6 +25,7 @@ UI updates:
|
|
| 25 |
- EHR table uses the requested clinical field names (including cardiovascular comorbidity columns).
|
| 26 |
- Embedded pipeline figure shown at top of interface.
|
| 27 |
- MRI input supports both file upload and public URL.
|
|
|
|
| 28 |
- If Hugging Face upload fails with `AxiosError 403`, use URL mode.
|
| 29 |
|
| 30 |
## Solving the Large Checkpoint Push Issue
|
|
@@ -90,6 +91,7 @@ bash scripts/publish_assets_from_local.sh SalmaHassan/HyperClinical-assets ../av
|
|
| 90 |
## App Entry
|
| 91 |
|
| 92 |
- Streamlit app: `src/streamlit_app.py`
|
|
|
|
| 93 |
|
| 94 |
## Local Run
|
| 95 |
|
|
|
|
| 25 |
- EHR table uses the requested clinical field names (including cardiovascular comorbidity columns).
|
| 26 |
- Embedded pipeline figure shown at top of interface.
|
| 27 |
- MRI input supports both file upload and public URL.
|
| 28 |
+
- MRI input supports file upload, public URL, and a built-in example NIfTI.
|
| 29 |
- If Hugging Face upload fails with `AxiosError 403`, use URL mode.
|
| 30 |
|
| 31 |
## Solving the Large Checkpoint Push Issue
|
|
|
|
| 91 |
## App Entry
|
| 92 |
|
| 93 |
- Streamlit app: `src/streamlit_app.py`
|
| 94 |
+
- Built-in sample MRI: `src/examples/example_case.nii.gz`
|
| 95 |
|
| 96 |
## Local Run
|
| 97 |
|
src/demo_backend/neurofusion/medgemma_encoder.py
CHANGED
|
@@ -41,6 +41,24 @@ def _normalize_loader_result(result):
|
|
| 41 |
raise ValueError(f"Unexpected loader tuple length: {len(result)}")
|
| 42 |
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
def _try_load_medgemma(model_name: str, cache_dir: Optional[str] = None):
|
| 45 |
"""Load MedGemma via ``AutoModelForCausalLM`` (generative Gemma model).
|
| 46 |
|
|
@@ -70,7 +88,7 @@ def _try_load_medgemma(model_name: str, cache_dir: Optional[str] = None):
|
|
| 70 |
model_name,
|
| 71 |
**load_kwargs,
|
| 72 |
)
|
| 73 |
-
hidden_size = model
|
| 74 |
logger.info(f"Loaded MedGemma: {model_name} (hidden_size={hidden_size})")
|
| 75 |
return model, tokenizer, hidden_size, None
|
| 76 |
except Exception as e:
|
|
|
|
| 41 |
raise ValueError(f"Unexpected loader tuple length: {len(result)}")
|
| 42 |
|
| 43 |
|
| 44 |
+
def _resolve_hidden_size(model) -> int:
|
| 45 |
+
"""Infer LM embedding width across Gemma config variants."""
|
| 46 |
+
cfg = getattr(model, "config", None)
|
| 47 |
+
hidden_size = getattr(cfg, "hidden_size", None)
|
| 48 |
+
|
| 49 |
+
if hidden_size is None and cfg is not None and hasattr(cfg, "text_config"):
|
| 50 |
+
hidden_size = getattr(cfg.text_config, "hidden_size", None)
|
| 51 |
+
|
| 52 |
+
if hidden_size is None:
|
| 53 |
+
emb = model.get_input_embeddings()
|
| 54 |
+
hidden_size = getattr(emb, "embedding_dim", None)
|
| 55 |
+
|
| 56 |
+
if hidden_size is None:
|
| 57 |
+
raise AttributeError("Could not infer hidden_size from model config or input embeddings")
|
| 58 |
+
|
| 59 |
+
return int(hidden_size)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
def _try_load_medgemma(model_name: str, cache_dir: Optional[str] = None):
|
| 63 |
"""Load MedGemma via ``AutoModelForCausalLM`` (generative Gemma model).
|
| 64 |
|
|
|
|
| 88 |
model_name,
|
| 89 |
**load_kwargs,
|
| 90 |
)
|
| 91 |
+
hidden_size = _resolve_hidden_size(model)
|
| 92 |
logger.info(f"Loaded MedGemma: {model_name} (hidden_size={hidden_size})")
|
| 93 |
return model, tokenizer, hidden_size, None
|
| 94 |
except Exception as e:
|
src/examples/example_case.nii.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9b6c4f69e0bcce412b8aecbb47ade5b6e222723bdf5a98d90cef38f36b5341eb
|
| 3 |
+
size 5301248
|
src/streamlit_app.py
CHANGED
|
@@ -29,6 +29,7 @@ from demo_backend.paths import (
|
|
| 29 |
)
|
| 30 |
from demo_backend.pipeline import run_full_inference
|
| 31 |
|
|
|
|
| 32 |
|
| 33 |
st.set_page_config(page_title="HyperClinical Demo", layout="wide")
|
| 34 |
|
|
@@ -164,13 +165,24 @@ require_true_hf_embeddings = (
|
|
| 164 |
|
| 165 |
st.subheader("1) MRI Input")
|
| 166 |
|
| 167 |
-
input_mode = st.radio(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
uploaded_file = None
|
| 169 |
nifti_url = ""
|
|
|
|
| 170 |
if input_mode == "Upload file":
|
| 171 |
uploaded_file = st.file_uploader("Upload a T1 MRI NIfTI (.nii or .nii.gz)", type=["nii", "gz"])
|
| 172 |
-
|
| 173 |
nifti_url = st.text_input("Public URL to .nii or .nii.gz")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
st.subheader("2) Enter EHR Features")
|
| 176 |
if "ehr_df" not in st.session_state:
|
|
@@ -195,6 +207,9 @@ if st.button("Run Full Inference", type="primary", use_container_width=True):
|
|
| 195 |
if input_mode == "Public URL" and not nifti_url.strip():
|
| 196 |
st.error("Please provide a public NIfTI URL.")
|
| 197 |
st.stop()
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
if auto_download_assets:
|
| 200 |
with st.spinner("Ensuring required assets are available..."):
|
|
@@ -234,13 +249,16 @@ if st.button("Run Full Inference", type="primary", use_container_width=True):
|
|
| 234 |
nifti_path = run_dir / safe_name
|
| 235 |
with nifti_path.open("wb") as f:
|
| 236 |
f.write(uploaded_file.getbuffer())
|
| 237 |
-
|
| 238 |
with st.spinner("Downloading NIfTI from URL..."):
|
| 239 |
try:
|
| 240 |
nifti_path = _download_nifti_from_url(nifti_url.strip(), run_dir)
|
| 241 |
except Exception as exc:
|
| 242 |
st.error(f"Failed to download NIfTI from URL: {exc}")
|
| 243 |
st.stop()
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
ehr_row_internal = _ehr_ui_to_internal_row(ehr_df.iloc[0].to_dict())
|
| 246 |
|
|
|
|
| 29 |
)
|
| 30 |
from demo_backend.pipeline import run_full_inference
|
| 31 |
|
| 32 |
+
EXAMPLE_NIFTI = PROJECT_ROOT / "src" / "examples" / "example_case.nii.gz"
|
| 33 |
|
| 34 |
st.set_page_config(page_title="HyperClinical Demo", layout="wide")
|
| 35 |
|
|
|
|
| 165 |
|
| 166 |
st.subheader("1) MRI Input")
|
| 167 |
|
| 168 |
+
input_mode = st.radio(
|
| 169 |
+
"Select input mode",
|
| 170 |
+
["Upload file", "Public URL", "Built-in example"],
|
| 171 |
+
horizontal=True,
|
| 172 |
+
)
|
| 173 |
uploaded_file = None
|
| 174 |
nifti_url = ""
|
| 175 |
+
use_example_nifti = False
|
| 176 |
if input_mode == "Upload file":
|
| 177 |
uploaded_file = st.file_uploader("Upload a T1 MRI NIfTI (.nii or .nii.gz)", type=["nii", "gz"])
|
| 178 |
+
elif input_mode == "Public URL":
|
| 179 |
nifti_url = st.text_input("Public URL to .nii or .nii.gz")
|
| 180 |
+
else:
|
| 181 |
+
use_example_nifti = True
|
| 182 |
+
if EXAMPLE_NIFTI.exists():
|
| 183 |
+
st.caption(f"Using built-in sample: `{EXAMPLE_NIFTI.name}`")
|
| 184 |
+
else:
|
| 185 |
+
st.warning("Built-in sample MRI is not available in this deployment.")
|
| 186 |
|
| 187 |
st.subheader("2) Enter EHR Features")
|
| 188 |
if "ehr_df" not in st.session_state:
|
|
|
|
| 207 |
if input_mode == "Public URL" and not nifti_url.strip():
|
| 208 |
st.error("Please provide a public NIfTI URL.")
|
| 209 |
st.stop()
|
| 210 |
+
if input_mode == "Built-in example" and not EXAMPLE_NIFTI.exists():
|
| 211 |
+
st.error("Built-in sample MRI is missing. Please use upload or URL mode.")
|
| 212 |
+
st.stop()
|
| 213 |
|
| 214 |
if auto_download_assets:
|
| 215 |
with st.spinner("Ensuring required assets are available..."):
|
|
|
|
| 249 |
nifti_path = run_dir / safe_name
|
| 250 |
with nifti_path.open("wb") as f:
|
| 251 |
f.write(uploaded_file.getbuffer())
|
| 252 |
+
elif input_mode == "Public URL":
|
| 253 |
with st.spinner("Downloading NIfTI from URL..."):
|
| 254 |
try:
|
| 255 |
nifti_path = _download_nifti_from_url(nifti_url.strip(), run_dir)
|
| 256 |
except Exception as exc:
|
| 257 |
st.error(f"Failed to download NIfTI from URL: {exc}")
|
| 258 |
st.stop()
|
| 259 |
+
else:
|
| 260 |
+
nifti_path = run_dir / EXAMPLE_NIFTI.name
|
| 261 |
+
shutil.copy2(EXAMPLE_NIFTI, nifti_path)
|
| 262 |
|
| 263 |
ehr_row_internal = _ehr_ui_to_internal_row(ehr_df.iloc[0].to_dict())
|
| 264 |
|