Spaces:
Running on L4
Running on L4
| #!/usr/bin/env python3 | |
| """ | |
| Attention Explorer Prototype — Task 1 (Issue #13) | |
| Tests whether GPT-2 attention weights, hidden states, and intermediate-layer | |
| unembedding produce interpretable results for the Attention Explorer tool. | |
| Key questions: | |
| 1. Does intermediate-layer unembedding produce interpretable nearest neighbors? | |
| 2. Does the word "class" show context-dependent neighbors at deeper layers? | |
| 3. Which method (unembedding vs cosine similarity) gives better results? | |
| Usage: | |
| source ~/venvs/responsible-ai-course-tools/bin/activate | |
| python prototype_attention.py | |
| """ | |
| import sys | |
| import os | |
| import logging | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Suppress transformers progress bars and warnings for clean output.
# NOTE(review): these environment variables are assigned after `transformers`
# has already been imported above; confirm the library reads them lazily,
# otherwise they need to be set before the import.
logging.getLogger("transformers").setLevel(logging.ERROR)
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

# ── Configuration ────────────────────────────────────────────────────────────
# Hugging Face model identifier passed to `from_pretrained` in main().
MODEL_NAME = "gpt2"
# Each sentence uses the focus word in a different sense; main() pairs them
# positionally with its `sentence_labels` list.
TEST_SENTENCES = [
    "This semester we are reading Song of Solomon in Prof. Mouton's class.",
    "The upper class kept getting richer while everyone else struggled.",
    "She was a class act from start to finish.",
]
# GPT-2 has 12 layers (0-11). We sample across the depth.
# main() reads hidden_states[layer + 1] for each entry listed here.
PROBE_LAYERS = [0, 3, 6, 9, 11]
TOP_K = 5  # Number of nearest neighbors to show
FOCUS_WORD = "class"  # Word to compare across sentences (matched lowercase)
| # ── Helpers ────────────────────────────────────────────────────────────────── | |
def group_tokens_into_words(tokens):
    """
    Group subword tokens into words.

    In GPT-2's tokenizer, a raw token that starts with 'Ġ' (U+0120, the
    byte-level stand-in for a leading space) begins a new word. The very
    first token also starts a new word, with or without the marker.
    Any other token is treated as a continuation and is appended to the
    current word, so punctuation without a leading space joins the word
    before it.

    Args:
        tokens: raw token strings in sequence order (e.g. the output of
            ``tokenizer.convert_ids_to_tokens``), not decoded text.

    Returns:
        list of (word_str, [token_indices]); empty for empty input.
    """
    # Written as an escape so the marker survives re-encoding of this file.
    word_start = "\u0120"  # 'Ġ'

    words = []  # mutable [text, indices] pairs while building
    for i, tok in enumerate(tokens):
        if i == 0 or tok.startswith(word_start):
            # Start a new word; drop the marker(s) from the display text.
            words.append([tok.lstrip(word_start), [i]])
        else:
            # Continuation of the current word.
            words[-1][0] += tok
            words[-1][1].append(i)
    return [(text, indices) for text, indices in words]
def merge_attention_to_words(attn_matrix, word_groups):
    """
    Merge a token-level attention matrix (seq_len x seq_len) into a
    word-level matrix (num_words x num_words).

    Strategy:
    - Rows (query): average across subword tokens in each word
    - Columns (key): sum across subword tokens in each word

    Summing over keys keeps each merged row's total equal to the average of
    the token rows it was built from, so rows that sum to 1 still sum to 1.

    Args:
        attn_matrix: 2-D tensor indexed as [query_token, key_token].
        word_groups: list of (word_str, [token_indices]) as produced by
            ``group_tokens_into_words``.

    Returns:
        Tensor of shape (num_words, num_words) with the same dtype and
        device as ``attn_matrix``.
    """
    num_words = len(word_groups)
    # new_zeros inherits dtype/device from the input instead of defaulting
    # to CPU float32.
    merged = attn_matrix.new_zeros((num_words, num_words))
    for wi, (_, query_indices) in enumerate(word_groups):
        # Average the query rows once per word; reused for every key word.
        query_avg = attn_matrix[query_indices].mean(dim=0)  # (seq_len,)
        for wj, (_, key_indices) in enumerate(word_groups):
            merged[wi, wj] = query_avg[key_indices].sum()
    return merged
