"""FastAPI inference service for a Bangla "shirk" text classifier.

Loads a Hugging Face sequence-classification model from the repository
root at import time and serves single-text predictions. A confidence
threshold is applied to the "shirk" label: a low-confidence "shirk"
prediction falls back to the second-most-likely label.
"""

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Dict

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re

# -----------------------------
# Load model from current folder
# -----------------------------
MODEL_PATH = "."  # we are in the repo root
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- Text normalization ----------
# Latin punctuation, typographic quotes/dashes, and the Bangla danda (।),
# all replaced by spaces before tokenization.
PUNCT_PATTERN = r"[\.!,?:;\"'”“’‘\-\–\—\(\)\[\]\{\}।]"
# Compiled once at import: the substitution runs on every request.
_PUNCT_RE = re.compile(PUNCT_PATTERN)


def normalize_bangla_text(text: str) -> str:
    """Collapse whitespace and strip punctuation from *text*.

    Non-string input yields "" so the endpoint degrades gracefully on
    malformed payloads instead of raising.
    """
    if not isinstance(text, str):
        return ""
    text = " ".join(text.split())        # collapse runs of whitespace
    text = _PUNCT_RE.sub(" ", text)      # punctuation -> space
    return " ".join(text.split())        # collapse again after substitution


# ---------- Load tokenizer + model ----------
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(DEVICE)
model.eval()

# Ordered label names derived from the model config (index -> name).
id2label = model.config.id2label
label_list = [id2label[i] for i in range(len(id2label))]

SHIRK_LABEL = "shirk"
# Raises ValueError at startup if the model config lacks a "shirk" label,
# which is preferable to failing per-request.
SHIRK_INDEX = label_list.index(SHIRK_LABEL)
SHIRK_THRESHOLD = 0.7  # tweak if needed

# ---------- FastAPI ----------
app = FastAPI(title="Bangla Shirk Classifier API")


class PredictRequest(BaseModel):
    # Raw Bangla text to classify; normalized server-side.
    text: str


class PredictResponse(BaseModel):
    # Final predicted label after the shirk-threshold fallback.
    label: str
    # Softmax probability for every label, keyed by label name.
    probabilities: Dict[str, float]


@app.get("/")
def root():
    """Liveness probe."""
    return {"status": "running"}


@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest):
    """Classify one text and return the label plus the full distribution.

    If the top prediction is "shirk" but its probability is below
    SHIRK_THRESHOLD, the second-most-likely label is returned instead
    (conservative handling of the sensitive class).
    """
    text = normalize_bangla_text(req.text)

    enc = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=64,
        return_tensors="pt",
    )
    enc = {k: v.to(DEVICE) for k, v in enc.items()}

    # inference_mode disables autograd tracking entirely (slightly faster
    # than no_grad for pure inference).
    with torch.inference_mode():
        outputs = model(**enc)
        logits = outputs.logits[0]
        probs = F.softmax(logits, dim=-1).cpu().numpy()

    # Shirk threshold logic: demote low-confidence "shirk" to the runner-up.
    top1 = int(probs.argmax())
    if top1 == SHIRK_INDEX and probs[SHIRK_INDEX] < SHIRK_THRESHOLD:
        pred_idx = int(probs.argsort()[-2])  # index of second-highest prob
    else:
        pred_idx = top1

    prob_dict = {label_list[i]: float(probs[i]) for i in range(len(label_list))}

    return PredictResponse(
        label=label_list[pred_idx],
        probabilities=prob_dict,
    )