"""FastAPI service that classifies text into UN Sustainable Development Goals."""

import logging
import time

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, field_validator
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded

from app.limiter import limiter
from app.model import classifier

logger = logging.getLogger(__name__)

# Request limits enforced by ClassifyRequest.
MAX_TEXT_LENGTH = 2000
MIN_TOP_K = 1
MAX_TOP_K = 5

# Thresholds for the "low diversity" heuristic in _diversity_warning.
# NOTE(review): the values suggest confidences are percentages (0-100);
# confirm against app.model.
HIGH_CONFIDENCE_THRESHOLD = 85
LOW_CONFIDENCE_THRESHOLD = 5
LOW_DIVERSITY_WARNING = "Low prediction diversity — input may not be SDG-related text."

app = FastAPI(
    title="SDG Classifier API",
    description="Classifies text into UN Sustainable Development Goals",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# slowapi reads the limiter from app.state and needs a handler to turn
# RateLimitExceeded into an HTTP response.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)


class ClassifyRequest(BaseModel):
    """Body of a POST /classify request."""

    text: str  # Input text; stripped of surrounding whitespace on validation.
    top_k: int = 3  # Number of predictions to return (MIN_TOP_K..MAX_TOP_K).

    @field_validator("text")
    @classmethod
    def text_must_not_be_empty(cls, v: str) -> str:
        """Strip *v* and reject it if empty or longer than MAX_TEXT_LENGTH.

        The length limit applies to the stripped text, i.e. the text that is
        actually classified and echoed back.

        Raises:
            ValueError: if the stripped text is empty or too long.
        """
        stripped = v.strip()
        if not stripped:
            raise ValueError("text must not be empty")
        if len(stripped) > MAX_TEXT_LENGTH:
            raise ValueError(f"text must be at most {MAX_TEXT_LENGTH} characters")
        return stripped

    @field_validator("top_k")
    @classmethod
    def top_k_must_be_valid(cls, v: int) -> int:
        """Reject *v* unless MIN_TOP_K <= v <= MAX_TOP_K.

        Raises:
            ValueError: if *v* is out of range.
        """
        if not MIN_TOP_K <= v <= MAX_TOP_K:
            raise ValueError(f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}")
        return v


class SDGResult(BaseModel):
    """A single predicted SDG with its confidence."""

    sdg: str
    name: str
    confidence: float


class ClassifyResponse(BaseModel):
    """Response body of POST /classify."""

    text: str  # The (stripped) input text.
    predictions: list[SDGResult]
    latency_ms: float  # Inference wall time in milliseconds, rounded to 2 dp.
    warning: str | None = None  # Set when _diversity_warning flags the result.


def _diversity_warning(predictions: list[dict]) -> str | None:
    """Return a warning when the top prediction dominates the runner-up.

    At least two predictions are needed for the comparison. With top_k=1, or
    an empty result, no warning is produced instead of raising IndexError.
    NOTE(review): assumes classifier.predict returns predictions ordered by
    descending confidence; confirm in app.model.
    """
    if len(predictions) < 2:
        return None
    top, runner_up = predictions[0], predictions[1]
    if (
        top["confidence"] > HIGH_CONFIDENCE_THRESHOLD
        and runner_up["confidence"] < LOW_CONFIDENCE_THRESHOLD
    ):
        return LOW_DIVERSITY_WARNING
    return None


@app.get("/")
def root():
    """Return a liveness message."""
    return {"status": "ok", "message": "SDG Classifier API is running"}


@app.get("/health")
def health():
    """Return a static health status."""
    return {"status": "healthy"}


@app.post("/classify", response_model=ClassifyResponse, summary="Classify text into SDGs")
@limiter.limit("20/minute")
async def classify(request: Request, body: ClassifyRequest):
    """Classify ``body.text`` and return the top ``body.top_k`` SDG predictions.

    ``request`` is unused here but must stay in the signature: slowapi's
    ``limiter.limit`` decorator needs it to apply the rate limit.

    Raises:
        HTTPException: 500 if the classifier raises.
    """
    start = time.perf_counter()  # Monotonic clock, suited to measuring durations.
    try:
        # NOTE(review): this synchronous call runs on the event-loop thread, so
        # it blocks other requests (including /health) while inference runs.
        # Consider fastapi.concurrency.run_in_threadpool once the classifier is
        # confirmed to be thread-safe.
        predictions = classifier.predict(body.text, body.top_k)
    except Exception as e:
        logger.exception("Inference failed")
        # NOTE(review): echoing str(e) to clients may leak internal details.
        raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}") from e
    latency = round((time.perf_counter() - start) * 1000, 2)

    return ClassifyResponse(
        text=body.text,
        predictions=predictions,
        latency_ms=latency,
        warning=_diversity_warning(predictions),
    )