"""
SHAP Text Explainer — Word-level attribution for text classification
Course: 215 AI Safety ch8
"""
import numpy as np
import torch
import gradio as gr
from transformers import pipeline
# Load sentiment model: binary SST-2 sentiment classifier (DistilBERT).
# Loaded once at module import; downloads weights on first run.
classifier = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    # NOTE(review): return_all_scores is deprecated in newer transformers
    # releases in favor of top_k=None — confirm return shape before upgrading.
    return_all_scores=True,
)
# Label names emitted by this model's classification head.
LABEL_NAMES = ["NEGATIVE", "POSITIVE"]
def simple_word_attribution(text: str):
    """
    Compute word-level attribution using the leave-one-out (LOO) method.

    Each word is deleted in turn and the text re-classified; the drop in
    the predicted class's score is that word's attribution. Faster and
    more reliable than full SHAP on CPU.

    Args:
        text: Sentence to explain.

    Returns:
        Tuple of (highlighted_html, details_markdown, base_scores):
        highlighted HTML (word spans + color legend), a markdown report
        with the score table and top-15 attributions, and the
        label -> score dict for the full text. Blank input yields
        ("", "", {}).
    """
    if not text.strip():
        return "", "", {}

    # Baseline prediction on the full, unmodified text.
    base_result = classifier(text)[0]
    base_scores = {r["label"]: r["score"] for r in base_result}
    pred_label = max(base_scores, key=base_scores.get)
    pred_score = base_scores[pred_label]

    words = text.split()
    if not words:
        # Defensive: unreachable after the strip() check above.
        return "", "", {}

    # LOO attribution: re-score the text with each word removed.
    attributions = []
    for i in range(len(words)):
        masked = " ".join(words[:i] + words[i + 1 :])
        if not masked.strip():
            # Removing the only word leaves nothing to classify.
            attributions.append(0.0)
            continue
        result = classifier(masked)[0]
        masked_scores = {r["label"]: r["score"] for r in result}
        # Attribution = how much removing this word lowers the predicted
        # class's score (positive => the word supports the prediction).
        attributions.append(base_scores[pred_label] - masked_scores[pred_label])

    # Normalize to [-1, 1] for display; guard against division by zero
    # when every attribution is exactly 0.
    max_abs = max(abs(a) for a in attributions) if attributions else 1.0
    if max_abs == 0:
        max_abs = 1.0

    # Build highlighted HTML: red = pushes toward the prediction,
    # blue = pushes against it; the rgba alpha encodes strength.
    html_parts = []
    for word, attr in zip(words, attributions):
        norm_attr = attr / max_abs  # in [-1, 1]
        rgb = "239, 68, 68" if norm_attr > 0 else "59, 130, 246"
        bg = f"rgba({rgb}, {abs(norm_attr) * 0.6})"
        html_parts.append(
            f'<span style="background:{bg};padding:2px 4px;margin:1px;'
            f'border-radius:3px;display:inline-block;">{word}</span>'
        )
    highlighted_html = (
        '<div style="font-size:16px;line-height:2;padding:10px;">'
        + " ".join(html_parts)
        + "</div>"
    )

    # Color legend rendered under the highlighted text.
    legend = (
        '<div style="margin-top:10px;font-size:13px;">'
        '<span style="background:rgba(239,68,68,0.5);padding:2px 8px;border-radius:3px;">Red</span>'
        f" = pushes toward {pred_label} "
        '<span style="background:rgba(59,130,246,0.5);padding:2px 8px;border-radius:3px;">Blue</span>'
        f" = pushes against {pred_label} "
        "(intensity = strength)"
        "</div>"
    )

    # Markdown report: prediction summary + per-label score table.
    pred_info = (
        f"**Prediction: {pred_label}** ({pred_score:.1%})\n\n"
        f"| Label | Score |\n|---|---|\n"
    )
    for r in base_result:
        pred_info += f"| {r['label']} | {r['score']:.1%} |\n"

    # Top-15 words ranked by absolute attribution.
    pred_info += "\n**Word attributions (leave-one-out):**\n\n"
    pred_info += "| Word | Attribution | Effect |\n|---|---|---|\n"
    sorted_attr = sorted(
        zip(words, attributions), key=lambda x: abs(x[1]), reverse=True
    )
    for word, attr in sorted_attr[:15]:
        direction = "supports" if attr > 0 else "opposes"
        pred_info += f"| {word} | {attr:+.4f} | {direction} |\n"

    return highlighted_html + legend, pred_info, base_scores
# Gradio UI: left column takes input + trigger, right column shows the
# highlighted attribution HTML and the markdown details.
with gr.Blocks(title="SHAP Text Explainer") as demo:
    gr.Markdown(
        "# SHAP Text Explainer\n"
        "Enter text to see which words contribute most to the sentiment prediction.\n"
        "Uses leave-one-out attribution (similar to SHAP) for word-level explanations.\n"
        "*Course: 215 AI Safety ch8 — Explainability*"
    )
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter a sentence to analyze...",
                lines=3,
            )
            btn = gr.Button("Explain", variant="primary")
            gr.Markdown(
                "*Note: Each word is removed one at a time to measure its impact. "
                "This takes a few seconds for longer texts.*"
            )
        with gr.Column():
            highlighted = gr.HTML(label="Word Attribution")
            details_md = gr.Markdown()
    # The attribution function returns a 3-tuple; only the first two
    # elements (html, markdown) feed UI components, so slice off the
    # raw score dict.
    btn.click(
        lambda t: simple_word_attribution(t)[:2],
        [text_input],
        [highlighted, details_md],
    )
    # Canned inputs covering positive, negative, and mixed sentiment.
    gr.Examples(
        examples=[
            "This movie was absolutely fantastic! The acting was superb and the plot kept me engaged throughout.",
            "The food was terrible and the service was even worse. I will never go back to this restaurant.",
            "The product works okay but nothing special. It does what it says but I expected more for the price.",
            "I love how this book combines beautiful writing with deep philosophical insights.",
            "The flight was delayed by 3 hours and the airline offered no compensation or explanation.",
        ],
        inputs=[text_input],
    )

if __name__ == "__main__":
    demo.launch()