"""
Interactive Gradio app for visualizing token-level attributions of a causal
language model's next-token prediction using Integrated Gradients.

The script loads a causal LM (``meta-llama/Llama-3.2-1B`` when access is
available, otherwise the public ``distilgpt2`` fallback), computes how much each
input token contributes to the probability of a chosen target token, and renders
a colour-coded grid that explains the prediction.

Features:
- Loads a causal language model and tokenizer (LLaMA, with a public fallback).
- Computes Integrated Gradients attributions for a prompt and target token.
- Visualizes token contributions as a grid of coloured boxes
  (green = positive, yellow = neutral, red = negative).
- Interactive Gradio UI for custom prompts and target tokens.
- Includes a Feynman-style explanation of the interpretability concepts.

How to run:
1. Install the dependencies: torch, transformers, captum, matplotlib, gradio.
2. Optionally export HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) with access to the
   gated LLaMA model; without it the app falls back to distilgpt2.
3. Run the script from the command line, e.g. ``python app.py``.
4. The app listens on port 7860; open the printed URL in your browser.
   Set GRADIO_SHARE=1 to additionally create a public gradio.live link.
5. Enter a prompt and target token to see the visualization and interpret
   the model's prediction.

Notes:
- The visualization is saved as 'token_attributions.png'.
- Prompts longer than 50 tokens are rejected to keep runtime reasonable.
- If the target text spans several sub-tokens, only the first one is scored.
- Example prompts are provided for quick testing.
"""
import os
import logging

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from captum.attr import IntegratedGradients
import matplotlib
matplotlib.use('Agg')  # Headless backend; must be selected before pyplot is imported.
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.patches import FancyBboxPatch
import gradio as gr  # Added for interactive UI

device = "cuda" if torch.cuda.is_available() else "cpu"

# Basic logger for helpful messages when loading gated models
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ---------------- Load model (gated models handled safely) ----------------
# Default attempts to load LLaMA-3.2-1B, but that model is gated on HF. We try to
# use HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) if available, otherwise fall back to a
# small public model for demo.
requested_model = "meta-llama/Llama-3.2-1B"
fallback_model = "distilgpt2"
hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
model_name = requested_model

try:
    load_kwargs = {}
    if hf_token:
        # `token` replaces the deprecated (and since removed) `use_auth_token` argument.
        load_kwargs["token"] = hf_token
    tokenizer = AutoTokenizer.from_pretrained(model_name, **load_kwargs)
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs).to(device)
    model.eval()
    logger.info("Loaded gated model: %s", model_name)
except Exception as e:  # Top-level boundary: any load failure triggers the demo fallback.
    logger.warning("Could not load requested model '%s': %s", requested_model, e)
    logger.info("Falling back to public model: %s for demo purposes.", fallback_model)
    model_name = fallback_model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    model.eval()

# Upper bound on prompt length, measured in tokens of the loaded tokenizer.
MAX_PROMPT_TOKENS = 50


# ---------------- Modularized Functions ----------------
def _target_token_ids(target_token):
    """Tokenize ``target_token`` without special tokens and return its id list.

    Raises:
        ValueError: if the text produces no tokens (e.g. an empty string).
    """
    ids = tokenizer(target_token, add_special_tokens=False)["input_ids"]
    if not ids:
        raise ValueError(f"Target token {target_token!r} produced no tokens.")
    return ids


def _escape_mathtext(text):
    """Escape '$' so matplotlib draws user text literally instead of as mathtext."""
    return str(text).replace('$', r'\$')


def _display_token(tok):
    """Turn a raw tokenizer piece into readable text (space/newline markers decoded)."""
    text = tok if isinstance(tok, str) else str(tok)
    # 'Ġ' / '▁' mark a leading space (byte-level BPE / SentencePiece); 'Ċ' is a newline.
    text = text.replace('Ġ', ' ').replace('▁', ' ').replace('Ċ', '\\n')
    return _escape_mathtext(text)


