# Spaces: Sleeping  (appears to be page-status text captured with this file; not part of the program)
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from typing import Dict | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import re | |
# -----------------------------
# Load model from current folder
# -----------------------------
# Directory handed to `from_pretrained` for both the tokenizer and the model.
MODEL_PATH = "."  # we are in the repo root
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ---------- Text normalization ----------
# Characters replaced by a space before tokenization: ASCII punctuation and
# brackets, curly quotes, en/em dashes and the Bangla danda (U+0964).
# The non-ASCII members are written as \uXXXX escapes (understood by `re` in
# str patterns) so the pattern cannot be damaged by an encoding mismatch.
# NOTE(review): the previous literal was mis-decoded text. The danda is
# recoverable from its byte sequence; the four quotes and two dashes are the
# most likely reading of the damaged characters -- confirm against the
# original file.
PUNCT_PATTERN = (
    r"[\.!,?:;\"'"
    r"\u201c\u201d\u2018\u2019"
    r"\-\u2013\u2014"
    r"\(\)\[\]\{\}"
    r"\u0964]"
)
_PUNCT_RE = re.compile(PUNCT_PATTERN)


def normalize_bangla_text(text: str) -> str:
    """Strip punctuation from *text* and collapse whitespace.

    Every character matched by ``PUNCT_PATTERN`` is replaced with a space,
    then runs of whitespace are collapsed to single spaces and leading and
    trailing whitespace is removed.

    Args:
        text: Raw input text.

    Returns:
        The normalized string, or ``""`` when *text* is not a ``str``.
    """
    if not isinstance(text, str):
        return ""
    # Punctuation becomes a space (not an empty string) so that words joined
    # only by punctuation, e.g. "a-b", stay separate tokens.
    without_punct = _PUNCT_RE.sub(" ", text)
    # A single split/join both collapses whitespace runs and trims the ends,
    # so no separate collapse is needed before the substitution.
    return " ".join(without_punct.split())
# ---------- Load tokenizer + model ----------
# Both are loaded once, at import time, from MODEL_PATH. The model is moved to
# DEVICE and put into evaluation mode.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(DEVICE)
model.eval()

# Label names ordered by class index, so label_list[i] names output index i.
id2label = model.config.id2label
label_list = [id2label[i] for i in range(len(id2label))]

# The "shirk" class is only accepted as the prediction when its probability
# reaches SHIRK_THRESHOLD (see the threshold logic in predict).
SHIRK_LABEL = "shirk"
# list.index raises ValueError at import time if the model has no such label.
SHIRK_INDEX = label_list.index(SHIRK_LABEL)
SHIRK_THRESHOLD = 0.7  # tweak if needed
# ---------- FastAPI ----------
# Application object that the endpoint functions below belong to.
app = FastAPI(title="Bangla Shirk Classifier API")
class PredictRequest(BaseModel):
    """Request body for a prediction."""

    # Raw text to classify; it is normalized inside predict before tokenizing.
    text: str
class PredictResponse(BaseModel):
    """Response body for a prediction."""

    # Name of the predicted class (after the shirk threshold rule is applied).
    label: str
    # Softmax probability for every class, keyed by label name.
    probabilities: Dict[str, float]
# NOTE(review): without a route decorator this function was never registered
# with `app`. The path "/" is inferred from the function name -- confirm.
@app.get("/")
def root():
    """Report that the service is up.

    Returns:
        A dict with the single key ``"status"`` set to ``"running"``.
    """
    return {"status": "running"}
# NOTE(review): without a route decorator this function was never registered
# with `app`. The path "/predict" is inferred from the function name -- confirm.
@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest) -> PredictResponse:
    """Classify the request text and return the label with all probabilities.

    The text is normalized, tokenized (truncated to 64 tokens) and run through
    the model without gradient tracking. The predicted class is the most
    probable one, except that "shirk" is only accepted when its probability is
    at least ``SHIRK_THRESHOLD``; below that the second most probable class is
    returned instead.

    Args:
        req: Request body carrying the text to classify.

    Returns:
        The chosen label and a mapping from every label to its probability.
    """
    text = normalize_bangla_text(req.text)

    enc = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=64,
        return_tensors="pt",
    )
    # Inputs must live on the same device as the model.
    enc = {name: tensor.to(DEVICE) for name, tensor in enc.items()}

    with torch.no_grad():
        outputs = model(**enc)

    # The batch holds a single sequence, so take row 0 of the logits.
    logits = outputs.logits[0]
    probs = F.softmax(logits, dim=-1).tolist()

    # Class indices from most to least probable. sorted() is stable, so on a
    # tie the lower index comes first, matching argmax.
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)

    # Shirk threshold logic: a "shirk" win that is not confident enough falls
    # back to the runner-up class.
    top1 = ranked[0]
    if top1 == SHIRK_INDEX and probs[SHIRK_INDEX] < SHIRK_THRESHOLD:
        pred_idx = ranked[1]
    else:
        pred_idx = top1

    prob_dict = {label: float(p) for label, p in zip(label_list, probs)}

    return PredictResponse(
        label=label_list[pred_idx],
        probabilities=prob_dict,
    )