# AttentionShow / app.py  (Hugging Face Space, commit 8c03cb1 "Update app.py")
# Gradio dashboard for mechanistic interpretability of Gemma-2-2B.
import os
import gc
import gradio as gr
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from typing import Dict, List, Tuple, Optional
from huggingface_hub import login
# ==========================================
# 0. SETUP & MEMORY MANAGEMENT
# ==========================================
# Constrain CUDA caching and OpenMP threading so the app fits CPU / low-VRAM hosts.
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

# Authenticate with the Hub only when a token is provided (Gemma is gated).
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Heavy libraries are imported lazily (see import_heavy) to keep UI startup fast.
torch = None
HookedTransformer = None
F = None
def import_heavy():
    """Lazily import torch / transformer_lens into the module-level globals.

    No-op when the imports have already been performed (torch is set).
    """
    global torch, HookedTransformer, F
    if torch is not None:
        return
    import torch as _torch
    import torch.nn.functional as _functional
    from transformer_lens import HookedTransformer as _HookedTransformer
    torch = _torch
    HookedTransformer = _HookedTransformer
    F = _functional
    print("📦 PyTorch & TransformerLens imported.")
# SINGLETON MODEL HOLDER — populated on first use by get_model().
MODEL = None


def cleanup_memory():
    """Aggressively clears VRAM/RAM.

    Forces a garbage-collection pass and, when torch is loaded and CUDA is
    available, releases PyTorch's cached GPU memory.
    """
    gc.collect()
    if torch is None:
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def get_model():
    """Return the shared HookedTransformer, loading Gemma-2-2B on first call.

    Returns None when loading fails (e.g. OOM or missing Hub access); callers
    are expected to check for that.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    import_heavy()
    print("⏳ Loading Gemma-2-2B...")
    cleanup_memory()
    try:
        # CPU + float32 is the most compatible configuration.
        # Switch to device="cuda", dtype="float16" on a GPU with >8GB VRAM.
        MODEL = HookedTransformer.from_pretrained(
            "google/gemma-2-2b",
            device="cpu",
            dtype="float32"
        )
        MODEL.eval()
        print("✅ Model Loaded.")
    except Exception as e:
        print(f"❌ Load Error: {e}")
        return None
    return MODEL
# ==========================================
# 1. MAIN INFERENCE PIPELINE
# ==========================================
def run_main_analysis(text):
    """
    Run Gemma on `text` and build every artifact the dashboard needs.

    Returns a 7-tuple matching the Gradio outputs wired in the UI:
        (highlighted tokens, prediction label string, overview heatmap figure,
         logit-lens figure, state dict, head-dropdown update, tab-visibility update)
    On model-load failure the first slot is an error string and the rest are None.
    """
    model = get_model()
    if not model:
        return "Model Failed", None, None, None, None, None, None
    cleanup_memory()
    with torch.inference_mode():
        # 1. Forward pass, caching all intermediate activations.
        logits, cache = model.run_with_cache(text)

        # 2. Baseline next-token prediction (greedy, last position).
        tokens = model.to_str_tokens(text)
        formatted_tokens = [(t, str(i)) for i, t in enumerate(tokens)]
        last_logit = logits[0, -1]
        baseline_prob = torch.softmax(last_logit, dim=-1).max().item()
        baseline_id = last_logit.argmax().item()
        baseline_token = model.to_string(baseline_id)

        # 3. FEATURE: Logit Lens — decode the residual stream at the end of
        #    every block to see what the model would predict at that depth.
        logit_lens_data = []
        for i in range(model.cfg.n_layers):
            # Residual stream after block i, last token only.
            resid = cache[f"blocks.{i}.hook_resid_post"][0, -1, :]
            # Mimic the real prediction head: final norm, then unembed.
            # (Gemma uses RMSNorm; TransformerLens's ln_final handles the details.)
            ln_resid = model.ln_final(resid)
            layer_logits = model.unembed(ln_resid)
            best_id = layer_logits.argmax().item()
            best_tok = model.to_string(best_id)
            prob = torch.softmax(layer_logits, dim=-1).max().item()
            logit_lens_data.append({
                "layer": i,
                "token": best_tok,
                "prob": prob
            })

        # 4. FEATURE: Head Summary — max pre-softmax attention score per head
        #    for the last query position. Feeds the overview heatmap.
        n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
        heatmap_data = np.zeros((n_layers, n_heads))
        # Only the last-query score rows are kept; the full QxK matrices are
        # re-computed on demand by the 2D inspector to avoid memory bloat.
        head_max_scores = {}
        for l in range(n_layers):
            # Cached shape: [Batch, Head, Q, K]; take the last query row.
            scores = cache[f"blocks.{l}.attn.hook_attn_scores"][0, :, -1, :]
            scores_np = scores.detach().cpu().numpy()
            heatmap_data[l] = scores_np.max(axis=-1)
            head_max_scores[l] = scores_np  # [Head, Key] for last pos

    # 5. Session state consumed by the other tabs.
    state = {
        "text": text,
        "tokens": tokens,
        "baseline_id": baseline_id,
        "baseline_token": baseline_token,
        "head_data": head_max_scores,  # Lightweight cache
        "logit_lens": logit_lens_data,
        "edits": []
    }

    # --- Visualizations ---
    # A. Overview heatmap of per-head max scores.
    fig_overview = go.Figure(data=go.Heatmap(
        z=heatmap_data, colorscale='RdBu_r', zmid=0,
        hovertemplate="Layer %{y}, Head %{x}<br>Max Score: %{z:.2f}<extra></extra>"
    ))
    fig_overview.update_layout(
        title="Overview: Max Attention Scores (Last Token)",
        xaxis_title="Head", yaxis_title="Layer",
        height=500
    )

    # B. Logit Lens chart: best guess per layer with its probability.
    lens_x = [d["layer"] for d in logit_lens_data]
    lens_y = [d["prob"] for d in logit_lens_data]
    lens_text = [d["token"] for d in logit_lens_data]
    fig_lens = go.Figure()
    fig_lens.add_trace(go.Scatter(
        x=lens_x, y=lens_y, mode='lines+markers+text',
        text=lens_text, textposition="top center",
        line=dict(color='indigo', width=2)
    ))
    fig_lens.update_layout(
        title="Logit Lens: Best Guess at each Layer",
        xaxis_title="Layer", yaxis_title="Probability",
        height=400
    )

    # Release the large activation cache before returning.
    del cache, logits
    cleanup_memory()

    # Dropdown choices for the head inspectors ("L<layer> H<head>").
    choices = [f"L{l} H{h}" for l in range(n_layers) for h in range(n_heads)]
    return (
        formatted_tokens,
        f"'{baseline_token}' ({baseline_prob:.2%})",
        fig_overview,
        fig_lens,
        state,
        gr.update(choices=choices),
        gr.update(visible=True)  # Reveal the advanced tabs
    )
# ==========================================
# 2. FEATURE: DEEP DIVE (2D PATTERNS)
# ==========================================
def get_2d_attention_pattern(state, selection):
    """
    Re-run the model to fetch the full QxK score matrix for one head.

    Memory efficient (only one hook is cached instead of all layers) at the
    cost of re-inference. Returns a Plotly heatmap figure, or None when
    state/selection is missing or the model is unavailable.
    """
    if not state or not selection:
        return None
    # Parse a selection of the form "L10 H2".
    parts = selection.split()
    layer = int(parts[0][1:])
    head = int(parts[1][1:])
    model = get_model()
    if model is None:
        # Model failed to load earlier; don't crash the callback.
        return None
    text = state["text"]
    tokens = state["tokens"]
    cleanup_memory()
    # Capture only the attention-score hook of the requested layer.
    hook_name = f"blocks.{layer}.attn.hook_attn_scores"
    with torch.inference_mode():
        _, cache = model.run_with_cache(text, names_filter=[hook_name])
        # Cached shape: [1, Head, Seq, Seq]
        attn_matrix = cache[hook_name][0, head, :, :].detach().cpu().numpy()
    del cache
    cleanup_memory()
    # 2D heatmap. Gemma already masks future positions with -inf, so no
    # extra causal masking is applied here.
    fig = go.Figure(data=go.Heatmap(
        z=attn_matrix,
        x=tokens, y=tokens,
        colorscale='RdBu_r', zmid=0,
        hoverongaps=False
    ))
    fig.update_layout(
        title=f"Full Attention Scores: Layer {layer}, Head {head}",
        xaxis_title="Key (Source)",
        yaxis_title="Query (Destination)",
        height=600,
        width=600,
        yaxis=dict(autorange="reversed")  # Standard attention map orientation
    )
    return fig
# ==========================================
# 3. FEATURE: CIRCUIT DISCOVERY (ABLATION)
# ==========================================
def run_circuit_discovery(state):
    """
    Iteratively ablate each attention head and measure its effect on the
    baseline prediction's logit (clean minus ablated).

    Returns a Plotly heatmap of per-head logit differences, or None when
    state is missing or the model is unavailable.
    """
    if not state:
        return None
    model = get_model()
    if model is None:
        # Model failed to load earlier; don't crash the callback.
        return None
    text = state["text"]
    target_id = state["baseline_id"]
    # Reference logit of the baseline token on the unmodified model.
    with torch.inference_mode():
        clean_logits = model(text)
        clean_score = clean_logits[0, -1, target_id].item()
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    importance_map = np.zeros((n_layers, n_heads))
    print("running ablation loop...")

    def get_ablation_hook(head_to_kill):
        """Build a hook that zeroes one head's output (hook_z) at all positions."""
        def hook(z, hook):
            # z shape: [batch, pos, head, d_head]
            z[:, :, head_to_kill, :] = 0.0
            return z
        return hook

    for l in range(n_layers):
        # hook_z is the per-head output before the output projection mixes them.
        hook_name = f"blocks.{l}.attn.hook_z"
        for h in range(n_heads):
            with torch.inference_mode():
                corrupt_logits = model.run_with_hooks(
                    text,
                    fwd_hooks=[(hook_name, get_ablation_hook(h))]
                )
            corrupt_score = corrupt_logits[0, -1, target_id].item()
            # Logit Difference (Clean - Corrupt):
            #   positive = head was helping the prediction
            #   negative = head was suppressing it
            importance_map[l, h] = clean_score - corrupt_score
    cleanup_memory()
    fig = go.Figure(data=go.Heatmap(
        z=importance_map,
        colorscale='RdBu', zmid=0,
        text=np.around(importance_map, 2),
        texttemplate="%{text}"
    ))
    fig.update_layout(
        title="Head Importance (Logit Difference via Ablation)",
        xaxis_title="Head", yaxis_title="Layer"
    )
    return fig
