maregu2023 commited on
Commit
e2c3db1
·
1 Parent(s): c049516

fix: nnUNet models disappearing from dropdown on HF Spaces

Browse files

Three bugs fixed:

1. _ensure_nnunet_from_hub: final success check only verified plans.json
but not the .pth weight file. Now checks both.

2. model_registry.py: register_nnunet() failures were logged at debug
level (invisible). Changed to warning with exc_info=True so the
actual download error is visible in HF Spaces logs.

3. orchestrator.py: get_available_models() checked cfg.checkpoint_path
directly from the stale NNUNET_MODELS config (still None before
download). Now uses is_model_registered() which triggers
_ensure_models_registered() -> register_nnunet() -> HF Hub download,
so checkpoint_path is populated by the time we check.

seg_app/config/settings.py CHANGED
@@ -235,9 +235,17 @@ def _ensure_nnunet_from_hub(
235
  shutil.copy2(str(downloaded_path), str(dest))
236
  logger.info(f" -> {dest}")
237
 
238
- if (local_dir / "plans.json").is_file():
 
 
 
239
  logger.info(f"nnUNet checkpoint ready: {local_dir}")
240
  return str(local_dir)
 
 
 
 
 
241
 
242
  except Exception as e:
243
  logger.warning(f"Failed to download nnUNet weights from {hf_hub_repo}: {e}")
 
235
  shutil.copy2(str(downloaded_path), str(dest))
236
  logger.info(f" -> {dest}")
237
 
238
+ if (
239
+ (local_dir / "plans.json").is_file()
240
+ and (local_dir / "fold_0" / checkpoint_name).is_file()
241
+ ):
242
  logger.info(f"nnUNet checkpoint ready: {local_dir}")
243
  return str(local_dir)
244
+ else:
245
+ logger.warning(
246
+ f"Download appeared to succeed but required files are missing "
247
+ f"in {local_dir} (plans.json or fold_0/{checkpoint_name})"
248
+ )
249
 
250
  except Exception as e:
251
  logger.warning(f"Failed to download nnUNet weights from {hf_hub_repo}: {e}")
seg_app/inference/model_registry.py CHANGED
@@ -230,7 +230,7 @@ def _ensure_models_registered() -> None:
230
  register_nnunet()
231
  except Exception as e:
232
  import logging
233
- logging.getLogger(__name__).debug(f"nnUNet registration skipped: {e}")
234
 
235
  # Register Medical SAM 3D for volumetric interactive refinement
236
  from seg_app.models.medical_sam_3d import register_medical_sam_3d
 
230
  register_nnunet()
231
  except Exception as e:
232
  import logging
233
+ logging.getLogger(__name__).warning(f"nnUNet registration failed: {e}", exc_info=True)
234
 
235
  # Register Medical SAM 3D for volumetric interactive refinement
236
  from seg_app.models.medical_sam_3d import register_medical_sam_3d
seg_app/inference/orchestrator.py CHANGED
@@ -191,15 +191,20 @@ def get_available_models() -> List[Dict[str, str]]:
191
 
192
  Returns:
193
  List of dicts with 'id' and 'display_name' for each model.
194
- Only models with real checkpoints on disk are included.
 
195
  """
196
  from seg_app.config.settings import NNUNET_MODELS
 
197
 
198
  models: List[Dict[str, str]] = []
199
 
200
- # Add every nnUNet variant that has a checkpoint on disk
 
 
 
201
  for cfg in NNUNET_MODELS:
202
- if cfg.checkpoint_path is not None:
203
  models.append({"id": cfg.model_id, "display_name": cfg.display_name})
204
 
205
  # Medical SAM 3D is always available (weights on HF Hub or local)
@@ -253,7 +258,7 @@ def _resolve_task_to_model(task_name: str) -> str:
253
  """
254
  # Stub - will delegate to config/tasks.py
255
  task_to_model = {
256
- "brain_lesion": "nnunet-isles",
257
  }
258
 
259
  if task_name not in task_to_model:
 
191
 
192
  Returns:
193
  List of dicts with 'id' and 'display_name' for each model.
194
+ Only models that are registered (checkpoint available or
195
+ downloadable from HF Hub) are included.
196
  """
197
  from seg_app.config.settings import NNUNET_MODELS
198
+ from seg_app.inference.model_registry import is_model_registered
199
 
200
  models: List[Dict[str, str]] = []
201
 
202
+ # Add every nnUNet variant that was successfully registered.
203
+ # is_model_registered() triggers _ensure_models_registered() which
204
+ # downloads missing weights from HF Hub, so checkpoint_path in the
205
+ # config object is updated by the time we check here.
206
  for cfg in NNUNET_MODELS:
207
+ if is_model_registered(cfg.model_id):
208
  models.append({"id": cfg.model_id, "display_name": cfg.display_name})
209
 
210
  # Medical SAM 3D is always available (weights on HF Hub or local)
 
258
  """
259
  # Stub - will delegate to config/tasks.py
260
  task_to_model = {
261
+ "brain_lesion": "nnunet-isles-dwi",
262
  }
263
 
264
  if task_name not in task_to_model: