# Hugging Face Space page chrome ("Spaces — Sleeping" status banner) captured
# by the scrape; not part of the application code.
| import os | |
| import gc | |
| import gradio as gr | |
| import numpy as np | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from typing import Dict, List, Tuple, Optional | |
| from huggingface_hub import login | |
# ==========================================
# 0. SETUP & MEMORY MANAGEMENT
# ==========================================
# Force CPU/Low-VRAM settings
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
# Gemma-2 is a gated model: authenticate with the Hub when a token is present.
if os.environ.get("HF_TOKEN"):
    login(token=os.environ["HF_TOKEN"])
# Delayed imports to keep UI startup fast.
# These module-level names are populated by import_heavy() on first use.
torch = None
HookedTransformer = None
F = None
def import_heavy():
    """Lazily import the heavy ML dependencies on first use.

    Populates the module-level ``torch``, ``HookedTransformer`` and ``F``
    names; a no-op on every call after the first.
    """
    global torch, HookedTransformer, F
    if torch is not None:
        return  # already imported
    import torch as _torch
    import torch.nn.functional as _functional
    from transformer_lens import HookedTransformer as _HookedTransformer
    torch = _torch
    HookedTransformer = _HookedTransformer
    F = _functional
    print("π¦ PyTorch & TransformerLens imported.")
# SINGLETON MODEL HOLDER
MODEL = None

def cleanup_memory():
    """Aggressively reclaim RAM, and cached VRAM when CUDA is live."""
    gc.collect()
    cuda_ready = torch is not None and torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.empty_cache()