def compute_attributions(prompt, target_token, n_steps=30):
    """
    Compute Integrated Gradients attributions for a given prompt and target token.

    The attributed quantity is the model's softmax probability of the target
    token at the position after the last prompt token. No baseline is passed to
    Captum, so its default (all-zeros embeddings) is used.

    Appeals to devs/ML: Shows model interpretability; business: Builds trust by
    explaining AI decisions.

    Args:
        prompt: Input text.
        target_token: Text of the next token to explain. If it tokenizes to
            several sub-tokens, only the first is scored and a warning is logged.
        n_steps: Number of IG interpolation steps (default 30).

    Returns:
        (tokens, token_attr_np): the prompt's tokenizer pieces and a 1-D float
        array with one attribution per token, divided by the largest absolute
        value so it lies in [-1, 1] with the sign preserved.

    Raises:
        ValueError: if the prompt or the target produces no tokens.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    if input_ids.shape[-1] == 0:
        raise ValueError("Prompt produced no tokens; please enter a non-empty prompt.")

    target_ids = _target_token_ids(target_token)
    if len(target_ids) > 1:
        logger.warning("Target %r spans %d tokens; attributing only the first one.",
                       target_token, len(target_ids))
    target_id = target_ids[0]

    def forward_func(embeds):
        # Probability of the target as the *next* token (logits of the last position).
        logits = model(inputs_embeds=embeds).logits[:, -1, :]
        return torch.softmax(logits, dim=-1)[:, target_id]

    # Embedding lookup needs no autograd graph; Captum enables gradients on its own
    # interpolated copies of these inputs.
    with torch.no_grad():
        embeddings = model.get_input_embeddings()(input_ids)

    ig = IntegratedGradients(forward_func)
    attributions, delta = ig.attribute(
        embeddings, n_steps=n_steps, return_convergence_delta=True
    )
    logger.info("IG convergence delta: %.4g", float(delta.abs().max()))

    # Index the single batch row instead of squeeze(): squeeze() collapses a
    # one-token prompt to a 0-d tensor that can no longer be iterated.
    token_attr_np = attributions.sum(-1)[0].detach().cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    # Normalize safely (epsilon avoids division by zero for all-zero attributions)
    token_attr_np = token_attr_np / (np.abs(token_attr_np).max() + 1e-8)
    return tokens, token_attr_np


def create_visualization(tokens, token_attr_np, prompt, target_token,
                         out_path='token_attributions.png'):
    """
    Render token attributions as a grid of coloured boxes and save it as a PNG.

    Colours use a symmetric RdYlGn scale centred on zero: green boxes raised
    the target's probability, red boxes lowered it, yellow is neutral.

    Args:
        tokens: Token strings, in prompt order.
        token_attr_np: One attribution per token (any scale; sign is meaningful).
        prompt: Prompt text, shown in the title.
        target_token: Target text, shown in the title.
        out_path: Where to write the PNG (default 'token_attributions.png').

    Returns:
        ``out_path``.
    """
    scores = np.atleast_1d(np.asarray(token_attr_np, dtype=float))
    num_tokens = max(1, len(tokens))
    cols = min(max(3, int(num_tokens ** 0.5)), 8)
    rows = (num_tokens + cols - 1) // cols
    box_h = 0.18
    fig_h = max(4, rows * 0.7 + 2.0)  # Extra height leaves room for colorbar and caption

    # Symmetric scale: 0 maps to the neutral midpoint, so a token is only drawn red
    # when its attribution is actually negative (min-max scaling lost the sign).
    max_abs = float(np.abs(scores).max()) if scores.size else 0.0
    normalized = scores / max_abs if max_abs > 0 else scores
    color_norm = Normalize(vmin=-1.0, vmax=1.0)
    cmap = plt.get_cmap('RdYlGn')  # Red = negative, green = positive

    fig = plt.figure(figsize=(12, fig_h))
    try:
        fig.suptitle(
            f"Token Contributions to Predicting '{_escape_mathtext(target_token)}' "
            f"in: '{_escape_mathtext(prompt)}'",
            fontsize=14, y=0.95, ha='center',
        )

        ax = fig.add_axes([0, 0.30, 1, 0.60])  # Grid sits high to leave bottom space
        ax.set_xlim(0, cols)
        ax.set_ylim(0, rows)
        ax.axis('off')

        pad = 0.08
        for idx, (tok, score) in enumerate(zip(tokens, normalized)):
            row, col = divmod(idx, cols)
            x, y = col, rows - 1 - row  # First token in the top-left corner
            rect = FancyBboxPatch(
                (x + pad * 0.15, y + pad * 0.15), 1 - pad, box_h - pad * 0.3,
                boxstyle='round,pad=0.02', linewidth=0.8,
                facecolor=cmap(color_norm(score)), edgecolor='gray', alpha=0.95,
            )
            ax.add_patch(rect)
            ax.text(x + 0.5, y + box_h / 2, _display_token(tok),
                    ha='center', va='center', fontsize=10, fontweight='bold')

        # Colorbar spanning the same -1..1 range used for the boxes
        sm = plt.cm.ScalarMappable(norm=color_norm, cmap=cmap)
        sm.set_array([])
        cax = fig.add_axes([0.1, 0.22, 0.8, 0.04])
        cb = fig.colorbar(sm, cax=cax, orientation='horizontal')
        cb.set_label('Contribution Strength', fontsize=11, fontweight='bold')

        # Legend aligned with the colorbar ends: red is on the left, green on the right.
        fig.text(0.1, 0.16, 'Red Negative (hinders prediction)', fontsize=10, ha='left')
        fig.text(0.9, 0.16, 'Green Positive (helps prediction)', fontsize=10, ha='right')

        # Engaging caption for mixed audience
        caption = (
            "How input tokens influence the model's target prediction: Green supports (builds AI trust), "
            "red opposes. For debugging (devs), reasoning insights (ML), reliable decisions (business). Normalized."
        )
        fig.text(0.5, 0.08, caption, fontsize=9, ha='center', va='top', wrap=True)

        fig.savefig(out_path, dpi=300, bbox_inches='tight', facecolor='white')
    finally:
        plt.close(fig)  # Always release the figure, even if rendering/saving fails
    return out_path


# ---------------- Gradio Interface for Interactivity ----------------
def generate_attribution(prompt, target_token):
    """
    Gradio wrapper: validate inputs, compute attributions and render the plot.

    Empty inputs fall back to the "The capital of France is" / " Paris" demo.

    Returns:
        Path of the saved PNG.

    Raises:
        gr.Error: for over-long prompts or any failure during computation.
            Gradio shows the message in the UI; returning a string instead would
            be misread by the gr.Image output as an image path.
    """
    if not (prompt or "").strip():
        prompt = "The capital of France is"
    if not (target_token or "").strip():
        target_token = " Paris"

    # Count real tokens (not whitespace-separated words) so the check matches its message
    if len(tokenizer(prompt)["input_ids"]) > MAX_PROMPT_TOKENS:
        raise gr.Error(
            f"Warning: Prompt too long (>{MAX_PROMPT_TOKENS} tokens). Shorten for better performance."
        )

    try:
        target_ids = _target_token_ids(target_token)
        tokens, token_attr_np = compute_attributions(prompt, target_token)
        # When the target is multi-token only its first piece is scored; title it honestly.
        label = target_token if len(target_ids) == 1 else tokenizer.decode(target_ids[:1])
        return create_visualization(tokens, token_attr_np, prompt, label)
    except Exception as e:
        logger.exception("Attribution failed")
        raise gr.Error(f"Error: {e}") from e


# Launch interactive app
iface = gr.Interface(
    fn=generate_attribution,
    inputs=[
        gr.Textbox(label="Prompt", value="The capital of France is", placeholder="Enter your prompt..."),
        gr.Textbox(label="Target Token", value=" Paris", placeholder="Enter target token (e.g., ' Paris')"),
    ],
    outputs=gr.Image(label="Token Attribution Visualization"),
    title="AI Interpretability Explorer: See How Tokens Influence Predictions",
    # Name the model actually loaded: it may be the public fallback rather than LLaMA.
    description=(
        "Input a prompt and target token to visualize token contributions using "
        "[Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients) "
        f"on `{model_name}`. Explore model reasoning interactively."
    ),
    # Feynman-style explanation rendered (as Markdown) below the app.
    article="""
