"""
YOLOv5 model wrapper adapted from the original Gradio implementation.
Compatible with the existing marina-benthic-33k.pt model.
"""
import torch
import yolov5
import numpy as np
from typing import Optional, List, Union, Dict, Any
from pathlib import Path
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
class MarineSpeciesYOLO:
    """
    Wrapper class for loading and running the marine species YOLOv5 model.

    Adapted from the original inference.py to work with FastAPI.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """
        Initialize the YOLO model and load weights immediately.

        Args:
            model_path: Path to the YOLOv5 model file
            device: Device to run inference on ('cpu', 'cuda', etc.);
                auto-detected when None.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.model_path = model_path
        self.device = device or self._get_device()
        self.model = None
        # Mapping of class id -> class name, populated in _load_model().
        self._class_names: Optional[Dict[int, str]] = None
        logger.info(f"Initializing MarineSpeciesYOLO with device: {self.device}")
        self._load_model()

    def _get_device(self) -> str:
        """Auto-detect the best available device (cuda > mps > cpu)."""
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"  # Apple Silicon
        return "cpu"

    def _load_model(self) -> None:
        """
        Load the YOLOv5 model from ``self.model_path``.

        Raises:
            FileNotFoundError: If the model file is missing.
            Exception: Re-raises whatever ``yolov5.load`` fails with.
        """
        try:
            if not Path(self.model_path).exists():
                raise FileNotFoundError(f"Model file not found: {self.model_path}")
            logger.info(f"Loading YOLOv5 model from: {self.model_path}")
            import warnings
            # PyTorch 2.6+ defaults torch.load to weights_only=True, which
            # rejects the numpy objects embedded in YOLOv5 checkpoints.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # add_safe_globals expects the actual objects, not their
                # dotted-name strings, so resolve them before registering.
                # Guarded: older torch versions lack add_safe_globals.
                if hasattr(torch.serialization, "add_safe_globals"):
                    try:
                        import numpy.core.multiarray as _multiarray
                        torch.serialization.add_safe_globals([
                            _multiarray._reconstruct,
                            np.ndarray,
                            np.dtype,
                            _multiarray.scalar,
                        ])
                    except (ImportError, AttributeError):
                        # numpy internals moved (e.g. NumPy 2.x renamed
                        # np.core); fall through and let yolov5.load try.
                        pass
                # Load the model with YOLOv5
                self.model = yolov5.load(self.model_path, device=self.device)
            # Get class names if available
            if hasattr(self.model, 'names'):
                self._class_names = self.model.names
                logger.info(f"Loaded model with {len(self._class_names)} classes")
            logger.info("YOLOv5 model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load YOLOv5 model: {str(e)}")
            raise

    def predict(
        self,
        image: Union[str, np.ndarray],
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        image_size: int = 720,
        classes: Optional[List[int]] = None
    ) -> torch.Tensor:
        """
        Run inference on an image.

        Args:
            image: Input image (file path or numpy array)
            conf_threshold: Confidence threshold for detections
            iou_threshold: IoU threshold for NMS
            image_size: Input image size for inference
            classes: List of class IDs to filter (None for all classes)

        Returns:
            YOLOv5 detection results

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")
        # NOTE(review): thresholds are set as shared model state, so
        # concurrent predict() calls with different thresholds can race.
        self.model.conf = conf_threshold
        self.model.iou = iou_threshold
        if classes is not None:
            self.model.classes = classes
        # Run inference
        try:
            detections = self.model(image, size=image_size)
            return detections
        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise

    def get_class_names(self) -> Optional[Dict[int, str]]:
        """Get the class names mapping (None if the model exposes none)."""
        return self._class_names

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information as a plain dict (path, device, classes)."""
        return {
            "model_path": self.model_path,
            "device": self.device,
            # Explicit None check: an empty names dict should report 0,
            # not None.
            "num_classes": len(self._class_names) if self._class_names is not None else None,
            "class_names": self._class_names
        }

    def warmup(self, image_size: int = 720) -> None:
        """
        Warm up the model with a dummy inference; no-op if not loaded.

        Args:
            image_size: Size for warmup inference
        """
        if self.model is None:
            return
        try:
            logger.info("Warming up model...")
            # Random RGB image matching the expected inference size.
            dummy_image = np.random.randint(0, 255, (image_size, image_size, 3), dtype=np.uint8)
            self.predict(dummy_image, conf_threshold=0.1)
            logger.info("Model warmup completed")
        except Exception as e:
            # Warmup is best-effort: a failure here must not prevent startup.
            logger.warning(f"Model warmup failed: {str(e)}")
# Process-wide singleton holding the loaded model.
_model_instance: Optional[MarineSpeciesYOLO] = None


def get_model() -> MarineSpeciesYOLO:
    """
    Return the shared model instance, creating it lazily on first call.

    Returns:
        The process-wide MarineSpeciesYOLO instance.
    """
    global _model_instance
    if _model_instance is not None:
        return _model_instance
    instance = MarineSpeciesYOLO(
        model_path=settings.MODEL_PATH,
        device=settings.DEVICE,
    )
    # Optional warmup right after creation; warmup() itself never raises.
    if settings.ENABLE_MODEL_WARMUP:
        instance.warmup()
    _model_instance = instance
    return _model_instance
def load_class_names(names_file: str) -> Dict[int, str]:
    """
    Load class names from a .names file (one class name per line).

    Line order defines the class IDs: line 0 -> class 0, and so on.
    Lines are stripped but deliberately NOT skipped when blank, so the
    indices stay aligned with the model's label order.

    Args:
        names_file: Path to the .names file

    Returns:
        Dictionary mapping class IDs to names. Best-effort: on a read
        failure the error is logged and whatever was read so far
        (possibly an empty dict) is returned instead of raising.
    """
    class_names: Dict[int, str] = {}
    try:
        # Explicit encoding so results don't depend on the platform's
        # default locale encoding.
        with open(names_file, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                class_names[idx] = line.strip()
        logger.info(f"Loaded {len(class_names)} class names from {names_file}")
    except (OSError, UnicodeError) as e:
        # Narrowed from bare Exception: only I/O and decoding failures
        # are expected here; anything else is a real bug and should raise.
        logger.error(f"Failed to load class names: {str(e)}")
    return class_names