VietCat Amp commited on
Commit
afb1d42
·
1 Parent(s): fdf53a5

Use monkey patch to disable weights_only for ultralytics model loading

Browse files

Amp-Thread-ID: https://ampcode.com/threads/T-22e83726-d77c-475f-b185-6d55f37c5979
Co-authored-by: Amp <amp@ampcode.com>

Files changed (1) hide show
  1. model.py +20 -12
model.py CHANGED
@@ -5,25 +5,33 @@ import yaml
5
  from huggingface_hub import hf_hub_download
6
  import os
7
  import torch
8
- from ultralytics.nn.tasks import DetectionModel
9
 
10
  class TrafficSignDetector:
11
  def __init__(self, config_path):
12
  with open(config_path, 'r') as f:
13
  config = yaml.safe_load(f)
14
 
15
- # Download model from HF Hub if not local
16
- model_path = config['model']['path']
17
- if model_path.startswith('VietCat/GTSRB-Model/'):
18
- # Extract repo and file path for GTSRB-Model
19
- repo_id = 'VietCat/GTSRB-Model'
20
- file_path = model_path.replace('VietCat/GTSRB-Model/', '')
21
- local_model_path = hf_hub_download(repo_id=repo_id, filename=file_path)
22
- # Allow safe loading of ultralytics model
23
- with torch.serialization.safe_globals([torch.nn.Module, DetectionModel]):
 
 
 
 
 
 
24
  self.model = YOLO(local_model_path)
25
- else:
26
- self.model = YOLO(model_path)
 
 
 
27
 
28
  self.conf_threshold = config['model']['confidence_threshold']
29
  self.box_color = config['inference']['box_color']
 
5
  from huggingface_hub import hf_hub_download
6
  import os
7
  import torch
 
8
 
9
  class TrafficSignDetector:
10
  def __init__(self, config_path):
11
  with open(config_path, 'r') as f:
12
  config = yaml.safe_load(f)
13
 
14
+ # Monkey patch torch.load to disable weights_only for ultralytics
15
+ original_torch_load = torch.load
16
+ def patched_torch_load(*args, **kwargs):
17
+ kwargs['weights_only'] = False
18
+ return original_torch_load(*args, **kwargs)
19
+ torch.load = patched_torch_load
20
+
21
+ try:
22
+ # Download model from HF Hub if not local
23
+ model_path = config['model']['path']
24
+ if model_path.startswith('VietCat/GTSRB-Model/'):
25
+ # Extract repo and file path for GTSRB-Model
26
+ repo_id = 'VietCat/GTSRB-Model'
27
+ file_path = model_path.replace('VietCat/GTSRB-Model/', '')
28
+ local_model_path = hf_hub_download(repo_id=repo_id, filename=file_path)
29
  self.model = YOLO(local_model_path)
30
+ else:
31
+ self.model = YOLO(model_path)
32
+ finally:
33
+ # Restore original torch.load
34
+ torch.load = original_torch_load
35
 
36
  self.conf_threshold = config['model']['confidence_threshold']
37
  self.box_color = config['inference']['box_color']