File size: 2,309 Bytes
86509aa
 
 
 
 
 
 
 
 
8b003a0
86509aa
8b003a0
86509aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b003a0
86509aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b003a0
86509aa
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Dict
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re

# -----------------------------
# Load model from current folder
# -----------------------------
MODEL_PATH = "."  # we are in the repo root
# Prefer GPU when one is visible to torch; everything below is moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- Text normalization ----------
# Punctuation stripped before classification (ASCII punctuation, curly
# quotes/dashes, brackets, and the Bangla danda at the end of the class).
PUNCT_PATTERN = r"[\.!,?:;\"'β€β€œβ€™β€˜\-\–\β€”\(\)\[\]\{\}ΰ₯€]"
# Compile once at import time instead of re-scanning the pattern string on
# every request.
_PUNCT_RE = re.compile(PUNCT_PATTERN)

def normalize_bangla_text(text: str) -> str:
    """Return *text* with punctuation replaced by spaces and whitespace collapsed.

    Non-string input (e.g. ``None``) yields an empty string. Runs of
    whitespace are collapsed to single spaces and the result is stripped.
    """
    if not isinstance(text, str):
        return ""
    # One collapse pass after substitution is sufficient: the final
    # split/join normalizes all whitespace, so a pre-substitution collapse
    # (as in the original) did the same work twice.
    text = _PUNCT_RE.sub(" ", text)
    return " ".join(text.split())

# ---------- Load tokenizer + model ----------
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(DEVICE)
model.eval()  # inference mode: disables dropout/batch-norm updates

# Build an index-ordered label list so softmax position i maps to label_list[i].
# NOTE(review): assumes id2label keys are the ints 0..n-1 — transformers
# normally normalizes them on load, but confirm against this checkpoint's config.
id2label = model.config.id2label
label_list = [id2label[i] for i in range(len(id2label))]

SHIRK_LABEL = "shirk"
# NOTE(review): .index() raises ValueError at import time if the checkpoint's
# labels do not include "shirk" — verify the fine-tuned config ships this label.
SHIRK_INDEX = label_list.index(SHIRK_LABEL)
SHIRK_THRESHOLD = 0.7  # tweak if needed

# ---------- FastAPI ----------
app = FastAPI(title="Bangla Shirk Classifier API")

class PredictRequest(BaseModel):
    # Raw Bangla text to classify; normalized server-side before tokenization.
    text: str

class PredictResponse(BaseModel):
    # Final predicted label, after the shirk-threshold adjustment.
    label: str
    # Softmax probability for every label, keyed by label name.
    probabilities: Dict[str, float]

@app.get("/")
def root():
    """Liveness probe: report that the service is up."""
    payload = {"status": "running"}
    return payload

@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest):
    """Classify one Bangla text, demoting low-confidence 'shirk' predictions.

    The text is normalized, tokenized (truncated to 64 tokens), and scored
    with a single forward pass. If "shirk" wins the argmax but falls below
    SHIRK_THRESHOLD, the runner-up label is returned instead.
    """
    cleaned = normalize_bangla_text(req.text)

    batch = tokenizer(
        cleaned,
        truncation=True,
        padding=True,
        max_length=64,
        return_tensors="pt",
    )
    batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}

    with torch.no_grad():
        logits = model(**batch).logits[0]
        probs = F.softmax(logits, dim=-1).cpu().numpy()

    best = int(probs.argmax())
    # Demote "shirk" to the second-best label unless it clears the threshold.
    if best == SHIRK_INDEX and probs[SHIRK_INDEX] < SHIRK_THRESHOLD:
        chosen = int(probs.argsort()[-2])
    else:
        chosen = best

    return PredictResponse(
        label=label_list[chosen],
        probabilities={name: float(p) for name, p in zip(label_list, probs)},
    )