from fastapi import FastAPI
from pydantic import BaseModel
from typing import Dict
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re
# -----------------------------
# Load model from current folder
# -----------------------------
# The tokenizer and fine-tuned model files are expected alongside this script.
MODEL_PATH = "." # we are in the repo root
# Prefer GPU when one is visible to torch; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ---------- Text normalization ----------
# Punctuation stripped before classification: ASCII punctuation, curly
# quotes, en/em dashes, brackets, and the Bangla danda (U+0964).
# NOTE(review): the original pattern contained mojibake (runs of "β" and
# "ΰ₯€") from an encoding round-trip; restored to the intended characters.
PUNCT_PATTERN = r"[.!,?:;\"'\u201c\u201d\u2018\u2019\-\u2013\u2014()\[\]{}\u0964]"
# Compile once at module level: the pattern is applied on every request.
_PUNCT_RE = re.compile(PUNCT_PATTERN)
def normalize_bangla_text(text: str) -> str:
    """Collapse whitespace and strip punctuation from *text*.

    Non-string input (e.g. None) yields the empty string.
    """
    if not isinstance(text, str):
        return ""
    # Collapse whitespace runs, replace punctuation with spaces, then
    # collapse again because punctuation removal can leave double spaces.
    text = " ".join(text.split())
    text = _PUNCT_RE.sub(" ", text)
    return " ".join(text.split())
# ---------- Load tokenizer + model ----------
# Module-level load: happens once at import, before the server starts.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(DEVICE)
model.eval()  # inference mode (disables dropout etc.)
# assumes id2label keys are ints 0..n-1 as transformers normalizes them — TODO confirm
id2label = model.config.id2label
label_list = [id2label[i] for i in range(len(id2label))]  # labels ordered by class index
SHIRK_LABEL = "shirk"
# Fails fast at startup (ValueError) if the model config lacks a "shirk" label.
SHIRK_INDEX = label_list.index(SHIRK_LABEL)
SHIRK_THRESHOLD = 0.7 # tweak if needed
# ---------- FastAPI ----------
# ASGI application object; the route handlers below register against it.
app = FastAPI(title="Bangla Shirk Classifier API")
class PredictRequest(BaseModel):
    """Request body for POST /predict."""
    text: str  # raw input text; normalized server-side before tokenization
class PredictResponse(BaseModel):
    """Response body for POST /predict."""
    label: str  # final predicted label after the shirk-threshold rule
    probabilities: Dict[str, float]  # softmax probability per label
@app.get("/")
def root():
    """Health-check endpoint: confirms the service is up."""
    payload = {"status": "running"}
    return payload
@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest):
    """Classify one text, demoting a low-confidence 'shirk' call to the runner-up.

    Normalizes the input, runs a single forward pass, and applies a
    threshold rule: a "shirk" prediction below SHIRK_THRESHOLD is replaced
    by the second-ranked class.
    """
    cleaned = normalize_bangla_text(req.text)

    # Tokenize and move every tensor onto the model's device.
    batch = tokenizer(
        cleaned,
        truncation=True,
        padding=True,
        max_length=64,
        return_tensors="pt",
    )
    batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}

    with torch.no_grad():
        logits = model(**batch).logits[0]
    probs = F.softmax(logits, dim=-1).cpu().numpy()

    # Threshold rule: accept "shirk" only when it clears SHIRK_THRESHOLD;
    # otherwise fall back to the second-ranked class.
    best = int(probs.argmax())
    pred_idx = best
    if best == SHIRK_INDEX and probs[SHIRK_INDEX] < SHIRK_THRESHOLD:
        pred_idx = int(probs.argsort()[-2])

    return PredictResponse(
        label=label_list[pred_idx],
        probabilities={name: float(p) for name, p in zip(label_list, probs)},
    )