Spaces:
Sleeping
Sleeping
Initial deploy
Browse files- README.md +11 -6
- app.py +155 -0
- requirements.txt +5 -0
README.md
CHANGED
|
@@ -1,12 +1,17 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: SHAP Text Explainer
|
| 3 |
+
emoji: "\U0001F50E"
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: "5.29.0"
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# SHAP Text Explainer
|
| 14 |
+
|
| 15 |
+
See which words push a text classifier toward positive or negative.
|
| 16 |
+
|
| 17 |
+
**Course**: 215 AI Safety ch8 — Explainability (NLP)
|
app.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
SHAP Text Explainer — Word-level attribution for text classification
Course: 215 AI Safety ch8
"""

import numpy as np
import torch
import gradio as gr
from transformers import pipeline

# SST-2 fine-tuned DistilBERT sentiment checkpoint.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

# NOTE(review): `return_all_scores=True` is deprecated in recent
# transformers releases in favor of `top_k=None`, but `top_k` also changes
# the output nesting for single-string inputs, so the legacy flag is kept
# deliberately — callers below rely on the `classifier(text)[0]` shape.
classifier = pipeline(
    "sentiment-analysis",
    model=MODEL_NAME,
    return_all_scores=True,
)

# Label set emitted by the SST-2 head, in index order.
LABEL_NAMES = ["NEGATIVE", "POSITIVE"]
def simple_word_attribution(text: str):
    """
    Explain a sentiment prediction with word-level leave-one-out (LOO)
    attribution.

    Each whitespace-separated word is removed in turn and the drop in the
    predicted class's score is taken as that word's attribution. Faster and
    more reliable than full SHAP on CPU.

    Parameters
    ----------
    text : str
        Input sentence; tokenized by ``str.split`` (whitespace).

    Returns
    -------
    tuple[str, str, dict]
        ``(highlighted HTML + legend, markdown report, label -> score dict)``.
        All-empty values are returned for blank input.
    """
    import html  # local stdlib import: only needed for escaping words below

    if not text.strip():
        return "", "", {}

    # Baseline prediction on the full text.
    base_result = classifier(text)[0]
    base_scores = {r["label"]: r["score"] for r in base_result}
    pred_label = max(base_scores, key=base_scores.get)
    pred_score = base_scores[pred_label]

    words = text.split()
    if len(words) == 0:
        return "", "", {}

    # LOO attribution: how much the predicted class's score drops when
    # word i is removed (positive = the word supports the prediction).
    attributions = []
    for i in range(len(words)):
        masked = " ".join(words[:i] + words[i + 1 :])
        if not masked.strip():
            # Removing the only word leaves nothing to classify.
            attributions.append(0.0)
            continue
        result = classifier(masked)[0]
        masked_scores = {r["label"]: r["score"] for r in result}
        diff = base_scores[pred_label] - masked_scores[pred_label]
        attributions.append(diff)

    # Normalize for display; guard against division by zero when every
    # attribution is exactly 0.
    max_abs = max(abs(a) for a in attributions) if attributions else 1.0
    if max_abs == 0:
        max_abs = 1.0

    # Build highlighted HTML. Words are HTML-escaped so user input cannot
    # inject markup into the rendered view (fixes a latent injection /
    # rendering bug; previously raw words were interpolated). Also removed
    # an unused `intensity` computation that appeared in both branches.
    html_parts = []
    for word, attr in zip(words, attributions):
        norm_attr = attr / max_abs  # -1 to 1
        if norm_attr > 0:
            # Pushes toward prediction (red = positive contribution)
            bg = f"rgba(239, 68, 68, {abs(norm_attr) * 0.6})"
        else:
            # Pushes against prediction (blue = negative contribution)
            bg = f"rgba(59, 130, 246, {abs(norm_attr) * 0.6})"
        html_parts.append(
            f'<span style="background:{bg};padding:2px 4px;margin:1px;'
            f'border-radius:3px;display:inline-block;">{html.escape(word)}</span>'
        )

    highlighted_html = (
        '<div style="font-size:16px;line-height:2;padding:10px;">'
        + " ".join(html_parts)
        + "</div>"
    )

    # Legend explaining the color coding.
    legend = (
        '<div style="margin-top:10px;font-size:13px;">'
        '<span style="background:rgba(239,68,68,0.5);padding:2px 8px;border-radius:3px;">Red</span>'
        f" = pushes toward {pred_label} "
        '<span style="background:rgba(59,130,246,0.5);padding:2px 8px;border-radius:3px;">Blue</span>'
        f" = pushes against {pred_label} "
        "(intensity = strength)"
        "</div>"
    )

    # Markdown report: prediction summary plus per-label scores.
    pred_info = (
        f"**Prediction: {pred_label}** ({pred_score:.1%})\n\n"
        f"| Label | Score |\n|---|---|\n"
    )
    for r in base_result:
        pred_info += f"| {r['label']} | {r['score']:.1%} |\n"

    # Attribution table, strongest absolute effects first (top 15 words).
    pred_info += "\n**Word attributions (leave-one-out):**\n\n"
    pred_info += "| Word | Attribution | Effect |\n|---|---|---|\n"
    sorted_attr = sorted(
        zip(words, attributions), key=lambda x: abs(x[1]), reverse=True
    )
    for word, attr in sorted_attr[:15]:
        direction = "supports" if attr > 0 else "opposes"
        pred_info += f"| {word} | {attr:+.4f} | {direction} |\n"

    return highlighted_html + legend, pred_info, base_scores
|
| 110 |
+
|
| 111 |
+
|
def _explain_for_ui(text):
    # Adapter for the click handler: the UI shows only the first two
    # outputs (highlighted HTML and the markdown report); the raw score
    # dict is dropped.
    highlighted_html, details, _scores = simple_word_attribution(text)
    return highlighted_html, details


# Gradio UI: input column (textbox + button + note) on the left,
# explanation outputs on the right.
with gr.Blocks(title="SHAP Text Explainer") as demo:
    gr.Markdown(
        "# SHAP Text Explainer\n"
        "Enter text to see which words contribute most to the sentiment prediction.\n"
        "Uses leave-one-out attribution (similar to SHAP) for word-level explanations.\n"
        "*Course: 215 AI Safety ch8 — Explainability*"
    )

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter a sentence to analyze...",
                lines=3,
            )
            explain_btn = gr.Button("Explain", variant="primary")
            gr.Markdown(
                "*Note: Each word is removed one at a time to measure its impact. "
                "This takes a few seconds for longer texts.*"
            )

        with gr.Column():
            highlighted = gr.HTML(label="Word Attribution")
            details_md = gr.Markdown()

    explain_btn.click(_explain_for_ui, [text_input], [highlighted, details_md])

    gr.Examples(
        examples=[
            "This movie was absolutely fantastic! The acting was superb and the plot kept me engaged throughout.",
            "The food was terrible and the service was even worse. I will never go back to this restaurant.",
            "The product works okay but nothing special. It does what it says but I expected more for the price.",
            "I love how this book combines beautiful writing with deep philosophical insights.",
            "The flight was delayed by 3 hours and the airline offered no compensation or explanation.",
        ],
        inputs=[text_input],
    )

if __name__ == "__main__":
    demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=5.0.0
|
| 2 |
+
transformers>=4.30.0
|
| 3 |
+
torch>=2.0.0
|
| 4 |
+
shap>=0.42.0
|
| 5 |
+
numpy
|