# Attention-Show / app.py — Hugging Face Space
# (page-scrape residue kept as comments so the file parses as Python:
#  commit 06eec43 "Update app.py" by Nicknam, verified)
import os
import gc
import gradio as gr
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from typing import Dict, List, Tuple
from huggingface_hub import login
# ==========================================
# 0. SETUP & ENVIRONMENT
# ==========================================
# Standard environment optimization for CPU/Low-VRAM
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
# Login if token exists (Gemma is a gated model)
if os.environ.get("HF_TOKEN"):
    login(token=os.environ["HF_TOKEN"])
# Heavy imports delayed to keep UI start fast.
# Both globals are populated by import_heavy() on first model load.
torch = None
HookedTransformer = None
def import_heavy():
    """Import torch and TransformerLens on first use.

    Binds both to the module-level globals so the Gradio UI can start
    before the heavy dependencies are loaded. Safe to call repeatedly.
    """
    global torch, HookedTransformer
    if torch is not None:
        return  # already imported on an earlier call
    import torch as _torch
    from transformer_lens import HookedTransformer as _HT
    torch = _torch
    HookedTransformer = _HT
    print("📦 Heavy libraries imported")
# ==========================================
# 1. LAZY SINGLETON LOADING
# ==========================================
MODEL = None  # process-wide model singleton, populated by get_model()

