File size: 1,089 Bytes
92885db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Application object; the route decorators further down register handlers on it.
app = FastAPI()

# Load the tokenizer and model once, at import time, so every request reuses
# the same objects. The checkpoint name defaults to the value below and can be
# overridden through the MODEL_NAME environment variable.
import os
model_name = os.getenv("MODEL_NAME", "titanabrian/whatsapp-text-classifier-distilbert")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()  # put the model in evaluation mode; this service only runs inference

# Class names indexed by predicted class id; the order must match the label
# ids used when the model was trained.
# NOTE(review): this list is hard-coded while MODEL_NAME is configurable, so a
# checkpoint with a different label set would be mislabelled or raise
# IndexError in the predict handler. Confirm against the deployed model.
labels = ["confirmation", "question", "task", "transactional", "unclassified"]

# Request body accepted by the POST /predict handler.
class Message(BaseModel):
    # Raw message text to classify.
    text: str

# Landing endpoint: answers with a fixed readiness message.
@app.get("/")
def root():
    banner = {"message": "WhatsApp Intent Classifier is ready."}
    return banner

# Reports the checkpoint name this process loaded its tokenizer and model from.
@app.get("/model")
def get_model():
    return dict(model_name=model_name)


# Classifies one message and returns the winning label together with its index.
@app.post("/predict")
def predict(msg: Message):
    # Encode the text as PyTorch tensors, with truncation enabled.
    encoded = tokenizer(
        msg.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )

    # Forward pass only; no gradients are needed when serving predictions.
    with torch.no_grad():
        logits = model(**encoded).logits

    # Index of the highest-scoring class. .item() requires a single-element
    # result, which holds here because one text is encoded per request.
    best_id = torch.argmax(logits, dim=-1).item()
    return {"label": labels[best_id], "label_id": best_id}