| | """ |
| | YOLOv5 model wrapper adapted from the original Gradio implementation. |
| | Compatible with the existing marina-benthic-33k.pt model. |
| | """ |
| |
|
import threading
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import yolov5

from app.core.config import settings
from app.core.logging import get_logger
| |
|
# Module-level logger, obtained from the project's get_logger helper using
# this module's __name__; shared by every function and method below.
logger = get_logger(__name__)
| |
|
| |
|
class MarineSpeciesYOLO:
    """
    Wrapper class for loading and running the marine species YOLOv5 model.
    Adapted from the original inference.py to work with FastAPI.

    The model is loaded eagerly in ``__init__``. ``predict`` writes the
    confidence/IoU thresholds and class filter onto the underlying model
    object on every call before running inference.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """
        Initialize the YOLO model.

        Args:
            model_path: Path to the YOLOv5 model file
            device: Device to run inference on ('cpu', 'cuda', etc.). Any
                falsy value (None, "") triggers auto-detection.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            Exception: Any error from ``yolov5.load`` is logged and re-raised.
        """
        self.model_path = model_path
        self.device = device or self._get_device()
        self.model = None
        self._class_names = None

        logger.info(f"Initializing MarineSpeciesYOLO with device: {self.device}")
        self._load_model()

    def _get_device(self) -> str:
        """Auto-detect the best available device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return "cuda"
        # torch.backends.mps may be absent on older torch builds, hence hasattr.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    @staticmethod
    def _register_numpy_safe_globals() -> None:
        """
        Best-effort allow-listing of numpy objects for torch's weights-only unpickler.

        ``torch.serialization.add_safe_globals`` expects the objects themselves
        (not dotted-name strings) and only exists on newer torch releases
        (2.4+). On older torch there is nothing to register, and any failure
        here is logged rather than raised so it cannot block model loading.
        """
        add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
        if add_safe_globals is None:
            # Older torch: no allow-list API, so nothing to do.
            return
        try:
            # numpy 2.x keeps numpy.core as a deprecated alias; the resulting
            # DeprecationWarning is silenced by the caller's warnings filter.
            from numpy.core import multiarray

            add_safe_globals(
                [multiarray._reconstruct, np.ndarray, np.dtype, multiarray.scalar]
            )
        except (ImportError, AttributeError, TypeError) as e:
            logger.warning(f"Could not register numpy safe globals: {e}")

    def _load_model(self) -> None:
        """
        Load the YOLOv5 model and cache its class-name mapping (if it has one).

        Raises:
            FileNotFoundError: If the model file does not exist.
            Exception: Any loading error is logged and re-raised unchanged.
        """
        try:
            if not Path(self.model_path).exists():
                raise FileNotFoundError(f"Model file not found: {self.model_path}")

            logger.info(f"Loading YOLOv5 model from: {self.model_path}")

            # Silence library warnings emitted while preparing and loading the
            # checkpoint (torch.load / numpy deprecation notices).
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                self._register_numpy_safe_globals()
                self.model = yolov5.load(self.model_path, device=self.device)

            if hasattr(self.model, "names"):
                self._class_names = self.model.names
                logger.info(f"Loaded model with {len(self._class_names)} classes")

            logger.info("YOLOv5 model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load YOLOv5 model: {str(e)}")
            raise

    def predict(
        self,
        image: Union[str, np.ndarray],
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        image_size: int = 720,
        classes: Optional[List[int]] = None,
    ) -> Any:
        """
        Run inference on an image.

        Args:
            image: Input image (file path or numpy array)
            conf_threshold: Confidence threshold for detections
            iou_threshold: IoU threshold for NMS
            image_size: Input image size for inference
            classes: List of class IDs to keep (None for all classes). Applied
                on every call, so a filter from a previous call never leaks.

        Returns:
            Whatever the loaded YOLOv5 model returns for the call (presumably
            the yolov5 package's detections object).

        Raises:
            RuntimeError: If no model is loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")

        # These settings live on the shared model object, so set all of them on
        # every call; previously a class filter from an earlier call persisted
        # when classes=None was passed.
        # NOTE(review): because the settings are shared state, concurrent
        # predict calls with different thresholds/filters can race.
        self.model.conf = conf_threshold
        self.model.iou = iou_threshold
        self.model.classes = classes

        try:
            return self.model(image, size=image_size)
        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise

    def get_class_names(self) -> Optional[Dict[int, str]]:
        """Get the class names mapping (None if the model exposed none)."""
        return self._class_names

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information: path, device, class count and class names."""
        names = self._class_names
        return {
            "model_path": self.model_path,
            "device": self.device,
            # Distinguish "no names available" (None) from an empty mapping (0).
            "num_classes": len(names) if names is not None else None,
            "class_names": names,
        }

    def warmup(self, image_size: int = 720) -> None:
        """
        Warm up the model with a dummy inference.

        Failures are logged as warnings and never raised.

        Args:
            image_size: Size for warmup inference (both the dummy image's
                height/width and the inference size)
        """
        if self.model is None:
            return

        try:
            logger.info("Warming up model...")
            # Random HxWx3 uint8 image over the full 0-255 range; content is
            # irrelevant, only the inference path needs exercising.
            dummy_image = np.random.randint(
                0, 256, (image_size, image_size, 3), dtype=np.uint8
            )
            # Pass image_size through so warmup runs at the requested size
            # instead of predict's default.
            self.predict(dummy_image, conf_threshold=0.1, image_size=image_size)
            logger.info("Model warmup completed")
        except Exception as e:
            logger.warning(f"Model warmup failed: {str(e)}")
