"""
YOLOv5 model wrapper adapted from the original Gradio implementation.

Compatible with the existing marina-benthic-33k.pt model.
"""

import importlib
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import yolov5

from app.core.config import settings
from app.core.logging import get_logger

logger = get_logger(__name__)


# (module, attribute) pairs for numpy objects that pickled checkpoints may
# reference. Both the NumPy 2.x (``numpy._core``) and 1.x (``numpy.core``)
# locations are listed; any that cannot be imported are skipped.
_NUMPY_SAFE_GLOBALS = (
    ("numpy._core.multiarray", "_reconstruct"),
    ("numpy._core.multiarray", "scalar"),
    ("numpy.core.multiarray", "_reconstruct"),
    ("numpy.core.multiarray", "scalar"),
    ("numpy", "ndarray"),
    ("numpy", "dtype"),
)


def _register_numpy_safe_globals() -> None:
    """Best-effort allow-listing of numpy objects for weights-only ``torch.load``.

    ``torch.serialization.add_safe_globals`` expects the actual Python objects,
    not dotted-path strings, and only exists on newer PyTorch releases. This
    resolves each entry of ``_NUMPY_SAFE_GLOBALS`` to its object, skips any
    that cannot be resolved, and never lets a failure here abort model loading.
    """
    add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
    if add_safe_globals is None:
        # Older PyTorch without the allow-list API: nothing to register.
        return

    resolved: List[Any] = []
    for module_name, attr_name in _NUMPY_SAFE_GLOBALS:
        try:
            obj = getattr(importlib.import_module(module_name), attr_name)
        except (ImportError, AttributeError):
            continue
        # numpy.core may be a shim over numpy._core, so de-duplicate.
        if not any(obj is seen for seen in resolved):
            resolved.append(obj)

    if not resolved:
        return
    try:
        add_safe_globals(resolved)
    except Exception as exc:  # allow-listing is an optimisation, never fatal
        logger.debug(f"Could not register numpy safe globals: {exc}")


