# Hugging Face Space page header ("Spaces: Running") — scrape artifact, not code.
| import os | |
| import gc | |
| import gradio as gr | |
| import numpy as np | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| from typing import Dict, List, Tuple | |
| from huggingface_hub import login | |
# ==========================================
# 0. SETUP & ENVIRONMENT
# ==========================================
# Footprint-reduction settings for CPU / low-VRAM deployments: disable the
# CUDA caching allocator and cap OpenMP to a single thread.
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
# Gemma is a gated model: authenticate with the Hub when a token is present.
if hf_token := os.environ.get("HF_TOKEN"):
    login(token=hf_token)
# Heavy imports delayed to keep UI start fast
torch = None
HookedTransformer = None

def import_heavy():
    """Lazily import torch and TransformerLens on first use.

    Populates the module-level ``torch`` / ``HookedTransformer`` names so the
    Gradio UI can start without paying the heavy import cost up front.
    Idempotent: subsequent calls are no-ops.
    """
    global torch, HookedTransformer
    if torch is not None:
        return  # already imported
    import torch as torch_mod
    torch = torch_mod
    from transformer_lens import HookedTransformer as hooked_cls
    HookedTransformer = hooked_cls
    print("π¦ Heavy libraries imported")
# ==========================================
# 1. LAZY SINGLETON LOADING
# ==========================================
MODEL = None

