# fishapi: app/api/v1/endpoints/inference.py
# Source: Hugging Face repository by kamau1 — "Initial commit" (bcc2f7b)
"""
Marine species detection inference endpoints.
"""
from fastapi import APIRouter, HTTPException, status, Depends
from typing import List
from app.models.inference import (
InferenceRequest,
InferenceResponse,
SpeciesListResponse,
SpeciesInfo,
ErrorResponse
)
from app.services.inference_service import inference_service
from app.api.dependencies import validate_model_health
from app.core.logging import get_logger
# Logger obtained from the project's get_logger helper, keyed by this module's __name__.
logger = get_logger(__name__)
# Router holding the inference endpoints defined below; presumably included by
# the application's API router elsewhere — confirm against the app setup.
router = APIRouter()
@router.post(
    "/detect",
    response_model=InferenceResponse,
    responses={
        400: {"model": ErrorResponse, "description": "Bad Request"},
        500: {"model": ErrorResponse, "description": "Internal Server Error"},
        503: {"model": ErrorResponse, "description": "Service Unavailable"},
    },
    summary="Detect Marine Species",
    description="Detect and identify marine species in an uploaded image using YOLOv5 model"
)
async def detect_marine_species(
    request: InferenceRequest,
    _: bool = Depends(validate_model_health)
) -> InferenceResponse:
    """
    Detect marine species in an image.

    - **image**: Base64 encoded image data
    - **confidence_threshold**: Minimum confidence for detections (0.0-1.0)
    - **iou_threshold**: IoU threshold for non-maximum suppression (0.0-1.0)
    - **image_size**: Input image size for inference (320-1280)
    - **return_annotated_image**: Whether to return annotated image with bounding boxes
    - **classes**: Optional list of class IDs to filter detections

    Returns detection results with bounding boxes, confidence scores, and species names.

    Raises:
        HTTPException: 400 when the inference service raises ``ValueError``;
            500 for any other unexpected failure. An ``HTTPException`` raised
            by the service itself is propagated unchanged.
    """
    try:
        logger.info("Processing marine species detection request")
        result = await inference_service.detect_species(
            image_data=request.image,
            confidence_threshold=request.confidence_threshold,
            iou_threshold=request.iou_threshold,
            image_size=request.image_size,
            return_annotated_image=request.return_annotated_image,
            classes=request.classes
        )
        logger.info(f"Detection completed: {len(result.detections)} species found")
        return result
    except HTTPException:
        # Deliberate HTTP errors from downstream code keep their status code
        # instead of being rewritten into a generic 500 below.
        raise
    except ValueError as e:
        logger.error(f"Invalid input data: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid input: {e}"
        ) from e
    except Exception as e:
        # logger.exception records the traceback; logger.error would drop it.
        logger.exception(f"Detection failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Detection failed: {e}"
        ) from e
@router.get(
    "/species",
    response_model=SpeciesListResponse,
    responses={
        500: {"model": ErrorResponse, "description": "Internal Server Error"},
    },
    summary="List Supported Species",
    description="Get a list of all marine species that can be detected by the model"
)
async def list_supported_species(
    _: bool = Depends(validate_model_health)
) -> SpeciesListResponse:
    """
    Get a list of all supported marine species.

    Returns every species reported by the inference service, each with its
    class ID and class name, together with the total count.

    Raises:
        HTTPException: 500 if fetching or converting the species data fails.
            An ``HTTPException`` raised by the service is propagated unchanged.
    """
    try:
        logger.info("Fetching supported species list")
        species_data = await inference_service.get_supported_species()
        # Each item is indexed by "class_id" and "class_name"; a missing key
        # surfaces as a 500 via the generic handler below.
        species_list = [
            SpeciesInfo(class_id=item["class_id"], class_name=item["class_name"])
            for item in species_data
        ]
        return SpeciesListResponse(
            species=species_list,
            total_count=len(species_list)
        )
    except HTTPException:
        # Keep deliberate HTTP errors from downstream code intact.
        raise
    except Exception as e:
        # logger.exception records the traceback; logger.error would drop it.
        logger.exception(f"Failed to fetch species list: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to fetch species list: {e}"
        ) from e
@router.get(
    "/species/{class_id}",
    response_model=SpeciesInfo,
    responses={
        404: {"model": ErrorResponse, "description": "Species Not Found"},
        500: {"model": ErrorResponse, "description": "Internal Server Error"},
    },
    summary="Get Species Information",
    description="Get information about a specific marine species by class ID"
)
async def get_species_info(
    class_id: int,
    _: bool = Depends(validate_model_health)
) -> SpeciesInfo:
    """
    Get information about a specific marine species.

    - **class_id**: The class ID of the species to look up

    Returns the class ID and class name of the first species whose class ID
    matches.

    Raises:
        HTTPException: 404 if no species has the given class ID; 500 if
            fetching or converting the species data fails. An
            ``HTTPException`` raised by the service is propagated unchanged.
    """
    try:
        logger.info(f"Fetching species info for class_id: {class_id}")
        species_data = await inference_service.get_supported_species()
        # First species with a matching class_id, or None when absent.
        found = next(
            (species for species in species_data if species["class_id"] == class_id),
            None,
        )
        if found is not None:
            return SpeciesInfo(
                class_id=found["class_id"],
                class_name=found["class_name"]
            )
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception records the traceback; logger.error would drop it.
        logger.exception(f"Failed to fetch species info: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to fetch species info: {e}"
        ) from e

    # No entry matched the requested class_id.
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Species with class_id {class_id} not found"
    )