kamau1 commited on
Commit
8b3f6c2
·
verified ·
1 Parent(s): 09efaa0

Downgrade PyTorch to 2.0.x for YOLOv5 compatibility and simplify model loading accordingly

Browse files
Files changed (2) hide show
  1. app/models/yolo.py +3 -31
  2. requirements.txt +2 -2
app/models/yolo.py CHANGED
@@ -54,37 +54,9 @@ class MarineSpeciesYOLO:
54
 
55
  logger.info(f"Loading YOLOv5 model from: {self.model_path}")
56
 
57
- # Handle PyTorch 2.6+ weights_only issue
58
- # Add safe globals for numpy operations that YOLOv5 needs
59
- safe_globals = [
60
- 'numpy.core.multiarray._reconstruct',
61
- 'numpy.ndarray',
62
- 'numpy.dtype',
63
- 'numpy.core.multiarray.scalar',
64
- 'collections.OrderedDict',
65
- 'torch._utils._rebuild_tensor_v2',
66
- 'torch.nn.modules.conv.Conv2d',
67
- 'torch.nn.modules.batchnorm.BatchNorm2d',
68
- 'torch.nn.modules.activation.SiLU'
69
- ]
70
-
71
- try:
72
- # Add safe globals
73
- torch.serialization.add_safe_globals(safe_globals)
74
- logger.info("Added safe globals for PyTorch loading")
75
- except Exception as e:
76
- logger.warning(f"Could not add safe globals: {e}")
77
-
78
- # Try loading with safe globals context manager as well
79
- try:
80
- with torch.serialization.safe_globals(safe_globals):
81
- self.model = yolov5.load(self.model_path, device=self.device)
82
- logger.info("Model loaded successfully with safe globals context")
83
- except Exception as e:
84
- logger.warning(f"Safe globals context failed: {e}")
85
- # Fallback to regular loading
86
- logger.info("Attempting fallback model loading...")
87
- self.model = yolov5.load(self.model_path, device=self.device)
88
 
89
  # Get class names if available
90
  if hasattr(self.model, 'names'):
 
54
 
55
  logger.info(f"Loading YOLOv5 model from: {self.model_path}")
56
 
57
+ # Load model exactly like the original Gradio implementation
58
+ # Using PyTorch 2.0.1 which doesn't have the weights_only=True issue
59
+ self.model = yolov5.load(self.model_path, device=self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  # Get class names if available
62
  if hasattr(self.model, 'names'):
requirements.txt CHANGED
@@ -3,8 +3,8 @@ fastapi==0.104.1
3
  uvicorn[standard]==0.24.0
4
 
5
  # Machine Learning and Computer Vision
6
- torch>=1.13.0
7
- torchvision>=0.14.0
8
  yolov5==6.2.3
9
  opencv-python-headless==4.8.1.78
10
  Pillow==10.1.0
 
3
  uvicorn[standard]==0.24.0
4
 
5
  # Machine Learning and Computer Vision
6
+ torch==2.0.1
7
+ torchvision==0.15.2
8
  yolov5==6.2.3
9
  opencv-python-headless==4.8.1.78
10
  Pillow==10.1.0