# Hugging Face Spaces scrape artifact (space status header: "Spaces: Sleeping")
| import torch | |
| from transformers import BertTokenizer, BertForSequenceClassification | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import gradio as gr | |
| import torch.nn.functional as F | |
# =====================
# CONFIG
# =====================
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Directory holding the fine-tuned tokenizer/model in HF `save_pretrained` format.
MODEL_DIR = "model_save"
# Checkpoint file containing the best fine-tuned weights (a state_dict).
BEST_MODEL_PATH = "best_model.pt"
# Maximum token length for model input; longer texts are truncated.
MAX_LEN = 128
# Class index -> human-readable label for the 3-way classifier.
label_mapping = {0: "Hate Speech", 1: "Offensive", 2: "Neither"}
# =====================
# LOAD MODEL
# =====================
# Load the tokenizer and model architecture from the saved directory,
# then overwrite the model weights with the best checkpoint.
tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
# NOTE(review): torch.load on an untrusted checkpoint can execute arbitrary
# code via pickle — assumes best_model.pt is a trusted, locally produced file.
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()  # inference mode: disables dropout, freezes batch-norm stats
# =====================
# CORE PREDICTION
# =====================
def predict_offensive(text: str):
    """Classify *text* as Hate Speech / Offensive / Neither.

    Returns a dict with the predicted label, its rounded confidence, an
    ``allowed`` flag (True only for a confident "Neither" prediction),
    and the per-class probabilities. Empty or whitespace-only input
    yields ``{"error": "Empty text"}`` without running the model.
    """
    if not text.strip():
        return {"error": "Empty text"}

    batch = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
    )

    with torch.no_grad():
        output = model(
            batch["input_ids"].to(DEVICE),
            attention_mask=batch["attention_mask"].to(DEVICE),
        )
    scores = F.softmax(output.logits, dim=1)[0]

    best = torch.argmax(scores).item()
    best_label = label_mapping[best]
    best_prob = scores[best].item()

    # Probability of every class, keyed by its human-readable label.
    per_class = {
        label_mapping[idx]: round(p.item(), 4) for idx, p in enumerate(scores)
    }

    # Content is allowed only when the model confidently predicts "Neither".
    return {
        "label": best_label,
        "confidence": round(best_prob, 4),
        "allowed": best_label == "Neither" and best_prob > 0.6,
        "probabilities": per_class,
    }
# =====================
# FASTAPI APP
# =====================
app = FastAPI(title="Offensive Language Detector API")


class TextItem(BaseModel):
    # Request body schema: the raw sentence to classify.
    text: str


# Bug fix: the original handler had no route decorator, so it was never
# registered with FastAPI and the JSON API was unreachable.
@app.post("/predict")
def api_predict(item: TextItem):
    """REST endpoint: classify ``item.text`` and return the prediction dict.

    Raises:
        HTTPException(400): when the input text is empty/whitespace-only
        (the core predictor signals this with an ``error`` key).
    """
    result = predict_offensive(item.text)
    # Surface empty-input as a proper HTTP error instead of a 200 payload
    # (uses the HTTPException the file imports but never used).
    if "error" in result:
        raise HTTPException(status_code=400, detail=result["error"])
    return result
# =====================
# GRADIO UI
# =====================
# Browser demo: single textbox in, JSON prediction out, backed by the
# same predict_offensive function as the REST API.
ui = gr.Interface(
    fn=predict_offensive,
    inputs=gr.Textbox(lines=2, placeholder="Enter a sentence here..."),
    outputs=gr.JSON(label="Prediction"),
    title="Offensive Language Detector",
    description="Enter a sentence and the model will predict if it is offensive, with confidence scores for all classes."
)

# Mount Gradio UI on FastAPI at the root path; mount_gradio_app returns
# the combined ASGI app, so rebinding `app` is what the server must serve.
app = gr.mount_gradio_app(app, ui, path="/")