"""FastAPI inference service for a BanglaBERT sequence-classification model."""

import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ==========================================
# 1. SETUP & CONFIGURATION
# ==========================================
app = FastAPI()

# Path to the model files (repository root).
MODEL_PATH = "."
device = torch.device("cpu")  # Hugging Face Spaces (Free Tier) uses CPU

# MANUAL LABEL MAPPING (Safety Net)
# Used when the model config carries no usable id2label mapping.
# Adjust these indices if your model predicts the wrong class.
ID2LABEL_MANUAL = {
    0: "neutral",
    1: "not_shirk",
    2: "shirk",
}

# ==========================================
# 2. LOAD MODEL
# ==========================================
# Pre-initialise so these names always exist: if loading fails the app can
# still start (so logs are visible) and /predict can report the failure
# cleanly instead of crashing with a NameError.
tokenizer = None
model = None

print("Loading model...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(device)
    model.eval()
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ CRITICAL ERROR LOADING MODEL: {e}")
    # Deliberately not re-raised: the app starts so the logs can be read,
    # but /predict will answer 503 until the model is available.

# ==========================================
# 3. INPUT SCHEMA
# ==========================================
class TextRequest(BaseModel):
    # Raw input text to classify.
    text: str

# ==========================================
# 4. API ENDPOINTS
# ==========================================
def _resolve_label(idx: int) -> str:
    """Map a class index to its label name.

    Priority: the model config's id2label (tolerating the common int/str
    key mismatch in saved configs), then the manual fallback map, then a
    generic "LABEL_{idx}" placeholder.
    """
    config_map = getattr(model.config, "id2label", None) if model is not None else None
    if config_map:
        label = config_map.get(idx, config_map.get(str(idx)))
        if label:
            return label
    return ID2LABEL_MANUAL.get(idx, f"LABEL_{idx}")


@app.get("/")
def home():
    """Health-check endpoint."""
    return {"status": "online", "system": "Dockerized BanglaBERT API"}


@app.post("/predict")
def predict(request: TextRequest):
    """Classify request.text; return label, confidence and per-class scores."""
    # Fail fast — and outside the broad try below, so the HTTPException is
    # not swallowed — if startup model loading failed.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded.")

    try:
        # 1. Tokenize Input
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            truncation=True,
            max_length=128,
            padding=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # 2. Perform Inference
        with torch.no_grad():
            outputs = model(**inputs)
            probs = F.softmax(outputs.logits, dim=1)

        # 3. Determine Winner
        pred_idx = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred_idx].item()

        # 4. Resolve Label Name (config first, manual map as fallback).
        # Centralised in _resolve_label so the predicted label and the
        # per-class score keys always agree.
        pred_label = _resolve_label(pred_idx)

        # 5. Format All Scores
        scores = {
            _resolve_label(i): float(probs[0][i]) for i in range(len(probs[0]))
        }

        return {
            "text": request.text,
            "label": pred_label,
            "confidence": confidence,
            "scores": scores,
        }
    except Exception as e:
        # Preserve the original API contract: per-request errors are
        # reported in the response body rather than as an HTTP error.
        print(f"Prediction Error: {e}")
        return {"error": str(e)}