Yousuf-Islam committed on
Commit
86509aa
·
verified ·
1 Parent(s): a29cb39

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +83 -0
main.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from typing import Dict
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
+ import re
8
+
9
+ # -----------------------------
10
+ # Load model from local folder
11
+ # -----------------------------
12
+ MODEL_PATH = "final_shirk_classifier"
13
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
+ # ---------- Text normalization ----------
16
+ PUNCT_PATTERN = r"[\.!,?:;\"'”“’‘\-\–\—\(\)\[\]\{\}।]"
17
+
18
+ def normalize_bangla_text(text: str) -> str:
19
+ if not isinstance(text, str):
20
+ return ""
21
+ text = " ".join(text.split())
22
+ text = re.sub(PUNCT_PATTERN, " ", text)
23
+ text = " ".join(text.split())
24
+ return text
25
+
26
+ # ---------- Load tokenizer + model ----------
27
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
28
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).to(DEVICE)
29
+ model.eval()
30
+
31
+ id2label = model.config.id2label
32
+ label_list = [id2label[i] for i in range(len(id2label))]
33
+
34
+ SHIRK_LABEL = "shirk"
35
+ SHIRK_INDEX = label_list.index(SHIRK_LABEL)
36
+ SHIRK_THRESHOLD = 0.7
37
+
38
+ # ---------- FastAPI ----------
39
+ app = FastAPI(title="Bangla Shirk Classifier API")
40
+
41
+ class PredictRequest(BaseModel):
42
+ text: str
43
+
44
+ class PredictResponse(BaseModel):
45
+ label: str
46
+ probabilities: Dict[str, float]
47
+
48
+ @app.get("/")
49
+ def root():
50
+ return {"status": "running"}
51
+
52
+ @app.post("/predict", response_model=PredictResponse)
53
+ def predict(req: PredictRequest):
54
+ text = normalize_bangla_text(req.text)
55
+
56
+ enc = tokenizer(
57
+ text,
58
+ truncation=True,
59
+ padding=True,
60
+ max_length=64,
61
+ return_tensors="pt"
62
+ )
63
+ enc = {k: v.to(DEVICE) for k, v in enc.items()}
64
+
65
+ with torch.no_grad():
66
+ outputs = model(**enc)
67
+ logits = outputs.logits[0]
68
+ probs = F.softmax(logits, dim=-1).cpu().numpy()
69
+
70
+ # apply threshold logic
71
+ top1 = int(probs.argmax())
72
+ if top1 == SHIRK_INDEX and probs[SHIRK_INDEX] < SHIRK_THRESHOLD:
73
+ top2 = int(probs.argsort()[-2])
74
+ pred_idx = top2
75
+ else:
76
+ pred_idx = top1
77
+
78
+ prob_dict = {label_list[i]: float(probs[i]) for i in range(len(label_list))}
79
+
80
+ return PredictResponse(
81
+ label=label_list[pred_idx],
82
+ probabilities=prob_dict
83
+ )