# AutoClassifier / main.py
# (Hugging Face Space source; page chrome removed — last commit 8b003a0, verified)
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Dict
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re
# -----------------------------
# Load model from current folder
# -----------------------------
MODEL_PATH = "." # we are in the repo root: config.json / weights live beside main.py
# Prefer GPU when available; every tensor and the model are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ---------- Text normalization ----------
# Punctuation stripped before classification: ASCII sentence punctuation,
# straight and curly quotes, hyphen/en-dash/em-dash, brackets, and the
# Bangla danda (U+0964).
# NOTE(review): the original literal was mojibake-garbled (e.g. "ΰ₯€" where
# "\u0964" was intended); restored with explicit \uXXXX escapes so the
# intent survives any future re-encoding.
PUNCT_PATTERN = r"[.!,?:;\"'\u201c\u201d\u2018\u2019\-\u2013\u2014()\[\]{}\u0964]"
_PUNCT_RE = re.compile(PUNCT_PATTERN)  # compiled once; used on every request


def normalize_bangla_text(text: str) -> str:
    """Collapse whitespace and strip punctuation from input text.

    Non-string input (None, numbers, missing field) yields "" instead of
    raising, so the endpoint degrades gracefully on bad payloads.

    Args:
        text: raw input text (Bangla or any script).

    Returns:
        Whitespace-normalized text with punctuation replaced by single
        spaces, then re-collapsed, so adjacent words never merge.
    """
    if not isinstance(text, str):
        return ""
    text = " ".join(text.split())
    text = _PUNCT_RE.sub(" ", text)
    return " ".join(text.split())
# ---------- Load tokenizer + model ----------
# Both are loaded once at import time from the repo root and kept in module
# scope; the model is moved to DEVICE and frozen in eval mode (disables
# dropout / batch-norm updates).
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(DEVICE)
model.eval()
# Build an index-ordered label list from the config's id2label mapping.
# NOTE(review): assumes id2label keys are the contiguous ints 0..n-1 (the
# transformers config normalizes keys to int) — confirm against config.json.
id2label = model.config.id2label
label_list = [id2label[i] for i in range(len(id2label))]
SHIRK_LABEL = "shirk"
# Raises ValueError at startup if the checkpoint has no "shirk" label —
# fail-fast is intentional here.
SHIRK_INDEX = label_list.index(SHIRK_LABEL)
SHIRK_THRESHOLD = 0.7 # tweak if needed; minimum confidence to accept "shirk"
# ---------- FastAPI ----------
# Application instance; endpoints are registered below via decorators.
app = FastAPI(title="Bangla Shirk Classifier API")
class PredictRequest(BaseModel):
    """Request body for POST /predict."""

    text: str  # raw input text; normalized server-side before tokenization
class PredictResponse(BaseModel):
    """Response body for POST /predict."""

    label: str  # final predicted label (after the shirk-threshold fallback)
    probabilities: Dict[str, float]  # softmax probability for every label
@app.get("/")
def root():
    """Health-check endpoint: confirms the service is up and the model loaded."""
    return {"status": "running"}
@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest):
    """Classify one text, demoting a low-confidence 'shirk' call to the runner-up.

    The text is normalized, tokenized (truncated to 64 tokens) and scored in a
    single forward pass; if the top prediction is "shirk" but below
    SHIRK_THRESHOLD, the second-best label is returned instead.
    """
    cleaned = normalize_bangla_text(req.text)

    batch = tokenizer(
        cleaned,
        truncation=True,
        padding=True,
        max_length=64,
        return_tensors="pt",
    )
    batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}

    with torch.no_grad():
        scores = model(**batch).logits[0]
    probs = F.softmax(scores, dim=-1).cpu().numpy()

    # Shirk threshold logic: accept "shirk" only above the confidence floor;
    # otherwise fall back to the runner-up label.
    best = int(probs.argmax())
    if best == SHIRK_INDEX and probs[SHIRK_INDEX] < SHIRK_THRESHOLD:
        best = int(probs.argsort()[-2])

    return PredictResponse(
        label=label_list[best],
        probabilities={name: float(p) for name, p in zip(label_list, probs)},
    )