def get_model():
    """Return the process-wide HookedTransformer, loading it on first call.

    Returns:
        The cached model instance, or None when loading fails (subsequent
        calls will retry the load).
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    import_heavy()
    print("β³ Loading Gemma-2-2B...")
    cleanup_memory()
    try:
        # Using CPU and float32 for maximum compatibility.
        # Change device="cuda" and dtype="float16" if you have a GPU > 8GB VRAM
        MODEL = HookedTransformer.from_pretrained(
            "google/gemma-2-2b",
            device="cpu",
            dtype="float32",
        )
        MODEL.eval()
        print("β Model Loaded.")
    except Exception as exc:
        print(f"β Load Error: {exc}")
        return None
    return MODEL
| # ========================================== | |
| # 1. MAIN INFERENCE PIPELINE | |
| # ========================================== | |
def run_main_analysis(text):
    """
    Run one forward pass and build the dashboard's primary artifacts:
    logits, logit-lens trajectory, and the attention-head overview.

    Args:
        text: the prompt string to analyze.

    Returns:
        7-tuple matching the Gradio outputs:
        (token highlights, prediction label, overview heatmap figure,
        logit-lens figure, state dict, head-dropdown update,
        tabs-visibility update). On model-load failure the first slot is
        the string "Model Failed" and the remaining six are None.
    """
    model = get_model()
    if not model:
        return "Model Failed", None, None, None, None, None, None
    cleanup_memory()
    with torch.inference_mode():
        # 1. Forward pass, caching all intermediate activations.
        logits, cache = model.run_with_cache(text)

        # 2. Greedy next-token prediction at the final position.
        tokens = model.to_str_tokens(text)
        formatted_tokens = [(t, str(i)) for i, t in enumerate(tokens)]
        last_logit = logits[0, -1]
        baseline_prob = torch.softmax(last_logit, dim=-1).max().item()
        baseline_id = last_logit.argmax().item()
        baseline_token = model.to_string(baseline_id)

        # 3. FEATURE: Logit Lens — decode the residual stream after every
        # block with the final norm + unembed, i.e. "what would the model
        # predict if it stopped at layer i?" (last token only).
        logit_lens_data = []
        for i in range(model.cfg.n_layers):
            resid = cache[f"blocks.{i}.hook_resid_post"][0, -1, :]
            # Gemma uses RMSNorm; model.ln_final handles the details.
            ln_resid = model.ln_final(resid)
            layer_logits = model.unembed(ln_resid)
            best_id = layer_logits.argmax().item()
            best_tok = model.to_string(best_id)
            prob = torch.softmax(layer_logits, dim=-1).max().item()
            logit_lens_data.append({
                "layer": i,
                "token": best_tok,
                "prob": prob
            })

        # 4. FEATURE: Head summary — max pre-softmax attention score per
        # head for the last query position (feeds the overview heatmap).
        n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
        heatmap_data = np.zeros((n_layers, n_heads))
        # Keep only the last-query score rows so the session state stays
        # small; the 2D inspector re-runs inference on demand instead of
        # caching the full [Q, K] matrices here.
        head_max_scores = {}
        for l in range(n_layers):
            # Cache shape: [batch, head, query, key] -> last query row.
            scores = cache[f"blocks.{l}.attn.hook_attn_scores"][0, :, -1, :]
            scores_np = scores.detach().cpu().numpy()
            heatmap_data[l] = scores_np.max(axis=-1)
            head_max_scores[l] = scores_np  # [head, key] for last pos

    # 5. Session state consumed by the other tabs' callbacks.
    state = {
        "text": text,
        "tokens": tokens,
        "baseline_id": baseline_id,
        "baseline_token": baseline_token,
        "head_data": head_max_scores,  # lightweight cache
        "logit_lens": logit_lens_data,
        "edits": []
    }

    # --- Visualizations ---
    # A. Overview heatmap of max attention scores.
    fig_overview = go.Figure(data=go.Heatmap(
        z=heatmap_data, colorscale='RdBu_r', zmid=0,
        hovertemplate="Layer %{y}, Head %{x}<br>Max Score: %{z:.2f}<extra></extra>"
    ))
    fig_overview.update_layout(
        title="Overview: Max Attention Scores (Last Token)",
        xaxis_title="Head", yaxis_title="Layer",
        height=500
    )
    # B. Logit-lens trajectory across layers.
    lens_x = [d["layer"] for d in logit_lens_data]
    lens_y = [d["prob"] for d in logit_lens_data]
    lens_text = [d["token"] for d in logit_lens_data]
    fig_lens = go.Figure()
    fig_lens.add_trace(go.Scatter(
        x=lens_x, y=lens_y, mode='lines+markers+text',
        text=lens_text, textposition="top center",
        line=dict(color='indigo', width=2)
    ))
    fig_lens.update_layout(
        title="Logit Lens: Best Guess at each Layer",
        xaxis_title="Layer", yaxis_title="Probability",
        height=400
    )

    # Free the large activation cache before handing control back to the UI.
    del cache, logits
    cleanup_memory()

    # Dropdown choices for the per-head tools ("L0 H0", "L0 H1", ...).
    choices = [f"L{l} H{h}" for l in range(n_layers) for h in range(n_heads)]
    return (
        formatted_tokens,
        f"'{baseline_token}' ({baseline_prob:.2%})",
        fig_overview,
        fig_lens,
        state,
        gr.update(choices=choices),
        gr.update(visible=True)  # reveal the advanced tabs
    )
| # ========================================== | |
| # 2. FEATURE: DEEP DIVE (2D PATTERNS) | |
| # ========================================== | |
def get_2d_attention_pattern(state, selection):
    """
    Re-run the model and plot the full Query x Key score matrix for one head.

    Memory-friendly (only a single hook is cached) at the cost of a fresh
    forward pass per request.
    """
    if not state or not selection:
        return None
    # Parse a label such as "L10 H2" into layer/head indices.
    layer_label, head_label = selection.split()
    layer = int(layer_label[1:])
    head = int(head_label[1:])

    model = get_model()
    text = state["text"]
    tokens = state["tokens"]
    cleanup_memory()

    # Capture only the attention scores of the requested layer.
    hook_name = f"blocks.{layer}.attn.hook_attn_scores"
    with torch.inference_mode():
        _, cache = model.run_with_cache(text, names_filter=[hook_name])
        # Cache shape: [1, head, seq, seq] -> one head's full matrix.
        attn_matrix = cache[hook_name][0, head, :, :].detach().cpu().numpy()
    del cache
    cleanup_memory()

    # Gemma already applies the causal mask (-inf above the diagonal),
    # so the raw scores can be plotted as-is.
    fig = go.Figure(data=go.Heatmap(
        z=attn_matrix,
        x=tokens, y=tokens,
        colorscale='RdBu_r', zmid=0,
        hoverongaps=False
    ))
    fig.update_layout(
        title=f"Full Attention Scores: Layer {layer}, Head {head}",
        xaxis_title="Key (Source)",
        yaxis_title="Query (Destination)",
        height=600,
        width=600,
        yaxis=dict(autorange="reversed")  # conventional attention-map layout
    )
    return fig
| # ========================================== | |
| # 3. FEATURE: CIRCUIT DISCOVERY (ABLATION) | |
| # ========================================== | |
def run_circuit_discovery(state):
    """
    Zero-ablate every attention head (one at a time) and measure how much
    the correct-token logit moves, producing a head-importance heatmap.

    Args:
        state: analysis state from run_main_analysis (needs "text" and
            "baseline_id").

    Returns:
        Plotly heatmap of logit differences (clean - ablated), or None if
        no analysis has been run yet.
    """
    if not state:
        return None
    model = get_model()
    text = state["text"]
    target_id = state["baseline_id"]

    # Baseline logit for the predicted token with no intervention.
    with torch.inference_mode():
        clean_logits = model(text)
        clean_score = clean_logits[0, -1, target_id].item()

    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    importance_map = np.zeros((n_layers, n_heads))
    print("running ablation loop...")

    def get_ablation_hook(head_to_kill):
        """Build a hook that zeroes one head's output at hook_z."""
        def hook(z, hook):
            # z shape: [batch, pos, head, d_head]
            z[:, :, head_to_kill, :] = 0.0
            return z
        return hook

    for l in range(n_layers):
        # hook_z is the per-head output before the output projection mixes
        # the heads together; hoisted out of the inner loop (invariant in h).
        hook_name = f"blocks.{l}.attn.hook_z"
        for h in range(n_heads):
            # One full forward pass per head with that head silenced.
            with torch.inference_mode():
                corrupt_logits = model.run_with_hooks(
                    text,
                    fwd_hooks=[(hook_name, get_ablation_hook(h))]
                )
            corrupt_score = corrupt_logits[0, -1, target_id].item()
            # Logit Difference (Clean - Corrupt):
            # positive = head was helping; negative = head was suppressing.
            importance_map[l, h] = clean_score - corrupt_score
        # Reclaim memory between layer sweeps (n_heads forward passes each).
        cleanup_memory()

    fig = go.Figure(data=go.Heatmap(
        z=importance_map,
        colorscale='RdBu', zmid=0,
        text=np.around(importance_map, 2),
        texttemplate="%{text}"
    ))
    fig.update_layout(
        title="Head Importance (Logit Difference via Ablation)",
        xaxis_title="Head", yaxis_title="Layer"
    )
    return fig