|
| |
|
| | |
# Process-wide singleton state for get_model(). The lock ensures concurrent
# first calls (e.g. sync FastAPI handlers running in a threadpool) construct
# and warm up the model only once.
_model_instance: Optional[MarineSpeciesYOLO] = None
_model_lock = threading.Lock()


def get_model() -> MarineSpeciesYOLO:
    """
    Get the global model instance (singleton pattern).

    The model is built on first use from ``settings.MODEL_PATH`` and
    ``settings.DEVICE``, and warmed up if ``settings.ENABLE_MODEL_WARMUP`` is
    set. If construction raises, nothing is cached and the next call retries.

    Returns:
        MarineSpeciesYOLO instance
    """
    global _model_instance

    # Fast path: once initialized, no locking is needed.
    instance = _model_instance
    if instance is not None:
        return instance

    with _model_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if _model_instance is None:
            instance = MarineSpeciesYOLO(
                model_path=settings.MODEL_PATH,
                device=settings.DEVICE,
            )
            if settings.ENABLE_MODEL_WARMUP:
                instance.warmup()
            # Publish only after warmup so the fast path never returns a
            # model that is still warming up.
            _model_instance = instance
        return _model_instance
| |
|
| |
|
def load_class_names(names_file: Union[str, Path]) -> Dict[int, str]:
    """
    Load class names from a .names file (one class name per line).

    Line ``i`` (0-based) becomes class ID ``i``. Surrounding whitespace is
    stripped from each name.

    Args:
        names_file: Path to the .names file (decoded as UTF-8)

    Returns:
        Dictionary mapping class IDs to names. On a read or decode error, the
        error is logged and whatever was read so far (possibly nothing) is
        returned instead of raising.
    """
    class_names: Dict[int, str] = {}
    try:
        # Explicit UTF-8 so names decode identically on every platform rather
        # than depending on the locale's default encoding.
        with open(names_file, "r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                class_names[idx] = line.strip()
        logger.info(f"Loaded {len(class_names)} class names from {names_file}")
    except (OSError, UnicodeDecodeError) as e:
        # Best-effort by design: log and return the partial mapping.
        logger.error(f"Failed to load class names: {str(e)}")

    return class_names
| |
|