def get_top_k_by_unembedding(hidden_state, unembed_weight, tokenizer, k=TOP_K,
                             layer_norm=None):
    """
    Project a hidden state through the unembedding matrix to get top-k tokens.

    unembed_weight shape: (vocab_size, hidden_dim)
    hidden_state shape: (hidden_dim,)

    If layer_norm is provided, it is applied to the hidden state before the
    projection; without it the raw hidden state is projected. (Elsewhere in
    this file the two variants are labelled "tuned lens" and "logit lens".)

    Args:
        hidden_state: vector to project.
        unembed_weight: output-projection matrix.
        tokenizer: object whose ``decode(list_of_ids)`` returns text.
        k: number of tokens to return; capped at the vocabulary size.
        layer_norm: optional callable applied to ``hidden_state`` first.

    Returns:
        list of (token_text, probability) pairs, highest probability first.
        Token text is whitespace-stripped for display.
    """
    # Pure inspection: avoid building an autograd graph through layer_norm,
    # whose parameters may require grad.
    with torch.no_grad():
        if layer_norm is not None:
            hidden_state = layer_norm(hidden_state)
        logits = hidden_state @ unembed_weight.T  # (vocab_size,)
        probs = F.softmax(logits, dim=-1)
        # topk raises if k exceeds the number of candidates.
        k = min(k, probs.shape[-1])
        top_probs, top_indices = probs.topk(k)
    # Hand plain ints to the tokenizer rather than 0-dim tensors.
    tokens = [tokenizer.decode([idx]).strip() for idx in top_indices.tolist()]
    return list(zip(tokens, top_probs.tolist()))
def get_top_k_by_cosine(hidden_state, embed_weight, tokenizer, k=TOP_K):
    """
    Find top-k tokens by cosine similarity against the embedding table.

    embed_weight shape: (vocab_size, hidden_dim)
    hidden_state shape: (hidden_dim,)

    Both sides are L2-normalized, so the score is a plain dot product of
    unit vectors and ignores the magnitude of the hidden state.

    Args:
        hidden_state: vector to compare.
        embed_weight: embedding table, one row per vocabulary entry.
        tokenizer: object whose ``decode(list_of_ids)`` returns text.
        k: number of tokens to return; capped at the vocabulary size.

    Returns:
        list of (token_text, cosine_similarity) pairs, most similar first.
        Token text is whitespace-stripped for display.
    """
    hidden_norm = F.normalize(hidden_state.unsqueeze(0), dim=-1)  # (1, hidden_dim)
    embed_norm = F.normalize(embed_weight, dim=-1)  # (vocab_size, hidden_dim)
    sims = (hidden_norm @ embed_norm.T).squeeze(0)  # (vocab_size,)
    # topk raises if k exceeds the number of candidates.
    k = min(k, sims.shape[-1])
    top_sims, top_indices = sims.topk(k)
    # Hand plain ints to the tokenizer rather than 0-dim tensors.
    tokens = [tokenizer.decode([idx]).strip() for idx in top_indices.tolist()]
    return list(zip(tokens, top_sims.tolist()))
def format_neighbors(neighbors):
    """Render (token, score) pairs as one comma-separated string.

    Each pair becomes ``'token' (score)`` with the score shown to four
    decimal places; an empty input yields an empty string.
    """
    rendered = []
    for token, value in neighbors:
        rendered.append("'{}' ({:.4f})".format(token, value))
    return ", ".join(rendered)
| # ── Main ───────────────────────────────────────────────────────────────────── | |
def _is_causal(attn, tol=1e-6):
    """Return True if no entry strictly above the diagonal of `attn` exceeds `tol`."""
    return not bool((torch.triu(attn, diagonal=1) > tol).any())


def _find_focus_word(word_groups, raw_tokens, focus_word):
    """Locate `focus_word` among grouped words, ignoring trailing punctuation.

    Returns (word_index, token_indices) for the first match, where
    token_indices keeps only the purely alphabetic tokens of the group
    (falling back to the whole group if none are), or (None, None) when the
    word does not occur.
    """
    target = focus_word.lower()
    for wi, (word, indices) in enumerate(word_groups):
        if word.rstrip(".,!?;:'\"").lower() != target:
            continue
        # Punctuation glued onto the word forms its own token; drop those so
        # they do not dilute the averaged hidden state.
        alpha_indices = [
            idx for idx in indices
            if raw_tokens[idx].lstrip("\u0120").isalpha()  # strip 'Ġ' marker
        ]
        return wi, (alpha_indices or indices)
    return None, None


