llm-explorer / prototype_attention.py
chyams's picture
Add Attention Explorer tab (GPT-2 Large) for Lecture 7.5
b701256
#!/usr/bin/env python3
"""
Attention Explorer Prototype — Task 1 (Issue #13)
Tests whether GPT-2 attention weights, hidden states, and intermediate-layer
unembedding produce interpretable results for the Attention Explorer tool.
Key questions:
1. Does intermediate-layer unembedding produce interpretable nearest neighbors?
2. Does the word "class" show context-dependent neighbors at deeper layers?
3. Which method (unembedding vs cosine similarity) gives better results?
Usage:
source ~/venvs/responsible-ai-course-tools/bin/activate
python prototype_attention.py
"""
import logging
import os
import sys

# Quiet-mode environment variables are set BEFORE the third-party imports.
# huggingface_hub evaluates HF_HUB_DISABLE_PROGRESS_BARS once, when it is
# first imported (transformers imports it), so assigning the variable after
# `from transformers import ...` would come too late to have any effect.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

import torch  # noqa: E402  (deliberately after the env setup above)
import torch.nn.functional as F  # noqa: E402
from transformers import AutoTokenizer, AutoModelForCausalLM  # noqa: E402

# Suppress transformers warnings for clean output. Done after the import so
# the library's own logger configuration cannot reset the level afterwards.
logging.getLogger("transformers").setLevel(logging.ERROR)
# ── Configuration ────────────────────────────────────────────────────────────
# Hugging Face model identifier passed to AutoTokenizer / AutoModelForCausalLM.
MODEL_NAME = "gpt2"
# Sentences fed through the model; each uses the focus word in a different
# sense (academic, socioeconomic, quality).
TEST_SENTENCES = [
    "This semester we are reading Song of Solomon in Prof. Mouton's class.",
    "The upper class kept getting richer while everyone else struggled.",
    "She was a class act from start to finish.",
]
# GPT-2 has 12 layers (0-11). We sample across the depth.
# main() reads hidden_states[layer + 1] for each entry, so every index here
# must be smaller than the loaded model's layer count.
PROBE_LAYERS = [0, 3, 6, 9, 11]
TOP_K = 5 # Number of nearest neighbors to show
FOCUS_WORD = "class" # Word to compare across sentences (matched lowercase)
# ── Helpers ──────────────────────────────────────────────────────────────────
def group_tokens_into_words(tokens):
    """
    Group subword tokens into words.

    In GPT-2's tokenizer, raw tokens that start with 'Ġ' (U+0120, the
    stand-in for a leading space) begin a new word. The very first token
    also starts a new word. Every other token is treated as a continuation
    of the current word (this includes punctuation that directly follows a
    word, e.g. "class" + "." -> "class.").

    Args:
        tokens: sequence of raw token strings, e.g. the output of
            ``tokenizer.convert_ids_to_tokens``.

    Returns:
        list of (word_str, [token_indices]) in sentence order; the marker is
        stripped from word_str. An empty input yields an empty list.
    """
    # Spelled as an escape so the marker survives encoding round-trips. The
    # previous literal had been corrupted into mojibake and never matched a
    # real token, which collapsed the whole sentence into a single "word".
    space_marker = "\u0120"

    words = []
    current_word_tokens = []
    current_word_str = ""
    for i, tok in enumerate(tokens):
        if i == 0 or tok.startswith(space_marker):
            # Start a new word; flush the one collected so far.
            if current_word_tokens:
                words.append((current_word_str, current_word_tokens))
            current_word_tokens = [i]
            current_word_str = tok.lstrip(space_marker)
        else:
            # Continuation of current word
            current_word_tokens.append(i)
            current_word_str += tok
    if current_word_tokens:
        words.append((current_word_str, current_word_tokens))
    return words
