File size: 2,457 Bytes
63f5626
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from pydantic import BaseModel, field_validator
from app.limiter import limiter
from app.model import classifier
import time

# Application instance; the metadata below is surfaced in the generated
# OpenAPI schema / interactive docs.
app = FastAPI(
    version="1.0.0",
    title="SDG Classifier API",
    description="Classifies text into UN Sustainable Development Goals",
)

# Fully permissive CORS: every origin, method and header is accepted.
# allow_credentials is not passed, so the middleware's default applies.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Attach the shared rate limiter to the app state and let slowapi's stock
# handler turn RateLimitExceeded into an HTTP response.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)


class ClassifyRequest(BaseModel):
    # Request body for POST /classify. Documented with comments rather than a
    # class docstring so the generated JSON/OpenAPI schema is left unchanged.

    text: str  # input to classify; stored stripped, non-empty, <= 2000 chars after stripping
    top_k: int = 3  # number of predictions requested; must lie in [1, 5]

    @field_validator("text")
    @classmethod
    def text_must_not_be_empty(cls, v: str) -> str:
        """Strip surrounding whitespace from ``v`` and validate the result.

        The length limit is applied to the *stripped* text, which is the value
        actually stored on the model and sent to the classifier, so leading or
        trailing whitespace does not count against it.

        Returns:
            The stripped text.

        Raises:
            ValueError: if the stripped text is empty or longer than 2000
                characters.
        """
        stripped = v.strip()
        if not stripped:
            raise ValueError("text must not be empty")
        if len(stripped) > 2000:
            raise ValueError("text must be under 2000 characters")
        return stripped

    @field_validator("top_k")
    @classmethod
    def top_k_must_be_valid(cls, v: int) -> int:
        """Ensure ``top_k`` is between 1 and 5 inclusive.

        Raises:
            ValueError: if ``v`` is outside [1, 5].
        """
        if not 1 <= v <= 5:
            raise ValueError("top_k must be between 1 and 5")
        return v


class SDGResult(BaseModel):
    # A single ranked prediction as returned inside ClassifyResponse.
    # Field order is significant: it is the key order of the JSON output.
    sdg: str  # SDG identifier; exact format comes from classifier.predict (not visible here)
    name: str  # human-readable goal name
    confidence: float  # model score; the 85 / 5 thresholds in /classify suggest a 0-100 scale — TODO confirm


class ClassifyResponse(BaseModel):
    # Response body for POST /classify.
    text: str  # the validated (stripped) input text, echoed back
    predictions: list[SDGResult]  # predictions in the order classifier.predict returned them
    latency_ms: float  # time spent in classifier.predict, in milliseconds
    warning: str | None = None  # low-diversity hint when triggered, otherwise null


@app.get("/")
def root():
    # Static payload confirming the API process is serving requests.
    status_payload = dict(status="ok", message="SDG Classifier API is running")
    return status_payload


@app.get("/health")
def health():
    # Health-check endpoint; always reports healthy and does not touch the model.
    return dict(status="healthy")


@app.post("/classify", response_model=ClassifyResponse, summary="Classify text into SDGs")
@limiter.limit("20/minute")
async def classify(request: Request, body: ClassifyRequest):
    # Classify body.text into its top body.top_k SDGs and report inference latency.
    #
    # `request` is not read here but must stay in the signature: slowapi's
    # limiter.limit decorator needs the endpoint to receive the Request.
    # Rate-limited to 20 requests/minute by limiter.limit.
    #
    # Raises HTTPException(500) if classifier.predict raises anything.
    #
    # NOTE(review): classifier.predict is called synchronously inside an
    # async endpoint, so inference runs on the event loop and blocks other
    # requests while it executes. Consider a plain `def` endpoint or
    # run_in_threadpool once the classifier is confirmed safe to call
    # from multiple threads.
    start = time.perf_counter()  # monotonic clock, suited to measuring durations

    try:
        predictions = classifier.predict(body.text, body.top_k)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference error: {e}") from e

    latency = round((time.perf_counter() - start) * 1000, 2)

    # Low-diversity heuristic: a dominant top score with a negligible runner-up.
    # The check needs a runner-up. top_k=1 is a valid request, and an empty
    # result is possible, so indexing [1] unconditionally would raise IndexError.
    warning = None
    if (
        len(predictions) >= 2
        and predictions[0]["confidence"] > 85
        and predictions[1]["confidence"] < 5
    ):
        warning = "Low prediction diversity — input may not be SDG-related text."

    return ClassifyResponse(
        text=body.text,
        predictions=predictions,
        latency_ms=latency,
        warning=warning,
    )