dschandra commited on
Commit
340ad0f
·
verified ·
1 Parent(s): fc642e1

Update detector.py

Browse files
Files changed (1) hide show
  1. detector.py +9 -6
detector.py CHANGED
@@ -1,17 +1,20 @@
1
  from ultralytics import YOLO
2
  import cv2
3
  import torch
4
- import torch.serialization
5
- import ultralytics # Required to access ultralytics-specific classes
6
 
7
  class LBWDetector:
8
  def __init__(self, model_path='best.pt'):
9
- """Initialize YOLO model with safe globals for PyTorch 2.6+."""
10
- with torch.serialization.safe_globals([torch.nn.modules.container.Sequential, ultralytics.nn.tasks.DetectionModel, ultralytics.nn.modules.conv.Conv, torch.nn.modules.conv.Conv2d]):
11
- self.model = YOLO(model_path)
 
 
 
 
 
 
12
 
13
  def detect_objects(self, frame):
14
- """Detect objects in a frame and return bounding boxes and class names."""
15
  results = self.model.predict(source=frame, conf=0.3, save=False, verbose=False)
16
  detections = results[0].boxes.data.cpu().numpy() # x1, y1, x2, y2, conf, class
17
  return detections, results[0].names
 
1
  from ultralytics import YOLO
2
  import cv2
3
  import torch
 
 
4
 
5
  class LBWDetector:
6
  def __init__(self, model_path='best.pt'):
7
+ # Temporarily override torch.load to use weights_only=False
8
+ original_load = torch.load
9
+ def custom_load(*args, **kwargs):
10
+ kwargs['weights_only'] = False
11
+ return original_load(*args, **kwargs)
12
+ torch.load = custom_load
13
+ self.model = YOLO(model_path)
14
+ # Restore original torch.load
15
+ torch.load = original_load
16
 
17
  def detect_objects(self, frame):
 
18
  results = self.model.predict(source=frame, conf=0.3, save=False, verbose=False)
19
  detections = results[0].boxes.data.cpu().numpy() # x1, y1, x2, y2, conf, class
20
  return detections, results[0].names