Yousuf-Islam's picture
Update main.py
265040f verified
import os
import re
import string
from typing import List, Union
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import uvicorn
import numpy as np
# ------------------------
# কনফিগারেশন
# ------------------------
# আপনার নতুন আপগ্রেডেড মডেল আইডি
MODEL_ID = "Yousuf-Islam/Upgraded_IndicBERT_Model"
LABEL2ID = {"shirk": 0, "tawheed": 1, "neutral": 2}
ID2LABEL = {v: k for k, v in LABEL2ID.items()}
MAX_LENGTH = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ------------------------
# বাংলা টেক্সট প্রিপসেসিং
# ------------------------
def clean_text(text: str) -> str:
if text is None:
return ""
text = str(text)
# ইউআরএল, ইমেইল এবং মেনশন মুছে ফেলা
text = re.sub(r"http\S+|www\S+|https\S+", "", text, flags=re.MULTILINE)
text = re.sub(r"\S+@\S+", "", text)
text = re.sub(r"@\w+|#\w+", "", text)
# ইংরেজি ও বাংলা বিরামচিহ্ন মুছে ফেলা
english_punct = string.punctuation
bengali_punct = "।॥‘’“”,;:—–-!?()[]{}<>…•°৳"
all_punct = english_punct + bengali_punct
text = text.translate(str.maketrans("", "", all_punct))
# অতিরিক্ত স্পেস মুছে ফেলা
text = " ".join(text.split())
# কমপক্ষে ২ অক্ষরের বেশি হলে রিটার্ন করবে
return text.strip() if len(text.strip()) >= 2 else ""
# ------------------------
# মডেল লোড করা
# ------------------------
print(f"Loading upgraded model from {MODEL_ID} on {device}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.to(device)
model.eval()
print("Model loaded successfully!")
# ------------------------
# FastAPI অ্যাপ সেটআপ
# ------------------------
app = FastAPI(
title="Upgraded Bangla Shirk-Tawheed Classifier",
description="API for classifying Bengali sentences into Shirk, Tawheed, or Neutral labels."
)
# CORS এনাবল করা (যাতে রিঅ্যাক্ট বা অন্য সাইট থেকে কল করা যায়)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False,
allow_methods=["*"],
allow_headers=["*"],
)
# ডাটা ফরম্যাট (Pydantic)
class PredictRequest(BaseModel):
text: Union[str, List[str]]
class PredictionResult(BaseModel):
input_text: str
clean_text: str
prediction: str
scores: dict
# রুটে চেক করার জন্য মেসেজ
@app.get("/")
def home():
return {
"status": "online",
"model": MODEL_ID,
"message": "Send a POST request to /predict"
}
# প্রেডিকশন এন্ডপয়েন্ট
@app.post("/predict")
def predict(req: PredictRequest):
# ইনপুট যদি একটা বাক্য হয় তবে লিস্টে রূপান্তর
input_texts = [req.text] if isinstance(req.text, str) else req.text
if not input_texts:
raise HTTPException(status_code=400, detail="No text provided")
# টেক্সট ক্লিন করা
cleaned = [clean_text(t) for t in input_texts]
# ভ্যালিড টেক্সটগুলোর ইনডেক্স খুঁজে বের করা
valid_indices = [i for i, t in enumerate(cleaned) if t != ""]
valid_texts = [cleaned[i] for i in valid_indices]
final_results = [None] * len(input_texts)
if valid_texts:
# টোকেনাইজ করা
inputs = tokenizer(
valid_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=MAX_LENGTH
).to(device)
with torch.no_grad():
outputs = model(**inputs)
probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy()
for i, idx in enumerate(valid_indices):
pred_id = np.argmax(probs[i])
label = ID2LABEL[pred_id]
scores = {ID2LABEL[j]: float(probs[i][j]) for j in range(len(ID2LABEL))}
final_results[idx] = {
"input_text": input_texts[idx],
"clean_text": cleaned[idx],
"prediction": label,
"scores": scores
}
# ইনভ্যালিড বা খালি ইনপুটের জন্য ডিফল্ট মেসেজ
for i in range(len(final_results)):
if final_results[i] is None:
final_results[i] = {
"input_text": input_texts[i],
"clean_text": cleaned[i],
"prediction": "invalid/empty after cleaning",
"scores": {}
}
return {"results": final_results}
if __name__ == "__main__":
# HF Spaces এর জন্য পোর্ট অবশ্যই ৭০০০-৮০০০ এর মধ্যে হতে হয়, ডিফল্ট ৭০০০ বা ৭860
uvicorn.run(app, host="0.0.0.0", port=7860)