| | import os |
| | import re |
| | import string |
| | from typing import List, Union |
| | import torch |
| | from fastapi import FastAPI, HTTPException |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from pydantic import BaseModel |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| | import uvicorn |
| | import numpy as np |
| |
|
| | |
| | |
| | |
| | |
# Hugging Face Hub ID of the fine-tuned IndicBERT classifier served by this API.
MODEL_ID = "Yousuf-Islam/Upgraded_IndicBERT_Model"

# Label <-> id mappings; must match the label order used when the model was
# fine-tuned — TODO confirm against the training config.
LABEL2ID = {"shirk": 0, "tawheed": 1, "neutral": 2}
ID2LABEL = {v: k for k, v in LABEL2ID.items()}
# Maximum number of tokens per input; longer sequences are truncated by the tokenizer.
MAX_LENGTH = 128
# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| |
|
| | |
| | |
| | |
def clean_text(text: str) -> str:
    """Normalize a raw input string for classification.

    Removes URLs, e-mail addresses, @mentions/#hashtags, and English plus
    common Bengali punctuation, then collapses whitespace.

    Args:
        text: Raw input. ``None`` is tolerated and treated as empty; any
            non-string value is coerced with ``str()``.

    Returns:
        The cleaned string, or ``""`` when fewer than 2 characters survive
        cleaning (too short to classify meaningfully).
    """
    if text is None:
        return ""
    text = str(text)

    # Remove URLs (http/https/www) and e-mail addresses.
    # NOTE: r"http\S+" already matches "https..." — the redundant "https\S+"
    # alternative from the original pattern was dropped (same behavior).
    text = re.sub(r"http\S+|www\S+", "", text, flags=re.MULTILINE)
    text = re.sub(r"\S+@\S+", "", text)
    # Remove @mentions and #hashtags.
    text = re.sub(r"@\w+|#\w+", "", text)

    # Strip English + Bengali punctuation in one C-level pass; duplicate
    # characters between the two sets are harmless to maketrans.
    bengali_punct = "।॥‘’“”,;:—–-!?()[]{}<>…•°৳"
    text = text.translate(str.maketrans("", "", string.punctuation + bengali_punct))

    # Collapse all runs of whitespace into single spaces.
    text = " ".join(text.split())

    # Require at least 2 characters; anything shorter is treated as empty.
    text = text.strip()
    return text if len(text) >= 2 else ""
| |
|
| | |
| | |
| | |
# Load the tokenizer and model once at import time (downloaded from the Hub
# on first run), move the model to the selected device, and switch to eval
# mode so dropout/normalization layers behave deterministically at inference.
print(f"Loading upgraded model from {MODEL_ID} on {device}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.to(device)
model.eval()
print("Model loaded successfully!")
| |
|
| | |
| | |
| | |
# FastAPI application instance; title/description appear in the generated
# OpenAPI docs at /docs.
app = FastAPI(
    title="Upgraded Bangla Shirk-Tawheed Classifier",
    description="API for classifying Bengali sentences into Shirk, Tawheed, or Neutral labels."
)

# Permissive CORS so browser clients on any origin can call the API.
# allow_credentials is False, which is what makes the "*" wildcard origin legal.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
| |
|
| | |
class PredictRequest(BaseModel):
    """Request body for /predict: a single sentence or a batch of sentences."""
    text: Union[str, List[str]]
| |
|
class PredictionResult(BaseModel):
    """Shape of one classification result.

    NOTE(review): declared but not currently wired up as a response_model on
    /predict, which returns plain dicts of this shape — consider attaching it.
    """
    input_text: str  # original raw input as received
    clean_text: str  # text after clean_text() normalization
    prediction: str  # predicted label, or an invalid/empty marker
    scores: dict     # label -> probability
| |
|
| | |
@app.get("/")
def home():
    """Health-check endpoint describing the running service."""
    status_payload = {
        "status": "online",
        "model": MODEL_ID,
        "message": "Send a POST request to /predict",
    }
    return status_payload
| |
|
| | |
@app.post("/predict")
def predict(req: PredictRequest):
    """Classify one or more Bengali sentences as shirk/tawheed/neutral.

    Accepts a single string or a list of strings and returns one result per
    input, in the original order. Inputs that become empty after cleaning are
    reported as invalid instead of being sent to the model.

    Raises:
        HTTPException: 400 when the request contains no text at all.
    """
    # Normalize to a list so single-string and batch requests share one path.
    input_texts = [req.text] if isinstance(req.text, str) else req.text
    if not input_texts:
        raise HTTPException(status_code=400, detail="No text provided")

    cleaned = [clean_text(t) for t in input_texts]

    # Only non-empty cleaned texts go to the model; remember their original
    # positions so results can be written back in request order.
    valid_indices = [i for i, t in enumerate(cleaned) if t != ""]
    valid_texts = [cleaned[i] for i in valid_indices]

    final_results = [None] * len(input_texts)

    if valid_texts:
        inputs = tokenizer(
            valid_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
        ).to(device)

        # Inference only: no gradient bookkeeping needed.
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy()

        for row, idx in enumerate(valid_indices):
            # int() keeps numpy integer types out of the label lookup / JSON.
            pred_id = int(np.argmax(probs[row]))
            final_results[idx] = {
                "input_text": input_texts[idx],
                "clean_text": cleaned[idx],
                "prediction": ID2LABEL[pred_id],
                "scores": {
                    ID2LABEL[j]: float(probs[row][j]) for j in range(len(ID2LABEL))
                },
            }

    # Fill placeholders for inputs that were empty after cleaning.
    for idx, result in enumerate(final_results):
        if result is None:
            final_results[idx] = {
                "input_text": input_texts[idx],
                "clean_text": cleaned[idx],
                "prediction": "invalid/empty after cleaning",
                "scores": {},
            }

    return {"results": final_results}
| |
|
if __name__ == "__main__":
    # Serve on all interfaces; 7860 is presumably chosen for Hugging Face
    # Spaces (its default app port) — confirm with the deployment target.
    uvicorn.run(app, host="0.0.0.0", port=7860)