Yousuf-Islam commited on
Commit
c275311
·
verified ·
1 Parent(s): 7ea2227

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -44
app.py CHANGED
@@ -1,70 +1,106 @@
1
  import torch
2
- from fastapi import FastAPI
 
3
  from pydantic import BaseModel
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
- import torch.nn.functional as F
6
 
 
 
 
7
  app = FastAPI()
8
 
9
- # 1. Load Model
10
- MODEL_PATH = "."
11
- device = torch.device("cpu")
12
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  print("Loading model...")
14
  try:
15
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
16
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
17
  model.to(device)
18
  model.eval()
19
- print("Model loaded successfully!")
20
  except Exception as e:
21
- print(f"CRITICAL ERROR LOADING MODEL: {e}")
22
-
23
- # 2. FORCE LABEL MAPPING (The Fix)
24
- # We strictly define the labels here to match your training:
25
- # 0 -> neutral, 1 -> not_shirk, 2 -> shirk (Alphabetical Order)
26
- ID2LABEL = {
27
- 0: "neutral",
28
- 1: "not_shirk",
29
- 2: "shirk"
30
- }
31
 
 
 
 
32
  class TextRequest(BaseModel):
33
  text: str
34
 
 
 
 
 
35
  @app.get("/")
36
  def home():
37
- return {"status": "online", "message": "BanglaBERT Shirk Detector"}
38
 
39
  @app.post("/predict")
40
  def predict(request: TextRequest):
41
- # Tokenize
42
- inputs = tokenizer(request.text, return_tensors="pt", truncation=True, max_length=128, padding=True)
43
- inputs = {k: v.to(device) for k, v in inputs.items()}
 
 
 
 
 
 
 
44
 
45
- # Predict
46
- with torch.no_grad():
47
- outputs = model(**inputs)
48
- probs = F.softmax(outputs.logits, dim=1)
49
-
50
- # Get Winner Index
51
- pred_idx = torch.argmax(probs, dim=1).item()
52
-
53
- # ✅ FORCE CORRECT LABEL NAME
54
- # We ignore the model's internal config and use our manual map
55
- pred_label = ID2LABEL[pred_idx]
56
 
57
- confidence = probs[0][pred_idx].item()
58
-
59
- # Get All Scores with Correct Names
60
- scores = {}
61
- for i in range(len(probs[0])):
62
- label_name = ID2LABEL[i]
63
- scores[label_name] = float(probs[0][i])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- return {
66
- "text": request.text,
67
- "label": pred_label,
68
- "confidence": confidence,
69
- "scores": scores
70
- }
 
1
  import torch
2
+ import torch.nn.functional as F
3
+ from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
6
 
7
+ # ==========================================
8
+ # 1. SETUP & CONFIGURATION
9
+ # ==========================================
10
  app = FastAPI()
11
 
12
+ # Define the path to the model files (Root directory)
13
+ MODEL_PATH = "."
14
+ device = torch.device("cpu") # Hugging Face Spaces (Free Tier) uses CPU
15
 
16
+ # MANUAL LABEL MAPPING (Safety Net)
17
+ # Use this to fix any confusion between Red/Green results.
18
+ # Adjust these indices if your model predicts the wrong class.
19
+ ID2LABEL_MANUAL = {
20
+ 0: "neutral",
21
+ 1: "not_shirk",
22
+ 2: "shirk"
23
+ }
24
+
25
+ # ==========================================
26
+ # 2. LOAD MODEL
27
+ # ==========================================
28
  print("Loading model...")
29
  try:
30
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
31
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
32
  model.to(device)
33
  model.eval()
34
+ print("Model loaded successfully!")
35
  except Exception as e:
36
+ print(f"CRITICAL ERROR LOADING MODEL: {e}")
37
+ # We do not raise an error here so the app can still start and show logs,
38
+ # but predictions will fail if model is None.
 
 
 
 
 
 
 
39
 
40
+ # ==========================================
41
+ # 3. INPUT SCHEMA
42
+ # ==========================================
43
  class TextRequest(BaseModel):
44
  text: str
45
 
46
+ # ==========================================
47
+ # 4. API ENDPOINTS
48
+ # ==========================================
49
+
50
  @app.get("/")
51
  def home():
52
+ return {"status": "online", "system": "Dockerized BanglaBERT API"}
53
 
54
  @app.post("/predict")
55
  def predict(request: TextRequest):
56
+ try:
57
+ # 1. Tokenize Input
58
+ inputs = tokenizer(
59
+ request.text,
60
+ return_tensors="pt",
61
+ truncation=True,
62
+ max_length=128,
63
+ padding=True
64
+ )
65
+ inputs = {k: v.to(device) for k, v in inputs.items()}
66
 
67
+ # 2. Perform Inference
68
+ with torch.no_grad():
69
+ outputs = model(**inputs)
70
+ probs = F.softmax(outputs.logits, dim=1)
 
 
 
 
 
 
 
71
 
72
+ # 3. Determine Winner
73
+ pred_idx = torch.argmax(probs, dim=1).item()
74
+ confidence = probs[0][pred_idx].item()
75
+
76
+ # 4. Resolve Label Name
77
+ # Priority: Try model config first, fall back to manual map if missing
78
+ if model.config.id2label and len(model.config.id2label) > 0:
79
+ # Handle potential string/int key mismatch in config
80
+ pred_label = model.config.id2label.get(pred_idx, model.config.id2label.get(str(pred_idx)))
81
+
82
+ # Fallback if config is empty or failed
83
+ if not pred_label:
84
+ pred_label = ID2LABEL_MANUAL.get(pred_idx, "unknown")
85
+
86
+ # 5. Format All Scores
87
+ scores = {}
88
+ for i in range(len(probs[0])):
89
+ # Get label name for this index
90
+ if model.config.id2label:
91
+ lbl = model.config.id2label.get(i, model.config.id2label.get(str(i)))
92
+ else:
93
+ lbl = ID2LABEL_MANUAL.get(i, f"LABEL_{i}")
94
+
95
+ scores[lbl] = float(probs[0][i])
96
+
97
+ return {
98
+ "text": request.text,
99
+ "label": pred_label,
100
+ "confidence": confidence,
101
+ "scores": scores
102
+ }
103
 
104
+ except Exception as e:
105
+ print(f"Prediction Error: {e}")
106
+ return {"error": str(e)}