# yuu1234's picture
# Add 7
# d3c6084
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import gradio as gr
import torch.nn.functional as F
# =====================
# CONFIG
# =====================
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Directory handed to from_pretrained() for both the tokenizer and the model.
MODEL_DIR = "model_save"
# Checkpoint whose state dict is loaded on top of the model from MODEL_DIR.
BEST_MODEL_PATH = "best_model.pt"
# Token length every input is truncated / padded to.
MAX_LEN = 128

# Class index -> human-readable label reported to callers.
label_mapping = dict(enumerate(("Hate Speech", "Offensive", "Neither")))
# =====================
# LOAD MODEL
# =====================
# Tokenizer and base model are both restored from MODEL_DIR.
tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
model = BertForSequenceClassification.from_pretrained(MODEL_DIR)

# Replace the weights with the checkpoint stored at BEST_MODEL_PATH, mapping
# its tensors onto DEVICE while loading.
# NOTE(review): torch.load unpickles the file, so the checkpoint must come
# from a trusted source.
model.load_state_dict(
    torch.load(BEST_MODEL_PATH, map_location=DEVICE)
)

# Move the model to the target device and switch it to evaluation mode.
model.to(DEVICE).eval()
# =====================
# CORE PREDICTION
# =====================
# Minimum confidence a "Neither" prediction needs before the text is allowed.
_ALLOWED_MIN_CONFIDENCE = 0.6


def predict_offensive(text: str) -> dict:
    """Classify *text* and report the predicted label with its confidence.

    Args:
        text: Sentence to classify. ``None``, an empty string, or a
            whitespace-only string is rejected without running the model.

    Returns:
        ``{"error": "Empty text"}`` when there is nothing to classify.
        Otherwise a dict with:

        * ``"label"``: name of the highest-probability class, looked up in
          ``label_mapping``.
        * ``"confidence"``: softmax probability of that class, rounded to
          4 decimals.
        * ``"allowed"``: ``True`` only when the label is ``"Neither"`` and
          the unrounded confidence is above 0.6.
        * ``"probabilities"``: label -> probability (rounded to 4 decimals)
          for every class the model outputs.
    """
    # A UI textbox may hand over None instead of ""; treat it like empty
    # input rather than failing on None.strip().
    if text is None or not text.strip():
        return {"error": "Empty text"}

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
    )
    input_ids = encoded["input_ids"].to(DEVICE)
    attention_mask = encoded["attention_mask"].to(DEVICE)

    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits
        # Single input, so row 0 holds the only probability vector.
        probs = F.softmax(logits, dim=1)[0]

    pred_idx = torch.argmax(probs).item()
    # Pull all probabilities to Python floats in one go instead of calling
    # .item() once per class.
    prob_values = probs.tolist()

    pred_label = label_mapping[pred_idx]
    confidence = prob_values[pred_idx]

    # Probability of every class, keyed by its label.
    probs_all = {
        label_mapping[idx]: round(value, 4)
        for idx, value in enumerate(prob_values)
    }

    # Allowed only if the label is "Neither" and the confidence is high enough.
    allowed = pred_label == "Neither" and confidence > _ALLOWED_MIN_CONFIDENCE

    return {
        "label": pred_label,
        "confidence": round(confidence, 4),
        "allowed": allowed,
        "probabilities": probs_all,
    }
# =====================
# FASTAPI APP
# =====================
# FastAPI application object; the /predict route below is registered on it.
app = FastAPI(title="Offensive Language Detector API")
# Request body accepted by the /predict endpoint.
class TextItem(BaseModel):
    # Sentence that is passed on to predict_offensive().
    text: str
# POST /predict: classify the submitted text and return the dict produced by
# predict_offensive() as the response (including its {"error": ...} form for
# empty text).
@app.post("/predict")
def api_predict(item: TextItem):
    prediction = predict_offensive(item.text)
    return prediction
# =====================
# GRADIO UI
# =====================
# Two-line text box the user types the sentence into.
_sentence_box = gr.Textbox(lines=2, placeholder="Enter a sentence here...")
# JSON viewer that shows the dict returned by predict_offensive().
_prediction_view = gr.JSON(label="Prediction")

ui = gr.Interface(
    fn=predict_offensive,
    inputs=_sentence_box,
    outputs=_prediction_view,
    title="Offensive Language Detector",
    description="Enter a sentence and the model will predict if it is offensive, with confidence scores for all classes.",
)

# Serve the Gradio UI from the root path of the same FastAPI application.
app = gr.mount_gradio_app(app=app, blocks=ui, path="/")