Yousuf-Islam commited on
Commit
8b003a0
·
verified ·
1 Parent(s): 3495248

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +4 -4
main.py CHANGED
@@ -7,9 +7,9 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
  import re
8
 
9
  # -----------------------------
10
- # Load model from local folder
11
  # -----------------------------
12
- MODEL_PATH = "final_shirk_classifier"
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
  # ---------- Text normalization ----------
@@ -33,7 +33,7 @@ label_list = [id2label[i] for i in range(len(id2label))]
33
 
34
  SHIRK_LABEL = "shirk"
35
  SHIRK_INDEX = label_list.index(SHIRK_LABEL)
36
- SHIRK_THRESHOLD = 0.7
37
 
38
  # ---------- FastAPI ----------
39
  app = FastAPI(title="Bangla Shirk Classifier API")
@@ -67,7 +67,7 @@ def predict(req: PredictRequest):
67
  logits = outputs.logits[0]
68
  probs = F.softmax(logits, dim=-1).cpu().numpy()
69
 
70
- # apply threshold logic
71
  top1 = int(probs.argmax())
72
  if top1 == SHIRK_INDEX and probs[SHIRK_INDEX] < SHIRK_THRESHOLD:
73
  top2 = int(probs.argsort()[-2])
 
7
  import re
8
 
9
  # -----------------------------
10
+ # Load model from current folder
11
  # -----------------------------
12
+ MODEL_PATH = "." # we are in the repo root
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
  # ---------- Text normalization ----------
 
33
 
34
  SHIRK_LABEL = "shirk"
35
  SHIRK_INDEX = label_list.index(SHIRK_LABEL)
36
+ SHIRK_THRESHOLD = 0.7 # tweak if needed
37
 
38
  # ---------- FastAPI ----------
39
  app = FastAPI(title="Bangla Shirk Classifier API")
 
67
  logits = outputs.logits[0]
68
  probs = F.softmax(logits, dim=-1).cpu().numpy()
69
 
70
+ # Shirk threshold logic
71
  top1 = int(probs.argmax())
72
  if top1 == SHIRK_INDEX and probs[SHIRK_INDEX] < SHIRK_THRESHOLD:
73
  top2 = int(probs.argsort()[-2])