|
|
"""API controllers for request handling and validation.""" |
|
|
|
|
|
import asyncio |
|
|
from fastapi import HTTPException |
|
|
|
|
|
from app.core.logging import logger |
|
|
from app.services.base import InferenceService |
|
|
from app.api.models import ImageRequest, PredictionResponse |
|
|
|
|
|
|
|
|
class PredictionController:
    """Controller for prediction endpoints."""

    @staticmethod
    async def predict(
        request: ImageRequest,
        service: InferenceService
    ) -> PredictionResponse:
        """Run inference using the configured service.

        Args:
            request: Incoming prediction request. Its ``image.mediaType``
                must be a string beginning with ``image/``.
            service: Inference backend. It must be truthy and report
                ``is_loaded`` before any prediction is attempted.

        Returns:
            The value returned by ``service.predict(request)``.

        Raises:
            HTTPException: 503 when the service is missing or not loaded;
                400 when the media type is not an image type or when a
                ``ValueError`` is raised while handling the request;
                500 for any other unexpected failure. An ``HTTPException``
                raised inside the guarded section is passed through as-is.
        """
        try:
            if not service or not service.is_loaded:
                raise HTTPException(503, "Service not available")

            media_type = request.image.mediaType
            # A missing or non-string media type is a client error; without
            # the isinstance guard it would surface as an AttributeError and
            # be reported as a 500 by the catch-all handler below.
            if not isinstance(media_type, str) or not media_type.startswith('image/'):
                raise HTTPException(400, f"Invalid media type: {media_type}")

            # service.predict is invoked as a plain callable, so hand it to a
            # worker thread instead of calling it on the event loop.
            return await asyncio.to_thread(service.predict, request)
        except HTTPException:
            # Already carries the intended status code; do not re-wrap.
            raise
        except ValueError as e:
            logger.error(f"Invalid input: {e}")
            # Chain the cause so the original traceback stays attached.
            raise HTTPException(400, str(e)) from e
        except Exception as e:
            logger.error(f"Prediction failed: {e}")
            # Generic message for the client; the cause is kept via chaining.
            raise HTTPException(500, "Internal server error") from e
|
|
|