Downgrade PyTorch to 2.0.x for YOLOv5 compatibility and simplify model loading accordingly
Browse files- app/models/yolo.py +3 -31
- requirements.txt +2 -2
app/models/yolo.py
CHANGED
|
@@ -54,37 +54,9 @@ class MarineSpeciesYOLO:
|
|
| 54 |
|
| 55 |
logger.info(f"Loading YOLOv5 model from: {self.model_path}")
|
| 56 |
|
| 57 |
-
#
|
| 58 |
-
#
|
| 59 |
-
|
| 60 |
-
'numpy.core.multiarray._reconstruct',
|
| 61 |
-
'numpy.ndarray',
|
| 62 |
-
'numpy.dtype',
|
| 63 |
-
'numpy.core.multiarray.scalar',
|
| 64 |
-
'collections.OrderedDict',
|
| 65 |
-
'torch._utils._rebuild_tensor_v2',
|
| 66 |
-
'torch.nn.modules.conv.Conv2d',
|
| 67 |
-
'torch.nn.modules.batchnorm.BatchNorm2d',
|
| 68 |
-
'torch.nn.modules.activation.SiLU'
|
| 69 |
-
]
|
| 70 |
-
|
| 71 |
-
try:
|
| 72 |
-
# Add safe globals
|
| 73 |
-
torch.serialization.add_safe_globals(safe_globals)
|
| 74 |
-
logger.info("Added safe globals for PyTorch loading")
|
| 75 |
-
except Exception as e:
|
| 76 |
-
logger.warning(f"Could not add safe globals: {e}")
|
| 77 |
-
|
| 78 |
-
# Try loading with safe globals context manager as well
|
| 79 |
-
try:
|
| 80 |
-
with torch.serialization.safe_globals(safe_globals):
|
| 81 |
-
self.model = yolov5.load(self.model_path, device=self.device)
|
| 82 |
-
logger.info("Model loaded successfully with safe globals context")
|
| 83 |
-
except Exception as e:
|
| 84 |
-
logger.warning(f"Safe globals context failed: {e}")
|
| 85 |
-
# Fallback to regular loading
|
| 86 |
-
logger.info("Attempting fallback model loading...")
|
| 87 |
-
self.model = yolov5.load(self.model_path, device=self.device)
|
| 88 |
|
| 89 |
# Get class names if available
|
| 90 |
if hasattr(self.model, 'names'):
|
|
|
|
| 54 |
|
| 55 |
logger.info(f"Loading YOLOv5 model from: {self.model_path}")
|
| 56 |
|
| 57 |
+
# Load model exactly like the original Gradio implementation
|
| 58 |
+
# Using PyTorch 2.0.1 which doesn't have the weights_only=True issue
|
| 59 |
+
self.model = yolov5.load(self.model_path, device=self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
# Get class names if available
|
| 62 |
if hasattr(self.model, 'names'):
|
requirements.txt
CHANGED
|
@@ -3,8 +3,8 @@ fastapi==0.104.1
|
|
| 3 |
uvicorn[standard]==0.24.0
|
| 4 |
|
| 5 |
# Machine Learning and Computer Vision
|
| 6 |
-
torch
|
| 7 |
-
torchvision
|
| 8 |
yolov5==6.2.3
|
| 9 |
opencv-python-headless==4.8.1.78
|
| 10 |
Pillow==10.1.0
|
|
|
|
| 3 |
uvicorn[standard]==0.24.0
|
| 4 |
|
| 5 |
# Machine Learning and Computer Vision
|
| 6 |
+
torch==2.0.1
|
| 7 |
+
torchvision==0.15.2
|
| 8 |
yolov5==6.2.3
|
| 9 |
opencv-python-headless==4.8.1.78
|
| 10 |
Pillow==10.1.0
|