|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import torch |
|
|
|
|
|
import logging
import os

logger = logging.getLogger(__name__)

# ASGI application; the route handlers defined below register on it.
app = FastAPI()

# Hugging Face model id or local path; override with the MODEL_NAME env var.
model_name = os.getenv("MODEL_NAME", "titanabrian/whatsapp-text-classifier-distilbert")

# Loaded once at import time so every request reuses the same weights.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Switch to evaluation mode (training-only layers such as dropout are disabled).
model.eval()

# Human-readable class names, indexed by the classifier's output id.
# NOTE(review): this list is hard-coded for the default model; a different
# MODEL_NAME may have a different number or order of classes.
labels = ["confirmation", "question", "task", "transactional", "unclassified"]

# Surface a label/model mismatch at startup instead of at the first request.
if model.config.num_labels != len(labels):
    logger.warning(
        "Model %r has %d output classes but %d labels are configured; "
        "predictions may be mislabeled.",
        model_name,
        model.config.num_labels,
        len(labels),
    )
|
|
|
|
|
class Message(BaseModel):
    """Request body for ``POST /predict``: the message text to classify."""

    text: str  # raw message text; handed to the tokenizer unchanged
|
|
|
|
|
@app.get("/")
def root():
    """Report that the service is up.

    The model is loaded at import time, so reaching this handler implies
    the classifier is available.
    """
    status_text = "WhatsApp Intent Classifier is ready."
    return {"message": status_text}
|
|
|
|
|
@app.get("/model")
def get_model():
    """Return the model identifier (``MODEL_NAME`` or the default) being served."""
    payload = {"model_name": model_name}
    return payload
|
|
|
|
|
|
|
|
@app.post("/predict")
def predict(msg: Message):
    """Classify the intent of a single message.

    Args:
        msg: request body; ``msg.text`` is the text to classify.

    Returns:
        dict with ``label`` (class name), ``label_id`` (index of the
        highest-scoring class) and ``score`` (that class's softmax
        probability, in [0, 1]).
    """
    # A single string needs no padding; truncation caps the input at the
    # tokenizer's maximum length so very long messages don't error out.
    inputs = tokenizer(msg.text, return_tensors="pt", truncation=True)

    # inference_mode disables autograd tracking entirely (cheaper than no_grad).
    with torch.inference_mode():
        logits = model(**inputs).logits

    # Batch of one: take the first (only) row of class probabilities.
    probs = logits.softmax(dim=-1)[0]
    pred_id = int(probs.argmax().item())
    score = float(probs[pred_id].item())

    if pred_id < len(labels):
        label = labels[pred_id]
    else:
        # The model has more classes than the hard-coded list (e.g. a
        # different MODEL_NAME); use the model's own mapping instead of
        # failing with IndexError.
        label = model.config.id2label.get(pred_id, str(pred_id))

    return {"label": label, "label_id": pred_id, "score": score}