class MarineSpeciesYOLO:
    """
    Wrapper class for loading and running the marine species YOLOv5 model.

    Adapted from the original inference.py to work with FastAPI.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """
        Initialize the YOLO model.

        Args:
            model_path: Path to the YOLOv5 model file
            device: Device to run inference on ('cpu', 'cuda', etc.). When
                falsy, the device is auto-detected.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            Exception: Any error raised while loading the model is re-raised.
        """
        self.model_path = model_path
        self.device = device or self._get_device()
        self.model = None
        self._class_names = None

        logger.info(f"Initializing MarineSpeciesYOLO with device: {self.device}")
        self._load_model()

    def _get_device(self) -> str:
        """Auto-detect the best available device: cuda, then mps, then cpu."""
        if torch.cuda.is_available():
            return "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"  # Apple Silicon
        else:
            return "cpu"

    def _load_model(self) -> None:
        """
        Load the YOLOv5 weights from ``self.model_path`` onto ``self.device``.

        If the loaded model exposes a ``names`` attribute, it is cached as the
        class-name mapping.

        Raises:
            FileNotFoundError: If ``self.model_path`` does not exist.
            Exception: Any error from ``yolov5.load`` is logged and re-raised.
        """
        try:
            if not Path(self.model_path).exists():
                raise FileNotFoundError(f"Model file not found: {self.model_path}")

            logger.info(f"Loading YOLOv5 model from: {self.model_path}")

            # Handle PyTorch 2.6+ weights_only default: silence the warnings it
            # produces and allow-list the numpy objects checkpoints reference.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                _register_numpy_safe_globals()
                self.model = yolov5.load(self.model_path, device=self.device)

            # Get class names if available
            if hasattr(self.model, 'names'):
                self._class_names = self.model.names
                logger.info(f"Loaded model with {len(self._class_names)} classes")

            logger.info("YOLOv5 model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load YOLOv5 model: {str(e)}")
            raise

    def predict(
        self,
        image: Union[str, np.ndarray],
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        image_size: int = 720,
        classes: Optional[List[int]] = None,
    ) -> Any:
        """
        Run inference on an image.

        Args:
            image: Input image (file path or numpy array)
            conf_threshold: Confidence threshold for detections
            iou_threshold: IoU threshold for NMS
            image_size: Input image size for inference
            classes: List of class IDs to filter (None for all classes)

        Returns:
            Whatever the loaded YOLOv5 model returns when called (presumably a
            yolov5 ``Detections`` object, not a raw tensor; verify with caller).

        Raises:
            RuntimeError: If the model has not been loaded.
            Exception: Any inference error is logged and re-raised.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")

        # These settings live on the shared model object, so all of them are
        # reset on every call. ``classes`` is assigned even when None so a
        # filter from a previous call does not leak into this one.
        # NOTE(review): concurrent predict() calls from different threads can
        # interleave these assignments — confirm how the API serializes calls.
        self.model.conf = conf_threshold
        self.model.iou = iou_threshold
        self.model.classes = classes

        try:
            return self.model(image, size=image_size)
        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise

    def get_class_names(self) -> Optional[Dict[int, str]]:
        """Return the cached class-name mapping, or None if unavailable."""
        return self._class_names

    def get_model_info(self) -> Dict[str, Any]:
        """Return the model path, device, class count and class names."""
        return {
            "model_path": self.model_path,
            "device": self.device,
            # An empty mapping reports 0 classes; only a missing one is None.
            "num_classes": (
                len(self._class_names) if self._class_names is not None else None
            ),
            "class_names": self._class_names,
        }

    def warmup(self, image_size: int = 720) -> None:
        """
        Warm up the model with a dummy inference.

        Failures are logged as warnings and never raised.

        Args:
            image_size: Size of the dummy image and of the warmup inference
        """
        if self.model is None:
            return

        try:
            logger.info("Warming up model...")
            # Local generator so warmup does not consume the global numpy RNG.
            rng = np.random.default_rng()
            dummy_image = rng.integers(
                0, 256, size=(image_size, image_size, 3), dtype=np.uint8
            )
            self.predict(dummy_image, conf_threshold=0.1, image_size=image_size)
            logger.info("Model warmup completed")
        except Exception as e:
            logger.warning(f"Model warmup failed: {str(e)}")


# Global model instance (singleton pattern)
_model_instance: Optional[MarineSpeciesYOLO] = None


def get_model() -> MarineSpeciesYOLO:
    """
    Get the global model instance, creating it on first use.

    The model is built from ``settings.MODEL_PATH`` / ``settings.DEVICE`` and,
    if ``settings.ENABLE_MODEL_WARMUP`` is set, warmed up once on creation.
    NOTE(review): creation is not guarded by a lock; confirm first use happens
    on a single thread (e.g. at application startup).

    Returns:
        MarineSpeciesYOLO instance
    """
    global _model_instance

    if _model_instance is None:
        _model_instance = MarineSpeciesYOLO(
            model_path=settings.MODEL_PATH,
            device=settings.DEVICE
        )

        # Warm up the model if enabled
        if settings.ENABLE_MODEL_WARMUP:
            _model_instance.warmup()

    return _model_instance


def load_class_names(names_file: str) -> Dict[int, str]:
    """
    Load class names from a .names file (one name per line).

    Line ``i`` (0-based) becomes class ID ``i``, with surrounding whitespace
    stripped. Read/decode errors are logged and whatever was read so far is
    returned.

    Args:
        names_file: Path to the .names file (read as UTF-8)

    Returns:
        Dictionary mapping class IDs to names
    """
    class_names: Dict[int, str] = {}
    try:
        with open(names_file, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                class_names[idx] = line.strip()
        logger.info(f"Loaded {len(class_names)} class names from {names_file}")
    except (OSError, UnicodeDecodeError) as e:
        logger.error(f"Failed to load class names: {str(e)}")

    return class_names