import os import re import string from typing import List, Union import torch from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from transformers import AutoTokenizer, AutoModelForSequenceClassification import uvicorn import numpy as np # ------------------------ # কনফিগারেশন # ------------------------ # আপনার নতুন আপগ্রেডেড মডেল আইডি MODEL_ID = "Yousuf-Islam/Upgraded_IndicBERT_Model" LABEL2ID = {"shirk": 0, "tawheed": 1, "neutral": 2} ID2LABEL = {v: k for k, v in LABEL2ID.items()} MAX_LENGTH = 128 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ------------------------ # বাংলা টেক্সট প্রিপসেসিং # ------------------------ def clean_text(text: str) -> str: if text is None: return "" text = str(text) # ইউআরএল, ইমেইল এবং মেনশন মুছে ফেলা text = re.sub(r"http\S+|www\S+|https\S+", "", text, flags=re.MULTILINE) text = re.sub(r"\S+@\S+", "", text) text = re.sub(r"@\w+|#\w+", "", text) # ইংরেজি ও বাংলা বিরামচিহ্ন মুছে ফেলা english_punct = string.punctuation bengali_punct = "।॥‘’“”,;:—–-!?()[]{}<>…•°৳" all_punct = english_punct + bengali_punct text = text.translate(str.maketrans("", "", all_punct)) # অতিরিক্ত স্পেস মুছে ফেলা text = " ".join(text.split()) # কমপক্ষে ২ অক্ষরের বেশি হলে রিটার্ন করবে return text.strip() if len(text.strip()) >= 2 else "" # ------------------------ # মডেল লোড করা # ------------------------ print(f"Loading upgraded model from {MODEL_ID} on {device}...") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID) model.to(device) model.eval() print("Model loaded successfully!") # ------------------------ # FastAPI অ্যাপ সেটআপ # ------------------------ app = FastAPI( title="Upgraded Bangla Shirk-Tawheed Classifier", description="API for classifying Bengali sentences into Shirk, Tawheed, or Neutral labels." ) # CORS এনাবল করা (যাতে রিঅ্যাক্ট বা অন্য সাইট থেকে কল করা যায়) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=False, allow_methods=["*"], allow_headers=["*"], ) # ডাটা ফরম্যাট (Pydantic) class PredictRequest(BaseModel): text: Union[str, List[str]] class PredictionResult(BaseModel): input_text: str clean_text: str prediction: str scores: dict # রুটে চেক করার জন্য মেসেজ @app.get("/") def home(): return { "status": "online", "model": MODEL_ID, "message": "Send a POST request to /predict" } # প্রেডিকশন এন্ডপয়েন্ট @app.post("/predict") def predict(req: PredictRequest): # ইনপুট যদি একটা বাক্য হয় তবে লিস্টে রূপান্তর input_texts = [req.text] if isinstance(req.text, str) else req.text if not input_texts: raise HTTPException(status_code=400, detail="No text provided") # টেক্সট ক্লিন করা cleaned = [clean_text(t) for t in input_texts] # ভ্যালিড টেক্সটগুলোর ইনডেক্স খুঁজে বের করা valid_indices = [i for i, t in enumerate(cleaned) if t != ""] valid_texts = [cleaned[i] for i in valid_indices] final_results = [None] * len(input_texts) if valid_texts: # টোকেনাইজ করা inputs = tokenizer( valid_texts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH ).to(device) with torch.no_grad(): outputs = model(**inputs) probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy() for i, idx in enumerate(valid_indices): pred_id = np.argmax(probs[i]) label = ID2LABEL[pred_id] scores = {ID2LABEL[j]: float(probs[i][j]) for j in range(len(ID2LABEL))} final_results[idx] = { "input_text": input_texts[idx], "clean_text": cleaned[idx], "prediction": label, "scores": scores } # ইনভ্যালিড বা খালি ইনপুটের জন্য ডিফল্ট মেসেজ for i in range(len(final_results)): if final_results[i] is None: final_results[i] = { "input_text": input_texts[i], "clean_text": cleaned[i], "prediction": "invalid/empty after cleaning", "scores": {} } return {"results": final_results} if __name__ == "__main__": # HF Spaces এর জন্য পোর্ট অবশ্যই ৭০০০-৮০০০ এর মধ্যে হতে হয়, ডিফল্ট ৭০০০ বা ৭860 uvicorn.run(app, host="0.0.0.0", port=7860)