import logging
import time

from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, field_validator
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded

from app.limiter import limiter
from app.model import classifier
|
|
# Metadata shown in the generated OpenAPI schema / interactive docs.
_API_INFO = dict(
    title="SDG Classifier API",
    description="Classifies text into UN Sustainable Development Goals",
    version="1.0.0",
)
app = FastAPI(**_API_INFO)

# Permissive CORS: any origin, any method, any request header.
# allow_credentials is not passed, so the middleware's default applies.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# slowapi wiring: the limiter is exposed on app.state for the
# @limiter.limit decorators, and RateLimitExceeded is rendered by
# slowapi's stock handler.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
|
|
|
class ClassifyRequest(BaseModel):
    """Request body for ``POST /classify``.

    Attributes:
        text: Text to classify. Leading/trailing whitespace is stripped; the
            stripped text must be non-empty and at most 2000 characters.
        top_k: How many top predictions to return, 1 to 5 inclusive.
    """

    text: str
    top_k: int = 3

    @field_validator("text")
    @classmethod
    def text_must_not_be_empty(cls, v: str) -> str:
        """Strip ``v`` and validate the result.

        Returns:
            The whitespace-stripped text.

        Raises:
            ValueError: If the stripped text is empty or longer than 2000
                characters.
        """
        stripped = v.strip()
        if not stripped:
            raise ValueError("text must not be empty")
        # Measure the stripped text -- the value actually stored and sent to
        # the model -- so surrounding padding cannot cause a rejection.
        if len(stripped) > 2000:
            raise ValueError("text must be at most 2000 characters")
        return stripped

    @field_validator("top_k")
    @classmethod
    def top_k_must_be_valid(cls, v: int) -> int:
        """Ensure ``top_k`` lies in the inclusive range 1..5.

        Raises:
            ValueError: If ``v`` is outside that range.
        """
        if not 1 <= v <= 5:
            raise ValueError("top_k must be between 1 and 5")
        return v
|
|
|
|
# A single SDG prediction entry inside the /classify response's predictions list.
class SDGResult(BaseModel):
    # Goal identifier as produced by classifier.predict; its exact format is not
    # visible here -- confirm against app.model.
    sdg: str
    # Goal name, presumably human-readable -- TODO confirm against app.model.
    name: str
    # Score for this goal. The /classify warning heuristic compares it with 85
    # and 5, which suggests a 0-100 percentage scale -- TODO confirm.
    confidence: float
|
|
|
|
# Response body returned by POST /classify.
class ClassifyResponse(BaseModel):
    # The request text as received after ClassifyRequest validation
    # (whitespace-stripped).
    text: str
    # Predictions from classifier.predict for the requested top_k, each
    # validated into an SDGResult.
    predictions: list[SDGResult]
    # Time spent in classifier.predict, in milliseconds, rounded to 2 decimals.
    latency_ms: float
    # Human-readable caveat about the predictions; None when nothing was flagged.
    warning: str | None = None
|
|
|
|
# Landing route: a fixed JSON banner confirming the service is reachable.
# (Comments rather than a docstring, so the OpenAPI output is unchanged.)
@app.get("/")
def root():
    return dict(status="ok", message="SDG Classifier API is running")
|
|
|
|
# Health-check route: always answers with a static "healthy" status and does
# not probe the model or any other dependency.
@app.get("/health")
def health():
    state = "healthy"
    return {"status": state}
|
|
|
|
logger = logging.getLogger(__name__)

# Thresholds for the low-diversity warning. ``confidence`` is compared against
# these, so it appears to be on a 0-100 scale -- TODO confirm against app.model.
_LOW_DIVERSITY_TOP1_MIN = 85
_LOW_DIVERSITY_TOP2_MAX = 5


def _low_diversity_warning(predictions) -> str | None:
    """Return a warning when the top prediction dominates, else ``None``.

    ``predictions`` is the indexable result of ``classifier.predict``; each
    entry must support ``entry["confidence"]``. The warning fires when the
    best score exceeds ``_LOW_DIVERSITY_TOP1_MIN`` while the runner-up is below
    ``_LOW_DIVERSITY_TOP2_MAX``. With fewer than two predictions (e.g.
    ``top_k=1``) there is no runner-up to compare, so no warning is produced
    instead of raising ``IndexError``.
    """
    if len(predictions) < 2:
        return None
    top = predictions[0]["confidence"]
    runner_up = predictions[1]["confidence"]
    if top > _LOW_DIVERSITY_TOP1_MIN and runner_up < _LOW_DIVERSITY_TOP2_MAX:
        return "Low prediction diversity — input may not be SDG-related text."
    return None


@app.post("/classify", response_model=ClassifyResponse, summary="Classify text into SDGs")
@limiter.limit("20/minute")
async def classify(request: Request, body: ClassifyRequest):
    """Classify ``body.text`` and return the top ``body.top_k`` SDG predictions.

    The response includes the time spent in the model call and an optional
    warning when the predictions look low-diversity. Rate-limited to 20
    requests per minute.

    Raises:
        HTTPException: 500 if ``classifier.predict`` raises.
    """
    # ``request`` is not used in the body: slowapi's limiter.limit requires the
    # endpoint to accept the Request explicitly so it can compute the limit key.
    #
    # NOTE(review): classifier.predict is called synchronously inside an async
    # endpoint, so it blocks the event loop while it runs. Offloading it (plain
    # ``def`` endpoint or a threadpool) is only safe if the model is thread-safe
    # -- confirm before changing.
    start = time.perf_counter()  # monotonic: immune to wall-clock adjustments
    try:
        predictions = classifier.predict(body.text, body.top_k)
    except Exception as e:
        # Log the full traceback server-side; the HTTPException alone would
        # otherwise discard it.
        logger.exception("Inference failed for /classify request")
        raise HTTPException(status_code=500, detail=f"Inference error: {e}") from e
    latency = round((time.perf_counter() - start) * 1000, 2)

    return ClassifyResponse(
        text=body.text,
        predictions=predictions,
        latency_ms=latency,
        warning=_low_diversity_warning(predictions),
    )