"""Offensive-language detection service.

Loads a BERT sequence classifier from ``MODEL_DIR``, overlays the weights
stored in ``BEST_MODEL_PATH``, and exposes a single prediction function
through both a FastAPI endpoint (``POST /predict``) and a Gradio UI mounted
at ``/``.
"""
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from fastapi import FastAPI, HTTPException  # HTTPException: currently unused, kept from the original imports
from pydantic import BaseModel
import gradio as gr
import torch.nn.functional as F

# =====================
# CONFIG
# =====================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_DIR = "model_save"
BEST_MODEL_PATH = "best_model.pt"
MAX_LEN = 128  # inputs are truncated/padded to this many tokens
ALLOW_THRESHOLD = 0.6  # minimum "Neither" confidence for text to be allowed
label_mapping = {0: "Hate Speech", 1: "Offensive", 2: "Neither"}

# =====================
# LOAD MODEL
# =====================
tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
# NOTE(review): torch.load unpickles the file; only point BEST_MODEL_PATH at
# trusted checkpoints (consider weights_only=True if the torch version in use
# supports it).
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()


# =====================
# CORE PREDICTION
# =====================
def predict_offensive(text: str) -> dict:
    """Classify *text* and report per-class probabilities.

    Args:
        text: The sentence to classify. ``None`` and blank/whitespace-only
            strings are treated as empty input.

    Returns:
        ``{"error": "Empty text"}`` for empty input; otherwise a dict with:
        ``label`` (name of the highest-probability class), ``confidence``
        (its probability, rounded to 4 places), ``allowed`` (``True`` only
        when the label is ``"Neither"`` and the unrounded confidence exceeds
        ``ALLOW_THRESHOLD``), and ``probabilities`` (class name -> rounded
        probability for every class).
    """
    # Guard against None as well as blank strings so the UI/API never
    # surfaces an AttributeError for a missing value.
    if text is None or not text.strip():
        return {"error": "Empty text"}

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
    )
    input_ids = encoded["input_ids"].to(DEVICE)
    attention_mask = encoded["attention_mask"].to(DEVICE)

    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits
        # A single text is encoded per call, so take row 0 of the batch.
        probs = F.softmax(logits, dim=1)[0]

    pred_idx = torch.argmax(probs).item()
    # Read all probabilities in one call instead of one .item() per class.
    prob_values = probs.tolist()

    pred_label = label_mapping[pred_idx]
    confidence = prob_values[pred_idx]

    # Probabilities of all classes
    probs_all = {
        label_mapping[i]: round(p, 4) for i, p in enumerate(prob_values)
    }

    # Allowed only if the label is "Neither" and confidence > ALLOW_THRESHOLD
    allowed = pred_label == "Neither" and confidence > ALLOW_THRESHOLD

    return {
        "label": pred_label,
        "confidence": round(confidence, 4),
        "allowed": allowed,
        "probabilities": probs_all,
    }


# =====================
# FASTAPI APP
# =====================
app = FastAPI(title="Offensive Language Detector API")


class TextItem(BaseModel):
    """Request body for ``POST /predict``."""

    # The sentence to classify.
    text: str


@app.post("/predict")
def api_predict(item: TextItem):
    """Return the ``predict_offensive`` result for the submitted text.

    NOTE(review): empty text is returned as a normal ``{"error": ...}``
    payload rather than raised as an ``HTTPException``; confirm whether
    clients rely on that before changing it.
    """
    return predict_offensive(item.text)


# =====================
# GRADIO UI
# =====================
ui = gr.Interface(
    fn=predict_offensive,
    inputs=gr.Textbox(lines=2, placeholder="Enter a sentence here..."),
    outputs=gr.JSON(label="Prediction"),
    title="Offensive Language Detector",
    description=(
        "Enter a sentence and the model will predict if it is offensive, "
        "with confidence scores for all classes."
    ),
)

# Mount Gradio UI on FastAPI
app = gr.mount_gradio_app(app, ui, path="/")