lotachi committed on
Commit
7d13847
·
verified ·
1 Parent(s): 571f01e

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib
6
+ matplotlib.use("Agg")
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.colors as mcolors
9
+ import shap
10
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
+
12
# Load the fine-tuned model from the Hub; fall back to the base HateBERT
# checkpoint when the fine-tuned repo is unavailable (e.g. before it is pushed).
MODEL_ID = "lotachi/hatebert-toxic-classifier"  # we push this in Cell 10
FALLBACK = "GroNLP/hateBERT"

device = torch.device("cpu")  # HF Spaces free tier is CPU

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).to(device)
    print(f"Loaded fine-tuned model: {MODEL_ID}")
except Exception:  # was a bare `except:` — keep KeyboardInterrupt/SystemExit propagating
    print("Fine-tuned model not found, loading base HateBERT")
    tokenizer = AutoTokenizer.from_pretrained(FALLBACK)
    model = AutoModelForSequenceClassification.from_pretrained(FALLBACK, num_labels=2).to(device)

model.eval()  # inference only: disable dropout etc.
CLASS_NAMES = ["Non-Toxic", "Toxic"]  # index 1 is the positive/toxic class
30
def predict_single(text):
    """Classify one string.

    Returns a tuple of (predicted label, {class name: probability}, P(toxic)).
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=128, padding=True)
    with torch.no_grad():
        scores = model(**encoded).logits
    probabilities = torch.softmax(scores, dim=1)[0].cpu().numpy()
    top = probabilities.argmax()
    prob_map = dict(zip(CLASS_NAMES, (float(p) for p in probabilities)))
    return CLASS_NAMES[top], prob_map, float(probabilities[1])
38
def predict_batch(texts):
    """Return an (N, 2) array of class probabilities for a sequence of texts.

    Processed in small batches to bound peak memory on the CPU-only host.
    """
    batch_size = 8
    prob_chunks = []
    for start in range(0, len(texts), batch_size):
        chunk = list(texts[start:start + batch_size])
        encoded = tokenizer(chunk, padding=True, truncation=True, max_length=128, return_tensors="pt")
        with torch.no_grad():
            scores = model(**encoded).logits
        prob_chunks.append(torch.softmax(scores, dim=1).cpu().numpy())
    return np.vstack(prob_chunks)
48
# SHAP text explainer: masks tokens and calls predict_batch to attribute each
# token's contribution to the class probabilities (output_names labels columns).
masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(predict_batch, masker, output_names=CLASS_NAMES)
51
def classify_text(text):
    """Gradio handler: classify `text` and build a SHAP token-importance plot.

    Returns (display string, {class: probability} dict, matplotlib figure or
    None when the input is empty or the SHAP explanation fails).
    """
    if not text or not text.strip():
        return "Please enter some text.", {}, None
    text = text.strip()[:800]  # cap length to keep CPU inference + SHAP fast
    label, prob_dict, toxic_prob = predict_single(text)
    # Confidence banner: red at >= 0.8 toxic, amber at >= 0.5, green otherwise.
    if toxic_prob >= 0.8:
        display = f"🚨 {label} ({toxic_prob:.0%} confidence)"
    elif toxic_prob >= 0.5:
        display = f"⚠️ {label} ({toxic_prob:.0%} confidence)"
    else:
        display = f"✅ {label} ({1-toxic_prob:.0%} confidence)"
    fig = None
    try:
        sv = explainer([text])
        # NOTE(review): tokenizer.tokenize() and the SHAP masker may segment the
        # text differently; values are aligned by position, which is best-effort.
        tokens = tokenizer.tokenize(text)[:25]
        vals = sv[0].values[:len(tokens), 1]  # column 1 = Toxic class
        vmax = max(abs(vals).max(), 0.01)  # floor avoids a degenerate color norm
        norm = mcolors.TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)
        cmap = plt.cm.RdYlGn_r
        fig, ax = plt.subplots(figsize=(10, max(3, len(tokens)*0.35)))
        ax.barh(range(len(tokens)), vals, color=cmap(norm(vals)), edgecolor="white", height=0.7)
        ax.set_yticks(range(len(tokens)))
        ax.set_yticklabels(tokens, fontsize=10)
        ax.axvline(0, color="black", linewidth=0.8)
        ax.set_xlabel("SHAP Value", fontsize=10)
        ax.set_title("Token SHAP Importance (Toxic class)", fontsize=12, fontweight="bold")
        ax.invert_yaxis()
        ax.spines[["top","right"]].set_visible(False)
        plt.tight_layout()
    except Exception:  # was a bare `except:`; SHAP failure degrades to no plot
        if fig is not None:
            plt.close(fig)  # don't leak a half-built figure
        fig = None
    return display, prob_dict, fig
83
# Example inputs for the gr.Examples widget (one-element rows: a single textbox).
EXAMPLES = [
    ["I really enjoyed the community event today!"],
    ["Thanks for your help, it made a big difference."],
    ["The policy has been criticised by many stakeholders."],
]
89
# Gradio UI. Statement order inside the Blocks context defines the layout, so
# the construction sequence below is load-bearing.
with gr.Blocks(title="Hate Speech Detector", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""# 🛡️ Hate Speech & Toxic Comment Detector
**MSc Data Science | CMP-L016 Deep Learning Applications**

Classifies text using fine-tuned **HateBERT** with **SHAP** word-level explanations.
""")
    with gr.Row():
        with gr.Column():
            # Left column: input textbox, trigger button, canned examples.
            text_input = gr.Textbox(lines=5, placeholder="Enter text...", label="Input Text")
            submit_btn = gr.Button("🔍 Classify", variant="primary")
            gr.Examples(examples=EXAMPLES, inputs=text_input)
        with gr.Column():
            # Right column: verdict text, class probabilities, SHAP bar chart.
            label_out = gr.Textbox(label="Result", interactive=False, lines=2)
            prob_out = gr.Label(num_top_classes=2, label="Confidence")
            shap_out = gr.Plot(label="SHAP Explanation")
    gr.Markdown("> ⚠️ Research tool only. Not for production moderation decisions.")
    # Both the button click and pressing Enter in the textbox run the classifier.
    submit_btn.click(classify_text, inputs=text_input, outputs=[label_out, prob_out, shap_out])
    text_input.submit(classify_text, inputs=text_input, outputs=[label_out, prob_out, shap_out])

demo.launch()