def cleanup_memory():
    """Aggressive memory cleanup: force a GC pass, then drop cached CUDA blocks.

    Safe to call before the heavy libraries are imported (``torch`` may
    still be the None placeholder at that point).
    """
    gc.collect()
    if torch is None:
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def get_model():
    """Singleton accessor: load Gemma-2-2B exactly once and reuse it.

    Returns the cached HookedTransformer instance, or None when loading
    fails (the error is printed, not raised, so the UI keeps running).
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    import_heavy()
    # NOTE(review): banner hard-codes "18 layers, 8 heads"; the real counts
    # are printed from MODEL.cfg after loading — confirm they agree.
    print("β³ Loading Gemma-2-2B... (18 layers, 8 heads)")
    cleanup_memory()
    try:
        # CPU + float32 keeps VRAM at zero. On a GPU, switch to
        # device="cuda" and dtype="float16".
        MODEL = HookedTransformer.from_pretrained(
            "google/gemma-2-2b",
            device="cpu",
            dtype="float32"  # Use float16 if on GPU
        )
        MODEL.eval()
        print(f"β Gemma Loaded! {MODEL.cfg.n_layers}L x {MODEL.cfg.n_heads}H")
    except Exception as e:
        print(f"β Error loading model: {e}")
        return None
    return MODEL
# ==========================================
# 2. CORE ANALYSIS (Inference + Extraction)
# ==========================================
def run_first_pass(text: str):
    """Run one clean forward pass and build the dashboard's initial artifacts.

    Caches pre-softmax attention scores for the last token as numpy arrays
    (so later head inspection needs no re-run), builds the layer-x-head
    heatmap, and returns a 9-tuple matching btn_run's output components:
    (baseline_text, heatmap_fig, token_display, state, head_dropdown_update,
     head_plot_update, controls_group_update, summary_text, summary_update).
    """
    model = get_model()
    if model is None:
        # Same 9-slot shape as the success path so the Gradio wiring stays valid.
        return "Model failed", None, None, None, None, None, None, "Error", None
    cleanup_memory()
    # 1. Run Inference & Capture Cache
    # We use pre-softmax scores (attn_scores) for better surgery precision
    try:
        with torch.inference_mode():
            logits, cache = model.run_with_cache(text)
        # 2. Extract Data to CPU/Numpy immediately (Greedy Memory Saving)
        # We only keep what we need for the dashboard to avoid storing heavy tensors
        # Token Info
        tokens = model.to_str_tokens(text)
        token_display = [(t, str(i)) for i, t in enumerate(tokens)]
        # Prediction: greedy argmax over the final position's logits.
        baseline_token_id = logits[0, -1].argmax().item()
        baseline_token = model.to_string([baseline_token_id])
        # Extract Attention Scores (Pre-Softmax) for the LAST token
        # Shape: [Layer, Head] -> Max Score
        n_layers = model.cfg.n_layers
        n_heads = model.cfg.n_heads
        # Store full scores map for the last token in a Numpy dictionary
        # This allows us to inspect heads later without re-running the model!
        attn_scores_cache = {}
        heatmap_data = np.zeros((n_layers, n_heads))
        for l in range(n_layers):
            # Extract [Batch, Head, Query, Key] -> [Head, Key] (for last query pos)
            scores = cache[f"blocks.{l}.attn.hook_attn_scores"][0, :, -1, :].cpu().numpy()
            attn_scores_cache[l] = scores
            # Heatmap value: Max attention score in that head
            heatmap_data[l] = scores.max(axis=-1)
        # 3. Create Visuals
        # Using RdBu_r because pre-softmax scores can be negative
        fig = go.Figure(data=go.Heatmap(
            z=heatmap_data,
            colorscale='RdBu_r',
            zmid=0,
            hoverongaps=False
        ))
        fig.update_layout(
            title=f"Max Attention Scores (Pre-Softmax) | Next: '{baseline_token}'",
            xaxis_title="Head",
            yaxis_title="Layer"
        )
        # 4. Generate Circuit Attribution (Simplified for speed)
        # We calculate this now so we can show it immediately
        top_heads_msg = "Run 'Advanced Tools' -> 'Compute Attribution' for detailed breakdown."
        # 5. Build Lightweight State (No Torch Tensors!)
        new_state = {
            "text": text,
            "tokens": tokens,
            "baseline": baseline_token,
            "baseline_id": baseline_token_id,
            "attn_scores": attn_scores_cache,  # Numpy arrays
            "edits": [],
            "corrupt_cache": None  # Will fill on demand
        }
        # Cleanup: drop the heavy tensors before returning to the UI thread.
        del cache
        del logits
        cleanup_memory()
        head_list = [f"Layer {l}, Head {h}" for l in range(n_layers) for h in range(n_heads)]
        return (
            f"Prediction: '{baseline_token}'",
            fig,
            token_display,
            new_state,
            gr.update(visible=True, choices=head_list),
            gr.update(visible=True),
            gr.update(visible=True),
            top_heads_msg,
            gr.update(visible=True)
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        # Error path also returns 9 values to satisfy the output wiring.
        return f"Error: {e}", None, None, None, None, None, None, None, None
# ==========================================
# 3. INSPECTION & EDITING (Works on Numpy)
# ==========================================
def show_head_attention(state, selected_head_str):
    """Plot the cached pre-softmax scores for one head (no model re-run).

    Parses a "Layer L, Head H" dropdown string, pulls the numpy scores out
    of session state, and returns (figure, info_text, layer_str, head_str)
    for the inspector plot plus the surgery layer/head textboxes.
    """
    if not state or not selected_head_str:
        return None, "Select a head first", "0", "0"
    try:
        # Parse "Layer L, Head H" into the two integers.
        layer_str, head_str = (
            selected_head_str.replace("Layer ", "").replace("Head ", "").split(", ")
        )
        layer, head = int(layer_str), int(head_str)

        tokens = state["tokens"]
        # Numpy data straight from state — and truncate defensively in case
        # the stored score row is longer than the token list.
        scores = state["attn_scores"][layer][head][:len(tokens)]

        positions = list(range(len(tokens)))
        fig = px.bar(
            x=positions,
            y=scores,
            labels={'x': 'Token', 'y': 'Raw Score'},
            title=f"L{layer}H{head} Pre-Softmax Scores",
            color=scores,
            color_continuous_scale='RdBu_r'
        )
        # Label the X-axis with the actual token strings.
        fig.update_xaxes(
            tickmode='array',
            tickvals=positions,
            ticktext=[f"{i}: '{t}'" for i, t in enumerate(tokens)],
            tickangle=-45
        )
        info = f"Max score: {scores.max():.2f} on '{tokens[scores.argmax()]}'"
        return fig, info, str(layer), str(head)
    except Exception as e:
        return None, f"Error: {e}", "0", "0"
def add_edit_to_state(state, layer, head, pos_idx, value):
    """Queue one attention-score edit in the session state.

    Args arrive as raw Textbox strings and are coerced here: layer/head/pos
    to int, value to float. Returns (state, status_text) for the
    [state, txt_status] outputs; status_text lists the full edit queue.
    """
    if not state:
        return state, "Run first pass first"
    try:
        # Keep the try body minimal: only the conversions can reasonably fail.
        l, h, p, v = int(layer), int(head), int(pos_idx), float(value)
    except (TypeError, ValueError):
        # Was a bare `except:` — narrowed so real bugs aren't swallowed.
        return state, "Invalid Input"
    if p >= len(state["tokens"]):
        return state, "Position out of range"
    state["edits"].append({"layer": l, "head": h, "pos": p, "value": v})
    status = "\n".join(
        f"L{e['layer']}H{e['head']} @ Pos {e['pos']} = {e['value']}"
        for e in state["edits"]
    )
    return state, status
def run_surgery(state):
    """Re-run the stored prompt with all queued pre-softmax edits applied.

    Installs one hook per edited layer that overwrites raw attention scores
    for the final query position, then compares the new greedy prediction
    against the baseline. Returns (message, None) for
    [out_result, plot_heatmap].
    """
    model = get_model()
    if not model or not state or not state["edits"]:
        return "No edits or model failed.", None
    text = state["text"]

    # Bucket the queued edits per layer so each hook invocation is cheap.
    edits_by_layer = {}
    for edit in state["edits"]:
        edits_by_layer.setdefault(edit["layer"], []).append(edit)

    def surgery_hook(attn_scores, hook):
        # attn_scores: [batch, head, query_len, key_len]
        for edit in edits_by_layer.get(hook.layer(), []):
            # Overwrite the raw score for the LAST query token only
            # (steering the next-token prediction).
            attn_scores[0, edit["head"], -1, edit["pos"]] = edit["value"]
        return attn_scores

    hooks = [
        (f"blocks.{l}.attn.hook_attn_scores", surgery_hook)
        for l in edits_by_layer
    ]
    cleanup_memory()
    try:
        with torch.inference_mode():
            logits = model.run_with_hooks(text, fwd_hooks=hooks)
        new_token = model.to_string([logits[0, -1].argmax().item()])
        baseline = state.get("baseline", "?")
        msg = f"Baseline: '{baseline}' -> Surgery: '{new_token}'"
        prefix = "π SUCCESS! " if baseline != new_token else "β No change. "
        return prefix + msg, None
    except Exception as e:
        return f"Error: {e}", None
def clear_edits(state):
    """Empty the queued-edit list; a missing/empty state is left untouched."""
    if state:
        state["edits"] = []
    return state, "Edits cleared."
# ==========================================
# 4. ADVANCED MI TOOLS (Circuit Discovery)
# ==========================================
def compute_head_attributions(state):
    """Compute which heads contribute most to the baseline logit.

    For every (layer, head) pair, zero-ablates that head's output (hook_z)
    in a separate forward pass and records the drop in the baseline token's
    logit. Positive values mean the head supported the prediction.

    Returns a Plotly heatmap figure, or None when model/state is missing.
    """
    model = get_model()
    if not model or not state:
        return None
    text = state["text"]
    target_id = state["baseline_id"]

    # 1. Clean (un-ablated) logit for the baseline token.
    with torch.inference_mode():
        logits = model(text)
        clean_logit = logits[0, -1, target_id].item()

    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    attr_map = np.zeros((n_layers, n_heads))

    # Hook factory hoisted out of the loops (it was redefined every layer
    # iteration). It only closes over head_idx, so one definition suffices.
    def get_ablate_hook(head_idx):
        def hook(z, hook):
            z[0, :, head_idx, :] = 0  # Zero out this head's output
            return z
        return hook

    # 2. Ablation loop — one forward pass per head. Iterative and slow,
    # but memory safe on CPU/low-VRAM setups.
    print("β³ Running ablation... this might take a moment.")
    for l in range(n_layers):
        for h in range(n_heads):
            with torch.inference_mode():
                ablated_logits = model.run_with_hooks(
                    text,
                    fwd_hooks=[(f"blocks.{l}.attn.hook_z", get_ablate_hook(h))]
                )
            # Logit drop = clean minus ablated: how much this head mattered.
            attr_map[l, h] = clean_logit - ablated_logits[0, -1, target_id].item()

    fig = go.Figure(data=go.Heatmap(
        z=attr_map, colorscale='RdBu_r', zmid=0
    ))
    fig.update_layout(title="Head Importance (Logit Drop via Ablation)", xaxis_title="Head", yaxis_title="Layer")
    cleanup_memory()
    return fig
# ==========================================
# 5. ACTIVATION PATCHING
# ==========================================
def prepare_patch_cache(state, corrupt_text):
    """Record the corrupted prompt for later activation patching.

    Memory strategy: rather than caching the corrupted run's activations
    (which would hog RAM/VRAM), only the corrupted *string* is stored;
    run_activation_patch re-runs the forward pass per requested layer.
    No forward pass is needed here — the previous implementation ran a full
    run_with_cache and immediately discarded the cache, pure wasted compute.

    Returns (status_message, state) for the [out_patch, state] outputs.
    """
    model = get_model()
    if not model or not state:
        return "Model missing", state
    state["corrupt_text"] = corrupt_text
    cleanup_memory()
    return f"Ready to patch. Corrupted input: '{corrupt_text}'", state
def run_activation_patch(state, layer, hook_type):
    """Denoising patch: inject one CLEAN activation into the CORRUPT run.

    Captures the requested component's activation (attn_out / mlp_out /
    resid_post at `layer`) from a clean-prompt run, then re-runs the
    corrupted prompt with that activation swapped in, reporting the
    resulting greedy token. Returns (message, None) for
    [out_patch, plot_heatmap].
    """
    model = get_model()
    # Fix: the old guard did `"corrupt_text" not in state` first, which
    # raises TypeError when state is still None (fresh session).
    if not model or not state or "corrupt_text" not in state:
        return "Cache corrupted run first.", None
    clean_text = state["text"]
    corrupt_text = state["corrupt_text"]
    layer = int(layer)
    cleanup_memory()

    # 1. Capture ONLY the one clean activation we need (names_filter keeps
    # the cache tiny).
    hook_map = {
        "attn": f"blocks.{layer}.hook_attn_out",
        "mlp": f"blocks.{layer}.hook_mlp_out",
        "resid": f"blocks.{layer}.hook_resid_post"
    }
    target_hook = hook_map.get(hook_type, hook_map["attn"])
    with torch.inference_mode():
        _, clean_cache = model.run_with_cache(clean_text, names_filter=[target_hook])
    clean_act = clean_cache[target_hook]

    # 2. Patch hook: wholesale-replace the corrupt activation with clean.
    # NOTE(review): assumes clean and corrupt prompts tokenize to the same
    # length; a mismatch will shape-error downstream — TODO confirm/guard.
    def patch_hook(act, hook):
        return clean_act

    # 3. Corrupted forward pass with the patch installed.
    with torch.inference_mode():
        patched_logits = model.run_with_hooks(
            corrupt_text,
            fwd_hooks=[(target_hook, patch_hook)]
        )
    patched_token = model.to_string(patched_logits[0, -1].argmax())

    # Cleanup before returning to the UI thread.
    del clean_cache, clean_act
    cleanup_memory()
    return f"Patching {hook_type} @ L{layer}: Resulting Token = '{patched_token}'", None
# ==========================================
# UI CONSTRUCTION
# ==========================================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Per-session store: prompt text, tokens, numpy attention scores,
    # queued edits, and (later) the corrupted prompt.
    state = gr.State()
    gr.Markdown("# π§ Gemma-2-2B: Pre-Softmax Surgery Lab")
    gr.Markdown("Optimized for Low-VRAM. Model loads once. Greedy cache removal.")
    # STAGE 1: Always visible
    with gr.Row():
        txt_input = gr.Textbox("The Eiffel Tower is in", label="Input Text")
        btn_run = gr.Button("1. Run Analysis", variant="primary")
    out_baseline = gr.Textbox(label="Baseline Prediction")
    with gr.Row():
        # NOTE(review): label hard-codes "18x8"; actual layer/head counts
        # come from MODEL.cfg at runtime — confirm they match for gemma-2-2b.
        plot_heatmap = gr.Plot(label="Pre-Softmax Attention (Gemma 18x8)")
        txt_tokens = gr.HighlightedText(label="Token Indices", combine_adjacent=False)
    circuit_summary = gr.Textbox(label="Quick Tips", visible=False)
    # STAGE 2: Hidden until run_first_pass succeeds (controls_group toggled
    # visible by its 7th return value).
    with gr.Group(visible=False) as controls_group:
        # HEAD INSPECTOR
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Inspect Head")
                head_dropdown = gr.Dropdown(choices=[], label="Select Head")
                btn_load_head = gr.Button("Load Head Data")
                head_info = gr.Textbox(label="Details", lines=2)
            with gr.Column(scale=2):
                specific_head_plot = gr.Plot(label="Raw Scores")
        # SURGERY STATION — raw string inputs; coerced in add_edit_to_state.
        gr.Markdown("### πͺ Pre-Softmax Surgery")
        gr.Markdown("*Modify raw attention scores before they become probabilities.*")
        with gr.Row():
            # NOTE(review): ranges in these labels assume 18 layers / 8 heads;
            # verify against the loaded model's config.
            num_layer = gr.Textbox(label="Layer (0-17)", value="10")
            num_head = gr.Textbox(label="Head (0-7)", value="0")
            num_pos = gr.Textbox(label="Token Pos", value="2")
            num_val = gr.Textbox(label="New Score (+/-)", value="-100")
        with gr.Row():
            btn_add = gr.Button("β Add Edit")
            btn_clear = gr.Button("ποΈ Clear")
            btn_exec = gr.Button("π EXECUTE", variant="stop")
        txt_status = gr.Textbox(label="Edit Queue", lines=3)
        out_result = gr.Textbox(label="Surgery Result", lines=2)
        # ADVANCED TOOLS
        with gr.Accordion("π¬ Advanced Tools (Circuit Discovery)", open=False):
            gr.Markdown("**Attribution Heatmap**: Runs ablation on every head to see importance.")
            btn_attribution = gr.Button("Compute Attribution Heatmap (Slow)")
            plot_attribution = gr.Plot()
            gr.Markdown("---")
            gr.Markdown("**Activation Patching (Denoising)**")
            with gr.Row():
                txt_corrupt = gr.Textbox("The Eiffel Tower is in Rome", label="Corrupted Prompt")
                btn_prep_patch = gr.Button("1. Prepare Corrupt Run")
            with gr.Row():
                num_patch_l = gr.Number(label="Layer", value=10)
                type_patch = gr.Dropdown(["attn", "mlp", "resid"], value="attn", label="Component")
                btn_do_patch = gr.Button("2. Run Patch")
            out_patch = gr.Textbox(label="Patch Result")
    # WIRING — output lists must match each handler's return arity.
    # NOTE(review): circuit_summary appears twice in btn_run's outputs (the
    # 8th return fills its text, the 9th toggles visibility). Confirm the
    # installed Gradio version accepts duplicate output components.
    btn_run.click(
        run_first_pass,
        txt_input,
        [out_baseline, plot_heatmap, txt_tokens, state, head_dropdown, specific_head_plot, controls_group, circuit_summary, circuit_summary]
    )
    btn_load_head.click(show_head_attention, [state, head_dropdown], [specific_head_plot, head_info, num_layer, num_head])
    btn_add.click(add_edit_to_state, [state, num_layer, num_head, num_pos, num_val], [state, txt_status])
    btn_clear.click(clear_edits, state, [state, txt_status])
    btn_exec.click(run_surgery, state, [out_result, plot_heatmap])
    btn_attribution.click(compute_head_attributions, state, plot_attribution)
    btn_prep_patch.click(prepare_patch_cache, [state, txt_corrupt], [out_patch, state])
    btn_do_patch.click(run_activation_patch, [state, num_patch_l, type_patch], [out_patch, plot_heatmap])

if __name__ == "__main__":
    demo.launch()