File size: 5,395 Bytes
275113a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
SHAP Text Explainer — Word-level attribution for text classification
Course: 215 AI Safety ch8
"""

import numpy as np
import torch
import gradio as gr
from transformers import pipeline

# Load sentiment model
# Loaded once at import time; `simple_word_attribution` below calls it
# repeatedly (once per word), so reusing a single pipeline matters.
# NOTE(review): `return_all_scores=True` is deprecated in recent
# transformers releases in favor of `top_k=None`; the two differ in
# output nesting for single inputs, so confirm before migrating.
classifier = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    return_all_scores=True,
)

# SST-2 head labels. NOTE(review): not referenced anywhere in this file —
# kept for documentation value; safe to use or remove after checking callers.
LABEL_NAMES = ["NEGATIVE", "POSITIVE"]


def simple_word_attribution(text: str):
    """
    Compute word-level attribution using leave-one-out (LOO) masking.

    Each word is removed in turn and the drop in the predicted class's
    score is taken as that word's attribution. Faster and more reliable
    than full SHAP on CPU.

    Args:
        text: Input text to explain (split on whitespace into "words").

    Returns:
        Tuple of (highlighted HTML with legend, markdown details,
        label -> score dict for the unmasked text). All three are
        empty ("" / "" / {}) when the input is blank.
    """
    if not text.strip():
        return "", "", {}

    # Baseline prediction on the full, unmasked text.
    base_result = classifier(text)[0]
    base_scores = {r["label"]: r["score"] for r in base_result}
    pred_label = max(base_scores, key=base_scores.get)
    pred_score = base_scores[pred_label]

    words = text.split()
    if not words:
        return "", "", {}

    # LOO attribution: score drop for the predicted class when word i is removed.
    attributions = []
    for i in range(len(words)):
        masked = " ".join(words[:i] + words[i + 1 :])
        if not masked.strip():
            # Removing the only word leaves nothing to classify.
            attributions.append(0.0)
            continue
        result = classifier(masked)[0]
        masked_scores = {r["label"]: r["score"] for r in result}
        # Attribution = how much removing this word changes the predicted class score
        diff = base_scores[pred_label] - masked_scores[pred_label]
        attributions.append(diff)

    # Normalize for display; guard against division by zero when every
    # attribution is exactly 0 (e.g. single-word input).
    max_abs = max(abs(a) for a in attributions) if attributions else 1.0
    if max_abs == 0:
        max_abs = 1.0

    # Build highlighted HTML. Red = pushes toward the prediction,
    # blue = pushes against it; the alpha channel encodes strength.
    # (The original computed an unused `intensity` value here; removed.)
    html_parts = []
    for word, attr in zip(words, attributions):
        norm_attr = attr / max_abs  # -1 to 1
        rgb = "239, 68, 68" if norm_attr > 0 else "59, 130, 246"
        bg = f"rgba({rgb}, {abs(norm_attr) * 0.6})"
        html_parts.append(
            f'<span style="background:{bg};padding:2px 4px;margin:1px;'
            f'border-radius:3px;display:inline-block;">{word}</span>'
        )

    highlighted_html = (
        '<div style="font-size:16px;line-height:2;padding:10px;">'
        + " ".join(html_parts)
        + "</div>"
    )

    # Color legend shown under the highlighted text.
    legend = (
        '<div style="margin-top:10px;font-size:13px;">'
        '<span style="background:rgba(239,68,68,0.5);padding:2px 8px;border-radius:3px;">Red</span>'
        f" = pushes toward {pred_label} &nbsp;&nbsp;"
        '<span style="background:rgba(59,130,246,0.5);padding:2px 8px;border-radius:3px;">Blue</span>'
        f" = pushes against {pred_label} &nbsp;&nbsp;"
        "(intensity = strength)"
        "</div>"
    )

    # Markdown summary: prediction header + per-label score table.
    pred_info = (
        f"**Prediction: {pred_label}** ({pred_score:.1%})\n\n"
        f"| Label | Score |\n|---|---|\n"
    )
    for r in base_result:
        pred_info += f"| {r['label']} | {r['score']:.1%} |\n"

    # Attribution table: top 15 words by absolute impact.
    pred_info += "\n**Word attributions (leave-one-out):**\n\n"
    pred_info += "| Word | Attribution | Effect |\n|---|---|---|\n"
    sorted_attr = sorted(
        zip(words, attributions), key=lambda x: abs(x[1]), reverse=True
    )
    for word, attr in sorted_attr[:15]:
        direction = "supports" if attr > 0 else "opposes"
        pred_info += f"| {word} | {attr:+.4f} | {direction} |\n"

    return highlighted_html + legend, pred_info, base_scores


# Gradio UI: text input on the left, highlighted attribution + details
# markdown on the right.
with gr.Blocks(title="SHAP Text Explainer") as demo:
    gr.Markdown(
        "# SHAP Text Explainer\n"
        "Enter text to see which words contribute most to the sentiment prediction.\n"
        "Uses leave-one-out attribution (similar to SHAP) for word-level explanations.\n"
        "*Course: 215 AI Safety ch8 — Explainability*"
    )

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter a sentence to analyze...",
                lines=3,
            )
            btn = gr.Button("Explain", variant="primary")
            gr.Markdown(
                "*Note: Each word is removed one at a time to measure its impact. "
                "This takes a few seconds for longer texts.*"
            )

        with gr.Column():
            highlighted = gr.HTML(label="Word Attribution")
            details_md = gr.Markdown()

    # simple_word_attribution returns a 3-tuple; the third element
    # (raw score dict) is intentionally dropped — only the HTML and
    # markdown components are wired to outputs.
    btn.click(
        lambda t: simple_word_attribution(t)[:2],
        [text_input],
        [highlighted, details_md],
    )

    gr.Examples(
        examples=[
            "This movie was absolutely fantastic! The acting was superb and the plot kept me engaged throughout.",
            "The food was terrible and the service was even worse. I will never go back to this restaurant.",
            "The product works okay but nothing special. It does what it says but I expected more for the price.",
            "I love how this book combines beautiful writing with deep philosophical insights.",
            "The flight was delayed by 3 hours and the airline offered no compensation or explanation.",
        ],
        inputs=[text_input],
    )

# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()