jacopo22295 committed on
Commit
2249ab9
·
verified ·
1 Parent(s): d733d16

Upload 6 files

Browse files
Files changed (1) hide show
  1. model.py +21 -1
model.py CHANGED
@@ -9,18 +9,38 @@ def build_model(num_classes: int) -> nn.Module:
9
  model.fc = nn.Linear(in_features, num_classes)
10
  return model
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def load_weights(model: nn.Module, ckpt_path: str, map_location="cpu") -> nn.Module:
13
- state = torch.load(ckpt_path, map_location=map_location)
 
14
  if isinstance(state, dict) and "state_dict" in state:
15
  state = state["state_dict"]
16
  if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
17
  state = state["model"]
 
 
18
  new_state = {}
19
  for k, v in state.items():
20
  if k.startswith("module."):
21
  new_state[k[len("module."):]] = v
22
  else:
23
  new_state[k] = v
 
24
  model.load_state_dict(new_state, strict=False)
25
  model.eval()
26
  return model
 
9
  model.fc = nn.Linear(in_features, num_classes)
10
  return model
11
 
12
def _torch_load(path, map_location):
    """Load a checkpoint with ``torch.load``, preferring the restricted unpickler.

    Attempts, in order:

    1. ``weights_only=True``.
    2. ``weights_only=True`` again after allowlisting numpy's
       ``multiarray.scalar`` via ``torch.serialization.add_safe_globals``.
    3. ``weights_only=False`` (full pickle). A ``RuntimeWarning`` is emitted
       first, because this mode can execute arbitrary code embedded in the
       file and must only be used with trusted checkpoints.

    Args:
        path: Checkpoint location, forwarded to ``torch.load``.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns for the file.

    Raises:
        OSError: Propagated immediately from the first attempt (for example
            ``FileNotFoundError``); retrying in another mode cannot help.
        Exception: Whatever the final ``weights_only=False`` attempt raises,
            chained to the error from the first safe attempt.
    """
    import warnings

    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except OSError:
        # The file itself is missing/unreadable: no loading mode can succeed.
        raise
    except Exception as strict_error:
        first_error = strict_error

    try:
        from torch.serialization import add_safe_globals
        import numpy as _np

        # numpy >= 2 exposes the module as ``numpy._core``; older releases
        # only provide ``numpy.core``.
        np_core = getattr(_np, "_core", None) or _np.core
        add_safe_globals([np_core.multiarray.scalar])
        return torch.load(path, map_location=map_location, weights_only=True)
    except Exception as allowlist_error:
        retry_error = allowlist_error

    # SECURITY: last resort. weights_only=False runs the full pickle machinery
    # and can execute arbitrary code present in the file. Use only for
    # trusted checkpoints.
    warnings.warn(
        f"Safe loading of {path!r} failed ({first_error!r}; retry with numpy "
        f"scalar allowlisted: {retry_error!r}). Falling back to "
        "weights_only=False, which can execute arbitrary code from the "
        "checkpoint. Only load checkpoints you trust.",
        RuntimeWarning,
        stacklevel=2,
    )
    try:
        return torch.load(path, map_location=map_location, weights_only=False)
    except Exception as final_error:
        # Keep the original safe-load failure visible in the traceback.
        raise final_error from first_error
27
+
28
  def load_weights(model: nn.Module, ckpt_path: str, map_location="cpu") -> nn.Module:
29
+ state = _torch_load(ckpt_path, map_location=map_location)
30
+ # Accept common formats: raw state_dict, {'state_dict': ...}, {'model': ...}
31
  if isinstance(state, dict) and "state_dict" in state:
32
  state = state["state_dict"]
33
  if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
34
  state = state["model"]
35
+
36
+ # Strip possible DistributedDataParallel prefixes
37
  new_state = {}
38
  for k, v in state.items():
39
  if k.startswith("module."):
40
  new_state[k[len("module."):]] = v
41
  else:
42
  new_state[k] = v
43
+
44
  model.load_state_dict(new_state, strict=False)
45
  model.eval()
46
  return model