sdg-api / app /main.py
MakPr016
SDG Api
63f5626
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from pydantic import BaseModel, field_validator
from app.limiter import limiter
from app.model import classifier
import time
app = FastAPI(
    title="SDG Classifier API",
    description="Classifies text into UN Sustainable Development Goals",
    version="1.0.0",
)

# Rate limiting: expose the shared slowapi limiter on the application state
# and translate RateLimitExceeded into an HTTP response via slowapi's handler.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# CORS is fully open: any origin, any HTTP method and any request header.
# Credentials are not enabled here (allow_credentials is not passed).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
class ClassifyRequest(BaseModel):
    """Request body for ``POST /classify``."""

    text: str  # input text; surrounding whitespace is stripped during validation
    top_k: int = 3  # how many predictions to request (validated to 1..5 inclusive)

    @field_validator("text")
    @classmethod
    def text_must_not_be_empty(cls, v: str) -> str:
        """Strip ``v`` and validate it.

        Returns:
            The stripped text.

        Raises:
            ValueError: if the stripped text is empty or longer than 2000
                characters.
        """
        stripped = v.strip()
        if not stripped:
            raise ValueError("text must not be empty")
        # Measure the stripped text: that is what is stored and classified,
        # so padding whitespace should not count against the limit.
        if len(stripped) > 2000:
            # 2000 itself is accepted, so the message says "at most".
            raise ValueError("text must be at most 2000 characters")
        return stripped

    @field_validator("top_k")
    @classmethod
    def top_k_must_be_valid(cls, v: int) -> int:
        """Ensure ``top_k`` lies in the inclusive range 1..5.

        Raises:
            ValueError: if ``v`` is outside 1..5.
        """
        if not 1 <= v <= 5:
            raise ValueError("top_k must be between 1 and 5")
        return v
class SDGResult(BaseModel):
    """One prediction entry inside ``ClassifyResponse.predictions``."""

    # SDG label as produced by classifier.predict (format defined in app.model).
    sdg: str
    # Display name accompanying the label.
    name: str
    # Score reported by the classifier for this label.
    confidence: float
class ClassifyResponse(BaseModel):
    """Response body returned by ``POST /classify``."""

    # Echo of the (validated) input text.
    text: str
    # Predictions in the order returned by the classifier.
    predictions: list[SDGResult]
    # Wall time spent in inference, in milliseconds.
    latency_ms: float
    # Optional advisory message; omitted (None) when there is nothing to flag.
    warning: str | None = None
@app.get("/")
def root():
    """Return a fixed status banner confirming the service is running."""
    return dict(status="ok", message="SDG Classifier API is running")
@app.get("/health")
def health():
    """Health-check endpoint; answers with a constant ``healthy`` status."""
    status_payload = {"status": "healthy"}
    return status_payload
# Thresholds for the "low prediction diversity" heuristic, compared against the
# ``confidence`` values returned by ``classifier.predict``.
# NOTE(review): these look like percentages (0-100) — confirm against app.model.
_LOW_DIVERSITY_TOP_CONFIDENCE = 85
_LOW_DIVERSITY_RUNNER_UP_CONFIDENCE = 5


def _low_diversity_warning(predictions: list[dict]) -> str | None:
    """Return a warning when the first prediction dominates the second.

    Fewer than two predictions (e.g. ``top_k=1`` or an empty result) leaves
    nothing to compare, so no warning is produced instead of raising
    IndexError.
    """
    if len(predictions) < 2:
        return None
    top, runner_up = predictions[0], predictions[1]
    if (
        top["confidence"] > _LOW_DIVERSITY_TOP_CONFIDENCE
        and runner_up["confidence"] < _LOW_DIVERSITY_RUNNER_UP_CONFIDENCE
    ):
        return "Low prediction diversity — input may not be SDG-related text."
    return None


@app.post("/classify", response_model=ClassifyResponse, summary="Classify text into SDGs")
@limiter.limit("20/minute")
async def classify(request: Request, body: ClassifyRequest):
    """Classify ``body.text`` and return the top ``body.top_k`` SDG predictions.

    ``request`` is not read in the body. It is kept in the signature because
    slowapi's ``limiter.limit`` decorator expects a ``Request`` parameter
    (per slowapi documentation).

    Returns:
        ClassifyResponse with the predictions, the inference latency in
        milliseconds, and an optional low-diversity warning.

    Raises:
        HTTPException: status 500 if ``classifier.predict`` raises.
    """
    # NOTE(review): classifier.predict is called synchronously inside an async
    # endpoint, so it blocks the event loop while inference runs. Consider a
    # plain ``def`` endpoint or a threadpool once the classifier is confirmed
    # thread-safe.
    start = time.perf_counter()  # monotonic clock, immune to wall-clock jumps
    try:
        predictions = classifier.predict(body.text, body.top_k)
    except Exception as e:
        # NOTE(review): the exception text is echoed to the client; consider
        # logging it server-side and returning a generic message instead.
        raise HTTPException(status_code=500, detail=f"Inference error: {e}") from e
    latency = round((time.perf_counter() - start) * 1000, 2)

    return ClassifyResponse(
        text=body.text,
        predictions=predictions,
        latency_ms=latency,
        warning=_low_diversity_warning(predictions),
    )