# Resnet34-Testrun / model.py
# Uploaded by jacopo22295 (commit 2249ab9, "Upload 6 files")
import torch
import torch.nn as nn
from torchvision import models
def build_model(num_classes: int) -> nn.Module:
    """Construct an untrained ResNet-34 whose final layer emits *num_classes* logits.

    The torchvision backbone is created with ``weights=None`` (no pretrained
    weights), then its fully-connected head is swapped for a fresh Linear layer
    sized for the target class count.
    """
    backbone = models.resnet34(weights=None)
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone
def _torch_load(path, map_location):
# Try safe load with weights_only=True, but allowlist needed numpy scalar if present.
try:
return torch.load(path, map_location=map_location, weights_only=True)
except Exception as e1:
# If it's the numpy scalar allowlist issue or any pickle restriction, retry with safe_globals
try:
from torch.serialization import add_safe_globals
import numpy as _np
add_safe_globals([_np._core.multiarray.scalar])
return torch.load(path, map_location=map_location, weights_only=True)
except Exception as e2:
# As a last resort, if and only if the file is trusted, load with weights_only=False
# This can execute arbitrary code present in the pickle. Use only for trusted checkpoints.
return torch.load(path, map_location=map_location, weights_only=False)
def load_weights(model: nn.Module, ckpt_path: str, map_location="cpu") -> nn.Module:
    """Load checkpoint weights from *ckpt_path* into *model* and switch it to eval mode.

    Tolerates the common checkpoint layouts: a raw state_dict, or one wrapped
    as ``{"state_dict": ...}`` / ``{"model": ...}``.  Keys prefixed with
    ``module.`` (DataParallel / DistributedDataParallel) are stripped before
    loading.  Loading uses ``strict=False`` so partially-matching checkpoints
    still apply.
    """
    checkpoint = _torch_load(ckpt_path, map_location=map_location)
    # Unwrap known container formats down to the bare state_dict.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    if isinstance(checkpoint, dict) and "model" in checkpoint and isinstance(checkpoint["model"], dict):
        checkpoint = checkpoint["model"]
    # Remove any DDP "module." prefix from parameter names.
    prefix = "module."
    stripped = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }
    model.load_state_dict(stripped, strict=False)
    model.eval()
    return model