"""Gradio demo: a DistilBERT multi-label toxic-comment classifier with
Integrated Gradients explanations (Captum)."""

import os

# Quiet Hugging Face warnings; set before transformers/tokenizers are imported.
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import json

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import gradio as gr
from transformers import AutoModel, AutoTokenizer, AutoConfig
from safetensors.torch import load_file
from captum.attr import LayerIntegratedGradients

# ---------------------------------------------------------------------------
# Paths & configuration
# ---------------------------------------------------------------------------
ARTI_DIR = "artifacts"
BEST_DIR = os.path.join(ARTI_DIR, "best")
THRESH_FP = os.path.join(ARTI_DIR, "thresholds.json")

LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
NUM_LABELS = len(LABELS)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LEN = 256
BASE_MODEL = "distilbert-base-uncased"
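
# Expected layout of artifacts/thresholds.json: one sigmoid cutoff per label.
# Illustrative values only (the file is created with 0.5 defaults if missing):
# {"toxic": 0.45, "severe_toxic": 0.30, "obscene": 0.50,
#  "threat": 0.35, "insult": 0.48, "identity_hate": 0.40}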

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class ToxicMultiLabel(nn.Module):
    """
    DistilBERT backbone + single linear head -> multi-label logits.
    (We apply sigmoid at inference to get probabilities.)
    """

    def __init__(self, base_model_name: str, num_labels: int, head_dropout: float = 0.30):
        super().__init__()
        cfg = AutoConfig.from_pretrained(base_model_name)
        self.backbone = AutoModel.from_pretrained(base_model_name, config=cfg)
        hidden = self.backbone.config.hidden_size
        self.dropout = nn.Dropout(head_dropout)
        self.classifier = nn.Linear(hidden, num_labels)

    def forward(self, input_ids=None, attention_mask=None):
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0]  # hidden state at the [CLS] position
        logits = self.classifier(self.dropout(cls))
        return logits
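
# Minimal usage sketch (assumes the base weights are available; values hypothetical):
#   model = ToxicMultiLabel(BASE_MODEL, NUM_LABELS)
#   enc = AutoTokenizer.from_pretrained(BASE_MODEL)("hello there", return_tensors="pt")
#   logits = model(**enc)           # shape (1, NUM_LABELS)
#   probs = torch.sigmoid(logits)   # independent per-label probabilities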

# ---------------------------------------------------------------------------
# Artifact loading
# ---------------------------------------------------------------------------
def load_artifacts():
    # Prefer the tokenizer saved with the fine-tuned checkpoint; otherwise
    # fall back to the base model's tokenizer.
    tok_src = BEST_DIR if os.path.isfile(os.path.join(BEST_DIR, "tokenizer.json")) else BASE_MODEL
    tok = AutoTokenizer.from_pretrained(tok_src, use_fast=True)

    model = ToxicMultiLabel(BASE_MODEL, NUM_LABELS)
    safep = os.path.join(BEST_DIR, "model.safetensors")
    binp = os.path.join(BEST_DIR, "pytorch_model.bin")

    if os.path.isfile(safep):
        state = load_file(safep)
    elif os.path.isfile(binp):
        state = torch.load(binp, map_location="cpu")
    else:
        raise FileNotFoundError(
            "No weights found (model.safetensors / pytorch_model.bin) in artifacts/best/"
        )

    # Drop training-only entries (e.g. a BCE pos_weight buffer) so that
    # strict loading succeeds.
    for k in list(state.keys()):
        if k.startswith("pos_weight") or k.startswith("loss_fn"):
            state.pop(k, None)

    model.load_state_dict(state, strict=True)
    model.to(DEVICE).eval()

    # Per-label decision thresholds; fall back to 0.5 and persist the defaults.
    if os.path.isfile(THRESH_FP):
        with open(THRESH_FP) as f:
            thresholds = json.load(f)
    else:
        thresholds = {lab: 0.5 for lab in LABELS}
        os.makedirs(ARTI_DIR, exist_ok=True)
        with open(THRESH_FP, "w") as f:
            json.dump(thresholds, f, indent=2)

    return model, tok, thresholds


MODEL, TOK, THRESH = load_artifacts()

# ---------------------------------------------------------------------------
# Classification
# ---------------------------------------------------------------------------
@torch.no_grad()
def classify_comment(text: str):
    """
    Returns: (DataFrame of per-label predictions, comma-separated positives)
    """
    text = (text or "").strip()
    if not text:
        return pd.DataFrame(columns=["label", "probability", "threshold", "margin", "decision"]), "(none)"

    enc = TOK(text, truncation=True, padding=True, max_length=MAX_LEN, return_tensors="pt")
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    logits = MODEL(**enc).squeeze(0).detach().cpu().numpy()
    probs = 1.0 / (1.0 + np.exp(-logits))  # element-wise sigmoid

    rows = []
    for i, lab in enumerate(LABELS):
        p = float(probs[i])
        t = float(THRESH.get(lab, 0.5))
        rows.append({
            "label": lab,
            "probability": round(p, 4),
            "threshold": round(t, 4),
            "margin": round(p - t, 4),
            "decision": "POS" if p >= t else "NEG",
        })

    # "POS" > "NEG" lexicographically, so descending order lists positives
    # first, ranked by margin and then probability.
    df = pd.DataFrame(rows).sort_values(
        ["decision", "margin", "probability"],
        ascending=[False, False, False],
    ).reset_index(drop=True)

    positives = [r["label"] for r in rows if r["probability"] >= r["threshold"]]
    return df, ", ".join(positives) if positives else "(none)"
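
# Illustrative call (values hypothetical):
#   df, pos = classify_comment("You are a complete idiot.")
#   # pos == "toxic, insult"; df holds one row per label with its
#   # probability, threshold, margin, and POS/NEG decision.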

# ---------------------------------------------------------------------------
# Explainability: Layer Integrated Gradients over the word embeddings
# ---------------------------------------------------------------------------
EMB_LAYER = MODEL.backbone.embeddings.word_embeddings


def _forward_for_label(input_ids, attention_mask, class_index: int):
    # Captum expects one scalar per example, so slice out a single label's logit.
    logits = MODEL(input_ids=input_ids, attention_mask=attention_mask)
    return logits[:, class_index]


LIG = LayerIntegratedGradients(_forward_for_label, EMB_LAYER)
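
# Integrated Gradients accumulates gradients along the straight-line path from
# the baseline embeddings to the input embeddings. By the completeness
# property, the per-token attributions sum approximately to
# logit(input) - logit(baseline), up to the reported convergence delta.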

def _tokenize_with_offsets(text: str):
    return TOK(text, truncation=True, padding=True, max_length=MAX_LEN,
               return_tensors="pt", return_offsets_mapping=True)


def _merge_wordpieces(tokens, offsets, scores):
    """Merge WordPiece tokens (##subwords) into words; sum scores."""
    words = []
    for tok_piece, (start, end), sc in zip(tokens, offsets, scores):
        # Skip special tokens such as [CLS]/[SEP], which map to the empty span.
        if (start, end) == (0, 0) and tok_piece.startswith("[") and tok_piece.endswith("]"):
            continue
        if tok_piece.startswith("##") and words:
            words[-1]["text"] += tok_piece[2:]
            words[-1]["end"] = end
            words[-1]["score"] += float(sc)
        else:
            words.append({"text": tok_piece, "start": start, "end": end, "score": float(sc)})
    return words
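
# Worked example (hypothetical pieces): tokens ["stu", "##pid"] with offsets
# (0, 3), (3, 6) and scores [0.3, 0.2] merge into a single word
# {"text": "stupid", "start": 0, "end": 6, "score": 0.5}.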

@torch.no_grad()
def _predict_probs(text: str):
    enc = TOK(text, truncation=True, padding=True, max_length=MAX_LEN, return_tensors="pt")
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    logits = MODEL(**enc).squeeze(0).detach().cpu().numpy()
    return 1.0 / (1.0 + np.exp(-logits))

def explain_comment(text: str, target_label: str, steps: int = 30):
    """
    Returns (HTML with colored spans, selected label prob as string).
    Red = supports the label; Blue = opposes the label.
    """
    import html as ihtml

    text = (text or "").strip()
    if not text:
        return "<i>Provide a comment to explain.</i>", "0.000"

    idx = LABELS.index(target_label)
    enc = _tokenize_with_offsets(text)
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)
    offsets = enc["offset_mapping"][0].tolist()
    tokens = TOK.convert_ids_to_tokens(enc["input_ids"][0])

    # Baseline: the same sequence with every position replaced by [PAD].
    ref_ids = torch.full_like(input_ids, TOK.pad_token_id)

    res = LIG.attribute(
        inputs=input_ids,
        baselines=ref_ids,
        additional_forward_args=(attention_mask, idx),
        n_steps=int(max(4, steps)),
        return_convergence_delta=True,
    )
    # res is (attributions, convergence_delta); a large |delta| suggests
    # raising n_steps for a more faithful approximation.
    attributions = res[0] if isinstance(res, tuple) else res
    # Collapse the embedding dimension to one scalar score per token.
    token_attr = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()

    pieces = _merge_wordpieces(tokens, offsets, token_attr)
    # Normalize to [-1, 1] so color intensity is comparable across inputs.
    arr = np.array([p["score"] for p in pieces], dtype=np.float32)
    denom = float(np.max(np.abs(arr))) if np.max(np.abs(arr)) > 1e-8 else 1.0
    for p in pieces:
        p["score_norm"] = p["score"] / denom

    def _color_for(s: float) -> str:
        alpha = min(1.0, max(0.06, abs(s)))
        return f"rgba(255,0,0,{alpha:.2f})" if s >= 0 else f"rgba(0,0,255,{alpha:.2f})"

    # Rebuild the original text, wrapping each merged word in a colored span.
    out, last = "", 0
    for p in pieces:
        out += ihtml.escape(text[last:p["start"]])
        out += (
            f'<span title="score={p["score_norm"]:+.3f}" '
            f'style="background:{_color_for(p["score_norm"])}; padding:1px 2px; border-radius:3px;">'
            f'{ihtml.escape(text[p["start"]:p["end"]])}</span>'
        )
        last = p["end"]
    out += ihtml.escape(text[last:])

    probs = _predict_probs(text)
    prob = float(probs[idx])
    header = (
        f"<h4 style='margin:6px 0;'>Label: <code>{target_label}</code> "
        f"| Prob: {prob:.3f}</h4>"
        "<div style='margin:4px 0 8px 0;'>Legend: "
        "<span style='background:rgba(255,0,0,.25);padding:0 6px;'>supports</span> "
        "<span style='background:rgba(0,0,255,.25);padding:0 6px;'>opposes</span></div>"
    )
    html_block = header + f"<div style='font-family:ui-sans-serif,system-ui;line-height:1.7;font-size:15px;'>{out}</div>"
    return html_block, f"{prob:.3f}"
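
# Illustrative call (output hypothetical):
#   html_out, p = explain_comment("you idiot", "insult")
#   # html_out shades "idiot" red (supports) and p is e.g. "0.930".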

# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
EXAMPLES = [
    "You are a complete idiot. Get banned already.",
    "I will kill you tomorrow. Watch your back.",
    "Thanks for your help—really appreciate your time!",
    "Shut up, this is the dumbest edit ever.",
    "Go away, you people don't belong here.",
]

with gr.Blocks(
    title="🧠 Toxic Comment Classifier & Explainer",
    theme=gr.themes.Soft(primary_hue="blue"),
) as demo:
    gr.Markdown(
        f"""
# 🧠 Toxic Comment Classifier & Explainer
A DistilBERT-based **multi-label** model for detecting toxicity in online comments,
with **Integrated Gradients** explanations (Captum).

**Device:** `{DEVICE}` • **Max length:** {MAX_LEN}
"""
    )

    txt = gr.Textbox(
        label="Enter a comment",
        lines=4,
        value=EXAMPLES[1],
        placeholder="Type or paste a comment here…",
    )

    with gr.Tab("🔍 Classify"):
        btn = gr.Button("Classify", variant="primary")
        out_tbl = gr.Dataframe(
            headers=["label", "probability", "threshold", "margin", "decision"],
            label="Per-label predictions",
            interactive=False, wrap=True,
        )
        out_pos = gr.Textbox(label="Predicted positive labels", interactive=False)
        btn.click(classify_comment, inputs=txt, outputs=[out_tbl, out_pos])
        gr.Examples(EXAMPLES, inputs=txt, label="Examples")

    with gr.Tab("🧩 Explain"):
        lab_dd = gr.Dropdown(choices=LABELS, value="toxic", label="Target label")
        steps_slider = gr.Slider(6, 80, value=30, step=2,
                                 label="IG steps (higher = smoother, slower)")
        explain_btn = gr.Button("Generate explanation", variant="primary")
        prob_box = gr.Textbox(label="Selected label probability", interactive=False)
        html_vis = gr.HTML(label="Attribution heatmap")
        explain_btn.click(
            fn=explain_comment,
            inputs=[txt, lab_dd, steps_slider],
            outputs=[html_vis, prob_box],
        )
        gr.Examples(EXAMPLES, inputs=txt, label="Examples for Explain")

    with gr.Accordion("ℹ️ About & Responsible Use", open=False):
        gr.Markdown(
            """
**Labels:** `toxic`, `severe_toxic`, `obscene`, `threat`, `insult`, `identity_hate`

This demo is for **research/education**. Do not use it as-is for moderation without
human oversight, bias assessment, and policy alignment. Explanations
(IG attributions) are **heuristics**, not proof of model causality.
"""
        )


if __name__ == "__main__":
    demo.launch(share=False)