yuu1234 commited on
Commit
d3c6084
·
1 Parent(s): 99fece0
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -27,7 +27,10 @@ model.eval()
27
  # =====================
28
  # CORE PREDICTION
29
  # =====================
30
- def predict_offensive(text):
 
 
 
31
  encoded = tokenizer(
32
  text,
33
  return_tensors="pt",
@@ -40,7 +43,7 @@ def predict_offensive(text):
40
 
41
  with torch.no_grad():
42
  logits = model(input_ids, attention_mask=attention_mask).logits
43
- probs = F.softmax(logits, dim=1)[0] # Softmax để lấy xác suất
44
 
45
  pred_idx = torch.argmax(probs).item()
46
  pred_label = label_mapping[pred_idx]
@@ -69,23 +72,19 @@ class TextItem(BaseModel):
69
 
70
  @app.post("/predict")
71
  def api_predict(item: TextItem):
72
- if not item.text:
73
- raise HTTPException(status_code=400, detail="Missing text")
74
  return predict_offensive(item.text)
75
 
76
  # =====================
77
  # GRADIO UI
78
  # =====================
79
- def gradio_ui(text):
80
- return predict_offensive(text)
81
-
82
  ui = gr.Interface(
83
- fn=gradio_ui,
84
  inputs=gr.Textbox(lines=2, placeholder="Enter a sentence here..."),
85
- outputs="json",
86
  title="Offensive Language Detector",
87
  description="Enter a sentence and the model will predict if it is offensive, with confidence scores for all classes."
88
  )
89
 
90
  # Mount Gradio UI on FastAPI
91
  app = gr.mount_gradio_app(app, ui, path="/")
 
 
27
  # =====================
28
  # CORE PREDICTION
29
  # =====================
30
+ def predict_offensive(text: str):
31
+ if not text.strip():
32
+ return {"error": "Empty text"}
33
+
34
  encoded = tokenizer(
35
  text,
36
  return_tensors="pt",
 
43
 
44
  with torch.no_grad():
45
  logits = model(input_ids, attention_mask=attention_mask).logits
46
+ probs = F.softmax(logits, dim=1)[0]
47
 
48
  pred_idx = torch.argmax(probs).item()
49
  pred_label = label_mapping[pred_idx]
 
72
 
73
  @app.post("/predict")
74
  def api_predict(item: TextItem):
 
 
75
  return predict_offensive(item.text)
76
 
77
  # =====================
78
  # GRADIO UI
79
  # =====================
 
 
 
80
  ui = gr.Interface(
81
+ fn=predict_offensive,
82
  inputs=gr.Textbox(lines=2, placeholder="Enter a sentence here..."),
83
+ outputs=gr.JSON(label="Prediction"),
84
  title="Offensive Language Detector",
85
  description="Enter a sentence and the model will predict if it is offensive, with confidence scores for all classes."
86
  )
87
 
88
  # Mount Gradio UI on FastAPI
89
  app = gr.mount_gradio_app(app, ui, path="/")
90
+