salmasoma commited on
Commit
a19ac32
·
1 Parent(s): 044452f

Fix Gemma3 hidden-size handling and add built-in example NIfTI

Browse files
README.md CHANGED
@@ -25,6 +25,7 @@ UI updates:
25
  - EHR table uses the requested clinical field names (including cardiovascular comorbidity columns).
26
  - Embedded pipeline figure shown at top of interface.
27
  - MRI input supports both file upload and public URL.
 
28
  - If Hugging Face upload fails with `AxiosError 403`, use URL mode.
29
 
30
  ## Solving the Large Checkpoint Push Issue
@@ -90,6 +91,7 @@ bash scripts/publish_assets_from_local.sh SalmaHassan/HyperClinical-assets ../av
90
  ## App Entry
91
 
92
  - Streamlit app: `src/streamlit_app.py`
 
93
 
94
  ## Local Run
95
 
 
25
  - EHR table uses the requested clinical field names (including cardiovascular comorbidity columns).
26
  - Embedded pipeline figure shown at top of interface.
27
  - MRI input supports both file upload and public URL.
28
+ - MRI input supports file upload, public URL, and a built-in example NIfTI.
29
  - If Hugging Face upload fails with `AxiosError 403`, use URL mode.
30
 
31
  ## Solving the Large Checkpoint Push Issue
 
91
  ## App Entry
92
 
93
  - Streamlit app: `src/streamlit_app.py`
94
+ - Built-in sample MRI: `src/examples/example_case.nii.gz`
95
 
96
  ## Local Run
97
 
src/demo_backend/neurofusion/medgemma_encoder.py CHANGED
@@ -41,6 +41,24 @@ def _normalize_loader_result(result):
41
  raise ValueError(f"Unexpected loader tuple length: {len(result)}")
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def _try_load_medgemma(model_name: str, cache_dir: Optional[str] = None):
45
  """Load MedGemma via ``AutoModelForCausalLM`` (generative Gemma model).
46
 
@@ -70,7 +88,7 @@ def _try_load_medgemma(model_name: str, cache_dir: Optional[str] = None):
70
  model_name,
71
  **load_kwargs,
72
  )
73
- hidden_size = model.config.hidden_size
74
  logger.info(f"Loaded MedGemma: {model_name} (hidden_size={hidden_size})")
75
  return model, tokenizer, hidden_size, None
76
  except Exception as e:
 
41
  raise ValueError(f"Unexpected loader tuple length: {len(result)}")
42
 
43
 
44
+ def _resolve_hidden_size(model) -> int:
45
+ """Infer LM embedding width across Gemma config variants."""
46
+ cfg = getattr(model, "config", None)
47
+ hidden_size = getattr(cfg, "hidden_size", None)
48
+
49
+ if hidden_size is None and cfg is not None and hasattr(cfg, "text_config"):
50
+ hidden_size = getattr(cfg.text_config, "hidden_size", None)
51
+
52
+ if hidden_size is None:
53
+ emb = model.get_input_embeddings()
54
+ hidden_size = getattr(emb, "embedding_dim", None)
55
+
56
+ if hidden_size is None:
57
+ raise AttributeError("Could not infer hidden_size from model config or input embeddings")
58
+
59
+ return int(hidden_size)
60
+
61
+
62
  def _try_load_medgemma(model_name: str, cache_dir: Optional[str] = None):
63
  """Load MedGemma via ``AutoModelForCausalLM`` (generative Gemma model).
64
 
 
88
  model_name,
89
  **load_kwargs,
90
  )
91
+ hidden_size = _resolve_hidden_size(model)
92
  logger.info(f"Loaded MedGemma: {model_name} (hidden_size={hidden_size})")
93
  return model, tokenizer, hidden_size, None
94
  except Exception as e:
src/examples/example_case.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b6c4f69e0bcce412b8aecbb47ade5b6e222723bdf5a98d90cef38f36b5341eb
3
+ size 5301248
src/streamlit_app.py CHANGED
@@ -29,6 +29,7 @@ from demo_backend.paths import (
29
  )
30
  from demo_backend.pipeline import run_full_inference
31
 
 
32
 
33
  st.set_page_config(page_title="HyperClinical Demo", layout="wide")
34
 