def merge_attention_to_words(attn_matrix, word_groups):
    """
    Merge a token-level attention matrix (seq_len x seq_len) into a
    word-level matrix (num_words x num_words).

    Strategy:
    - Rows (query): average across subword tokens in each word
    - Columns (key): sum across subword tokens in each word

    Averaging queries and summing keys keeps each merged row summing to the
    same total as the token-level rows it was built from.

    Args:
        attn_matrix: 2-D tensor indexed as [query_token, key_token].
        word_groups: list of (word_str, [token_indices]) as produced by
            group_tokens_into_words; only the index lists are used.

    Returns:
        A new (num_words x num_words) tensor created with torch.zeros
        (default dtype/device).
    """
    num_words = len(word_groups)
    merged = torch.zeros(num_words, num_words)
    for wi, (_, query_indices) in enumerate(word_groups):
        # The query-side average does not depend on the key word, so compute
        # it once per row instead of once per (row, column) pair.
        query_avg = attn_matrix[query_indices].mean(dim=0)  # (seq_len,)
        for wj, (_, key_indices) in enumerate(word_groups):
            merged[wi, wj] = query_avg[key_indices].sum()
    return merged
def get_top_k_by_unembedding(hidden_state, unembed_weight, tokenizer, k=TOP_K,
                             layer_norm=None):
    """
    Return the k most probable vocabulary tokens for a hidden state.

    The hidden state is multiplied by the transposed unembedding matrix,
    the resulting logits are turned into probabilities with softmax, and
    the k largest entries are decoded.

    unembed_weight shape: (vocab_size, hidden_dim)
    hidden_state shape: (hidden_dim,)

    If layer_norm is given it is applied to the hidden state first; with
    layer_norm=None the raw hidden state is projected. The rest of this
    script labels those two modes "tuned lens" and "logit lens".
    NOTE(review): elsewhere "logit lens" often includes the final layer
    norm and "tuned lens" means learned per-layer probes - confirm the
    naming before it reaches user-facing material.

    Returns: list of (decoded_token_str, probability) pairs, best first;
    decoded strings are whitespace-stripped.
    """
    vector = hidden_state if layer_norm is None else layer_norm(hidden_state)
    scores = torch.matmul(vector, unembed_weight.T)  # (vocab_size,)
    distribution = F.softmax(scores, dim=-1)
    best_probs, best_ids = distribution.topk(k)
    decoded = []
    for token_id in best_ids:
        decoded.append(tokenizer.decode([token_id]).strip())
    return list(zip(decoded, best_probs.tolist()))
def get_top_k_by_cosine(hidden_state, embed_weight, tokenizer, k=TOP_K):
    """
    Return the k vocabulary tokens whose embedding rows are most similar
    (by cosine similarity) to the given hidden state.

    embed_weight shape: (vocab_size, hidden_dim)
    hidden_state shape: (hidden_dim,)

    Both sides are L2-normalised along the last dimension, so the matrix
    product below yields cosine similarities directly.

    Returns: list of (decoded_token_str, similarity) pairs, best first;
    decoded strings are whitespace-stripped.
    """
    query = F.normalize(hidden_state.unsqueeze(0), dim=-1)  # (1, hidden_dim)
    table = F.normalize(embed_weight, dim=-1)  # (vocab_size, hidden_dim)
    similarities = torch.matmul(query, table.T).squeeze(0)  # (vocab_size,)
    best_scores, best_ids = similarities.topk(k)
    decoded = []
    for token_id in best_ids:
        decoded.append(tokenizer.decode([token_id]).strip())
    return list(zip(decoded, best_scores.tolist()))
def format_neighbors(neighbors):
    """Render (token, score) pairs as "'tok' (0.1234)" items joined by commas."""
    rendered = []
    for token, value in neighbors:
        rendered.append("'{}' ({:.4f})".format(token, value))
    return ", ".join(rendered)
