# fishapi/yolo.py
# kamau1 — commit b5859c0 (verified):
# Fix YOLOv5 model loading by setting weights_only=False for PyTorch 2.6 compatibility
"""
YOLOv5 model wrapper adapted from the original Gradio implementation.
Compatible with the existing marina-benthic-33k.pt model.
"""
import torch
import yolov5
import numpy as np
from typing import Optional, List, Union, Dict, Any
from pathlib import Path
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
class MarineSpeciesYOLO:
    """
    Wrapper class for loading and running the marine species YOLOv5 model.
    Adapted from the original inference.py to work with FastAPI.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """
        Initialize the YOLO model.

        Args:
            model_path: Path to the YOLOv5 model file
            device: Device to run inference on ('cpu', 'cuda', etc.);
                auto-detected when None.
        """
        self.model_path = model_path
        self.device = device or self._get_device()
        self.model = None
        # Mapping of class id -> class name, populated by _load_model()
        # when the loaded model exposes a `names` attribute.
        self._class_names: Optional[Dict[int, str]] = None
        logger.info(f"Initializing MarineSpeciesYOLO with device: {self.device}")
        self._load_model()

    def _get_device(self) -> str:
        """Auto-detect the best available device (cuda > mps > cpu)."""
        if torch.cuda.is_available():
            return "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"  # Apple Silicon
        else:
            return "cpu"

    def _load_model(self) -> None:
        """
        Load the YOLOv5 model from self.model_path.

        Raises:
            FileNotFoundError: If the model file does not exist.
            Exception: Any loader failure is logged and re-raised.
        """
        try:
            if not Path(self.model_path).exists():
                raise FileNotFoundError(f"Model file not found: {self.model_path}")
            logger.info(f"Loading YOLOv5 model from: {self.model_path}")
            import warnings
            # Suppress loader warnings for the duration of the load only.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # PyTorch 2.6 switched torch.load to weights_only=True, which
                # rejects the numpy globals pickled inside YOLOv5 checkpoints.
                # add_safe_globals() expects the actual objects (or
                # (obj, "qualified.name") tuples) — NOT bare strings; passing
                # strings raises when torch resolves them. Allow-list the
                # numpy pieces the checkpoint needs, skipping gracefully on
                # torch builds without add_safe_globals or numpy layouts
                # where these internals moved.
                try:
                    from numpy.core.multiarray import _reconstruct, scalar
                    torch.serialization.add_safe_globals([
                        _reconstruct,
                        np.ndarray,
                        np.dtype,
                        scalar,
                    ])
                except (AttributeError, ImportError):
                    pass  # older torch / relocated numpy internals — best effort
                # Load the model with YOLOv5 (yolov5.load handles weights_only
                # compatibility internally where needed).
                self.model = yolov5.load(self.model_path, device=self.device)
            # Get class names if available
            if hasattr(self.model, 'names'):
                self._class_names = self.model.names
                logger.info(f"Loaded model with {len(self._class_names)} classes")
            logger.info("YOLOv5 model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load YOLOv5 model: {str(e)}")
            raise

    def predict(
        self,
        image: Union[str, np.ndarray],
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        image_size: int = 720,
        classes: Optional[List[int]] = None
    ) -> torch.Tensor:
        """
        Run inference on an image.

        Args:
            image: Input image (file path or numpy array)
            conf_threshold: Confidence threshold for detections
            iou_threshold: IoU threshold for NMS
            image_size: Input image size for inference
            classes: List of class IDs to filter (None for all classes)

        Returns:
            YOLOv5 detection results

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")
        # NOTE(review): thresholds are set as shared model state, so
        # concurrent predict() calls with different thresholds can race —
        # acceptable for the current single-worker use, confirm if scaled out.
        self.model.conf = conf_threshold
        self.model.iou = iou_threshold
        if classes is not None:
            self.model.classes = classes
        try:
            detections = self.model(image, size=image_size)
            return detections
        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise

    def get_class_names(self) -> Optional[Dict[int, str]]:
        """Get the class names mapping."""
        return self._class_names

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information (path, device, class inventory)."""
        return {
            "model_path": self.model_path,
            "device": self.device,
            "num_classes": len(self._class_names) if self._class_names else None,
            "class_names": self._class_names
        }

    def warmup(self, image_size: int = 720) -> None:
        """
        Warm up the model with a dummy inference.

        Failures are logged but never raised — warmup is best-effort.

        Args:
            image_size: Size for warmup inference
        """
        if self.model is None:
            return
        try:
            logger.info("Warming up model...")
            # Random noise image is enough to trigger kernel compilation /
            # memory allocation on the target device.
            dummy_image = np.random.randint(0, 255, (image_size, image_size, 3), dtype=np.uint8)
            self.predict(dummy_image, conf_threshold=0.1)
            logger.info("Model warmup completed")
        except Exception as e:
            logger.warning(f"Model warmup failed: {str(e)}")
# Lazily-created module-level singleton; populated on first get_model() call.
_model_instance: Optional[MarineSpeciesYOLO] = None


def get_model() -> MarineSpeciesYOLO:
    """
    Return the shared MarineSpeciesYOLO instance, creating it on first use.

    The model is built from the configured MODEL_PATH / DEVICE settings and,
    when ENABLE_MODEL_WARMUP is set, warmed up immediately after creation.

    Returns:
        MarineSpeciesYOLO instance
    """
    global _model_instance
    if _model_instance is not None:
        return _model_instance
    _model_instance = MarineSpeciesYOLO(
        model_path=settings.MODEL_PATH,
        device=settings.DEVICE,
    )
    # Optional eager warmup so the first real request doesn't pay the cost.
    if settings.ENABLE_MODEL_WARMUP:
        _model_instance.warmup()
    return _model_instance
def load_class_names(names_file: str) -> Dict[int, str]:
    """
    Load class names from a .names file (one class name per line).

    Best-effort: read/decoding failures are logged and an empty mapping is
    returned rather than raised, matching the original contract.

    Args:
        names_file: Path to the .names file

    Returns:
        Dictionary mapping zero-based line numbers (class IDs) to names;
        empty on failure.
    """
    class_names: Dict[int, str] = {}
    try:
        # Explicit encoding avoids the platform-dependent default
        # (e.g. cp1252 on Windows) silently mangling non-ASCII names.
        with open(names_file, 'r', encoding='utf-8') as f:
            class_names = {idx: line.strip() for idx, line in enumerate(f)}
        logger.info(f"Loaded {len(class_names)} class names from {names_file}")
    except (OSError, UnicodeError) as e:
        # Narrowed from bare Exception: only I/O and decoding errors are
        # expected here; programming errors should surface normally.
        logger.error(f"Failed to load class names: {str(e)}")
    return class_names