File size: 4,104 Bytes
51e944e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from fastapi import FastAPI, File, UploadFile, HTTPException, Query, Request
from fastapi.middleware import Middleware
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse
import datetime
import time
from api_backend.configs import settings, logger
from api_backend.models import models, MODEL_REGISTRY, ModelNotFoundError, InvalidImageError
from api_backend.schemas import ApiResponse, HealthCheckResponse, ModelName
from api_backend.services import async_predict, preprocess_image

# Assemble the ASGI middleware stack that will be handed to the app factory.
_cors = Middleware(
    CORSMiddleware,
    allow_origins=settings.allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
middleware = [_cors]

# Force HTTPS only when the deployment configuration asks for it.
if settings.enable_https_redirect:
    middleware.append(Middleware(HTTPSRedirectMiddleware))

# Application metadata pulled out into named values for readability.
_contact_info = {
    "name": "Brian",
    "email": "brayann.8189@gmail.com",
}
_license_info = {"name": "MIT"}
_openapi_tags = [
    {
        "name": "predictions",
        "description": "Operations with image predictions",
    },
]

# Create the FastAPI application instance with the middleware stack above.
app = FastAPI(
    title=settings.app_name,
    description="FastAPI backend for AI Image Classifier with multiple Keras models",
    version=settings.app_version,
    contact=_contact_info,
    license_info=_license_info,
    openapi_tags=_openapi_tags,
    middleware=middleware,
)

# Middleware
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log each request's method, URL, and processing time in milliseconds.

    Args:
        request: Incoming HTTP request.
        call_next: Callable that forwards the request down the handler chain.

    Returns:
        The response produced by the downstream handlers, unmodified.
    """
    # perf_counter is monotonic; time.time() can jump on NTP/wall-clock
    # adjustments, which would yield negative or wildly wrong durations.
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = (time.perf_counter() - start_time) * 1000
    # Lazy %-args defer string formatting until the record is actually emitted.
    logger.info(
        "Request: %s %s completed in %.2fms",
        request.method,
        request.url,
        process_time,
    )
    return response

# Exception Handlers
@app.exception_handler(ModelNotFoundError)
async def model_not_found_handler(request, exc):
    """Translate a ModelNotFoundError into a 404 JSON response."""
    payload = {"message": str(exc)}
    return JSONResponse(status_code=404, content=payload)

@app.exception_handler(InvalidImageError)
async def invalid_image_handler(request, exc):
    """Translate an InvalidImageError into a 400 JSON response."""
    payload = {"message": str(exc)}
    return JSONResponse(status_code=400, content=payload)

# Endpoints
@app.post("/predict", response_model=ApiResponse, tags=["predictions"])
async def predict(
    request: Request,
    file: UploadFile = File(...),
    model_name: ModelName = Query(..., description="Choose model for inference")
):
    """Classify an uploaded image with the selected model.

    Args:
        request: Incoming HTTP request (currently unused by the handler body).
        file: Uploaded image to classify.
        model_name: Registered model to run inference with.

    Returns:
        dict with the top-3 predictions (label + confidence percentage),
        the model name, inference time in seconds, and a UTC ISO-8601 timestamp.

    Raises:
        ModelNotFoundError: Requested model is not loaded (mapped to 404).
        InvalidImageError: Upload could not be decoded as an image (mapped to 400).
        HTTPException: 500 for any other inference failure.
    """
    if model_name.value not in models:
        logger.error(f"Model '{model_name}' not found in loaded models")
        raise ModelNotFoundError(
            f"Model '{model_name}' not available. Available options: {list(models.keys())}"
        )

    try:
        model = models[model_name.value]
        config = MODEL_REGISTRY[model_name.value]
        contents = await file.read()

        # Preprocess raw bytes into the tensor the model expects.
        input_tensor = preprocess_image(contents, config["input_size"], config["preprocess"])

        # Inference — perf_counter is monotonic, so the reported duration
        # cannot be skewed by wall-clock adjustments (time.time() can).
        start = time.perf_counter()
        predictions = await async_predict(model, input_tensor)
        elapsed = time.perf_counter() - start

        # Decode raw model output into (id, label, score) triples; keep top 3.
        decoded = config["decode"](predictions, top=3)[0]
        results = [
            {"label": label.replace("_", " "), "confidence": round(float(score * 100), 2)}
            for (_, label, score) in decoded
        ]

        return {
            "predictions": results,
            "model_version": model_name.value,
            "inference_time": round(elapsed, 4),
            # Timezone-aware UTC: datetime.utcnow() is deprecated (3.12+)
            # and returns a naive datetime.
            "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat()
        }
    except InvalidImageError:
        # Re-raise so the dedicated exception handler converts it to a 400.
        raise
    except Exception as e:
        logger.error(f"Inference error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
    
@app.get("/", include_in_schema=False)
def root():
    """Liveness message at the API root (excluded from the OpenAPI schema)."""
    message = "Image Classifier API is running."
    return {"message": message}

@app.get("/health", response_model=HealthCheckResponse, tags=["health"])
async def health_check():
    """Report service health: status, loaded model names, and a UTC timestamp."""
    return {
        "status": "healthy",
        "models_loaded": list(models.keys()),
        # Timezone-aware UTC: datetime.utcnow() is deprecated (3.12+)
        # and returns a naive datetime.
        "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat()
    }