katelynhur committed on
Commit
a044fb0
·
verified ·
1 Parent(s): 7474d7a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -28
app.py CHANGED
@@ -229,18 +229,6 @@ def list_available_checkpoints_from_hub(repo_id: str) -> Dict[str, str]:
229
  ckpts[disp] = f
230
  return ckpts
231
 
232
- def _shape_compatible_load(model: nn.Module, state: Dict):
233
- """
234
- Load only keys whose shapes match, warn about the rest.
235
- This mirrors your flexible loader used elsewhere, but keeps strict as the default path.
236
- """
237
- model_k = model.state_dict()
238
- filtered = {k: v for k, v in state.items() if k in model_k and model_k[k].shape == v.shape}
239
- missing = sorted(set(model_k.keys()) - set(filtered.keys()))
240
- if missing:
241
- print(f"[WARN] Skipping {len(missing)} keys due to shape mismatch (likely classifier).")
242
- model.load_state_dict(filtered, strict=False)
243
-
244
  _LOADED: Dict[str, nn.Module] = {}
245
  def load_model_from_hub(display_name: str, repo_filename: str) -> nn.Module:
246
  key = f"{MODEL_REPO}::{repo_filename}"
@@ -254,26 +242,15 @@ def load_model_from_hub(display_name: str, repo_filename: str) -> nn.Module:
254
  local_dir_use_symlinks=False
255
  )
256
 
257
- # Resolve architecture (you already have this mapping logic earlier)
258
- arch = DISPLAY_TO_ARCH.get(display_name) or infer_arch_from_filename(repo_filename)
259
  if arch not in MODEL_BUILDERS:
260
- raise ValueError(f"Unknown or unsupported architecture for {display_name} / {repo_filename}")
261
 
262
- # Build fresh model and adapt head if checkpoint expects a Sequential head
263
  model = MODEL_BUILDERS[arch](NUM_CLASSES).to(DEVICE)
264
  state = torch.load(local_path, map_location="cpu")
265
  model = adapt_head_for_state_dict(model, arch, state, NUM_CLASSES)
266
-
267
- # Strict-first, then fallback to shape-compatible load
268
- try:
269
- model.load_state_dict(state, strict=True)
270
- print(f"[INFO] StrictLoad=True for {display_name} ({repo_filename})")
271
- except Exception as e:
272
- print(f"[ERROR] StrictLoad failed for {display_name} ({repo_filename}): {e}")
273
- _shape_compatible_load(model, state)
274
- print(f"[INFO] StrictLoad=False (shape-compatible fallback) for {display_name} ({repo_filename})")
275
-
276
- model.to(DEVICE).eval()
277
  _LOADED[key] = model
278
  return model
279
 
@@ -642,8 +619,11 @@ def build_app():
642
 
643
  clear_btn = gr.Button("Clear files")
644
 
 
 
 
645
  # Predictions
646
- results = gr.Dataframe(label="Predictions", wrap=True, interactive=False, type="pandas")
647
 
648
  # Action
649
  run_btn = gr.Button("🔍 Run inference", variant="primary")
 
229
  ckpts[disp] = f
230
  return ckpts
231
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  _LOADED: Dict[str, nn.Module] = {}
233
  def load_model_from_hub(display_name: str, repo_filename: str) -> nn.Module:
234
  key = f"{MODEL_REPO}::{repo_filename}"
 
242
  local_dir_use_symlinks=False
243
  )
244
 
245
+ arch = arch_from_filename(repo_filename)
 
246
  if arch not in MODEL_BUILDERS:
247
+ raise RuntimeError(f"Unknown architecture inferred from filename: {arch}")
248
 
 
249
  model = MODEL_BUILDERS[arch](NUM_CLASSES).to(DEVICE)
250
  state = torch.load(local_path, map_location="cpu")
251
  model = adapt_head_for_state_dict(model, arch, state, NUM_CLASSES)
252
+ model.load_state_dict(state, strict=True)
253
+ model.eval()
 
 
 
 
 
 
 
 
 
254
  _LOADED[key] = model
255
  return model
256
 
 
619
 
620
  clear_btn = gr.Button("Clear files")
621
 
622
+ # Pre-define empty dataframe with column headers
623
+ empty_df = pd.DataFrame(columns=["Model 1", "Model 2", "Ensemble"])
624
+
625
  # Predictions
626
+ results = gr.Dataframe(label="Predictions", value=empty_df, wrap=True, interactive=False, type="pandas")
627
 
628
  # Action
629
  run_btn = gr.Button("🔍 Run inference", variant="primary")