# ==========================================
# 4. FEATURE: ACTIVATION PATCHING
# ==========================================
def run_activation_patching(state, corrupt_text, layer, component):
    """
    Denoising patch: capture one activation from the clean prompt, then run
    the corrupted prompt with that activation injected.

    Returns (result string, None) — the None slot clears the paired plot
    output in the UI.
    """
    if not state or not corrupt_text:
        return "Missing inputs", None
    model = get_model()
    if model is None:
        # Model failed to load earlier; report instead of crashing.
        return "Model failed to load", None
    clean_text = state["text"]
    layer = int(layer)
    # Map the UI component choice to the TransformerLens hook name.
    hook_map = {
        "Attention Output": f"blocks.{layer}.hook_attn_out",
        "MLP Output": f"blocks.{layer}.hook_mlp_out",
        "Residual Stream": f"blocks.{layer}.hook_resid_post"
    }
    target_hook = hook_map.get(component, hook_map["Residual Stream"])
    cleanup_memory()
    # 1. Capture the CLEAN activation at the chosen hook.
    with torch.inference_mode():
        _, clean_cache = model.run_with_cache(clean_text, names_filter=[target_hook])
    clean_act = clean_cache[target_hook]  # [batch, seq, d_model]-shaped tensor
    # Patching whole activations requires both prompts to tokenize to the
    # same length, otherwise the injected tensor has the wrong shape.
    clean_len = clean_act.shape[1]
    corrupt_len = model.to_tokens(corrupt_text).shape[1]
    if clean_len != corrupt_len:
        del clean_cache, clean_act
        cleanup_memory()
        return (f"Token length mismatch (clean={clean_len}, corrupt={corrupt_len}); "
                "use prompts that tokenize to the same length.", None)

    # 2. Hook that replaces the corrupt activation with the clean one.
    def patch_hook(act, hook):
        return clean_act

    # 3. Corrupt pass with the patch applied.
    with torch.inference_mode():
        patched_logits = model.run_with_hooks(
            corrupt_text,
            fwd_hooks=[(target_hook, patch_hook)]
        )
    # 4. Report the patched model's greedy prediction.
    patched_id = patched_logits[0, -1].argmax().item()
    patched_token = model.to_string(patched_id)
    patched_prob = torch.softmax(patched_logits[0, -1], dim=-1).max().item()
    # Release the cached activation tensor.
    del clean_cache, clean_act
    cleanup_memory()
    return f"Patched Token: '{patched_token}' ({patched_prob:.2%})", None
