Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
def build_model(num_classes: int) -> nn.Module:
    """Return a ResNet-34 whose final layer emits ``num_classes`` logits.

    The backbone is constructed without pretrained weights (``weights=None``).
    Its ``fc`` head is swapped for a fresh ``nn.Linear`` that keeps the
    original input width and outputs ``num_classes`` features.
    """
    net = models.resnet34(weights=None)
    # The right-hand side reads the old head's width before the assignment replaces it.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net
def _numpy_safe_globals():
    """Return numpy objects to allowlist for ``weights_only`` unpickling.

    Checkpoints that store numpy scalars (e.g. a metric saved as ``np.float64``)
    reference ``multiarray.scalar``, ``numpy.dtype`` and, on newer numpy, the
    concrete dtype classes (e.g. ``numpy.dtypes.Float64DType``) in the pickle
    stream. NOTE(review): the exact set a checkpoint needs depends on how it was
    saved; extend this list if ``torch.load`` still reports an unsupported global.
    """
    import numpy as np

    # numpy >= 2.0 renamed ``numpy.core`` to ``numpy._core`` (the old name is a
    # deprecated alias there); numpy 1.x only has ``numpy.core``.
    core = getattr(np, "_core", None)
    if core is None:
        core = np.core
    allowed = [core.multiarray.scalar, np.dtype]
    # Concrete dtype classes for common scalar types; on older numpy these may
    # simply be ``np.dtype`` itself, which is a harmless duplicate.
    for scalar_type in (np.bool_, np.int32, np.int64, np.float32, np.float64):
        allowed.append(type(np.dtype(scalar_type)))
    return allowed


def _torch_load(path, map_location, *, allow_unsafe: bool = True):
    """Load a checkpoint, preferring PyTorch's restricted ``weights_only`` unpickler.

    Attempts, in order:
      1. ``torch.load(..., weights_only=True)``.
      2. The same, after allowlisting the numpy objects from
         ``_numpy_safe_globals`` via ``torch.serialization.add_safe_globals``
         (a process-wide registration that persists after this call).
      3. ``torch.load(..., weights_only=False)``. This is full pickle loading,
         which can execute arbitrary code embedded in the file. It is only
         reached when ``allow_unsafe`` is true, and a ``RuntimeWarning`` is
         emitted so the fallback is visible.

    Args:
        path: Checkpoint location, passed through to ``torch.load``.
        map_location: Forwarded to ``torch.load``.
        allow_unsafe: When false, never fall back to full unpickling; re-raise the
            last safe-load error instead. Defaults to true to keep the previous
            behaviour. Only allow it for checkpoints from trusted sources.

    Returns:
        Whatever object the checkpoint deserializes to.

    Raises:
        OSError: If the file cannot be opened (raised immediately, no retries).
        Exception: The last safe-load error, when ``allow_unsafe`` is false and
            both restricted attempts fail.
    """
    import warnings

    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except OSError:
        # Missing/unreadable file: a different unpickler configuration cannot help.
        raise
    except Exception as err:  # typically an unpickling error on a non-allowlisted global
        last_err = err

    try:
        from torch.serialization import add_safe_globals

        add_safe_globals(_numpy_safe_globals())
    except Exception as err:  # torch without add_safe_globals, or numpy unavailable
        last_err = err
    else:
        try:
            return torch.load(path, map_location=map_location, weights_only=True)
        except Exception as err:
            last_err = err

    if not allow_unsafe:
        raise last_err

    warnings.warn(
        f"Safe (weights_only=True) loading of {path!r} failed ({last_err!r}); "
        "falling back to weights_only=False, which can execute arbitrary code. "
        "Only load checkpoints from trusted sources.",
        RuntimeWarning,
        stacklevel=2,
    )
    return torch.load(path, map_location=map_location, weights_only=False)
def load_weights(
    model: nn.Module,
    ckpt_path: str,
    map_location="cpu",
    *,
    strict: bool = False,
) -> nn.Module:
    """Load checkpoint weights from ``ckpt_path`` into ``model`` and set eval mode.

    Accepted checkpoint layouts:
      * a raw ``state_dict`` mapping,
      * ``{"state_dict": <state_dict>, ...}``,
      * ``{"model": <state_dict or nn.Module>, ...}``,
      * a pickled ``nn.Module`` (its ``state_dict()`` is used).
    A leading ``"module."`` prefix on keys (added by DistributedDataParallel-style
    wrappers) is stripped before loading.

    Args:
        model: Module to load into; it is modified in place.
        ckpt_path: Checkpoint path, read via ``_torch_load``.
        map_location: Forwarded to ``torch.load``.
        strict: Forwarded to ``nn.Module.load_state_dict``. With the default
            ``False``, mismatched keys are tolerated but reported via a
            ``RuntimeWarning`` so a non-matching checkpoint is not silently ignored.

    Returns:
        The same ``model`` instance, switched to eval mode.

    Raises:
        TypeError: If the checkpoint does not resolve to a mapping of weights.
    """
    import warnings

    state = _torch_load(ckpt_path, map_location=map_location)

    # Unwrap the supported container formats down to a plain state_dict.
    if isinstance(state, nn.Module):
        state = state.state_dict()
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    if isinstance(state, dict) and isinstance(state.get("model"), nn.Module):
        state = state["model"].state_dict()
    elif isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
        state = state["model"]
    if not isinstance(state, dict):
        raise TypeError(
            f"Checkpoint {ckpt_path!r} does not contain a state_dict "
            f"(got {type(state).__name__})"
        )

    # Strip possible DistributedDataParallel prefixes.
    prefix = "module."
    new_state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }

    incompatible = model.load_state_dict(new_state, strict=strict)
    missing, unexpected = incompatible.missing_keys, incompatible.unexpected_keys
    if missing or unexpected:
        # strict=False would otherwise hide a checkpoint that matches little or
        # nothing, leaving those parameters at their initial values.
        warnings.warn(
            f"load_weights({ckpt_path!r}): {len(missing)} missing and "
            f"{len(unexpected)} unexpected keys "
            f"(first missing: {missing[:5]}, first unexpected: {unexpected[:5]})",
            RuntimeWarning,
            stacklevel=2,
        )
    model.eval()
    return model