karl0706 committed
Commit e7ff7b6 · 1 Parent(s): d08c94a

Add application file
Files changed (2)
  1. app.py +328 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,328 @@
+ """
+ Gradio app for multi-label toxicity prediction using a Hugging Face model.
+ Standalone file: app.py
+
+ Two modes are supported:
+ - Remote inference via the Hugging Face Inference API (recommended if you don't
+   want to download the model locally). Set the env var HF_API_TOKEN if your
+   model is private.
+ - Local loading via transformers.from_pretrained (downloads the model weights
+   locally). Used as a fallback when the Inference API is unavailable.
+
+ Usage:
+ 1) Install requirements:
+        pip install -r requirements.txt
+    or
+        pip install transformers torch gradio numpy pandas huggingface-hub
+
+ 2) If your model is private, export your token:
+        export HF_API_TOKEN=hf_xxx...   # macOS / Linux
+        set HF_API_TOKEN=hf_xxx...      # Windows (PowerShell: $env:HF_API_TOKEN = 'hf_xxx')
+
+ 3) Run:
+        python app.py
+
+ Notes:
+ - By default the app tries the Hugging Face Inference API (no heavy downloads).
+   If the Inference API cannot be initialised, the app falls back to downloading
+   the model via transformers.
+ - Change MODEL_ID to your HF repo id if needed.
+ """
+
+ import os
+ import tempfile
+ from pathlib import Path
+
+ import torch
+ import numpy as np
+ import pandas as pd
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+ # Optional: remote inference via the Hugging Face Inference API
+ from huggingface_hub import InferenceApi
+
+ # ---- Config ----
+ MODEL_ID = "NathanDB/toxic-bert-dsti"  # change if needed
+ LABEL_COLS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
+ MAX_LEN = 256
+ # Per-label thresholds used to convert probabilities -> binary labels
+ THRESHOLDS = np.array([0.90, 0.25, 0.90, 0.10, 0.40, 0.15], dtype=np.float32)
+
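+ # Each threshold lines up positionally with LABEL_COLS, so a label is flagged
+ # when its sigmoid probability meets or exceeds its own cutoff. For example
+ # (illustrative numbers only), probs = [0.95, 0.10, 0.20, 0.05, 0.50, 0.10]
+ # compared against THRESHOLDS yields preds = [1, 0, 0, 0, 1, 0]: only
+ # "toxic" (0.95 >= 0.90) and "insult" (0.50 >= 0.40) are predicted.
+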
+ # ---- Device ----
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # ---- Mode: remote inference (Inference API) if available, otherwise local from_pretrained ----
+ HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
+ USE_REMOTE = True  # prefer remote when a token is available or the model is public
+
+ inference_api = None
+ model = None
+ tokenizer = None
+
+ # Try to initialise remote inference if possible
+ if USE_REMOTE:
+     try:
+         if HF_API_TOKEN:
+             inference_api = InferenceApi(repo_id=MODEL_ID, token=HF_API_TOKEN)
+             print("Using Hugging Face Inference API (private or token provided).")
+         else:
+             # Try without a token (works for public models)
+             inference_api = InferenceApi(repo_id=MODEL_ID)
+             print("Using Hugging Face Inference API (public model).")
+     except Exception as e:
+         print("Remote Inference API unavailable or failed to init:", e)
+         inference_api = None
+
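+ # Note: newer releases of huggingface_hub deprecate InferenceApi in favour of
+ # InferenceClient. If InferenceApi is missing from your installed version, a
+ # rough equivalent (a sketch, not tested here) would be:
+ #     from huggingface_hub import InferenceClient
+ #     client = InferenceClient(model=MODEL_ID, token=HF_API_TOKEN)
+ #     scores = client.text_classification("some text to score")
+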
+ # If remote inference is not available, fall back to local loading (this downloads the model)
+ if inference_api is None:
+     print("Falling back to local model download via transformers.from_pretrained()")
+     # Load tokenizer & model once at startup
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+     model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
+     model.to(DEVICE)
+     model.eval()
+     print("Model downloaded and loaded locally.")
+
+
+ # ---- Prediction helpers ----
+
+ def predict_toxicity_local(text: str):
+     """Run the model locally (downloaded weights). Returns (probs, preds, dict)."""
+     if not isinstance(text, str) or text.strip() == "":
+         return None
+
+     enc = tokenizer(
+         [text],
+         padding=True,
+         truncation=True,
+         max_length=MAX_LEN,
+         return_tensors="pt",
+     ).to(DEVICE)
+
+     with torch.no_grad():
+         logits = model(**enc).logits
+         probs = torch.sigmoid(logits).cpu().numpy()[0]
+
+     preds = (probs >= THRESHOLDS).astype(int)
+
+     result_dict = {
+         lbl: {"probability": round(float(probs[i]), 6), "predicted": bool(preds[i])}
+         for i, lbl in enumerate(LABEL_COLS)
+     }
+
+     return probs, preds, result_dict
+
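+ # Illustrative shape of the values returned above (numbers are made up):
+ #   probs       -> array([0.97, 0.12, 0.85, 0.02, 0.44, 0.08])
+ #   preds       -> array([1, 0, 0, 0, 1, 0])
+ #   result_dict -> {"toxic": {"probability": 0.97, "predicted": True}, ...}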
+
+ def predict_toxicity_remote(text: str):
+     """Call the Hugging Face Inference API and map the results back to LABEL_COLS."""
+     if inference_api is None:
+         raise RuntimeError("Inference API is not initialized")
+     # The Inference API for sequence classification returns a list of {label: ..., score: ...}
+     try:
+         response = inference_api(inputs=text)
+     except Exception:
+         return None
+
+     # Example response: [{'label': 'toxic', 'score': 0.95}, ...]
+     # or: {'error': '...'}
+     if isinstance(response, dict) and response.get("error"):
+         raise RuntimeError(f"Inference API error: {response.get('error')}")
+
+     # Normalise into a dict: label -> score
+     label_to_score = {}
+     if isinstance(response, list):
+         for item in response:
+             lab = item.get("label")
+             score = float(item.get("score", 0.0))
+             # Some models return labels like "LABEL_0"; those are handled below
+             label_to_score[lab] = score
+     elif isinstance(response, dict):
+         # Rare form: the model may return a dict with a 'scores' list; map by order
+         if "scores" in response and isinstance(response["scores"], list):
+             scores = response["scores"]
+             for i, lab in enumerate(LABEL_COLS):
+                 label_to_score[lab] = float(scores[i]) if i < len(scores) else 0.0
+
+     # Try to match LABEL_COLS case-insensitively
+     probs = np.zeros(len(LABEL_COLS), dtype=float)
+     for i, lbl in enumerate(LABEL_COLS):
+         # direct match
+         if lbl in label_to_score:
+             probs[i] = label_to_score[lbl]
+             continue
+         # case-insensitive match (the else clause runs when no break fires)
+         for k, v in label_to_score.items():
+             if k.lower() == lbl.lower():
+                 probs[i] = v
+                 break
+         else:
+             # the label may be e.g. "LABEL_0"; the order-matching below covers that case
+             pass
+
+     # If we didn't get any scores (all zeros), infer from ordering when lengths match
+     if probs.sum() == 0 and isinstance(response, list) and len(response) == len(LABEL_COLS):
+         for i, item in enumerate(response):
+             probs[i] = float(item.get("score", 0.0))
+
+     preds = (probs >= THRESHOLDS).astype(int)
+     result_dict = {
+         lbl: {"probability": round(float(probs[i]), 6), "predicted": bool(preds[i])}
+         for i, lbl in enumerate(LABEL_COLS)
+     }
+
+     return probs, preds, result_dict
+
+
+ # Unified wrapper used by Gradio
+ def predict_toxicity(text: str):
+     if not isinstance(text, str) or text.strip() == "":
+         empty_df = pd.DataFrame(columns=["label", "probability", "predicted"])
+         return empty_df, {}
+
+     if inference_api is not None:
+         out = predict_toxicity_remote(text)
+         if out is None:
+             # fall back to local inference if the remote call fails and a local model exists
+             if model is not None:
+                 probs, preds, result_dict = predict_toxicity_local(text)
+             else:
+                 raise RuntimeError("Remote call failed and no local model available")
+         else:
+             probs, preds, result_dict = out
+     else:
+         probs, preds, result_dict = predict_toxicity_local(text)
+
+     rows = []
+     for i, lbl in enumerate(LABEL_COLS):
+         rows.append({
+             "label": lbl,
+             "probability": round(float(probs[i]), 6),
+             "predicted": int(preds[i]),
+         })
+
+     df = pd.DataFrame(rows).sort_values("probability", ascending=False).reset_index(drop=True)
+     return df, result_dict
+
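+ # Illustrative output DataFrame for a toxic input (values are made up):
+ #   label     probability  predicted
+ #   toxic        0.971234          1
+ #   insult       0.512345          1
+ #   obscene      0.203400          0
+ #   ...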
+
+ # Helper to save results as a CSV in the temp directory
+ def save_df_to_csv(df: pd.DataFrame):
+     tmpdir = Path(tempfile.gettempdir())
+     path = tmpdir / f"toxicity_result_{os.getpid()}.csv"
+     df.to_csv(path, index=False)
+     return str(path)
+
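+ # Design note: the filename is keyed on the process id, so repeated analyses in
+ # one running app overwrite the same temp file instead of accumulating files.
+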
+ # Small helper to build a simple, styled HTML result
+ def build_result_html(df: pd.DataFrame, result_dict: dict, text: str):
+     """Build a stylized HTML result with thin bars and a toxicity summary."""
+     style = """
+     <style>
+     .card { background:linear-gradient(180deg,#0b1220,#0f1724); padding:20px; border-radius:14px; color:#e6eef8; font-family:'Segoe UI',Inter,Arial; }
+     .title { display:flex; justify-content:space-between; align-items:center; gap:12px; margin-bottom:16px; }
+     .title h3 { margin:0; font-size:18px; font-weight:700; }
+     .title-desc { color:#9fb0c7; font-size:13px; }
+     .badge { padding:8px 12px; border-radius:999px; font-weight:700; font-size:14px; }
+     .good { background:#10b981; color:#022; }
+     .bad { background:#ef4444; color:#fef2f2; }
+
+     .summary-box { background:#071226; padding:12px; border-radius:10px; margin-top:14px; border-left:3px solid #ef4444; }
+     .summary-box.clean { border-left-color:#10b981; }
+     .summary-text { color:#cfe8ff; font-size:14px; line-height:1.5; }
+     .summary-text strong { font-weight:700; color:#10b981; }
+     .summary-text .toxic-label { background:#ef4444; color:#fef2f2; padding:2px 6px; border-radius:4px; margin:0 4px; font-weight:700; font-size:13px; }
+
+     .row { display:flex; align-items:center; gap:10px; margin-top:10px; }
+     .label { width:140px; text-transform:capitalize; font-weight:600; color:#cfe8ff; font-size:13px; }
+     .bar-container { display:flex; flex-direction:column; gap:4px; flex:1; }
+     .bar-bg { background:#071226; width:100%; border-radius:999px; height:6px; overflow:hidden; box-shadow:inset 0 1px 2px rgba(0,0,0,0.3); position:relative; }
+     .bar { height:100%; border-radius:999px; transition: width .6s cubic-bezier(0.34, 1.56, 0.64, 1); background:#06b6d4; }
+     .threshold-line { position:absolute; top:0; height:100%; width:2px; background:#ef4444; opacity:0.8; }
+     .bar-labels { display:flex; justify-content:space-between; font-size:11px; color:#9fb0c7; }
+     .prob { min-width:50px; text-align:right; font-weight:700; color:#cfe8ff; font-size:13px; }
+     .predicted-badge { padding:2px 6px; border-radius:4px; font-weight:700; font-size:11px; margin-left:8px; }
+     .predicted-true { background:#ef4444; color:#fef2f2; }
+     .predicted-false { background:#10b981; color:#022; }
+     </style>
+     """
+     html = style + "<div class='card'>"
+
+     # Header with overall status badge
+     any_toxic = any(v["predicted"] for v in result_dict.values())
+     status = "<div class='badge bad'>⚠️ Toxic</div>" if any_toxic else "<div class='badge good'>✅ Clean</div>"
+     html += f"<div class='title'><div><h3>Toxicity Analysis</h3><div class='title-desc'>Probability per category</div></div>{status}</div>"
+
+     # Summary text
+     if any_toxic:
+         toxic_categories = [lbl.replace('_', ' ').title() for lbl, v in result_dict.items() if v["predicted"]]
+         toxic_str = ", ".join([f"<span class='toxic-label'>{cat}</span>" for cat in toxic_categories])
+         html += f"<div class='summary-box'><div class='summary-text'><strong>Message detected as toxic</strong> — we identified the following categories: {toxic_str}</div></div>"
+     else:
+         html += "<div class='summary-box clean'><div class='summary-text'><strong>✅ No toxicity detected</strong> — this message appears safe and appropriate.</div></div>"
+
+     # Thin bars with threshold indicators
+     html += "<div style='margin-top:16px;'>"
+     for _, row in df.iterrows():
+         label = row['label']
+         label_display = label.replace('_', ' ')
+         prob = float(row['probability'])
+         is_predicted = result_dict[label]["predicted"]
+         threshold = float(THRESHOLDS[LABEL_COLS.index(label)])
+         threshold_percent = threshold * 100
+         prob_percent = prob * 100
+
+         # Badge showing whether this label crossed its threshold
+         badge_class = "predicted-true" if is_predicted else "predicted-false"
+         badge_text = "🚨 Toxic" if is_predicted else "✓ Safe"
+
+         html += "<div class='row'>"
+         html += f"<div class='label'>{label_display}</div>"
+         html += "<div style='display:flex; align-items:center; gap:8px; flex:1;'>"
+         html += "<div class='bar-container'>"
+         html += "<div class='bar-bg' style='position:relative;'>"
+         html += f"<div class='bar' style='width:{prob_percent:.2f}%;'></div>"
+         html += f"<div class='threshold-line' style='left:{threshold_percent:.2f}%;'></div>"
+         html += "</div>"
+         html += f"<div class='bar-labels'><span>0%</span><span style='text-align:center; flex:1;'>Threshold: {threshold_percent:.1f}%</span><span>100%</span></div>"
+         html += "</div>"
+         html += f"<div class='prob'>{prob_percent:.1f}%</div>"
+         html += f"<div class='predicted-badge {badge_class}'>{badge_text}</div>"
+         html += "</div></div>"
+     html += "</div>"
+     html += "</div>"
+     return html
+
+ # Gradio UI
+ with gr.Blocks(title="Toxicity Analyzer") as demo:
+     gr.HTML("<h2 style='margin:8px 0;color:#e6eef8;font-family:Inter,Arial;text-align:center;'>🛡️ Toxicity Analyzer</h2>")
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             txt = gr.Textbox(label="Text to analyze", placeholder="Type or paste text here...", lines=6)
+             with gr.Row():
+                 btn = gr.Button("Analyze", variant="primary", scale=2)
+                 btn_clear = gr.Button("Clear", scale=1)
+
+     with gr.Row():
+         out_html = gr.HTML()
+
+     download_file = gr.File(label="📥 Download CSV", visible=False)
+
+     def analyze(text):
+         df, result_dict = predict_toxicity(text)
+         html = build_result_html(df, result_dict, text)
+         csv_path = save_df_to_csv(df)
+         # Reveal the download component once a CSV actually exists
+         return html, gr.update(value=csv_path, visible=True)
+
+     def clear_all():
+         return "", "", gr.update(visible=False)
+
+     btn.click(analyze, inputs=txt, outputs=[out_html, download_file])
+     btn_clear.click(clear_all, inputs=None, outputs=[txt, out_html, download_file])
+
+     gr.Examples(examples=[
+         "I will kill you!",
+         "You are wonderful and helpful.",
+         "Get out of here, you idiot.",
+         "This is the best day ever!",
+         "I hate everything about this.",
+         "You are so stupid and worthless.",
+         "Let's grab coffee tomorrow.",
+         "Go die in a fire.",
+         "Have a great day!",
+         "I'm going to punch you in the face."
+     ], inputs=txt)
+
+ if __name__ == "__main__":
+     # share=True requests a public gradio.live link; it is not needed on HF Spaces,
+     # where the app is already served publicly.
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ numpy
+ pandas
+ transformers
+ gradio
+ huggingface_hub