"""FastAPI service exposing a BanglaBERT sequence classifier.

Endpoints:
    GET  /             -- status check.
    POST /api/predict  -- classify one sentence as "shirk" / "not shirk".
"""

import logging
import os

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logger = logging.getLogger(__name__)

# Directory holding the tokenizer and model weights. Overridable through the
# MODEL_PATH environment variable; defaults to the original hard-coded path.
MODEL_PATH = os.environ.get("MODEL_PATH", "/code/model")

# Inputs longer than this many tokens are truncated before inference.
MAX_LENGTH = 64

# Comma-separated CORS origins from CORS_ALLOW_ORIGINS; the default "*"
# keeps the original allow-any-origin behaviour.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app = FastAPI()

# Enable CORS (allows the React frontend to talk to this API).
# NOTE(review): a wildcard origin together with allow_credentials=True lets any
# website send credentialed requests; set CORS_ALLOW_ORIGINS to the frontend's
# origin(s) in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the model once at import time (module-level globals) so every request
# reuses the same tokenizer/model instead of reloading from disk.
print("Loading AI Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
# Explicitly put the model in inference mode (disables dropout).
model.eval()


class InputData(BaseModel):
    """Request body for ``/api/predict``."""

    sentence: str


@app.get("/")
def home():
    """Report that the service is up and which model it serves."""
    return {"status": "Online", "model": "BanglaBERT"}


@app.post("/api/predict")
def predict(data: InputData):
    """Classify ``data.sentence`` with the loaded model.

    Returns:
        On success: ``{"result": "shirk" | "not shirk",
        "confidence": "<probability as a percentage string>",
        "cleaned_sentence": <the input sentence, echoed unchanged>}``.
        On any exception: ``{"error": <exception message>}``. The HTTP status
        is still 200, as in the original API contract.
    """
    try:
        # Tokenize a single sentence into a batch of size 1.
        inputs = tokenizer(
            data.sentence,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
        )

        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            logits = model(**inputs).logits

        # Softmax over the label axis; the top probability is the confidence
        # and its index is the predicted class id.
        probs = torch.nn.functional.softmax(logits, dim=-1)
        top_prob, top_id = probs.max(dim=-1)
        conf = top_prob.item()
        pred_id = top_id.item()

        # Label mapping (1 = Shirk, 0 = Not Shirk).
        label = "shirk" if pred_id == 1 else "not shirk"

        return {
            "result": label,
            "confidence": f"{conf:.2%}",
            # No cleaning is applied; the key name is kept for API compatibility.
            "cleaned_sentence": data.sentence,
        }
    except Exception as e:
        # Top-level request boundary: record the traceback server-side rather
        # than swallowing it, then keep the original error payload.
        logger.exception("Prediction failed")
        return {"error": str(e)}