def _mean_hidden(hidden_states, layer_idx, token_indices):
    """Average the hidden states of `token_indices` at the output of `layer_idx`."""
    # hidden_states[0] is the embedding output, [1] is after layer 0, etc.
    layer_hidden = hidden_states[layer_idx + 1][0]  # (seq_len, hidden_dim)
    return layer_hidden[token_indices].mean(dim=0)  # (hidden_dim,)


def main():
    """Run the prototype end to end and print a report to stdout.

    Loads MODEL_NAME, then for every sentence in TEST_SENTENCES prints the
    tokenization, attention shapes, a causal-mask check, the subword
    grouping, and nearest-neighbor tokens of selected words at each layer in
    PROBE_LAYERS (unembedding, cosine similarity, and with/without the final
    layer norm). Finishes with a cross-sentence summary for FOCUS_WORD and a
    short assessment.
    """
    print("=" * 80)
    print("ATTENTION EXPLORER PROTOTYPE — Task 1")
    print("Model:", MODEL_NAME)
    print("=" * 80)

    # ── Load model and tokenizer ─────────────────────────────────────────
    print("\n[1] Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Must use eager attention to get attention weight tensors.
    # (Original note: Transformers 5.0 defaults to SDPA, which returns None
    # for attentions.)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, attn_implementation="eager"
    )
    model.eval()

    # Model info
    num_params = sum(p.numel() for p in model.parameters())
    mem_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    print(f" Parameters: {num_params:,}")
    print(f" Memory footprint: {mem_bytes / 1024 / 1024:.1f} MB")
    print(f" Layers: {model.config.n_layer}")
    print(f" Heads: {model.config.n_head}")
    print(f" Hidden dim: {model.config.n_embd}")
    print(f" Vocab size: {model.config.vocab_size}")

    # Weight matrices and final layer norm
    unembed_weight = model.lm_head.weight.detach()  # (vocab_size, hidden_dim)
    embed_weight = model.transformer.wte.weight.detach()  # (vocab_size, hidden_dim)
    ln_f = model.transformer.ln_f  # final layer norm

    # Check whether the embedding and unembedding matrices hold equal values.
    weights_tied = torch.equal(unembed_weight, embed_weight)
    print(f" Weights tied (embed == unembed): {weights_tied}")

    # ── Process each sentence ────────────────────────────────────────────
    # sent_idx -> {layer_idx: neighbors} for the focus word, for comparison.
    class_results = {}

    for sent_idx, text in enumerate(TEST_SENTENCES):
        print("\n" + "=" * 80)
        print(f"SENTENCE {sent_idx + 1}: \"{text}\"")
        print("=" * 80)

        # Tokenize
        inputs = tokenizer(text, return_tensors="pt")
        token_ids = inputs["input_ids"][0].tolist()
        tokens = [tokenizer.decode([tid]) for tid in token_ids]

        # ── (c) Tokenization and forward pass ────────────────────────────
        print(f"\n[c] Tokenization ({len(tokens)} tokens):")
        for i, (tid, tok) in enumerate(zip(token_ids, tokens)):
            print(f" [{i:2d}] {tid:6d} -> {repr(tok)}")

        with torch.no_grad():
            outputs = model(
                **inputs,
                output_attentions=True,
                output_hidden_states=True,
            )
        attentions = outputs.attentions  # tuple of num_layers tensors
        hidden_states = outputs.hidden_states  # tuple of (num_layers+1) tensors

        # Attention shape
        print(f"\n Attention tensors: {len(attentions)} layers")
        print(f" Each shape: {attentions[0].shape}")
        print(f" -> (batch=1, heads={model.config.n_head}, "
              f"seq_len={len(tokens)}, seq_len={len(tokens)})")

        # Verify causal mask (lower triangular) on layer 0, head 0 only.
        attn_l0_h0 = attentions[0][0, 0]  # (seq_len, seq_len)
        is_causal = _is_causal(attn_l0_h0)
        print(f" Causal mask verified (upper triangle ~0): {is_causal}")

        # ── (d) Subword merging ──────────────────────────────────────────
        # Grouping needs the raw token strings, which keep GPT-2's 'Ġ'
        # word-start prefix; decoded text turns it into a plain space.
        raw_tokens = tokenizer.convert_ids_to_tokens(token_ids)
        word_groups = group_tokens_into_words(raw_tokens)
        print(f"\n[d] Subword grouping ({len(word_groups)} words):")
        for word, indices in word_groups:
            subtokens = [raw_tokens[i] for i in indices]
            print(f" '{word}' <- tokens {indices}: {subtokens}")

        # Merge attention for layer 0, head 0 as a sanity check
        merged = merge_attention_to_words(attn_l0_h0, word_groups)
        print(f"\n Merged attention (layer 0, head 0): {merged.shape}")
        # Check rows sum to ~1
        row_sums = merged.sum(dim=-1)
        print(f" Row sums (should be ~1.0): "
              f"min={row_sums.min():.4f}, max={row_sums.max():.4f}, "
              f"mean={row_sums.mean():.4f}")

        # ── (e) Unembedding at intermediate layers ───────────────────────
        print(f"\n[e] Intermediate-layer unembedding (top-{TOP_K} neighbors):")
        print("-" * 78)

        class_word_idx, class_token_indices = _find_focus_word(
            word_groups, raw_tokens, FOCUS_WORD
        )

        last_word_idx = len(word_groups) - 1
        for wi, (word, indices) in enumerate(word_groups):
            is_focus = class_word_idx is not None and wi == class_word_idx
            # Only show the first word, the last word, and the focus word.
            if wi not in (0, last_word_idx) and not is_focus:
                continue

            # For the focus word, use only the core token indices
            # (excluding trailing punctuation that got grouped with it).
            use_indices = class_token_indices if is_focus else indices
            label = FOCUS_WORD if is_focus else word
            print(f"\n Word: '{label}' (tokens {use_indices})")

            layer_results = {}
            for layer_idx in PROBE_LAYERS:
                # Multi-token words are represented by their mean hidden state.
                word_hidden = _mean_hidden(hidden_states, layer_idx, use_indices)
                neighbors = get_top_k_by_unembedding(
                    word_hidden, unembed_weight, tokenizer
                )
                layer_results[layer_idx] = neighbors
                print(f" Layer {layer_idx:2d}: {format_neighbors(neighbors)}")

            # Store focus-word results for cross-sentence comparison
            if is_focus:
                class_results[sent_idx] = layer_results

        if class_word_idx is None:
            continue

        # ── (f) Cosine similarity comparison (for focus word only) ───────
        print(f"\n[f] Cosine similarity comparison for '{FOCUS_WORD}':")
        print("-" * 78)
        for layer_idx in PROBE_LAYERS:
            word_hidden = _mean_hidden(hidden_states, layer_idx, class_token_indices)
            unembed_neighbors = get_top_k_by_unembedding(
                word_hidden, unembed_weight, tokenizer
            )
            cosine_neighbors = get_top_k_by_cosine(
                word_hidden, embed_weight, tokenizer
            )
            print(f" Layer {layer_idx:2d}:")
            print(f" Unembed: {format_neighbors(unembed_neighbors)}")
            print(f" Cosine: {format_neighbors(cosine_neighbors)}")

        # ── (f2) Logit lens vs tuned lens (with/without ln_f) ────────────
        print(f"\n[f2] Logit lens vs Tuned lens (with ln_f) for '{FOCUS_WORD}':")
        print("-" * 78)
        print(" 'Logit lens' = raw hidden state -> unembed (what word IS this?)")
        print(" 'Tuned lens' = hidden state -> ln_f -> unembed (what comes NEXT?)")
        for layer_idx in PROBE_LAYERS:
            word_hidden = _mean_hidden(hidden_states, layer_idx, class_token_indices)
            raw_neighbors = get_top_k_by_unembedding(
                word_hidden, unembed_weight, tokenizer
            )
            ln_neighbors = get_top_k_by_unembedding(
                word_hidden, unembed_weight, tokenizer,
                layer_norm=ln_f
            )
            print(f" Layer {layer_idx:2d}:")
            print(f" Logit: {format_neighbors(raw_neighbors)}")
            print(f" Tuned: {format_neighbors(ln_neighbors)}")

    # ── (g) Summary and cross-sentence comparison ────────────────────────
    print("\n" + "=" * 80)
    print(f"SUMMARY: Cross-sentence comparison for '{FOCUS_WORD}'")
    print("=" * 80)

    # Human-readable sense labels, positionally matched to TEST_SENTENCES.
    sentence_labels = [
        "academic (Song of Solomon)",
        "socioeconomic (upper class)",
        "quality (class act)",
    ]

    def label_for(idx):
        # Fall back to a generic label if sentences were added without labels.
        if idx < len(sentence_labels):
            return sentence_labels[idx]
        return f"sentence {idx + 1}"

    for layer_idx in PROBE_LAYERS:
        print(f"\n Layer {layer_idx}:")
        for sent_idx in sorted(class_results):
            neighbors = class_results[sent_idx][layer_idx]
            top_tokens = [tok for tok, _ in neighbors]
            print(f" {label_for(sent_idx):35s} -> {top_tokens}")

    # ── Final assessment ─────────────────────────────────────────────────
    print("\n" + "=" * 80)
    print("ASSESSMENT")
    print("=" * 80)

    # Check if deeper layers show divergence for the focus word by comparing
    # the top-1 token at the shallowest and deepest probed layers.
    if len(class_results) >= 2:
        first_layer, last_layer = PROBE_LAYERS[0], PROBE_LAYERS[-1]
        early_layer_tops = {
            label_for(s): by_layer[first_layer][0][0]
            for s, by_layer in class_results.items()
        }
        final_layer_tops = {
            label_for(s): by_layer[last_layer][0][0]
            for s, by_layer in class_results.items()
        }
        early_unique = len(set(early_layer_tops.values()))
        final_unique = len(set(final_layer_tops.values()))
        total = len(class_results)

        print(f"\n 1. Context-dependent neighbors for '{FOCUS_WORD}':")
        print(f" Layer {first_layer} top-1 predictions: {early_layer_tops}")
        print(f" Layer {last_layer} top-1 predictions: {final_layer_tops}")
        print(f" Unique top-1 at layer {first_layer}: {early_unique}/{total}")
        print(f" Unique top-1 at layer {last_layer}: {final_unique}/{total}")
        if final_unique > early_unique:
            print(" -> YES: Deeper layers show more context differentiation")
        elif final_unique == early_unique and final_unique > 1:
            print(" -> PARTIAL: Both layers show differentiation")
        else:
            print(" -> UNCLEAR: Need further analysis")

    print("\n 2. Unembedding vs Cosine Similarity:")
    if weights_tied:
        print(" GPT-2 ties embed and unembed weights. In theory, unembedding")
        print(" logits and cosine similarity would rank the same (same matrix).")
        print(" In PRACTICE, they diverge at deeper layers because:")
        print(" - Unembedding (dot product) is what the model actually does")
        print(" to predict the next token — it accounts for magnitude.")
        print(" - Cosine similarity normalizes away magnitude, and at deep")
        print(" layers the hidden states have been transformed so far from")
        print(" the input embedding space that cosine similarity becomes")
        print(" meaningless (note negative scores and junk tokens at L11).")
        print(" CONCLUSION: Use unembedding (logit lens), not cosine similarity.")
    else:
        print(" Weights are NOT tied — methods should show different results.")

    print(f"\n 3. Memory footprint: {mem_bytes / 1024 / 1024:.1f} MB")
    print(f" Suitable for: {'browser (with ONNX/WASM)' if mem_bytes < 500*1024*1024 else 'server-side only'}")

    print("\n 4. Intermediate-layer interpretability:")
    print(" Check the layer-by-layer output above. If neighbors shift from")
    print(" generic/syntactic (early layers) to contextual/semantic (later")
    print(" layers), Feature 2 (nearest neighbors per layer) is VIABLE.")

    print("\n" + "=" * 80)
    print("PROTOTYPE COMPLETE")
    print("=" * 80)


if __name__ == "__main__":
    main()