def cleanup_memory():
    """Force a GC pass and, when CUDA is active, release its cached VRAM."""
    gc.collect()
    if torch is None:
        return  # heavy libs not imported yet — nothing more to free
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def get_model():
    """Return the shared Gemma-2-2B instance, loading it on first call.

    Returns None when loading fails (e.g. missing HF token for the gated
    model); callers are expected to handle that.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    import_heavy()
    print("⏳ Loading Gemma-2-2B... (18 layers, 8 heads)")
    cleanup_memory()
    try:
        # CPU + float32 by default to save VRAM; on a GPU box switch to
        # device="cuda" and dtype="float16".
        MODEL = HookedTransformer.from_pretrained(
            "google/gemma-2-2b",
            device="cpu",
            dtype="float32",
        )
        MODEL.eval()
        print(f"✅ Gemma Loaded! {MODEL.cfg.n_layers}L x {MODEL.cfg.n_heads}H")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return None
    return MODEL
# ==========================================
# 2. CORE ANALYSIS (Inference + Extraction)
# ==========================================
def run_first_pass(text):
    """Run one forward pass over `text` and build every first-stage UI artifact.

    Returns a 9-tuple matching btn_run's output list: (baseline label, heatmap
    figure, highlighted tokens, new session state, head-dropdown update,
    head-plot update, controls-group update, tips message, tips update).
    """
    model = get_model()
    if model is None:
        # Error path mirrors the 9-slot success tuple so Gradio outputs align.
        return "Model failed", None, None, None, None, None, None, "Error", None
    cleanup_memory()
    # 1. Run Inference & Capture Cache
    # We use pre-softmax scores (attn_scores) for better surgery precision
    try:
        with torch.inference_mode():
            logits, cache = model.run_with_cache(text)
        # 2. Extract Data to CPU/Numpy immediately (Greedy Memory Saving)
        # We only keep what we need for the dashboard to avoid storing heavy tensors
        # Token Info
        tokens = model.to_str_tokens(text)
        token_display = [(t, str(i)) for i, t in enumerate(tokens)]
        # Prediction: greedy argmax over the last position's logits
        baseline_token_id = logits[0, -1].argmax().item()
        baseline_token = model.to_string([baseline_token_id])
        # Extract Attention Scores (Pre-Softmax) for the LAST token
        # Shape: [Layer, Head] -> Max Score
        n_layers = model.cfg.n_layers
        n_heads = model.cfg.n_heads
        # Store full scores map for the last token in a Numpy dictionary.
        # This allows us to inspect heads later without re-running the model!
        attn_scores_cache = {}
        heatmap_data = np.zeros((n_layers, n_heads))
        for l in range(n_layers):
            # Extract [Batch, Head, Query, Key] -> [Head, Key] (for last query pos)
            scores = cache[f"blocks.{l}.attn.hook_attn_scores"][0, :, -1, :].cpu().numpy()
            attn_scores_cache[l] = scores
            # Heatmap value: Max attention score in that head
            heatmap_data[l] = scores.max(axis=-1)
        # 3. Create Visuals
        # Using RdBu_r because pre-softmax scores can be negative
        fig = go.Figure(data=go.Heatmap(
            z=heatmap_data,
            colorscale='RdBu_r',
            zmid=0,
            hoverongaps=False
        ))
        fig.update_layout(
            title=f"Max Attention Scores (Pre-Softmax) | Next: '{baseline_token}'",
            xaxis_title="Head",
            yaxis_title="Layer"
        )
        # 4. Generate Circuit Attribution (Simplified for speed)
        # We calculate this now so we can show it immediately
        top_heads_msg = "Run 'Advanced Tools' -> 'Compute Attribution' for detailed breakdown."
        # 5. Build Lightweight State (No Torch Tensors!) — this dict is what
        # show_head_attention / run_surgery / the patching tools consume later.
        new_state = {
            "text": text,
            "tokens": tokens,
            "baseline": baseline_token,
            "baseline_id": baseline_token_id,
            "attn_scores": attn_scores_cache,  # Numpy arrays
            "edits": [],
            "corrupt_cache": None  # Will fill on demand
        }
        # Cleanup: drop the full activation cache before handing back to the UI.
        del cache
        del logits
        cleanup_memory()
        head_list = [f"Layer {l}, Head {h}" for l in range(n_layers) for h in range(n_heads)]
        return (
            f"Prediction: '{baseline_token}'",
            fig,
            token_display,
            new_state,
            gr.update(visible=True, choices=head_list),
            gr.update(visible=True),
            gr.update(visible=True),
            top_heads_msg,
            gr.update(visible=True)
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error: {e}", None, None, None, None, None, None, None, None
# ==========================================
# 3. INSPECTION & EDITING (Works on Numpy)
# ==========================================
def show_head_attention(state, selected_head_str):
    """Plot one head's pre-softmax scores over the prompt tokens.

    Reads only the numpy scores stashed in `state` by the first pass —
    no model or GPU work happens here.
    """
    if not state or not selected_head_str:
        return None, "Select a head first", "0", "0"
    try:
        # "Layer L, Head H" -> (L, H)
        parts = selected_head_str.replace("Layer ", "").replace("Head ", "").split(", ")
        layer, head = int(parts[0]), int(parts[1])
        tokens = state["tokens"]
        # Truncate defensively in case of any shape mismatch.
        scores = state["attn_scores"][layer][head][:len(tokens)]
        positions = list(range(len(tokens)))
        fig = px.bar(
            x=positions,
            y=scores,
            labels={'x': 'Token', 'y': 'Raw Score'},
            title=f"L{layer}H{head} Pre-Softmax Scores",
            color=scores,
            color_continuous_scale='RdBu_r'
        )
        # Label the X-axis with the actual token strings.
        fig.update_xaxes(
            tickmode='array',
            tickvals=positions,
            ticktext=[f"{i}: '{t}'" for i, t in enumerate(tokens)],
            tickangle=-45
        )
        info = f"Max score: {scores.max():.2f} on '{tokens[scores.argmax()]}'"
        return fig, info, str(layer), str(head)
    except Exception as e:
        return None, f"Error: {e}", "0", "0"
def add_edit_to_state(state, layer, head, pos_idx, value):
    """Queue one attention-score edit (layer/head/pos/value) in session state.

    Inputs arrive as strings from the Gradio textboxes, so they are parsed
    here. Returns (state, human-readable edit-queue summary).
    """
    if not state:
        return state, "Run first pass first"
    try:
        l, h, p, v = int(layer), int(head), int(pos_idx), float(value)
        # Reject negative positions too: a bare `p >= len(...)` check let
        # them through, and they would silently index from the sequence end.
        if p < 0 or p >= len(state["tokens"]):
            return state, "Position out of range"
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit and
    # genuine bugs are no longer swallowed.
    except (TypeError, ValueError, KeyError):
        return state, "Invalid Input"
    state["edits"].append({"layer": l, "head": h, "pos": p, "value": v})
    status = "\n".join(
        f"L{e['layer']}H{e['head']} @ Pos {e['pos']} = {e['value']}"
        for e in state["edits"]
    )
    return state, status
def run_surgery(state):
    """Re-run the prompt with all queued score edits applied via hooks.

    Compares the post-surgery greedy prediction against the stored baseline
    and reports whether the edits changed the model's next token.
    """
    model = get_model()
    if not model or not state or not state["edits"]:
        return "No edits or model failed.", None
    text = state["text"]
    # Bucket queued edits by layer so each hook only scans its own layer.
    edits_by_layer = {}
    for edit in state["edits"]:
        edits_by_layer.setdefault(edit["layer"], []).append(edit)

    def surgery_hook(attn_scores, hook):
        # attn_scores: [batch, head, query_len, key_len].
        # Only the LAST query row is edited (steering the next-token step).
        for edit in edits_by_layer.get(hook.layer(), []):
            attn_scores[0, edit["head"], -1, edit["pos"]] = edit["value"]
        return attn_scores

    hooks = [
        (f"blocks.{l}.attn.hook_attn_scores", surgery_hook)
        for l in edits_by_layer
    ]
    cleanup_memory()
    try:
        with torch.inference_mode():
            logits = model.run_with_hooks(text, fwd_hooks=hooks)
        new_id = logits[0, -1].argmax().item()
        new_token = model.to_string([new_id])
        baseline = state.get("baseline", "?")
        msg = f"Baseline: '{baseline}' -> Surgery: '{new_token}'"
        if baseline != new_token:
            msg = "🎉 SUCCESS! " + msg
        else:
            msg = "❌ No change. " + msg
        return msg, None
    except Exception as e:
        return f"Error: {e}", None
def clear_edits(state):
    """Empty the surgery edit queue (no-op when there is no state yet)."""
    if not state:
        return state, "Edits cleared."
    state["edits"] = []
    return state, "Edits cleared."
# ==========================================
# 4. ADVANCED MI TOOLS (Circuit Discovery)
# ==========================================
def compute_head_attributions(state):
    """Computes which heads contribute most to the baseline logit."""
    model = get_model()
    if not model or not state:
        return None
    prompt = state["text"]
    target_id = state["baseline_id"]
    # Baseline: the unmodified logit for the predicted token.
    with torch.inference_mode():
        clean_logit = model(prompt)[0, -1, target_id].item()
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    attr_map = np.zeros((n_layers, n_heads))

    def make_zero_hook(head_idx):
        # Returns a hook that zeroes one head's output (hook_z),
        # leaving every other head untouched.
        def zero_hook(z, hook):
            z[0, :, head_idx, :] = 0
            return z
        return zero_hook

    # One forward pass per head: slow, but memory-safe on CPU/low VRAM.
    print("⏳ Running ablation... this might take a moment.")
    for layer_idx in range(n_layers):
        hook_name = f"blocks.{layer_idx}.attn.hook_z"
        for head_idx in range(n_heads):
            with torch.inference_mode():
                ablated = model.run_with_hooks(
                    prompt,
                    fwd_hooks=[(hook_name, make_zero_hook(head_idx))]
                )
            # Positive diff => ablating this head hurt the prediction.
            attr_map[layer_idx, head_idx] = clean_logit - ablated[0, -1, target_id].item()
    fig = go.Figure(data=go.Heatmap(z=attr_map, colorscale='RdBu_r', zmid=0))
    fig.update_layout(title="Head Importance (Logit Drop via Ablation)", xaxis_title="Head", yaxis_title="Layer")
    cleanup_memory()
    return fig
# ==========================================
# 5. ACTIVATION PATCHING
# ==========================================
def prepare_patch_cache(state, corrupt_text):
    """Record the corrupted prompt for a later activation patch.

    The patch step re-runs the corrupted prompt itself for whichever layer is
    requested, so all that must persist here is the string. Previously this
    called `run_with_cache` and immediately discarded the cache — paying for a
    full set of stored activations for nothing. A plain forward pass keeps the
    "does this prompt run?" sanity check without the memory spike.
    """
    model = get_model()
    if not model or not state:
        return "Model missing", state
    cleanup_memory()
    with torch.inference_mode():
        model(corrupt_text)  # sanity-check forward pass; no activations kept
    state["corrupt_text"] = corrupt_text
    cleanup_memory()
    return f"Ready to patch. Corrupted input: '{corrupt_text}'", state
def run_activation_patch(state, layer, hook_type):
    """Denoising patch: inject one CLEAN activation into the CORRUPT run.

    Caches the clean activation for the chosen component (attn/mlp/resid at
    `layer`), re-runs the corrupted prompt with that activation spliced in,
    and reports the resulting next-token prediction.
    """
    model = get_model()
    # Guard `state` itself: `"corrupt_text" not in None` raised TypeError
    # when this button was clicked before any analysis run.
    if not model or not state or "corrupt_text" not in state:
        return "Cache corrupted run first.", None
    clean_text = state["text"]
    corrupt_text = state["corrupt_text"]
    layer = int(layer)
    cleanup_memory()
    # 1. Get the CLEAN activation for this specific component only
    #    (names_filter keeps the cache tiny).
    hook_map = {
        "attn": f"blocks.{layer}.hook_attn_out",
        "mlp": f"blocks.{layer}.hook_mlp_out",
        "resid": f"blocks.{layer}.hook_resid_post",
    }
    target_hook = hook_map.get(hook_type, hook_map["attn"])
    with torch.inference_mode():
        _, clean_cache = model.run_with_cache(clean_text, names_filter=[target_hook])
    clean_act = clean_cache[target_hook]

    # 2. Patch hook: overwrite the corrupt activation with the clean one.
    def patch_hook(act, hook):
        # Clean and corrupt prompts may tokenize to different lengths, so
        # only overwrite the overlapping positions; returning clean_act
        # wholesale crashed downstream on a sequence-length mismatch.
        n = min(act.shape[1], clean_act.shape[1])
        act[:, :n] = clean_act[:, :n]
        return act

    # 3. Run the corrupted pass with the patch installed.
    with torch.inference_mode():
        patched_logits = model.run_with_hooks(
            corrupt_text,
            fwd_hooks=[(target_hook, patch_hook)]
        )
    # `[... .item()]` matches the to_string usage elsewhere in this file.
    patched_token = model.to_string([patched_logits[0, -1].argmax().item()])
    # Cleanup
    del clean_cache, clean_act
    cleanup_memory()
    return f"Patching {hook_type} @ L{layer}: Resulting Token = '{patched_token}'", None
# ==========================================
# UI CONSTRUCTION
# ==========================================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Per-session analysis state: tokens, numpy attention scores, edit queue.
    state = gr.State()
    gr.Markdown("# 🧠 Gemma-2-2B: Pre-Softmax Surgery Lab")
    gr.Markdown("Optimized for Low-VRAM. Model loads once. Greedy cache removal.")
    # STAGE 1: Always visible
    with gr.Row():
        txt_input = gr.Textbox("The Eiffel Tower is in", label="Input Text")
        btn_run = gr.Button("1. Run Analysis", variant="primary")
    out_baseline = gr.Textbox(label="Baseline Prediction")
    with gr.Row():
        plot_heatmap = gr.Plot(label="Pre-Softmax Attention (Gemma 18x8)")
        txt_tokens = gr.HighlightedText(label="Token Indices", combine_adjacent=False)
    circuit_summary = gr.Textbox(label="Quick Tips", visible=False)
    # STAGE 2: Hidden until run (revealed by run_first_pass's visibility updates)
    with gr.Group(visible=False) as controls_group:
        # HEAD INSPECTOR
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Inspect Head")
                head_dropdown = gr.Dropdown(choices=[], label="Select Head")
                btn_load_head = gr.Button("Load Head Data")
                head_info = gr.Textbox(label="Details", lines=2)
            with gr.Column(scale=2):
                specific_head_plot = gr.Plot(label="Raw Scores")
        # SURGERY STATION
        gr.Markdown("### 🔪 Pre-Softmax Surgery")
        gr.Markdown("*Modify raw attention scores before they become probabilities.*")
        with gr.Row():
            num_layer = gr.Textbox(label="Layer (0-17)", value="10")
            num_head = gr.Textbox(label="Head (0-7)", value="0")
            num_pos = gr.Textbox(label="Token Pos", value="2")
            num_val = gr.Textbox(label="New Score (+/-)", value="-100")
        with gr.Row():
            btn_add = gr.Button("➕ Add Edit")
            btn_clear = gr.Button("🗑️ Clear")
            btn_exec = gr.Button("🚀 EXECUTE", variant="stop")
        txt_status = gr.Textbox(label="Edit Queue", lines=3)
        out_result = gr.Textbox(label="Surgery Result", lines=2)
        # ADVANCED TOOLS
        with gr.Accordion("🔬 Advanced Tools (Circuit Discovery)", open=False):
            gr.Markdown("**Attribution Heatmap**: Runs ablation on every head to see importance.")
            btn_attribution = gr.Button("Compute Attribution Heatmap (Slow)")
            plot_attribution = gr.Plot()
            gr.Markdown("---")
            gr.Markdown("**Activation Patching (Denoising)**")
            with gr.Row():
                txt_corrupt = gr.Textbox("The Eiffel Tower is in Rome", label="Corrupted Prompt")
                btn_prep_patch = gr.Button("1. Prepare Corrupt Run")
            with gr.Row():
                num_patch_l = gr.Number(label="Layer", value=10)
                type_patch = gr.Dropdown(["attn", "mlp", "resid"], value="attn", label="Component")
                btn_do_patch = gr.Button("2. Run Patch")
            out_patch = gr.Textbox(label="Patch Result")
    # WIRING
    # run_first_pass returns 9 values; the output list below has 9 slots.
    # NOTE(review): circuit_summary appears twice (slots 8 and 9 — the message
    # text and a visibility update). Confirm the installed Gradio version
    # accepts duplicate output components; some versions reject them.
    btn_run.click(
        run_first_pass,
        txt_input,
        [out_baseline, plot_heatmap, txt_tokens, state, head_dropdown, specific_head_plot, controls_group, circuit_summary, circuit_summary]
    )
    btn_load_head.click(show_head_attention, [state, head_dropdown], [specific_head_plot, head_info, num_layer, num_head])
    btn_add.click(add_edit_to_state, [state, num_layer, num_head, num_pos, num_val], [state, txt_status])
    btn_clear.click(clear_edits, state, [state, txt_status])
    btn_exec.click(run_surgery, state, [out_result, plot_heatmap])
    btn_attribution.click(compute_head_attributions, state, plot_attribution)
    btn_prep_patch.click(prepare_patch_cache, [state, txt_corrupt], [out_patch, state])
    btn_do_patch.click(run_activation_patch, [state, num_patch_l, type_patch], [out_patch, plot_heatmap])

if __name__ == "__main__":
    demo.launch()