|
|
""" |
|
|
Pydantic models for API requests and responses. |
|
|
""" |
|
|
|
|
|
from typing import List, Optional, Dict, Any |
|
|
from pydantic import BaseModel, Field, field_validator |
|
|
import base64 |
|
|
|
|
|
|
|
|
class BoundingBox(BaseModel):
    """Bounding box coordinates."""

    # Anchor point of the box: its top-left corner.
    x: float = Field(
        ...,
        description="X coordinate of top-left corner",
    )
    y: float = Field(
        ...,
        description="Y coordinate of top-left corner",
    )

    # Size of the box. Units are not fixed here; they are whatever the
    # producer uses for x/y (presumably pixels — confirm against caller).
    width: float = Field(
        ...,
        description="Width of bounding box",
    )
    height: float = Field(
        ...,
        description="Height of bounding box",
    )
|
|
|
|
|
|
|
|
class Detection(BaseModel):
    """Single detection result."""

    # Identity of the detected species: numeric ID plus human-readable name.
    class_id: int = Field(
        ...,
        description="Class ID of detected species",
    )
    class_name: str = Field(
        ...,
        description="Name of detected marine species",
    )

    # Score is validated to the closed interval [0.0, 1.0].
    confidence: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Detection confidence score",
    )

    # Location of the detection within the image.
    bbox: BoundingBox = Field(
        ...,
        description="Bounding box coordinates",
    )
|
|
|
|
|
|
|
|
class ModelInfo(BaseModel):
    """Model information."""

    # model_name and model_path begin with "model_", a prefix Pydantic
    # protects by default; clearing protected_namespaces permits them
    # without a namespace-conflict warning.
    model_config = dict(protected_namespaces=())

    model_name: str = Field(
        ...,
        description="Name of the model",
    )
    total_classes: int = Field(
        ...,
        description="Total number of species classes",
    )
    # Free-form device label; its exact format is set by the producer.
    device: str = Field(
        ...,
        description="Device used for inference",
    )
    model_path: str = Field(
        ...,
        description="Path to model file",
    )
|
|
|
|
|
|
|
|
class InferenceRequest(BaseModel):
    """Request model for marine species detection."""

    # The image to analyse, as a base64 string; checked by validate_image.
    image: str = Field(..., description="Base64 encoded image data")

    # Post-processing thresholds, both constrained to [0.0, 1.0].
    confidence_threshold: float = Field(
        default=0.25,
        ge=0.0,
        le=1.0,
        description="Confidence threshold for detections"
    )
    iou_threshold: float = Field(
        default=0.45,
        ge=0.0,
        le=1.0,
        description="IoU threshold for non-maximum suppression"
    )

    # Inference resolution, constrained to [320, 1280].
    image_size: int = Field(
        default=720,
        ge=320,
        le=1280,
        description="Input image size for inference"
    )

    return_annotated_image: bool = Field(
        default=True,
        description="Whether to return annotated image with detections"
    )

    # None means "no class filter" (per the field description).
    classes: Optional[List[int]] = Field(
        default=None,
        description="List of class IDs to filter (None for all classes)"
    )

    @field_validator('image')
    @classmethod
    def validate_image(cls, v):
        """Check that ``image`` is non-empty, strictly valid base64.

        ASCII whitespace (e.g. MIME-style line wrapping) is ignored. Any
        other character outside the base64 alphabet, incorrect padding, or
        a payload that decodes to zero bytes is rejected.

        Args:
            v: The raw ``image`` string.

        Returns:
            ``v`` unchanged; decoding for inference is left to the caller.

        Raises:
            ValueError: If the data is not valid base64 or decodes to
                nothing. Pydantic surfaces this as a validation error.
        """
        # NOTE(review): "data:<mime>;base64," URI prefixes are rejected;
        # strip them client-side if such inputs must be supported.
        try:
            # Non-ASCII text raises UnicodeEncodeError here; it and
            # binascii.Error (from b64decode) both subclass ValueError.
            raw = v.encode("ascii")
            # validate=True stops b64decode from silently discarding
            # non-alphabet characters, which let almost any string pass.
            decoded = base64.b64decode(b"".join(raw.split()), validate=True)
        except ValueError:
            raise ValueError("Invalid base64 image data") from None
        if not decoded:
            # "" (or whitespace only) is syntactically valid base64 but
            # carries no image.
            raise ValueError("Invalid base64 image data")
        return v
