Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -7,9 +7,9 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
| 7 |
import re
|
| 8 |
|
| 9 |
# -----------------------------
|
| 10 |
-
# Load model from
|
| 11 |
# -----------------------------
|
| 12 |
-
MODEL_PATH = "
|
| 13 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
|
| 15 |
# ---------- Text normalization ----------
|
|
@@ -33,7 +33,7 @@ label_list = [id2label[i] for i in range(len(id2label))]
|
|
| 33 |
|
| 34 |
SHIRK_LABEL = "shirk"
|
| 35 |
SHIRK_INDEX = label_list.index(SHIRK_LABEL)
|
| 36 |
-
SHIRK_THRESHOLD = 0.7
|
| 37 |
|
| 38 |
# ---------- FastAPI ----------
|
| 39 |
app = FastAPI(title="Bangla Shirk Classifier API")
|
|
@@ -67,7 +67,7 @@ def predict(req: PredictRequest):
|
|
| 67 |
logits = outputs.logits[0]
|
| 68 |
probs = F.softmax(logits, dim=-1).cpu().numpy()
|
| 69 |
|
| 70 |
-
#
|
| 71 |
top1 = int(probs.argmax())
|
| 72 |
if top1 == SHIRK_INDEX and probs[SHIRK_INDEX] < SHIRK_THRESHOLD:
|
| 73 |
top2 = int(probs.argsort()[-2])
|
|
|
|
| 7 |
import re
|
| 8 |
|
| 9 |
# -----------------------------
|
| 10 |
+
# Load model from current folder
|
| 11 |
# -----------------------------
|
| 12 |
+
MODEL_PATH = "." # we are in the repo root
|
| 13 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
|
| 15 |
# ---------- Text normalization ----------
|
|
|
|
| 33 |
|
| 34 |
SHIRK_LABEL = "shirk"
|
| 35 |
SHIRK_INDEX = label_list.index(SHIRK_LABEL)
|
| 36 |
+
SHIRK_THRESHOLD = 0.7 # tweak if needed
|
| 37 |
|
| 38 |
# ---------- FastAPI ----------
|
| 39 |
app = FastAPI(title="Bangla Shirk Classifier API")
|
|
|
|
| 67 |
logits = outputs.logits[0]
|
| 68 |
probs = F.softmax(logits, dim=-1).cpu().numpy()
|
| 69 |
|
| 70 |
+
# Shirk threshold logic
|
| 71 |
top1 = int(probs.argmax())
|
| 72 |
if top1 == SHIRK_INDEX and probs[SHIRK_INDEX] < SHIRK_THRESHOLD:
|
| 73 |
top2 = int(probs.argsort()[-2])
|