### How it works — Feynman-style

This tool explains which input tokens most influence the model's next-token prediction using Integrated Gradients https://captum.ai/docs/extension/integrated_gradients.

- What it does: Interpolates from a baseline to the actual input in embedding space, accumulates gradients along the path, and attributes importance to each input token.
- Why it helps: Highlights which tokens push the model toward (green) or away from (red) the chosen target token. Useful for debugging, bias detection, and model transparency.
- How to read results: Higher positive values (green) mean the token increases the probability of the target; negative values (red) mean the token reduces it. Values are normalized per example.
- Watch-outs: IG depends on the baseline choice and number of interpolation steps. Subword tokens (e.g., Ġ) are shown with spaces; long prompts may be noisy.
""",
    examples=[
        ["The capital of France is", " Paris"],
        ["I love this product because", " it's amazing"],
        ["The weather today is", " sunny"],
    ],
)


if __name__ == "__main__":
    # Run the original example for backward compatibility, then launch Gradio
    print("Generating default example...")
    try:
        default_img = generate_attribution("", "")
        print(f"Default plot saved to: {default_img}")
    except Exception as exc:  # A failed demo plot should not prevent launching the UI
        logger.error("Default example failed: %s", exc)

    print("\nLaunching interactive Gradio app... Open in browser for custom examples.")
    # A public gradio.live tunnel is opt-in (GRADIO_SHARE=1); the previous hard-coded
    # share=True exposed the model to the internet on every run.
    share = os.environ.get("GRADIO_SHARE", "").strip().lower() in {"1", "true", "yes"}
    iface.launch(share=share, server_name="0.0.0.0", server_port=7860)