"""
SHAP Text Explainer — Word-level attribution for text classification
Course: 215 AI Safety ch8
"""
import numpy as np
import torch
import gradio as gr
from transformers import pipeline
# Load sentiment model
# DistilBERT fine-tuned on SST-2 (binary NEGATIVE/POSITIVE sentiment).
# return_all_scores=True makes each call return a score entry for EVERY
# label, not just the argmax — simple_word_attribution() depends on this
# shape when it builds its {label: score} dicts.
# NOTE(review): return_all_scores is deprecated in newer transformers
# releases (top_k=None is the replacement, but it changes the output
# nesting) — confirm against the pinned transformers version.
classifier = pipeline(
"sentiment-analysis",
model="distilbert-base-uncased-finetuned-sst-2-english",
return_all_scores=True,
)
# Canonical label order for this model; kept for reference / UI use.
LABEL_NAMES = ["NEGATIVE", "POSITIVE"]
def simple_word_attribution(text: str):
    """
    Compute word-level attribution using the leave-one-out (LOO) method.

    Each whitespace-delimited word is removed in turn and the remaining
    text is re-classified; the drop in the predicted class's score is that
    word's attribution.  Faster and more reliable than full SHAP on CPU.

    Args:
        text: Raw input sentence to explain.

    Returns:
        Tuple of (highlighted-HTML + legend, markdown details, dict of
        label -> baseline score).  Returns ("", "", {}) for blank input.
    """
    if not text.strip():
        return "", "", {}

    # Baseline prediction over the full, unmasked text.
    base_result = classifier(text)[0]
    base_scores = {r["label"]: r["score"] for r in base_result}
    pred_label = max(base_scores, key=base_scores.get)
    pred_score = base_scores[pred_label]

    words = text.split()
    if not words:
        return "", "", {}

    # LOO attribution: re-score the text with each word removed.
    attributions = []
    for i in range(len(words)):
        masked = " ".join(words[:i] + words[i + 1 :])
        if not masked.strip():
            # Single-word input: nothing left to classify once it's removed.
            attributions.append(0.0)
            continue
        result = classifier(masked)[0]
        masked_scores = {r["label"]: r["score"] for r in result}
        # Attribution = how much removing this word changes the predicted
        # class's score (positive => the word supported the prediction).
        attributions.append(base_scores[pred_label] - masked_scores[pred_label])

    # Normalize to [-1, 1] for display.  `attributions` is non-empty here
    # (words was non-empty); guard only against the all-zero case.
    max_abs = max(abs(a) for a in attributions)
    if max_abs == 0:
        max_abs = 1.0

    # Build highlighted HTML: red = pushes toward the prediction,
    # blue = pushes against it; alpha encodes strength.
    html_parts = []
    for word, attr in zip(words, attributions):
        norm_attr = attr / max_abs  # -1 to 1
        if norm_attr > 0:
            bg = f"rgba(239, 68, 68, {abs(norm_attr) * 0.6})"  # red
        else:
            bg = f"rgba(59, 130, 246, {abs(norm_attr) * 0.6})"  # blue
        html_parts.append(
            f'<span style="background-color: {bg}; padding: 2px 4px; '
            f'border-radius: 4px;">{word}</span>'
        )
    highlighted_html = (
        '<div style="line-height: 2.2; padding: 12px; font-size: 16px;">'
        + " ".join(html_parts)
        + "</div>"
    )

    # Legend explaining the color coding.
    legend = (
        '<div style="margin-top: 8px; font-size: 13px;">'
        '<span style="color: #ef4444; font-weight: bold;">Red</span>'
        f" = pushes toward {pred_label} &nbsp;&nbsp; "
        '<span style="color: #3b82f6; font-weight: bold;">Blue</span>'
        f" = pushes against {pred_label} &nbsp;&nbsp; "
        "(intensity = strength)"
        "</div>"
    )

    # Prediction summary as a markdown table.
    pred_info = (
        f"**Prediction: {pred_label}** ({pred_score:.1%})\n\n"
        f"| Label | Score |\n|---|---|\n"
    )
    for r in base_result:
        pred_info += f"| {r['label']} | {r['score']:.1%} |\n"

    # Attribution table: top 15 words by absolute attribution.
    pred_info += "\n**Word attributions (leave-one-out):**\n\n"
    pred_info += "| Word | Attribution | Effect |\n|---|---|---|\n"
    sorted_attr = sorted(
        zip(words, attributions), key=lambda x: abs(x[1]), reverse=True
    )
    for word, attr in sorted_attr[:15]:
        direction = "supports" if attr > 0 else "opposes"
        pred_info += f"| {word} | {attr:+.4f} | {direction} |\n"

    return highlighted_html + legend, pred_info, base_scores
def _explain_for_ui(text):
    """Click-handler adapter: keep the HTML and markdown, drop the raw score dict."""
    highlighted_html, details, _scores = simple_word_attribution(text)
    return highlighted_html, details


# Assemble the Gradio UI: input column (textbox + button + note) on the left,
# attribution output (highlighted HTML + markdown details) on the right.
with gr.Blocks(title="SHAP Text Explainer") as demo:
    gr.Markdown(
        "# SHAP Text Explainer\n"
        "Enter text to see which words contribute most to the sentiment prediction.\n"
        "Uses leave-one-out attribution (similar to SHAP) for word-level explanations.\n"
        "*Course: 215 AI Safety ch8 — Explainability*"
    )

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter a sentence to analyze...",
                lines=3,
            )
            btn = gr.Button("Explain", variant="primary")
            gr.Markdown(
                "*Note: Each word is removed one at a time to measure its impact. "
                "This takes a few seconds for longer texts.*"
            )
        with gr.Column():
            highlighted = gr.HTML(label="Word Attribution")
            details_md = gr.Markdown()

    btn.click(
        _explain_for_ui,
        [text_input],
        [highlighted, details_md],
    )

    gr.Examples(
        examples=[
            "This movie was absolutely fantastic! The acting was superb and the plot kept me engaged throughout.",
            "The food was terrible and the service was even worse. I will never go back to this restaurant.",
            "The product works okay but nothing special. It does what it says but I expected more for the price.",
            "I love how this book combines beautiful writing with deep philosophical insights.",
            "The flight was delayed by 3 hours and the airline offered no compensation or explanation.",
        ],
        inputs=[text_input],
    )

if __name__ == "__main__":
    demo.launch()