# Author: Yousuf-Islam
# Commit: 05b38bc "Create app.py" (verified)
import logging
import os

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
app = FastAPI()

# Cross-origin (CORS) policy so a browser frontend served from a different
# origin (e.g. the React app) may call this API. Every origin, method and
# header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# permissive pairing; this API does not appear to rely on cookies or auth
# headers, so allow_credentials=False may suffice — confirm with the frontend
# before changing it.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Load the tokenizer and classifier once, at import time, into module-level
# globals so every request reuses the same instances.
# The model directory defaults to "/code/model" (the original hard-coded
# location) and can be overridden with the MODEL_PATH environment variable,
# e.g. to run the service outside its container.
MODEL_PATH = os.environ.get("MODEL_PATH", "/code/model")
print("Loading AI Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
# Request body for POST /api/predict: a JSON object {"sentence": "<text>"}.
# (Documented with comments rather than a docstring because pydantic would
# copy a class docstring into the generated OpenAPI schema.)
class InputData(BaseModel):
    sentence: str  # raw text to classify; handed to the tokenizer unchanged
@app.get("/")
def home():
    """Report that the service is up and which model it serves."""
    return dict(status="Online", model="BanglaBERT")
@app.post("/api/predict")
def predict(data: InputData):
    """Classify one sentence as "shirk" or "not shirk" with the loaded model.

    Response keys on success:
      - ``result``: ``"shirk"`` when the predicted class id is 1, otherwise
        ``"not shirk"``.
      - ``confidence``: softmax probability of the predicted class, formatted
        as a percentage string with two decimals (e.g. ``"97.31%"``).
      - ``cleaned_sentence``: the input sentence echoed back unchanged (no
        cleaning is performed here despite the key name).

    If tokenization or inference raises, the response is
    ``{"error": "<message>"}`` as a regular response body (no HTTP error
    status), and the full traceback is logged server-side.
    """
    try:
        # Inputs longer than 64 tokens are truncated, not rejected.
        inputs = tokenizer(
            data.sentence,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=64,
        )
        with torch.no_grad():
            logits = model(**inputs).logits
        # Exactly one sentence is encoded, so read row 0 of the batch rather
        # than reducing over the flattened tensor.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        top_prob, top_idx = torch.max(probs, dim=-1)
        conf = top_prob.item()
        pred_id = top_idx.item()
    except Exception as e:
        # Keep the original error response for the frontend, but leave a
        # traceback in the server logs instead of discarding it.
        logging.getLogger(__name__).exception("Prediction failed")
        return {"error": str(e)}

    # Label mapping (1 = shirk, 0 = not shirk)
    label = "shirk" if pred_id == 1 else "not shirk"
    return {
        "result": label,
        "confidence": f"{conf:.2%}",
        "cleaned_sentence": data.sentence,
    }