# NOTE(review): "Spaces: Sleeping" banner from the Hugging Face Spaces page
# was captured here by the extraction — it is page residue, not program code.
| import torch | |
| import torch.nn.functional as F | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# ==========================================
# 1. SETUP & CONFIGURATION
# ==========================================
app = FastAPI()

# Location of the model artifacts (repository root).
MODEL_PATH = "."

# Free-tier Hugging Face Spaces provide CPU only.
device = torch.device("cpu")

# Manual id -> label mapping used as a safety net when the model's
# config.id2label is missing or wrong. Adjust these indices if the
# model predicts the wrong class.
ID2LABEL_MANUAL = {
    0: "neutral",
    1: "not_shirk",
    2: "shirk",
}
# ==========================================
# 2. LOAD MODEL
# ==========================================
print("Loading model...")
# Pre-declare both names so they exist even when loading fails.
# (Previously a failed load left them undefined, so any later use raised
# NameError instead of the clean "model is None" check the comment promised.)
tokenizer = None
model = None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(device)
    model.eval()
    print("✅ Model loaded successfully!")
except Exception as e:
    # Deliberately swallow the error so the app still starts and the
    # Spaces logs are visible; predictions fail cleanly while model is None.
    print(f"❌ CRITICAL ERROR LOADING MODEL: {e}")
# ==========================================
# 3. INPUT SCHEMA
# ==========================================
class TextRequest(BaseModel):
    """JSON body for the predict endpoint: the raw text to classify."""

    text: str
# ==========================================
# 4. API ENDPOINTS
# ==========================================
# NOTE(review): the route decorator was absent in the reviewed copy (likely
# lost in extraction); without it FastAPI never registers this endpoint.
# Restored as GET / — confirm against the original source.
@app.get("/")
def home():
    """Health-check endpoint: confirms the service is online."""
    return {"status": "online", "system": "Dockerized BanglaBERT API"}
def _resolve_label(idx: int) -> str:
    """Map a class index to its human-readable label.

    Prefers the model's ``config.id2label``, tolerating both int and str
    keys (serialized configs often store keys as strings), then falls
    back to ID2LABEL_MANUAL, then to a generic ``LABEL_{idx}``.
    """
    id2label = getattr(model.config, "id2label", None) or {}
    label = id2label.get(idx, id2label.get(str(idx)))
    if not label:
        label = ID2LABEL_MANUAL.get(idx, f"LABEL_{idx}")
    return label


# NOTE(review): the route decorator was absent in the reviewed copy (likely
# lost in extraction); restored as POST /predict — confirm against the original.
@app.post("/predict")
def predict(request: TextRequest):
    """Classify ``request.text`` and return label, confidence, and all scores.

    Returns:
        dict with keys ``text``, ``label``, ``confidence``, ``scores``.

    Raises:
        HTTPException 503: the model failed to load at startup.
        HTTPException 500: any unexpected error during inference.
    """
    # Guard against a failed startup load instead of crashing mid-inference.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded.")
    try:
        # 1. Tokenize input (truncate long texts to 128 tokens).
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            truncation=True,
            max_length=128,
            padding=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # 2. Inference — no gradients needed for a forward pass.
        with torch.no_grad():
            outputs = model(**inputs)
            probs = F.softmax(outputs.logits, dim=1)

        # 3. Winning class and its probability.
        # (The original code only assigned pred_label inside an `if` on
        # config.id2label, then read it unconditionally — a NameError when
        # the config had no labels. _resolve_label always returns a name.)
        pred_idx = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred_idx].item()

        # 4. Per-class score breakdown.
        scores = {
            _resolve_label(i): float(probs[0][i]) for i in range(probs.shape[1])
        }

        return {
            "text": request.text,
            "label": _resolve_label(pred_idx),
            "confidence": confidence,
            "scores": scores,
        }
    except HTTPException:
        raise
    except Exception as e:
        # Surface a real HTTP 500 instead of the original behavior of
        # returning a 200 response with an {"error": ...} body, which
        # silently broke clients checking status codes.
        print(f"Prediction Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))