Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -229,18 +229,6 @@ def list_available_checkpoints_from_hub(repo_id: str) -> Dict[str, str]:
|
|
| 229 |
ckpts[disp] = f
|
| 230 |
return ckpts
|
| 231 |
|
| 232 |
-
def _shape_compatible_load(model: nn.Module, state: Dict):
|
| 233 |
-
"""
|
| 234 |
-
Load only keys whose shapes match, warn about the rest.
|
| 235 |
-
This mirrors your flexible loader used elsewhere, but keeps strict as the default path.
|
| 236 |
-
"""
|
| 237 |
-
model_k = model.state_dict()
|
| 238 |
-
filtered = {k: v for k, v in state.items() if k in model_k and model_k[k].shape == v.shape}
|
| 239 |
-
missing = sorted(set(model_k.keys()) - set(filtered.keys()))
|
| 240 |
-
if missing:
|
| 241 |
-
print(f"[WARN] Skipping {len(missing)} keys due to shape mismatch (likely classifier).")
|
| 242 |
-
model.load_state_dict(filtered, strict=False)
|
| 243 |
-
|
| 244 |
_LOADED: Dict[str, nn.Module] = {}
|
| 245 |
def load_model_from_hub(display_name: str, repo_filename: str) -> nn.Module:
|
| 246 |
key = f"{MODEL_REPO}::{repo_filename}"
|
|
@@ -254,26 +242,15 @@ def load_model_from_hub(display_name: str, repo_filename: str) -> nn.Module:
|
|
| 254 |
local_dir_use_symlinks=False
|
| 255 |
)
|
| 256 |
|
| 257 |
-
|
| 258 |
-
arch = DISPLAY_TO_ARCH.get(display_name) or infer_arch_from_filename(repo_filename)
|
| 259 |
if arch not in MODEL_BUILDERS:
|
| 260 |
-
raise
|
| 261 |
|
| 262 |
-
# Build fresh model and adapt head if checkpoint expects a Sequential head
|
| 263 |
model = MODEL_BUILDERS[arch](NUM_CLASSES).to(DEVICE)
|
| 264 |
state = torch.load(local_path, map_location="cpu")
|
| 265 |
model = adapt_head_for_state_dict(model, arch, state, NUM_CLASSES)
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
try:
|
| 269 |
-
model.load_state_dict(state, strict=True)
|
| 270 |
-
print(f"[INFO] StrictLoad=True for {display_name} ({repo_filename})")
|
| 271 |
-
except Exception as e:
|
| 272 |
-
print(f"[ERROR] StrictLoad failed for {display_name} ({repo_filename}): {e}")
|
| 273 |
-
_shape_compatible_load(model, state)
|
| 274 |
-
print(f"[INFO] StrictLoad=False (shape-compatible fallback) for {display_name} ({repo_filename})")
|
| 275 |
-
|
| 276 |
-
model.to(DEVICE).eval()
|
| 277 |
_LOADED[key] = model
|
| 278 |
return model
|
| 279 |
|
|
@@ -642,8 +619,11 @@ def build_app():
|
|
| 642 |
|
| 643 |
clear_btn = gr.Button("Clear files")
|
| 644 |
|
|
|
|
|
|
|
|
|
|
| 645 |
# Predictions
|
| 646 |
-
results = gr.Dataframe(label="Predictions", wrap=True, interactive=False, type="pandas")
|
| 647 |
|
| 648 |
# Action
|
| 649 |
run_btn = gr.Button("🔍 Run inference", variant="primary")
|
|
|
|
| 229 |
ckpts[disp] = f
|
| 230 |
return ckpts
|
| 231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
_LOADED: Dict[str, nn.Module] = {}
|
| 233 |
def load_model_from_hub(display_name: str, repo_filename: str) -> nn.Module:
|
| 234 |
key = f"{MODEL_REPO}::{repo_filename}"
|
|
|
|
| 242 |
local_dir_use_symlinks=False
|
| 243 |
)
|
| 244 |
|
| 245 |
+
arch = arch_from_filename(repo_filename)
|
|
|
|
| 246 |
if arch not in MODEL_BUILDERS:
|
| 247 |
+
raise RuntimeError(f"Unknown architecture inferred from filename: {arch}")
|
| 248 |
|
|
|
|
| 249 |
model = MODEL_BUILDERS[arch](NUM_CLASSES).to(DEVICE)
|
| 250 |
state = torch.load(local_path, map_location="cpu")
|
| 251 |
model = adapt_head_for_state_dict(model, arch, state, NUM_CLASSES)
|
| 252 |
+
model.load_state_dict(state, strict=True)
|
| 253 |
+
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
_LOADED[key] = model
|
| 255 |
return model
|
| 256 |
|
|
|
|
| 619 |
|
| 620 |
clear_btn = gr.Button("Clear files")
|
| 621 |
|
| 622 |
+
# Pre-define empty dataframe with column headers
|
| 623 |
+
empty_df = pd.DataFrame(columns=["Model 1", "Model 2", "Ensemble"])
|
| 624 |
+
|
| 625 |
# Predictions
|
| 626 |
+
results = gr.Dataframe(label="Predictions", value=empty_df, wrap=True, interactive=False, type="pandas")
|
| 627 |
|
| 628 |
# Action
|
| 629 |
run_btn = gr.Button("🔍 Run inference", variant="primary")
|