# fishapi/app/services/inference_service.py
"""
Inference service for marine species detection.
"""
import time
import base64
import io
from typing import List, Optional, Dict, Tuple
import numpy as np
from PIL import Image
import cv2
from app.core.config import settings
from app.core.logging import get_logger
from app.models.inference import Detection, BoundingBox, InferenceResponse, ModelInfo
from app.services.model_service import model_service
from app.utils.image_processing import decode_base64_image, encode_image_to_base64
# Module-level logger named after this module's import path.
logger = get_logger(__name__)
class InferenceService:
    """Service for running marine species detection inference."""

    def __init__(self):
        # Reuse the shared model service singleton so the underlying
        # model weights are loaded only once per process.
        self.model_service = model_service

    async def detect_species(
        self,
        image_data: str,
        confidence_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        image_size: int = 720,
        return_annotated_image: bool = True,
        classes: Optional[List[int]] = None
    ) -> InferenceResponse:
        """
        Detect marine species in an image.

        Args:
            image_data: Base64 encoded image.
            confidence_threshold: Minimum confidence for a detection to be kept.
            iou_threshold: IoU threshold used by non-maximum suppression.
            image_size: Input image size for inference.
            return_annotated_image: Whether to return an annotated image.
                Note: the annotated image is only produced when at least
                one detection was found.
            classes: Optional list of class IDs to restrict detection to.

        Returns:
            InferenceResponse with detection results, timing, model info
            and the original image dimensions.

        Raises:
            Exception: Any decoding or inference failure is logged
                (with traceback) and re-raised.
        """
        # perf_counter is monotonic, so the reported duration is immune
        # to wall-clock adjustments mid-request (time.time is not).
        start_time = time.perf_counter()
        try:
            # Decode the base64 payload into an image + (width, height)
            image, original_dims = decode_base64_image(image_data)
            logger.info(f"Processing image with dimensions: {original_dims}")

            # Run inference
            model = self.model_service.get_model()
            predictions = model.predict(
                image=image,
                conf_threshold=confidence_threshold,
                iou_threshold=iou_threshold,
                image_size=image_size,
                classes=classes
            )

            # Convert raw predictions into API Detection objects
            detections = self._process_predictions(predictions)

            # Render only when requested AND something was detected;
            # empty results keep the response payload small.
            annotated_image_b64 = None
            if return_annotated_image and detections:
                annotated_image = self._create_annotated_image(image, predictions)
                annotated_image_b64 = encode_image_to_base64(annotated_image)

            model_info = ModelInfo(**self.model_service.get_model_info())

            processing_time = time.perf_counter() - start_time
            logger.info(f"Inference completed in {processing_time:.3f}s, found {len(detections)} detections")

            return InferenceResponse(
                detections=detections,
                annotated_image=annotated_image_b64,
                processing_time=processing_time,
                model_info=model_info,
                image_dimensions={"width": original_dims[0], "height": original_dims[1]}
            )
        except Exception as e:
            # logger.exception records the full traceback, not just the message.
            logger.exception(f"Inference failed: {str(e)}")
            raise

    def _process_predictions(self, predictions) -> List[Detection]:
        """
        Convert YOLOv5 predictions into Detection objects.

        Args:
            predictions: YOLOv5 prediction results (must support
                ``.pandas().xyxy``).

        Returns:
            List of Detection objects with top-left x/y + width/height boxes.

        Raises:
            Exception: Re-raised after logging if the prediction frame
                cannot be parsed.
        """
        detections: List[Detection] = []
        # May be None/empty if the model exposes no name mapping.
        class_names = self.model_service.get_class_names()
        try:
            # YOLOv5 exposes results as one pandas DataFrame per image;
            # index 0 is the single image we ran.
            pred_df = predictions.pandas().xyxy[0]
            for _, row in pred_df.iterrows():
                x1, y1, x2, y2 = row['xmin'], row['ymin'], row['xmax'], row['ymax']
                class_id = int(row['class'])

                # Fall back to a synthetic name when the model provides
                # no mapping for this class ID.
                if class_names and class_id in class_names:
                    class_name = class_names[class_id]
                else:
                    class_name = f"class_{class_id}"

                detections.append(Detection(
                    class_id=class_id,
                    class_name=class_name,
                    confidence=float(row['confidence']),
                    # API bbox format: top-left corner plus width/height,
                    # converted from the xyxy corner coordinates.
                    bbox=BoundingBox(
                        x=float(x1),
                        y=float(y1),
                        width=float(x2 - x1),
                        height=float(y2 - y1)
                    )
                ))
        except Exception as e:
            logger.exception(f"Failed to process predictions: {str(e)}")
            raise
        return detections

    def _create_annotated_image(self, original_image: np.ndarray, predictions) -> np.ndarray:
        """
        Create an annotated image with detection boxes and labels.

        Args:
            original_image: Original input image.
            predictions: YOLOv5 prediction results.

        Returns:
            Annotated image as a numpy array, or the original image when
            rendering fails or yields nothing (best-effort fallback).
        """
        try:
            # Use YOLOv5's built-in rendering (one image in, one out).
            rendered_imgs = predictions.render()
            if rendered_imgs:
                return rendered_imgs[0]
            # Fallback: return original image if rendering produced nothing
            return original_image
        except Exception as e:
            # Annotation is cosmetic: log with traceback and fall back
            # rather than failing the whole request.
            logger.exception(f"Failed to create annotated image: {str(e)}")
            return original_image

    async def get_supported_species(self) -> List[Dict]:
        """
        Get the list of all supported marine species.

        Returns:
            List of {"class_id", "class_name"} dicts sorted by class name;
            empty when the model exposes no class mapping.
        """
        class_names = self.model_service.get_class_names()
        if not class_names:
            return []
        species = [
            {"class_id": class_id, "class_name": class_name}
            for class_id, class_name in class_names.items()
        ]
        return sorted(species, key=lambda item: item["class_name"])
# Global service instance — module-level singleton shared by API routes.
inference_service = InferenceService()