| # ========================================== | |
| # 4. FEATURE: ACTIVATION PATCHING | |
| # ========================================== | |
def run_activation_patching(state, corrupt_text, layer, component):
    """
    Denoising patch: capture one clean activation, then inject it while
    running the corrupted prompt and report the new top prediction.
    """
    if not state or not corrupt_text:
        return "Missing inputs", None
    model = get_model()
    clean_text = state["text"]
    layer = int(layer)

    # Translate the UI component name into a TransformerLens hook point;
    # unknown names fall back to the residual stream.
    hook_map = {
        "Attention Output": f"blocks.{layer}.hook_attn_out",
        "MLP Output": f"blocks.{layer}.hook_mlp_out",
        "Residual Stream": f"blocks.{layer}.hook_resid_post"
    }
    target_hook = hook_map.get(component, hook_map["Residual Stream"])
    cleanup_memory()

    # 1. Clean pass: cache only the chosen hook's activation.
    with torch.inference_mode():
        _, clean_cache = model.run_with_cache(clean_text, names_filter=[target_hook])
    clean_act = clean_cache[target_hook]

    # 2. Patch hook: the whole clean activation replaces the corrupt one.
    # NOTE(review): this assumes both prompts tokenize to the same sequence
    # length — confirm before relying on mismatched-length prompts.
    def patch_hook(act, hook):
        return clean_act

    # 3. Corrupted pass with the clean activation spliced in.
    with torch.inference_mode():
        patched_logits = model.run_with_hooks(
            corrupt_text,
            fwd_hooks=[(target_hook, patch_hook)]
        )

    # 4. Report the patched top-1 prediction.
    final_logits = patched_logits[0, -1]
    patched_id = final_logits.argmax().item()
    patched_token = model.to_string(patched_id)
    patched_prob = torch.softmax(final_logits, dim=-1).max().item()

    # Drop the cached activation before returning to the UI.
    del clean_cache, clean_act
    cleanup_memory()
    return f"Patched Token: '{patched_token}' ({patched_prob:.2%})", None
