File size: 3,419 Bytes
cb7d52f
c275311
 
cb7d52f
 
 
c275311
 
 
cb7d52f
 
c275311
 
 
cb7d52f
c275311
 
 
 
 
 
 
 
 
 
 
 
cb7d52f
 
 
 
 
 
c275311
cb7d52f
c275311
 
 
7ea2227
c275311
 
 
cb7d52f
 
 
c275311
 
 
 
cb7d52f
 
c275311
cb7d52f
 
 
c275311
 
 
 
 
 
 
 
 
 
cb7d52f
c275311
 
 
 
7ea2227
c275311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb7d52f
c275311
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ==========================================
# 1. SETUP & CONFIGURATION
# ==========================================
app = FastAPI()

# The tokenizer and model files are read from the root (current) directory.
MODEL_PATH = "."

# Inference is pinned to CPU: Hugging Face Spaces (Free Tier) uses CPU.
device = torch.device("cpu")

# Manual label mapping (safety net).
# predict() falls back to this table when the model config does not supply a
# usable label name. The position of each name in the tuple is its class
# index (0, 1, 2), so reorder the names if the model predicts the wrong class.
ID2LABEL_MANUAL = dict(enumerate(("neutral", "not_shirk", "shirk")))

# ==========================================
# 2. LOAD MODEL
# ==========================================
print("Loading model...")
# Bind both names up front so that a failed load leaves them as None rather
# than undefined; request handlers can then test for None explicitly.
tokenizer = None
model = None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(device)
    model.eval()  # inference mode: disables dropout and similar layers
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ CRITICAL ERROR LOADING MODEL: {e}")
    # We deliberately do not re-raise, so the app can still start and expose
    # its logs. Reset both names so a partially completed load (for example,
    # tokenizer loaded but model failed) is never used for predictions.
    tokenizer = None
    model = None

# ==========================================
# 3. INPUT SCHEMA
# ==========================================
class TextRequest(BaseModel):
    """Request body for the /predict endpoint."""

    # Raw input text to classify; it is passed to the tokenizer as-is.
    text: str

# ==========================================
# 4. API ENDPOINTS
# ==========================================

@app.get("/")
def home():
    """Report that the service is up (simple liveness endpoint)."""
    status_payload = dict(status="online", system="Dockerized BanglaBERT API")
    return status_payload

def _resolve_label(id2label, idx, default):
    """Return the class name for index *idx*.

    Lookup order: the model-config mapping with an int key, the same mapping
    with a string key, the manual ``ID2LABEL_MANUAL`` table, then *default*.
    """
    # Config mappings can be keyed by int or by str depending on how they
    # were loaded, so try both forms.
    label = id2label.get(idx) or id2label.get(str(idx))
    if not label:
        label = ID2LABEL_MANUAL.get(idx, default)
    return label


@app.post("/predict")
def predict(request: TextRequest):
    """Classify ``request.text`` and return the winning label with all scores.

    Args:
        request: Body carrying the ``text`` to classify.

    Returns:
        On success, a dict with ``text`` (the input echoed back), ``label``
        (predicted class name), ``confidence`` (softmax probability of the
        predicted class) and ``scores`` (class name -> probability).
        On any failure, ``{"error": <message>}``; the exception is logged and
        not propagated.
    """
    try:
        # Fail with a clear message if model loading did not succeed.
        if tokenizer is None or model is None:
            raise RuntimeError("Model is not loaded; check the startup logs.")

        # 1. Tokenize input (inputs longer than 128 tokens are truncated)
        encoded = tokenizer(
            request.text,
            return_tensors="pt",
            truncation=True,
            max_length=128,
            padding=True,
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        # 2. Perform inference without tracking gradients
        with torch.no_grad():
            outputs = model(**encoded)
            probs = F.softmax(outputs.logits, dim=1)

        # 3. Determine winner (single input, so only row 0 is used)
        pred_idx = int(torch.argmax(probs, dim=1).item())
        row = probs[0].tolist()
        confidence = row[pred_idx]

        # 4. Resolve label names.
        # A missing or empty config mapping is treated as {} so the manual
        # table is always reachable as a fallback.
        id2label = getattr(model.config, "id2label", None) or {}
        pred_label = _resolve_label(id2label, pred_idx, "unknown")

        # 5. Format all scores; every index gets a real string key even when
        # the config mapping lacks that index.
        scores = {
            _resolve_label(id2label, i, f"LABEL_{i}"): prob
            for i, prob in enumerate(row)
        }

        return {
            "text": request.text,
            "label": pred_label,
            "confidence": confidence,
            "scores": scores,
        }

    except Exception as e:
        # NOTE(review): errors are reported in the response body rather than
        # through an HTTP error status; kept as-is so existing clients that
        # check for the "error" key keep working.
        print(f"Prediction Error: {e}")
        return {"error": str(e)}