|
|
|
|
|
|
|
|
class InferenceResponse(BaseModel):
    """Response model for marine species detection."""

    # The model_info field starts with Pydantic's protected "model_" prefix;
    # an empty protected_namespaces tuple allows it.
    model_config = dict(protected_namespaces=())

    # Main payload: every detection produced for the image.
    detections: List[Detection] = Field(
        ...,
        description="List of detected marine species",
    )

    # Optional rendering of the detections; None unless the producer sets
    # it (the description says it is present only if requested).
    annotated_image: Optional[str] = Field(
        default=None,
        description="Base64 encoded annotated image (if requested)",
    )

    # Timing and provenance of this result.
    processing_time: float = Field(
        ...,
        description="Processing time in seconds",
    )
    model_info: ModelInfo = Field(
        ...,
        description="Information about the model used",
    )

    # Keyed by dimension name; presumably "width" and "height" — confirm
    # against the code that builds this response.
    image_dimensions: Dict[str, int] = Field(
        ...,
        description="Original image dimensions (width, height)",
    )
|
|
|
|
|
|
|
|
class SpeciesInfo(BaseModel):
    """Information about a marine species."""

    # Pairs a numeric class ID with its display name.
    class_id: int = Field(
        ...,
        description="Class ID",
    )
    class_name: str = Field(
        ...,
        description="Species name",
    )
|
|
|
|
|
|
|
|
class SpeciesListResponse(BaseModel):
    """Response model for species list endpoint."""

    species: List[SpeciesInfo] = Field(
        ...,
        description="List of all supported marine species",
    )
    # Supplied independently by the producer. It is not cross-checked
    # against len(species) here.
    total_count: int = Field(
        ...,
        description="Total number of species",
    )
|
|
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Response model for health check."""

    # model_loaded and model_info use Pydantic's protected "model_" prefix;
    # clearing protected_namespaces permits them.
    model_config = dict(protected_namespaces=())

    # Free-form status label chosen by the producer.
    status: str = Field(
        ...,
        description="API status",
    )
    model_loaded: bool = Field(
        ...,
        description="Whether the model is loaded",
    )
    # May be omitted (defaults to None).
    model_info: Optional[ModelInfo] = Field(
        default=None,
        description="Model information",
    )
    # Kept as a plain string; its format is chosen by the producer.
    timestamp: str = Field(
        ...,
        description="Response timestamp",
    )
|
|
|
|
|
|
|
|
class ErrorResponse(BaseModel):
    """Error response model."""

    # Short error category plus a human-readable explanation.
    error: str = Field(
        ...,
        description="Error type",
    )
    message: str = Field(
        ...,
        description="Error message",
    )
    # Optional free-form context; None when there is nothing extra.
    details: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional error details",
    )
|
|
|
|
|
|
|
|
class APIInfo(BaseModel):
    """API information response."""

    # model_info uses Pydantic's protected "model_" prefix; an empty
    # protected_namespaces tuple allows it.
    model_config = dict(protected_namespaces=())

    # Descriptive metadata about the service itself.
    name: str = Field(
        ...,
        description="API name",
    )
    version: str = Field(
        ...,
        description="API version",
    )
    description: str = Field(
        ...,
        description="API description",
    )

    model_info: ModelInfo = Field(
        ...,
        description="Model information",
    )
    # Endpoint identifiers as plain strings; their format is set by the
    # producer.
    endpoints: List[str] = Field(
        ...,
        description="Available endpoints",
    )
|
|
|