from fastapi import FastAPI, File, UploadFile, HTTPException, Query, Request
from fastapi.middleware import Middleware
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse
import datetime
import time
from api_backend.configs import settings, logger
from api_backend.models import models, MODEL_REGISTRY, ModelNotFoundError, InvalidImageError
from api_backend.schemas import ApiResponse, HealthCheckResponse, ModelName
from api_backend.services import async_predict, preprocess_image
# Setup middleware
# CORS is always enabled; the HTTPS redirect is opt-in via settings.
cors = Middleware(
    CORSMiddleware,
    allow_origins=settings.allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
middleware = [cors]
if settings.enable_https_redirect:
    middleware.append(Middleware(HTTPSRedirectMiddleware))
# Create FastAPI app
# Metadata below feeds the auto-generated OpenAPI docs (/docs, /redoc).
app = FastAPI(
    title=settings.app_name,
    description="FastAPI backend for AI Image Classifier with multiple Keras models",
    version=settings.app_version,
    contact={
        "name": "Brian",
        "email": "brayann.8189@gmail.com",
    },
    license_info={
        "name": "MIT",
    },
    # Groups the /predict endpoint under "predictions" in the docs UI.
    openapi_tags=[{
        "name": "predictions",
        "description": "Operations with image predictions",
    }],
    middleware=middleware
)
# Middleware
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log each request's method, URL, and processing time in milliseconds."""
    # perf_counter is monotonic — unlike time.time(), it cannot go backwards
    # under system clock adjustments, so the measured duration is reliable.
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = (time.perf_counter() - start_time) * 1000
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info(
        "Request: %s %s completed in %.2fms",
        request.method,
        request.url,
        process_time,
    )
    return response
# Exception Handlers
@app.exception_handler(ModelNotFoundError)
async def model_not_found_handler(request, exc):
    """Translate a ModelNotFoundError raised by any endpoint into a 404 JSON response."""
    body = {"message": str(exc)}
    return JSONResponse(status_code=404, content=body)
@app.exception_handler(InvalidImageError)
async def invalid_image_handler(request, exc):
    """Translate an InvalidImageError raised by any endpoint into a 400 JSON response."""
    payload = {"message": str(exc)}
    return JSONResponse(status_code=400, content=payload)
# Endpoints
@app.post("/predict", response_model=ApiResponse, tags=["predictions"])
async def predict(
    request: Request,
    file: UploadFile = File(...),
    model_name: ModelName = Query(..., description="Choose model for inference")
):
    """Classify an uploaded image with the selected model.

    Returns a dict with the top-3 predictions (label + confidence %),
    the model version used, the inference time in seconds, and a UTC
    timestamp.

    Raises:
        ModelNotFoundError: model_name is not among the loaded models (mapped to 404).
        InvalidImageError: the upload could not be preprocessed (mapped to 400).
        HTTPException: any other inference failure (500).
    """
    if model_name.value not in models:
        logger.error(f"Model '{model_name}' not found in loaded models")
        raise ModelNotFoundError(
            f"Model '{model_name}' not available. Available options: {list(models.keys())}"
        )
    try:
        model = models[model_name.value]
        config = MODEL_REGISTRY[model_name.value]
        contents = await file.read()
        # Preprocess the raw bytes into the model's expected input tensor.
        input_tensor = preprocess_image(contents, config["input_size"], config["preprocess"])
        # Time only the model call; perf_counter is monotonic, unlike time.time().
        start = time.perf_counter()
        predictions = await async_predict(model, input_tensor)
        elapsed = time.perf_counter() - start
        # Decode top-3 predictions for the first (only) batch item.
        decoded = config["decode"](predictions, top=3)[0]
        results = [
            {"label": label.replace("_", " "), "confidence": round(float(score * 100), 2)}
            for (_, label, score) in decoded
        ]
        return {
            "predictions": results,
            "model_version": model_name.value,
            "inference_time": round(elapsed, 4),
            # Timezone-aware UTC: datetime.utcnow() is deprecated since Python 3.12.
            "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat()
        }
    except InvalidImageError:
        # Re-raise so the dedicated exception handler produces the 400 response.
        raise
    except Exception as e:
        logger.error(f"Inference error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
@app.get("/", include_in_schema=False)
def root():
    """Simple liveness message; excluded from the OpenAPI schema."""
    message = "Image Classifier API is running."
    return {"message": message}
@app.get("/health", response_model=HealthCheckResponse, tags=["health"])
async def health_check():
    """Report service health: status flag, names of loaded models, UTC timestamp."""
    return {
        "status": "healthy",
        "models_loaded": list(models.keys()),
        # Timezone-aware UTC: datetime.utcnow() is deprecated since Python 3.12.
        "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat()
    }