"""
SHAP Text Explainer — Word-level attribution for text classification
Course: 215 AI Safety ch8
"""

import html

import numpy as np
import torch
import gradio as gr
from transformers import pipeline

# Load sentiment model.
# NOTE(review): `return_all_scores=True` is deprecated in recent transformers
# (modern spelling is `top_k=None`); kept as-is so the output shape the rest
# of this script relies on is unchanged.
classifier = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    return_all_scores=True,
)

LABEL_NAMES = ["NEGATIVE", "POSITIVE"]


def _loo_attributions(words, base_scores, pred_label):
    """Leave-one-out attribution: re-score the text with each word removed.

    Returns a list of floats, one per word: how much the predicted class's
    score DROPS when that word is removed (positive = word supports the
    prediction, negative = word opposes it).
    """
    attributions = []
    for i in range(len(words)):
        masked = " ".join(words[:i] + words[i + 1 :])
        if not masked.strip():
            # Removing the only word leaves nothing to classify.
            attributions.append(0.0)
            continue
        result = classifier(masked)[0]
        masked_scores = {r["label"]: r["score"] for r in result}
        # Attribution = how much removing this word changes the predicted
        # class score relative to the full-text baseline.
        attributions.append(base_scores[pred_label] - masked_scores[pred_label])
    return attributions


def _render_highlight(words, attributions, pred_label):
    """Build highlighted-text HTML (color-coded spans) plus a legend."""
    # Normalize by the largest magnitude so colors use the full range;
    # guard against all-zero attributions to avoid division by zero.
    max_abs = max((abs(a) for a in attributions), default=1.0)
    if max_abs == 0:
        max_abs = 1.0

    parts = []
    for word, attr in zip(words, attributions):
        norm_attr = attr / max_abs  # in [-1, 1]
        # Red = pushes toward the prediction, blue = pushes against it;
        # opacity encodes strength.
        rgb = "239, 68, 68" if norm_attr > 0 else "59, 130, 246"
        bg = f"rgba({rgb}, {abs(norm_attr) * 0.6})"
        # html.escape: words come from user input and are interpolated
        # into HTML — escape to avoid broken markup / script injection.
        parts.append(
            f'<span style="background:{bg}; padding:2px 4px; '
            f'border-radius:3px;" title="{attr:+.4f}">'
            f"{html.escape(word)}</span>"
        )

    highlighted_html = (
        '<div style="line-height:2.2; font-size:16px; padding:10px;">'
        + " ".join(parts)
        + "</div>"
    )
    legend = (
        '<div style="margin-top:10px; font-size:13px; color:#555;">'
        '<span style="background:rgba(239, 68, 68, 0.5); '
        'padding:2px 6px;">Red</span>'
        f" = pushes toward {pred_label} &nbsp; "
        '<span style="background:rgba(59, 130, 246, 0.5); '
        'padding:2px 6px;">Blue</span>'
        f" = pushes against {pred_label} &nbsp; "
        "(intensity = strength)"
        "</div>"
    )
    return highlighted_html + legend


def simple_word_attribution(text: str):
    """
    Compute word-level attribution using leave-one-out (LOO) method.
    Faster and more reliable than full SHAP on CPU.

    Returns a 3-tuple:
        (highlighted HTML + legend, markdown details table, label->score dict).
    Empty/whitespace-only input yields ("", "", {}).
    """
    if not text.strip():
        return "", "", {}

    # Baseline prediction on the full text.
    base_result = classifier(text)[0]
    base_scores = {r["label"]: r["score"] for r in base_result}
    pred_label = max(base_scores, key=base_scores.get)
    pred_score = base_scores[pred_label]

    words = text.split()
    if not words:
        return "", "", {}

    attributions = _loo_attributions(words, base_scores, pred_label)
    html_out = _render_highlight(words, attributions, pred_label)

    # Prediction summary table.
    pred_info = (
        f"**Prediction: {pred_label}** ({pred_score:.1%})\n\n"
        f"| Label | Score |\n|---|---|\n"
    )
    for r in base_result:
        pred_info += f"| {r['label']} | {r['score']:.1%} |\n"

    # Attribution table, strongest effects first (top 15 words).
    pred_info += "\n**Word attributions (leave-one-out):**\n\n"
    pred_info += "| Word | Attribution | Effect |\n|---|---|---|\n"
    ranked = sorted(
        zip(words, attributions), key=lambda x: abs(x[1]), reverse=True
    )
    for word, attr in ranked[:15]:
        direction = "supports" if attr > 0 else "opposes"
        pred_info += f"| {word} | {attr:+.4f} | {direction} |\n"

    return html_out, pred_info, base_scores


with gr.Blocks(title="SHAP Text Explainer") as demo:
    gr.Markdown(
        "# SHAP Text Explainer\n"
        "Enter text to see which words contribute most to the sentiment prediction.\n"
        "Uses leave-one-out attribution (similar to SHAP) for word-level explanations.\n"
        "*Course: 215 AI Safety ch8 — Explainability*"
    )

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter a sentence to analyze...",
                lines=3,
            )
            btn = gr.Button("Explain", variant="primary")
            gr.Markdown(
                "*Note: Each word is removed one at a time to measure its impact. "
                "This takes a few seconds for longer texts.*"
            )
        with gr.Column():
            highlighted = gr.HTML(label="Word Attribution")
            details_md = gr.Markdown()

    # Only the first two outputs (HTML, markdown) are displayed;
    # the raw scores dict is dropped.
    btn.click(
        lambda t: simple_word_attribution(t)[:2],
        [text_input],
        [highlighted, details_md],
    )

    gr.Examples(
        examples=[
            "This movie was absolutely fantastic! The acting was superb and "
            "the plot kept me engaged throughout.",
            "The food was terrible and the service was even worse. I will "
            "never go back to this restaurant.",
            "The product works okay but nothing special. It does what it "
            "says but I expected more for the price.",
            "I love how this book combines beautiful writing with deep "
            "philosophical insights.",
            "The flight was delayed by 3 hours and the airline offered no "
            "compensation or explanation.",
        ],
        inputs=[text_input],
    )

if __name__ == "__main__":
    demo.launch()