fishapi / app /models /inference.py
kamau1's picture
Downgrade YOLOv5 for compatibility, simplify model loading, fix Pydantic warnings, and improve error handling
28552c3 verified
"""
Pydantic models for API requests and responses.
"""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field, field_validator
import base64
class BoundingBox(BaseModel):
    """Axis-aligned rectangle locating a detection.

    The rectangle is anchored at its top-left corner (``x``, ``y``) and sized
    by ``width`` and ``height``. No range checks are applied to any of these
    values here.
    """

    # Anchor point of the rectangle.
    x: float = Field(description="X coordinate of top-left corner")
    y: float = Field(description="Y coordinate of top-left corner")
    # Extent of the rectangle.
    width: float = Field(description="Width of bounding box")
    height: float = Field(description="Height of bounding box")
class Detection(BaseModel):
    """A single detected object: its class, its score and where it is."""

    class_id: int = Field(description="Class ID of detected species")
    class_name: str = Field(description="Name of detected marine species")
    # Score must fall in the closed interval [0, 1].
    confidence: float = Field(
        ge=0.0,
        le=1.0,
        description="Detection confidence score",
    )
    bbox: BoundingBox = Field(description="Bounding box coordinates")
class ModelInfo(BaseModel):
    """Descriptive metadata about the detection model in use."""

    # Pydantic reserves the "model_" prefix by default; clearing the protected
    # namespaces lets ``model_name`` / ``model_path`` be declared as fields.
    model_config = dict(protected_namespaces=())

    model_name: str = Field(description="Name of the model")
    total_classes: int = Field(description="Total number of species classes")
    device: str = Field(description="Device used for inference")
    model_path: str = Field(description="Path to model file")
class InferenceRequest(BaseModel):
    """Request body for a marine-species detection call.

    Only ``image`` is required; every inference option has a default.
    """

    image: str = Field(..., description="Base64 encoded image data")
    confidence_threshold: float = Field(
        default=0.25,
        ge=0.0,
        le=1.0,
        description="Confidence threshold for detections"
    )
    iou_threshold: float = Field(
        default=0.45,
        ge=0.0,
        le=1.0,
        description="IoU threshold for non-maximum suppression"
    )
    # NOTE(review): YOLO-style detectors usually want a multiple of the model
    # stride (commonly 32) and 720 is not one -- confirm the inference code
    # rounds this value rather than relying on the range check alone.
    image_size: int = Field(
        default=720,
        ge=320,
        le=1280,
        description="Input image size for inference"
    )
    return_annotated_image: bool = Field(
        default=True,
        description="Whether to return annotated image with detections"
    )
    # None means "do not filter by class".
    classes: Optional[List[int]] = Field(
        default=None,
        description="List of class IDs to filter (None for all classes)"
    )

    @field_validator('image')
    @classmethod
    def validate_image(cls, v: str) -> str:
        """Ensure ``image`` is non-empty, well-formed base64.

        Whitespace (e.g. line breaks from MIME-style wrapping) is ignored.
        Any other character outside the standard base64 alphabet is rejected
        rather than silently discarded, which the lenient
        ``base64.b64decode`` default would do.

        Args:
            v: The raw ``image`` string from the request.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If the data is empty or not valid base64.
        """
        compact = "".join(v.split())
        if not compact:
            raise ValueError("Invalid base64 image data")
        try:
            # validate=True rejects non-alphabet characters instead of
            # dropping them; binascii.Error is a subclass of ValueError, and
            # non-ASCII input raises ValueError as well.
            decoded = base64.b64decode(compact, validate=True)
        except ValueError as exc:
            raise ValueError("Invalid base64 image data") from exc
        if not decoded:
            # e.g. padding-only input decodes to zero bytes.
            raise ValueError("Invalid base64 image data")
        return v
class InferenceResponse(BaseModel):
    """Result payload of a marine-species detection call."""

    # Clearing pydantic's protected namespaces permits the ``model_info`` field.
    model_config = dict(protected_namespaces=())

    detections: List[Detection] = Field(
        description="List of detected marine species"
    )
    # Left as None when no annotated image is produced.
    annotated_image: Optional[str] = Field(
        None,
        description="Base64 encoded annotated image (if requested)",
    )
    processing_time: float = Field(description="Processing time in seconds")
    model_info: ModelInfo = Field(
        description="Information about the model used"
    )
    # NOTE(review): key names are not enforced here; presumably "width" and
    # "height" -- confirm against the code that builds this response.
    image_dimensions: Dict[str, int] = Field(
        description="Original image dimensions (width, height)"
    )
class SpeciesInfo(BaseModel):
    """Maps one class ID to its species name."""

    class_id: int = Field(description="Class ID")
    class_name: str = Field(description="Species name")
class SpeciesListResponse(BaseModel):
    """Payload listing every species the model can detect.

    ``total_count`` is supplied by the caller; this model does not check it
    against ``len(species)``.
    """

    species: List[SpeciesInfo] = Field(
        description="List of all supported marine species"
    )
    total_count: int = Field(description="Total number of species")
class HealthResponse(BaseModel):
    """Health-check payload reporting API status and model readiness."""

    # Needed so ``model_loaded`` / ``model_info`` don't clash with pydantic's
    # reserved "model_" prefix.
    model_config = dict(protected_namespaces=())

    status: str = Field(description="API status")
    model_loaded: bool = Field(description="Whether the model is loaded")
    # Optional: may be omitted, e.g. when no model information is available.
    model_info: Optional[ModelInfo] = Field(
        None,
        description="Model information",
    )
    # NOTE(review): timestamp format is not constrained here (plain str).
    timestamp: str = Field(description="Response timestamp")
class ErrorResponse(BaseModel):
    """Uniform error payload: an error type, a message and optional extras."""

    error: str = Field(description="Error type")
    message: str = Field(description="Error message")
    # Free-form extra context; omitted (None) when there is nothing to add.
    details: Optional[Dict[str, Any]] = Field(
        None,
        description="Additional error details",
    )
class APIInfo(BaseModel):
    """Top-level description of the API and the model it serves."""

    # Clearing pydantic's protected namespaces permits the ``model_info`` field.
    model_config = dict(protected_namespaces=())

    name: str = Field(description="API name")
    version: str = Field(description="API version")
    description: str = Field(description="API description")
    model_info: ModelInfo = Field(description="Model information")
    endpoints: List[str] = Field(description="Available endpoints")