msn-enginenova21 commited on
Commit
2e4976f
·
verified ·
1 Parent(s): 8a2901a

Update yolov5/models/experimental.py

Browse files
Files changed (1) hide show
  1. yolov5/models/experimental.py +19 -8
yolov5/models/experimental.py CHANGED
@@ -3,16 +3,15 @@
3
  Experimental modules
4
  """
5
  import math
6
- import models.yolo
7
  import numpy as np
8
  import torch
9
  import torch.nn as nn
10
- import torch
11
 
12
- from models.common import Conv # as the error previously asked for
13
- from models.yolo import DetectionModel
14
  from torch.nn.modules.container import Sequential
15
- from models.common import Conv # <--- new: the class in the latest error
 
16
  from utils.downloads import attempt_download
17
 
18
 
@@ -82,10 +81,22 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
82
  model = Ensemble()
83
 
84
  for w in weights if isinstance(weights, list) else [weights]:
85
- with torch.serialization.safe_globals([DetectionModel, Sequential, Conv, torch.nn.modules.conv.Conv2d]):
86
- ckpt = torch.load(attempt_download(w), map_location='cpu') # safe load
 
 
 
 
 
 
 
 
 
 
 
 
87
  ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
88
-
89
  # Model compatibility updates
90
  if not hasattr(ckpt, 'stride'):
91
  ckpt.stride = torch.tensor([32.])
 
3
  Experimental modules
4
  """
5
  import math
 
6
  import numpy as np
7
  import torch
8
  import torch.nn as nn
 
9
 
10
+ import models.yolo # ensures module resolution
11
+ from models.yolo import DetectionModel # for allowlisting
12
  from torch.nn.modules.container import Sequential
13
+ from models.common import Conv # YOLO-specific common conv block
14
+
15
  from utils.downloads import attempt_download
16
 
17
 
 
81
  model = Ensemble()
82
 
83
  for w in weights if isinstance(weights, list) else [weights]:
84
+ # try safe allowlisted load first
85
+ try:
86
+ with torch.serialization.safe_globals([
87
+ DetectionModel,
88
+ Sequential,
89
+ Conv,
90
+ torch.nn.modules.conv.Conv2d,
91
+ torch.nn.modules.batchnorm.BatchNorm2d
92
+ ]):
93
+ ckpt = torch.load(attempt_download(w), map_location='cpu') # safe load
94
+ except Exception:
95
+ # fallback to full load if checkpoint is trusted
96
+ ckpt = torch.load(attempt_download(w), map_location='cpu', weights_only=False)
97
+
98
  ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
99
+
100
  # Model compatibility updates
101
  if not hasattr(ckpt, 'stride'):
102
  ckpt.stride = torch.tensor([32.])