File size: 788 Bytes
5fa76ab
 
358adf7
5fa76ab
 
 
358adf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fa76ab
63b4fe7
358adf7
 
5fa76ab
358adf7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline

# FastAPI application instance; the /predict route in this file is registered on it.
app = FastAPI()

# Build the zero-shot classification pipeline once, at import time, so every
# request reuses the same loaded model instead of constructing it per call.
# NOTE(review): the model id suggests a multilingual NLI checkpoint — confirm
# against the model card before relying on specific language coverage.
classifier = pipeline(
    "zero-shot-classification",
    model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
)

# Categories the classifier chooses among; passed as `candidate_labels` on
# every prediction. The strings are sent to the model verbatim, so rewording
# one changes what the model is asked to score.
# NOTE: "Other" is scored like any other label here — nothing in this file
# applies a confidence threshold or falls back to it.
CANDIDATE_LABELS = [
    "Garbage issue",
    "Streetlight not working",
    "Road damage",
    "Water supply issue",
    "Noise pollution",
    "Flooding",
    "Corruption",
    "Other"
]

class Query(BaseModel):
    """Request body accepted by the ``/predict`` endpoint."""

    # Text to classify; handed to the classifier as-is by the endpoint.
    text: str

@app.post("/predict")
def predict(query: Query):
    """Classify the submitted text against ``CANDIDATE_LABELS``.

    Args:
        query: Request body whose ``text`` field holds the text to classify.

    Returns:
        A dict of the form ``{"label": <label>}`` containing the first entry
        of the classifier's ``labels`` output.
    """
    # multi_label=False asks the pipeline for a single best class.
    inference_options = {
        "candidate_labels": CANDIDATE_LABELS,
        "multi_label": False,
    }
    ranked_labels = classifier(query.text, **inference_options)["labels"]

    # The leading entry is taken as the top-1 prediction.
    return {"label": ranked_labels[0]}