| # ========================================== | |
| # 5. SURGERY (EDITING) | |
| # ========================================== | |
def add_edit(state, selection, pos, value):
    """Queue one attention-score edit parsed from the Surgery tab inputs.

    Args:
        state: analysis state dict containing an "edits" list; falsy when
            the model has not been run yet.
        selection: head label such as "L10 H2".
        pos: key position to edit (int-like).
        value: new pre-softmax score (float-like).

    Returns:
        (state, status message) for the Gradio outputs.
    """
    if not state:
        return state, "Run model first"
    try:
        # "L10 H2" -> layer 10, head 2.
        layer_part, head_part = selection.split()
        l = int(layer_part[1:])
        h = int(head_part[1:])
        p = int(pos)
        v = float(value)
    except (AttributeError, IndexError, TypeError, ValueError):
        # Narrowed from a bare `except:` — only parsing failures are
        # expected here; anything else should surface, not be swallowed.
        return state, "Invalid input format"
    state["edits"].append({"layer": l, "head": h, "pos": p, "value": v})
    return state, f"Added: L{l}H{h} Pos{p} -> {v}"
def execute_surgery(state):
    """Re-run the prompt with every queued attention-score edit applied.

    Returns (result message, None) — the None clears the paired plot output.
    """
    if not state or not state["edits"]:
        return "No edits queued", None
    model = get_model()
    text = state["text"]

    # Bucket the queued edits per layer so each layer needs only one hook.
    edits_by_layer = {}
    for edit in state["edits"]:
        edits_by_layer.setdefault(edit["layer"], []).append(edit)

    def make_hook(layer_edits):
        # Factory binds this layer's edits now, avoiding the classic
        # late-binding closure pitfall inside the loop below.
        def hook(scores, h):
            # scores: [batch, head, query, key]; edits target the LAST
            # query position (the prediction step).
            for edit in layer_edits:
                scores[0, edit["head"], -1, edit["pos"]] = edit["value"]
            return scores
        return hook

    hooks = [
        (f"blocks.{layer}.attn.hook_attn_scores", make_hook(layer_edits))
        for layer, layer_edits in edits_by_layer.items()
    ]

    with torch.inference_mode():
        new_logits = model.run_with_hooks(text, fwd_hooks=hooks)
    new_id = new_logits[0, -1].argmax().item()
    new_token = model.to_string(new_id)
    return f"Surgery Result: '{new_token}' (Baseline was '{state['baseline_token']}')", None
def clear_edits(state):
    """Drop all queued surgery edits.

    Returns:
        (state, status message) for the Gradio outputs; safe to call
        before any analysis has been run.
    """
    if not state:
        # Guard against a None/empty state, consistent with add_edit;
        # the original crashed with a TypeError here.
        return state, "Run model first"
    state["edits"] = []
    return state, "Edits cleared."