# ==========================================
# 5. SURGERY (EDITING)
# ==========================================
def add_edit(state, selection, pos, value):
    """
    Queue one pre-softmax attention-score edit in the session state.

    `selection` is of the form "L<layer> H<head>"; `pos` is the key position
    and `value` the new score. Returns (state, status message).
    """
    if not state:
        return state, "Run model first"
    try:
        parts = selection.split()
        l = int(parts[0][1:])
        h = int(parts[1][1:])
        p = int(pos)
        v = float(value)
        state["edits"].append({"layer": l, "head": h, "pos": p, "value": v})
        return state, f"Added: L{l}H{h} Pos{p} -> {v}"
    except (AttributeError, IndexError, KeyError, TypeError, ValueError):
        # Narrowed from a bare `except:` — only malformed inputs are expected.
        return state, "Invalid input format"
def execute_surgery(state):
    """
    Re-run the prompt with all queued score edits applied at the last query
    position, and report the new greedy prediction.

    Returns (result string, None) — the None slot clears the paired plot.
    """
    if not state or not state["edits"]:
        return "No edits queued", None
    model = get_model()
    if model is None:
        # Model failed to load earlier; report instead of crashing.
        return "Model failed to load", None
    text = state["text"]
    # Group edits by layer so each layer gets a single hook.
    edits_by_layer = {}
    for e in state["edits"]:
        edits_by_layer.setdefault(e["layer"], []).append(e)
    hooks = []
    for l, edits in edits_by_layer.items():
        def make_hook(current_edits):
            """Bind this layer's edit list (avoids the late-binding closure bug)."""
            def hook(scores, h):
                # scores: [batch, head, query, key].
                # Only the LAST query position (the prediction step) is edited.
                for edit in current_edits:
                    scores[0, edit["head"], -1, edit["pos"]] = edit["value"]
                return scores
            return hook
        hooks.append((f"blocks.{l}.attn.hook_attn_scores", make_hook(edits)))
    with torch.inference_mode():
        new_logits = model.run_with_hooks(text, fwd_hooks=hooks)
    new_id = new_logits[0, -1].argmax().item()
    new_token = model.to_string(new_id)
    return f"Surgery Result: '{new_token}' (Baseline was '{state['baseline_token']}')", None