@@ -164,13 +165,24 @@ require_true_hf_embeddings = (
164
 
165
  st.subheader("1) MRI Input")
166
 
167
- input_mode = st.radio("Select input mode", ["Upload file", "Public URL"], horizontal=True)
 
 
 
 
168
  uploaded_file = None
169
  nifti_url = ""
 
170
  if input_mode == "Upload file":
171
  uploaded_file = st.file_uploader("Upload a T1 MRI NIfTI (.nii or .nii.gz)", type=["nii", "gz"])
172
- else:
173
  nifti_url = st.text_input("Public URL to .nii or .nii.gz")
 
 
 
 
 
 
174
 
175
  st.subheader("2) Enter EHR Features")
176
  if "ehr_df" not in st.session_state:
@@ -195,6 +207,9 @@ if st.button("Run Full Inference", type="primary", use_container_width=True):
195
  if input_mode == "Public URL" and not nifti_url.strip():
196
  st.error("Please provide a public NIfTI URL.")
197
  st.stop()
 
 
 
198
 
199
  if auto_download_assets:
200
  with st.spinner("Ensuring required assets are available..."):
@@ -234,13 +249,16 @@ if st.button("Run Full Inference", type="primary", use_container_width=True):
234
  nifti_path = run_dir / safe_name
235
  with nifti_path.open("wb") as f:
236
  f.write(uploaded_file.getbuffer())
237
- else:
238
  with st.spinner("Downloading NIfTI from URL..."):
239
  try:
240
  nifti_path = _download_nifti_from_url(nifti_url.strip(), run_dir)
241
  except Exception as exc:
242
  st.error(f"Failed to download NIfTI from URL: {exc}")
243
  st.stop()
 
 
 
244
 
245
  ehr_row_internal = _ehr_ui_to_internal_row(ehr_df.iloc[0].to_dict())
246
 
 
29
  )
30
  from demo_backend.pipeline import run_full_inference
31
 
32
+ EXAMPLE_NIFTI = PROJECT_ROOT / "src" / "examples" / "example_case.nii.gz"
33
 
34
  st.set_page_config(page_title="HyperClinical Demo", layout="wide")
35
 
 
165
 
166
  st.subheader("1) MRI Input")
167
 
168
+ input_mode = st.radio(
169
+ "Select input mode",
170
+ ["Upload file", "Public URL", "Built-in example"],
171
+ horizontal=True,
172
+ )
173
  uploaded_file = None
174
  nifti_url = ""
175
+ use_example_nifti = False
176
  if input_mode == "Upload file":
177
  uploaded_file = st.file_uploader("Upload a T1 MRI NIfTI (.nii or .nii.gz)", type=["nii", "gz"])
178
+ elif input_mode == "Public URL":
179
  nifti_url = st.text_input("Public URL to .nii or .nii.gz")
180
+ else:
181
+ use_example_nifti = True
182
+ if EXAMPLE_NIFTI.exists():
183
+ st.caption(f"Using built-in sample: `{EXAMPLE_NIFTI.name}`")
184
+ else:
185
+ st.warning("Built-in sample MRI is not available in this deployment.")
186
 
187
  st.subheader("2) Enter EHR Features")
188
  if "ehr_df" not in st.session_state:
 
207
  if input_mode == "Public URL" and not nifti_url.strip():
208
  st.error("Please provide a public NIfTI URL.")
209
  st.stop()
210
+ if input_mode == "Built-in example" and not EXAMPLE_NIFTI.exists():
211
+ st.error("Built-in sample MRI is missing. Please use upload or URL mode.")
212
+ st.stop()
213
 
214
  if auto_download_assets:
215
  with st.spinner("Ensuring required assets are available..."):
 
249
  nifti_path = run_dir / safe_name
250
  with nifti_path.open("wb") as f:
251
  f.write(uploaded_file.getbuffer())
252
+ elif input_mode == "Public URL":
253
  with st.spinner("Downloading NIfTI from URL..."):
254
  try:
255
  nifti_path = _download_nifti_from_url(nifti_url.strip(), run_dir)
256
  except Exception as exc:
257
  st.error(f"Failed to download NIfTI from URL: {exc}")
258
  st.stop()
259
+ else:
260
+ nifti_path = run_dir / EXAMPLE_NIFTI.name
261
+ shutil.copy2(EXAMPLE_NIFTI, nifti_path)
262
 
263
  ehr_row_internal = _ehr_ui_to_internal_row(ehr_df.iloc[0].to_dict())
264