# banglabert-api / app.py
# (Hugging Face Space metadata: Yousuf-Islam — "Update app.py", commit c275311 verified)
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ==========================================
# 1. SETUP & CONFIGURATION
# ==========================================
app = FastAPI()
# Define the path to the model files (Root directory)
# ("." works because the Space stores the model weights alongside app.py)
MODEL_PATH = "."
device = torch.device("cpu")  # Hugging Face Spaces (Free Tier) uses CPU
# MANUAL LABEL MAPPING (Safety Net)
# Use this to fix any confusion between Red/Green results.
# Adjust these indices if your model predicts the wrong class.
# Fallback used by /predict when the model config's id2label mapping is
# missing or does not contain the predicted index.
ID2LABEL_MANUAL = {
    0: "neutral",
    1: "not_shirk",
    2: "shirk",
}
# ==========================================
# 2. LOAD MODEL
# ==========================================
print("Loading model...")
# Pre-bind to None so a failed load leaves both names defined.  Previously
# a load failure left `tokenizer`/`model` unbound, so /predict crashed with
# a NameError even though the comment below promised "model is None".
tokenizer = None
model = None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(device)
    model.eval()  # inference mode: disable dropout etc.
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ CRITICAL ERROR LOADING MODEL: {e}")
    # We deliberately do not re-raise so the Space can still start and
    # expose its logs; predictions will fail while model is None.
# ==========================================
# 3. INPUT SCHEMA
# ==========================================
class TextRequest(BaseModel):
    """Request body for POST /predict."""
    # Raw text to classify (Bangla expected, but any string is accepted).
    text: str
# ==========================================
# 4. API ENDPOINTS
# ==========================================
@app.get("/")
def home():
    """Health-check endpoint confirming the API container is reachable."""
    payload = {
        "status": "online",
        "system": "Dockerized BanglaBERT API",
    }
    return payload
def _resolve_label(id2label, idx, fallback="unknown"):
    """Map a class index to a human-readable label name.

    Tries *id2label* with both int and str keys (configs deserialized from
    JSON often carry string keys), then falls back to ID2LABEL_MANUAL, and
    finally to *fallback*.  Always returns a string — the original inline
    logic could leave the label as None or entirely unbound.
    """
    if id2label:
        label = id2label.get(idx, id2label.get(str(idx)))
        if label:
            return label
    return ID2LABEL_MANUAL.get(idx, fallback)


@app.post("/predict")
def predict(request: TextRequest):
    """Classify request.text; return label, confidence and per-class scores."""
    try:
        # Guard: startup may have failed to load the model (see section 2).
        # Without this, the code below would raise on `None` attributes.
        if model is None or tokenizer is None:
            return {"error": "Model is not loaded; check startup logs."}

        # 1. Tokenize input (truncate to the model's 128-token window).
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            truncation=True,
            max_length=128,
            padding=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # 2. Inference without gradient tracking (CPU, eval mode).
        with torch.no_grad():
            outputs = model(**inputs)
            probs = F.softmax(outputs.logits, dim=1)

        # 3. Winning class index and its probability.
        pred_idx = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred_idx].item()

        # 4. Resolve label names via the shared helper.  The original code
        # read `pred_label` before assignment when the config had no
        # id2label (NameError), and could emit a `None` key in `scores`
        # when an index was missing from the mapping.
        id2label = getattr(model.config, "id2label", None)
        pred_label = _resolve_label(id2label, pred_idx)

        # 5. Full probability distribution keyed by label name.
        scores = {
            _resolve_label(id2label, i, f"LABEL_{i}"): float(probs[0][i])
            for i in range(probs.shape[1])
        }

        return {
            "text": request.text,
            "label": pred_label,
            "confidence": confidence,
            "scores": scores,
        }
    except Exception as e:
        # Kept as HTTP 200 + error body for backward compatibility with
        # existing clients; raising HTTPException(500) would be cleaner.
        print(f"Prediction Error: {e}")
        return {"error": str(e)}