# app.py — Classify + Explain (Captum IG) — polished UX # (Optional) silence common warnings on Windows/HF import os os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1") os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") import json import numpy as np import pandas as pd import torch import torch.nn as nn import gradio as gr from transformers import AutoModel, AutoTokenizer, AutoConfig from safetensors.torch import load_file from captum.attr import LayerIntegratedGradients # explainability # ---------------------------- # Paths / labels / config # ---------------------------- ARTI_DIR = "artifacts" BEST_DIR = os.path.join(ARTI_DIR, "best") THRESH_FP = os.path.join(ARTI_DIR, "thresholds.json") LABELS = ["toxic","severe_toxic","obscene","threat","insult","identity_hate"] NUM_LABELS = len(LABELS) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" MAX_LEN = 256 BASE_MODEL = "distilbert-base-uncased" # same backbone as training # ---------------------------- # Model definition (same logic) # ---------------------------- class ToxicMultiLabel(nn.Module): """ DistilBERT backbone + single linear head -> multi-label logits. (We apply sigmoid at inference to get probabilities.) """ def __init__(self, base_model_name: str, num_labels: int, head_dropout: float = 0.30): super().__init__() cfg = AutoConfig.from_pretrained(base_model_name) self.backbone = AutoModel.from_pretrained(base_model_name, config=cfg) hidden = self.backbone.config.hidden_size self.dropout = nn.Dropout(head_dropout) self.classifier = nn.Linear(hidden, num_labels) def forward(self, input_ids=None, attention_mask=None): out = self.backbone(input_ids=input_ids, attention_mask=attention_mask) cls = out.last_hidden_state[:, 0] # [CLS]-like token logits = self.classifier(self.dropout(cls)) # (B, L) return logits # ---------------------------- # Load artifacts (tokenizer, model, thresholds) # ---------------------------- def load_artifacts(): # tokenizer (prefer the saved one if present) tok_src = BEST_DIR if os.path.isfile(os.path.join(BEST_DIR, "tokenizer.json")) else BASE_MODEL tok = AutoTokenizer.from_pretrained(tok_src, use_fast=True) # model weights model = ToxicMultiLabel(BASE_MODEL, NUM_LABELS) safep = os.path.join(BEST_DIR, "model.safetensors") binp = os.path.join(BEST_DIR, "pytorch_model.bin") if os.path.isfile(safep): state = load_file(safep) elif os.path.isfile(binp): state = torch.load(binp, map_location="cpu") else: raise FileNotFoundError("No weights found (model.safetensors / pytorch_model.bin) in artifacts/best/") # strip training-only keys if any slipped in for k in list(state.keys()): if k.startswith("pos_weight") or k.startswith("loss_fn"): state.pop(k, None) model.load_state_dict(state, strict=True) model.to(DEVICE).eval() # thresholds if os.path.isfile(THRESH_FP): with open(THRESH_FP) as f: thresholds = json.load(f) else: thresholds = {lab: 0.5 for lab in LABELS} os.makedirs(ARTI_DIR, exist_ok=True) with open(THRESH_FP, "w") as f: json.dump(thresholds, f, indent=2) return model, tok, thresholds MODEL, TOK, THRESH = load_artifacts() # ========================= # Inference (Classify tab) # ========================= @torch.no_grad() def classify_comment(text: str): """ Returns: (DataFrame of per-label predictions, comma-separated positives) """ text = (text or "").strip() if not text: return pd.DataFrame(columns=["label","probability","threshold","margin","decision"]), "(none)" enc = TOK(text, truncation=True, padding=True, max_length=MAX_LEN, return_tensors="pt") enc = {k: v.to(DEVICE) for k, v in enc.items()} logits = MODEL(**enc).squeeze(0).detach().cpu().numpy() probs = 1.0 / (1.0 + np.exp(-logits)) # sigmoid rows = [] for i, lab in enumerate(LABELS): p = float(probs[i]) t = float(THRESH.get(lab, 0.5)) rows.append({ "label": lab, "probability": round(p, 4), "threshold": round(t, 4), "margin": round(p - t, 4), "decision": "POS" if p >= t else "NEG", }) df = pd.DataFrame(rows).sort_values( ["decision", "margin", "probability"], ascending=[False, False, False] ).reset_index(drop=True) positives = [r["label"] for r in rows if r["probability"] >= r["threshold"]] return df, ", ".join(positives) if positives else "(none)" # ========================= # Explainability (IG tab) # ========================= # Layer IG on embedding layer EMB_LAYER = MODEL.backbone.embeddings.word_embeddings # Captum forward: single logit for chosen label def _forward_for_label(input_ids, attention_mask, class_index: int): logits = MODEL(input_ids=input_ids, attention_mask=attention_mask) # (B, L) return logits[:, class_index] LIG = LayerIntegratedGradients(_forward_for_label, EMB_LAYER) def _tokenize_with_offsets(text: str): return TOK(text, truncation=True, padding=True, max_length=MAX_LEN, return_tensors="pt", return_offsets_mapping=True) def _merge_wordpieces(tokens, offsets, scores): """Merge WordPiece tokens (##subwords) into words; sum scores.""" words = [] for tok_piece, (start, end), sc in zip(tokens, offsets, scores): # skip special tokens with (0,0) offsets if (start, end) == (0, 0) and tok_piece.startswith("[") and tok_piece.endswith("]"): continue if tok_piece.startswith("##") and words: words[-1]["text"] += tok_piece[2:] words[-1]["end"] = end words[-1]["score"] += float(sc) else: words.append({"text": tok_piece, "start": start, "end": end, "score": float(sc)}) return words @torch.no_grad() def _predict_probs(text: str): enc = TOK(text, truncation=True, padding=True, max_length=MAX_LEN, return_tensors="pt") enc = {k: v.to(DEVICE) for k, v in enc.items()} logits = MODEL(**enc).squeeze(0).detach().cpu().numpy() return 1.0 / (1.0 + np.exp(-logits)) # (L,) def explain_comment(text: str, target_label: str, steps: int = 30): """ Returns (HTML with colored spans, selected label prob as string). Red = supports the label; Blue = opposes the label. """ import html as ihtml text = (text or "").strip() if not text: return "Provide a comment to explain.", "0.000" idx = LABELS.index(target_label) enc = _tokenize_with_offsets(text) input_ids = enc["input_ids"].to(DEVICE) attention_mask = enc["attention_mask"].to(DEVICE) offsets = enc["offset_mapping"][0].tolist() tokens = TOK.convert_ids_to_tokens(enc["input_ids"][0]) # PAD baseline ref_ids = torch.full_like(input_ids, TOK.pad_token_id) # Be robust to Captum return signature res = LIG.attribute( inputs=input_ids, baselines=ref_ids, additional_forward_args=(attention_mask, idx), n_steps=int(max(4, steps)), return_convergence_delta=True, ) attributions = res[0] if isinstance(res, tuple) else res token_attr = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy() pieces = _merge_wordpieces(tokens, offsets, token_attr) arr = np.array([p["score"] for p in pieces], dtype=np.float32) denom = float(np.max(np.abs(arr))) if np.max(np.abs(arr)) > 1e-8 else 1.0 for p in pieces: p["score_norm"] = p["score"] / denom def _color_for(s: float) -> str: alpha = min(1.0, max(0.06, abs(s))) return f"rgba(255,0,0,{alpha:.25f})" if s >= 0 else f"rgba(0,0,255,{alpha:.25f})" out, last = "", 0 for p in pieces: out += ihtml.escape(text[last:p["start"]]) out += ( f'' f'{ihtml.escape(text[p["start"]:p["end"]])}' ) last = p["end"] out += ihtml.escape(text[last:]) probs = _predict_probs(text) prob = float(probs[idx]) header = ( f"
{target_label} "
f"| Prob: {prob:.3f}