def clear_edits(state):
    """
    Empty the queued edit list.

    Tolerates state=None (button clicked before the first analysis run),
    which previously raised a TypeError.
    """
    if state:
        state["edits"] = []
    return state, "Edits cleared."
# ==========================================
# UI LAYOUT
# ==========================================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Per-session state dict built by run_main_analysis
    # (tokens, baseline prediction, queued surgery edits).
    state = gr.State()
    gr.Markdown("# 🔬 Gemma-2-2B Mechanistic Dashboard (Full Suite)")
    gr.Markdown("Includes: Logit Lens, 2D Attention Patterns, Activation Patching, and Circuit Discovery.")
    # --- TOP CONTROL ---
    with gr.Row():
        txt_input = gr.Textbox("The capital of France is", label="Input Prompt", scale=2)
        btn_run = gr.Button("🚀 Run Analysis", variant="primary", scale=1)
    with gr.Row():
        txt_tokens = gr.HighlightedText(label="Tokenized Input", combine_adjacent=False)
        lbl_pred = gr.Label(label="Prediction")
    # --- MAIN TABS (hidden until the first analysis run reveals them) ---
    with gr.Tabs(visible=False) as main_tabs:
        # TAB 1: OVERVIEW & LENS
        with gr.TabItem("Overview & Logit Lens"):
            with gr.Row():
                plot_overview = gr.Plot(label="Attention Head Max Scores")
                plot_lens = gr.Plot(label="Logit Lens (Layer Predictions)")
        # TAB 2: PATTERN INSPECTOR
        with gr.TabItem("🔍 2D Pattern Inspector"):
            gr.Markdown("**Deep Dive:** Select a head to view the full Query x Key attention matrix.")
            with gr.Row():
                with gr.Column(scale=1):
                    dd_head = gr.Dropdown(label="Select Head")
                    btn_show_pattern = gr.Button("Show 2D Pattern")
                with gr.Column(scale=3):
                    plot_pattern = gr.Plot(label="Attention Matrix")
        # TAB 3: CIRCUIT DISCOVERY
        with gr.TabItem("🧬 Circuit Discovery"):
            gr.Markdown("**Ablation Study:** Automatically zeroes out each head to see if it helps or hurts the correct prediction.")
            btn_circuit = gr.Button("Run Ablation Sweep (Takes ~10s)")
            plot_circuit = gr.Plot()
        # TAB 4: SURGERY
        with gr.TabItem("🔪 Surgery"):
            gr.Markdown("**Intervention:** Edit Pre-Softmax attention scores.")
            with gr.Row():
                dd_surgery_head = gr.Dropdown(label="Target Head")
                num_pos = gr.Number(label="Key Position", value=0, precision=0)
                num_val = gr.Number(label="New Score", value=-10.0)
            btn_add = gr.Button("Add Edit")
            txt_log = gr.Textbox(label="Edit Log")
            with gr.Row():
                btn_exec = gr.Button("Execute Surgery", variant="stop")
                btn_clear = gr.Button("Clear Edits")
            out_surgery = gr.Textbox(label="Surgery Output")
        # TAB 5: PATCHING
        with gr.TabItem("🩹 Activation Patching"):
            gr.Markdown("**Denoising:** Inject clean activations into a corrupted run.")
            txt_corrupt = gr.Textbox(label="Corrupted Prompt", value="The capital of Germany is")
            with gr.Row():
                num_layer = gr.Number(label="Layer", value=10)
                dd_comp = gr.Dropdown(["Residual Stream", "Attention Output", "MLP Output"], value="Residual Stream", label="Component")
            btn_patch = gr.Button("Run Patch")
            out_patch = gr.Textbox(label="Patch Result")

    # --- EVENTS ---
    # Main Run: populates tokens, prediction, both overview plots, the state,
    # the head dropdown choices, and un-hides the tab group.
    btn_run.click(
        run_main_analysis,
        inputs=[txt_input],
        outputs=[txt_tokens, lbl_pred, plot_overview, plot_lens, state, dd_head, main_tabs]
    )
    # 2D Pattern: re-runs inference for the selected head only.
    btn_show_pattern.click(
        get_2d_attention_pattern,
        inputs=[state, dd_head],
        outputs=[plot_pattern]
    )
    # Circuit Discovery: full per-head ablation sweep.
    btn_circuit.click(
        run_circuit_discovery,
        inputs=[state],
        outputs=[plot_circuit]
    )
    # Surgery
    # Keep the surgery dropdown in sync with the inspector's selection.
    dd_head.change(lambda x: x, inputs=dd_head, outputs=dd_surgery_head)
    btn_add.click(add_edit, [state, dd_surgery_head, num_pos, num_val], [state, txt_log])
    btn_clear.click(clear_edits, state, [state, txt_log])
    btn_exec.click(execute_surgery, state, [out_surgery, plot_overview])
    # Patching: denoising run with the chosen component injected.
    btn_patch.click(
        run_activation_patching,
        [state, txt_corrupt, num_layer, dd_comp],
        [out_patch, plot_overview]
    )

if __name__ == "__main__":
    demo.launch()