| # ========================================== | |
| # UI LAYOUT | |
| # ========================================== | |
# Gradio UI: one analysis entry point feeding five tool tabs that share a
# single gr.State produced by run_main_analysis.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Per-session analysis state (dict built by run_main_analysis).
    state = gr.State()
    gr.Markdown("# π¬ Gemma-2-2B Mechanistic Dashboard (Full Suite)")
    gr.Markdown("Includes: Logit Lens, 2D Attention Patterns, Activation Patching, and Circuit Discovery.")
    # --- TOP CONTROL ---
    with gr.Row():
        txt_input = gr.Textbox("The capital of France is", label="Input Prompt", scale=2)
        btn_run = gr.Button("π Run Analysis", variant="primary", scale=1)
    with gr.Row():
        txt_tokens = gr.HighlightedText(label="Tokenized Input", combine_adjacent=False)
        lbl_pred = gr.Label(label="Prediction")
    # --- MAIN TABS --- hidden until the first analysis run reveals them.
    with gr.Tabs(visible=False) as main_tabs:
        # TAB 1: OVERVIEW & LENS
        with gr.TabItem("Overview & Logit Lens"):
            with gr.Row():
                plot_overview = gr.Plot(label="Attention Head Max Scores")
                plot_lens = gr.Plot(label="Logit Lens (Layer Predictions)")
        # TAB 2: PATTERN INSPECTOR
        with gr.TabItem("π 2D Pattern Inspector"):
            gr.Markdown("**Deep Dive:** Select a head to view the full Query x Key attention matrix.")
            with gr.Row():
                with gr.Column(scale=1):
                    dd_head = gr.Dropdown(label="Select Head")
                    btn_show_pattern = gr.Button("Show 2D Pattern")
                with gr.Column(scale=3):
                    plot_pattern = gr.Plot(label="Attention Matrix")
        # TAB 3: CIRCUIT DISCOVERY
        with gr.TabItem("𧬠Circuit Discovery"):
            gr.Markdown("**Ablation Study:** Automatically zeroes out each head to see if it helps or hurts the correct prediction.")
            btn_circuit = gr.Button("Run Ablation Sweep (Takes ~10s)")
            plot_circuit = gr.Plot()
        # TAB 4: SURGERY
        with gr.TabItem("πͺ Surgery"):
            gr.Markdown("**Intervention:** Edit Pre-Softmax attention scores.")
            with gr.Row():
                dd_surgery_head = gr.Dropdown(label="Target Head")
                num_pos = gr.Number(label="Key Position", value=0, precision=0)
                num_val = gr.Number(label="New Score", value=-10.0)
            btn_add = gr.Button("Add Edit")
            txt_log = gr.Textbox(label="Edit Log")
            with gr.Row():
                btn_exec = gr.Button("Execute Surgery", variant="stop")
                btn_clear = gr.Button("Clear Edits")
            out_surgery = gr.Textbox(label="Surgery Output")
        # TAB 5: PATCHING
        with gr.TabItem("π©Ή Activation Patching"):
            gr.Markdown("**Denoising:** Inject clean activations into a corrupted run.")
            txt_corrupt = gr.Textbox(label="Corrupted Prompt", value="The capital of Germany is")
            with gr.Row():
                num_layer = gr.Number(label="Layer", value=10)
                dd_comp = gr.Dropdown(["Residual Stream", "Attention Output", "MLP Output"], value="Residual Stream", label="Component")
            btn_patch = gr.Button("Run Patch")
            out_patch = gr.Textbox(label="Patch Result")
    # --- EVENTS ---
    # Main run: fills the plots, state, and head dropdown, and unhides tabs.
    btn_run.click(
        run_main_analysis,
        inputs=[txt_input],
        outputs=[txt_tokens, lbl_pred, plot_overview, plot_lens, state, dd_head, main_tabs]
    )
    # 2D Pattern inspector.
    btn_show_pattern.click(
        get_2d_attention_pattern,
        inputs=[state, dd_head],
        outputs=[plot_pattern]
    )
    # Circuit Discovery ablation sweep.
    btn_circuit.click(
        run_circuit_discovery,
        inputs=[state],
        outputs=[plot_circuit]
    )
    # Surgery: keep the surgery dropdown in sync with the inspector's.
    dd_head.change(lambda x: x, inputs=dd_head, outputs=dd_surgery_head)
    btn_add.click(add_edit, [state, dd_surgery_head, num_pos, num_val], [state, txt_log])
    btn_clear.click(clear_edits, state, [state, txt_log])
    btn_exec.click(execute_surgery, state, [out_surgery, plot_overview])
    # Activation patching.
    btn_patch.click(
        run_activation_patching,
        [state, txt_corrupt, num_layer, dd_comp],
        [out_patch, plot_overview]
    )
if __name__ == "__main__":
    demo.launch()