import torch import torch.nn as nn from torchvision import models def build_model(num_classes: int) -> nn.Module: model = models.resnet34(weights=None) in_features = model.fc.in_features model.fc = nn.Linear(in_features, num_classes) return model def _torch_load(path, map_location): # Try safe load with weights_only=True, but allowlist needed numpy scalar if present. try: return torch.load(path, map_location=map_location, weights_only=True) except Exception as e1: # If it's the numpy scalar allowlist issue or any pickle restriction, retry with safe_globals try: from torch.serialization import add_safe_globals import numpy as _np add_safe_globals([_np._core.multiarray.scalar]) return torch.load(path, map_location=map_location, weights_only=True) except Exception as e2: # As a last resort, if and only if the file is trusted, load with weights_only=False # This can execute arbitrary code present in the pickle. Use only for trusted checkpoints. return torch.load(path, map_location=map_location, weights_only=False) def load_weights(model: nn.Module, ckpt_path: str, map_location="cpu") -> nn.Module: state = _torch_load(ckpt_path, map_location=map_location) # Accept common formats: raw state_dict, {'state_dict': ...}, {'model': ...} if isinstance(state, dict) and "state_dict" in state: state = state["state_dict"] if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict): state = state["model"] # Strip possible DistributedDataParallel prefixes new_state = {} for k, v in state.items(): if k.startswith("module."): new_state[k[len("module."):]] = v else: new_state[k] = v model.load_state_dict(new_state, strict=False) model.eval() return model