File size: 1,933 Bytes
aa983ba
 
 
 
 
 
 
 
 
 
 
2249ab9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa983ba
2249ab9
 
aa983ba
 
 
 
2249ab9
 
aa983ba
 
 
 
 
 
2249ab9
aa983ba
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47

import torch
import torch.nn as nn
from torchvision import models

def build_model(num_classes: int) -> nn.Module:
    """Create a ResNet-34 with a fresh classification head.

    Args:
        num_classes: Number of output classes for the final linear layer.

    Returns:
        A randomly initialized (``weights=None``) ResNet-34 whose ``fc``
        layer maps the backbone features to ``num_classes`` logits.
    """
    net = models.resnet34(weights=None)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net

def _torch_load(path, map_location):
    # Try safe load with weights_only=True, but allowlist needed numpy scalar if present.
    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except Exception as e1:
        # If it's the numpy scalar allowlist issue or any pickle restriction, retry with safe_globals
        try:
            from torch.serialization import add_safe_globals
            import numpy as _np
            add_safe_globals([_np._core.multiarray.scalar])
            return torch.load(path, map_location=map_location, weights_only=True)
        except Exception as e2:
            # As a last resort, if and only if the file is trusted, load with weights_only=False
            # This can execute arbitrary code present in the pickle. Use only for trusted checkpoints.
            return torch.load(path, map_location=map_location, weights_only=False)

def load_weights(model: nn.Module, ckpt_path: str, map_location="cpu") -> nn.Module:
    """Load checkpoint weights into ``model`` and switch it to eval mode.

    Accepts the common checkpoint layouts: a raw state_dict, or a dict
    wrapping one under ``"state_dict"`` or ``"model"``.  Keys prefixed with
    ``"module."`` (DistributedDataParallel) are stripped before loading.

    Args:
        model: Module whose matching parameters will be overwritten.
        ckpt_path: Path to the checkpoint file.
        map_location: Device mapping forwarded to ``torch.load``.

    Returns:
        The same ``model`` instance, in eval mode.
    """
    state = _torch_load(ckpt_path, map_location=map_location)

    # Unwrap the common container formats.
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
        state = state["model"]

    # Strip DistributedDataParallel's "module." prefix in a single pass.
    prefix = "module."
    state = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state.items()
    }

    # strict=False tolerates missing/unexpected keys (e.g. a replaced
    # classification head); mismatched keys are silently skipped.
    model.load_state_dict(state, strict=False)
    model.eval()
    return model