# ── Main ─────────────────────────────────────────────────────────────────────
def main():
    """
    Run the Attention Explorer prototype end to end and print a report.

    Steps:
    1. Load MODEL_NAME (tokenizer + causal LM, eager attention) and print
       basic model statistics.
    2. For every sentence in TEST_SENTENCES: print the tokenization, run a
       forward pass collecting attentions and hidden states, check that
       layer 0 / head 0 is causal, merge subword attention to word level,
       and print top-k neighbours at each layer in PROBE_LAYERS for the
       first word, the last word and FOCUS_WORD (unembedding, cosine
       similarity, and unembedding with/without the final layer norm).
    3. Print a cross-sentence comparison for FOCUS_WORD and a short
       assessment.

    Returns None; all results are written to stdout.
    """
    # GPT-2's raw tokens mark a leading space with U+0120. Written as an
    # escape because the literal had previously been corrupted by a bad
    # encoding round-trip.
    space_marker = "\u0120"

    print("=" * 80)
    print("ATTENTION EXPLORER PROTOTYPE — Task 1")
    print("Model:", MODEL_NAME)
    print("=" * 80)

    # ── Load model and tokenizer ─────────────────────────────────────────
    print("\n[1] Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Must use eager attention to get attention weight tensors.
    # Transformers 5.0 defaults to SDPA which returns None for attentions.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, attn_implementation="eager"
    )
    model.eval()

    # Model info
    num_params = sum(p.numel() for p in model.parameters())
    mem_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    mem_mb = mem_bytes / 1024 / 1024
    print(f" Parameters: {num_params:,}")
    print(f" Memory footprint: {mem_mb:.1f} MB")
    print(f" Layers: {model.config.n_layer}")
    print(f" Heads: {model.config.n_head}")
    print(f" Hidden dim: {model.config.n_embd}")
    print(f" Vocab size: {model.config.vocab_size}")

    # Get weight matrices and final layer norm
    unembed_weight = model.lm_head.weight.detach()  # (vocab_size, hidden_dim)
    embed_weight = model.transformer.wte.weight.detach()  # (vocab_size, hidden_dim)
    ln_f = model.transformer.ln_f  # final layer norm

    # Check if unembed and embed are tied (GPT-2 ties them)
    weights_tied = torch.equal(unembed_weight, embed_weight)
    print(f" Weights tied (embed == unembed): {weights_tied}")

    # ── Process each sentence ────────────────────────────────────────────
    # Focus-word results per sentence index, for the cross-sentence summary:
    # {sent_idx: [(layer_idx, neighbors), ...]}
    class_results = {}
    for sent_idx, text in enumerate(TEST_SENTENCES):
        print("\n" + "=" * 80)
        print(f'SENTENCE {sent_idx + 1}: "{text}"')
        print("=" * 80)

        # Tokenize. Plain ints are used from here on so they format and
        # decode without relying on 0-dim tensor conversions.
        inputs = tokenizer(text, return_tensors="pt")
        token_ids = inputs["input_ids"][0].tolist()
        tokens = [tokenizer.decode([tid]) for tid in token_ids]
        seq_len = len(tokens)

        # ── (c) Tokenization and forward pass ────────────────────────────
        print(f"\n[c] Tokenization ({seq_len} tokens):")
        for i, (tid, tok) in enumerate(zip(token_ids, tokens)):
            print(f" [{i:2d}] {tid:6d} -> {repr(tok)}")

        with torch.no_grad():
            outputs = model(
                **inputs,
                output_attentions=True,
                output_hidden_states=True,
            )
        attentions = outputs.attentions  # tuple of num_layers tensors
        hidden_states = outputs.hidden_states  # tuple of (num_layers+1) tensors

        # Attention shape
        print(f"\n Attention tensors: {len(attentions)} layers")
        print(f" Each shape: {attentions[0].shape}")
        print(f" -> (batch=1, heads={model.config.n_head}, "
              f"seq_len={seq_len}, seq_len={seq_len})")

        # Verify causal mask: every entry strictly above the diagonal of
        # layer 0 / head 0 must be ~0. torch.triu zeroes the rest, so one
        # vectorised comparison replaces the nested Python loops.
        attn_l0_h0 = attentions[0][0, 0]  # (seq_len, seq_len)
        is_causal = not bool((torch.triu(attn_l0_h0, diagonal=1) > 1e-6).any())
        print(f" Causal mask verified (upper triangle ~0): {is_causal}")

        # ── (d) Subword merging ──────────────────────────────────────────
        # Group on the RAW token strings: they carry the word-boundary
        # marker, which tokenizer.decode() turns back into a plain space.
        raw_tokens = tokenizer.convert_ids_to_tokens(token_ids)
        word_groups = group_tokens_into_words(raw_tokens)
        print(f"\n[d] Subword grouping ({len(word_groups)} words):")
        for word, indices in word_groups:
            subtokens = [raw_tokens[i] for i in indices]
            print(f" '{word}' <- tokens {indices}: {subtokens}")

        # Merge attention for layer 0, head 0 as a sanity check
        merged = merge_attention_to_words(attn_l0_h0, word_groups)
        print(f"\n Merged attention (layer 0, head 0): {merged.shape}")

        # Check rows sum to ~1
        row_sums = merged.sum(dim=-1)
        print(f" Row sums (should be ~1.0): "
              f"min={row_sums.min():.4f}, max={row_sums.max():.4f}, "
              f"mean={row_sums.mean():.4f}")

        # ── (e) Unembedding at intermediate layers ───────────────────────
        print(f"\n[e] Intermediate-layer unembedding (top-{TOP_K} neighbors):")
        print("-" * 78)

        # Find the first word group matching FOCUS_WORD (ignoring trailing
        # punctuation) and the token indices that belong to the word itself.
        class_word_idx = None
        class_token_indices = None
        for wi, (word, indices) in enumerate(word_groups):
            if word.rstrip(".,!?;:'\"").lower() != FOCUS_WORD:
                continue
            class_word_idx = wi
            # Keep only alphabetic tokens so punctuation grouped with the
            # word does not dilute the averaged hidden state.
            focus_indices = [
                idx for idx in indices
                if raw_tokens[idx].lstrip(space_marker).isalpha()
            ]
            class_token_indices = focus_indices or indices
            break

        last_word_idx = len(word_groups) - 1
        for wi, (word, indices) in enumerate(word_groups):
            # Only show the first word, the last word, and the focus word.
            if wi not in (0, last_word_idx) and wi != class_word_idx:
                continue
            is_focus = wi == class_word_idx
            # For the focus word, use only the core token indices
            # (excluding trailing punctuation that got grouped with it).
            use_indices = class_token_indices if is_focus else indices
            label = FOCUS_WORD if is_focus else word
            print(f"\n Word: '{label}' (tokens {use_indices})")

            layer_results = []
            for layer_idx in PROBE_LAYERS:
                # hidden_states[0] is embedding output, [1] is after layer 0, etc.
                hs = hidden_states[layer_idx + 1][0]  # (seq_len, hidden_dim)
                # Average hidden states for multi-token words
                word_hidden = hs[use_indices].mean(dim=0)  # (hidden_dim,)
                neighbors = get_top_k_by_unembedding(
                    word_hidden, unembed_weight, tokenizer
                )
                layer_results.append((layer_idx, neighbors))
                print(f" Layer {layer_idx:2d}: {format_neighbors(neighbors)}")

            # Store focus-word results for cross-sentence comparison
            if is_focus:
                class_results[sent_idx] = layer_results

        # ── (f) Cosine similarity comparison (for focus word only) ───────
        if class_word_idx is not None:
            print(f"\n[f] Cosine similarity comparison for '{FOCUS_WORD}':")
            print("-" * 78)
            for layer_idx in PROBE_LAYERS:
                hs = hidden_states[layer_idx + 1][0]
                word_hidden = hs[class_token_indices].mean(dim=0)
                unembed_neighbors = get_top_k_by_unembedding(
                    word_hidden, unembed_weight, tokenizer
                )
                cosine_neighbors = get_top_k_by_cosine(
                    word_hidden, embed_weight, tokenizer
                )
                print(f" Layer {layer_idx:2d}:")
                print(f" Unembed: {format_neighbors(unembed_neighbors)}")
                print(f" Cosine: {format_neighbors(cosine_neighbors)}")

        # ── (f2) Logit lens vs tuned lens (with/without ln_f) ────────────
        if class_word_idx is not None:
            print(f"\n[f2] Logit lens vs Tuned lens (with ln_f) for '{FOCUS_WORD}':")
            print("-" * 78)
            print(" 'Logit lens' = raw hidden state -> unembed (what word IS this?)")
            print(" 'Tuned lens' = hidden state -> ln_f -> unembed (what comes NEXT?)")
            for layer_idx in PROBE_LAYERS:
                hs = hidden_states[layer_idx + 1][0]
                word_hidden = hs[class_token_indices].mean(dim=0)
                raw_neighbors = get_top_k_by_unembedding(
                    word_hidden, unembed_weight, tokenizer
                )
                ln_neighbors = get_top_k_by_unembedding(
                    word_hidden, unembed_weight, tokenizer,
                    layer_norm=ln_f
                )
                print(f" Layer {layer_idx:2d}:")
                print(f" Logit: {format_neighbors(raw_neighbors)}")
                print(f" Tuned: {format_neighbors(ln_neighbors)}")

    # ── (g) Summary and cross-sentence comparison ────────────────────────
    print("\n" + "=" * 80)
    print(f"SUMMARY: Cross-sentence comparison for '{FOCUS_WORD}'")
    print("=" * 80)
    sentence_labels = [
        "academic (Song of Solomon)",
        "socioeconomic (upper class)",
        "quality (class act)",
    ]
    # Generic labels for any sentences added beyond the three described
    # above, so indexing by sentence number can never run off the end.
    sentence_labels += [
        f"sentence {i + 1}"
        for i in range(len(sentence_labels), len(TEST_SENTENCES))
    ]

    for layer_idx in PROBE_LAYERS:
        print(f"\n Layer {layer_idx}:")
        for sent_idx in sorted(class_results):
            neighbors = dict(class_results[sent_idx]).get(layer_idx)
            if neighbors is not None:
                top_tokens = [tok for tok, _ in neighbors]
                print(f" {sentence_labels[sent_idx]:35s} -> {top_tokens}")

    # ── Final assessment ─────────────────────────────────────────────────
    print("\n" + "=" * 80)
    print("ASSESSMENT")
    print("=" * 80)

    # Check if deeper layers show divergence for the focus word
    if len(class_results) >= 2:
        # Compare top-1 at the shallowest and deepest probed layers. These
        # come from PROBE_LAYERS rather than literal 0 / 11 so the check
        # keeps working when the probe set or the model depth changes.
        first_layer, last_layer = PROBE_LAYERS[0], PROBE_LAYERS[-1]
        early_layer_tops = {}
        final_layer_tops = {}
        for sent_idx, layer_results in class_results.items():
            by_layer = dict(layer_results)
            early_layer_tops[sent_idx] = by_layer[first_layer][0][0]
            final_layer_tops[sent_idx] = by_layer[last_layer][0][0]

        early_unique = len(set(early_layer_tops.values()))
        final_unique = len(set(final_layer_tops.values()))
        total = len(TEST_SENTENCES)
        early_by_label = {sentence_labels[k]: v for k, v in early_layer_tops.items()}
        final_by_label = {sentence_labels[k]: v for k, v in final_layer_tops.items()}

        print(f"\n 1. Context-dependent neighbors for '{FOCUS_WORD}':")
        print(f" Layer {first_layer} top-1 predictions: {early_by_label}")
        print(f" Layer {last_layer} top-1 predictions: {final_by_label}")
        print(f" Unique top-1 at layer {first_layer}: {early_unique}/{total}")
        print(f" Unique top-1 at layer {last_layer}: {final_unique}/{total}")
        if final_unique > early_unique:
            print(" -> YES: Deeper layers show more context differentiation")
        elif final_unique == early_unique and final_unique > 1:
            print(" -> PARTIAL: Both layers show differentiation")
        else:
            print(" -> UNCLEAR: Need further analysis")

    print("\n 2. Unembedding vs Cosine Similarity:")
    if weights_tied:
        print(" GPT-2 ties embed and unembed weights. In theory, unembedding")
        print(" logits and cosine similarity would rank the same (same matrix).")
        print(" In PRACTICE, they diverge at deeper layers because:")
        print(" - Unembedding (dot product) is what the model actually does")
        print(" to predict the next token — it accounts for magnitude.")
        print(" - Cosine similarity normalizes away magnitude, and at deep")
        print(" layers the hidden states have been transformed so far from")
        print(" the input embedding space that cosine similarity becomes")
        print(" meaningless (note negative scores and junk tokens at L11).")
        print(" CONCLUSION: Use unembedding (logit lens), not cosine similarity.")
    else:
        print(" Weights are NOT tied — methods should show different results.")

    # 500 MB is the script's cut-off between browser and server deployment.
    if mem_bytes < 500 * 1024 * 1024:
        suitable_for = "browser (with ONNX/WASM)"
    else:
        suitable_for = "server-side only"
    print(f"\n 3. Memory footprint: {mem_mb:.1f} MB")
    print(f" Suitable for: {suitable_for}")

    print("\n 4. Intermediate-layer interpretability:")
    print(" Check the layer-by-layer output above. If neighbors shift from")
    print(" generic/syntactic (early layers) to contextual/semantic (later")
    print(" layers), Feature 2 (nearest neighbors per layer) is VIABLE.")

    print("\n" + "=" * 80)
    print("PROTOTYPE COMPLETE")
    print("=" * 80)
# Run the prototype only when executed as a script, not when imported.
if __name__ == "__main__":
    main()