| |
|
| | """
|
| | Representation analysis: CKA and Logit Lens for Prisma / Circuit Transformer.
|
| |
|
| | CKA (Centered Kernel Alignment):
|
| | Measures representational similarity between all layer pairs.
|
| | Produces a heatmap revealing mirror symmetry, phase transitions,
|
| | and cross-model alignment.
|
| |
|
| | Logit Lens:
|
| | Projects intermediate representations to vocabulary space at every layer.
|
| | Reveals what the model "thinks" at each processing stage -- from raw
|
| | tokens through the semantic bottleneck back to specific predictions.
|
| |
|
| | Also computes representation drift (cosine similarity between consecutive layers).
|
| |
|
| | Usage:
|
| | # Full analysis (CKA + logit lens)
|
| | python -m circuits.scripts.representation_analysis \\
|
| | --checkpoint path/to/checkpoint.pt \\
|
| | --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train
|
| |
|
| | # Cross-model CKA
|
| | python -m circuits.scripts.representation_analysis \\
|
| | --checkpoint path/to/prisma.pt --hf-model gpt2-medium \\
|
| | --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train
|
| |
|
| | # CKA only (skip logit lens)
|
| | python -m circuits.scripts.representation_analysis \\
|
| | --checkpoint path/to/checkpoint.pt \\
|
| | --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train \\
|
| | --no-logit-lens
|
| | """
|
| |
|
| | import argparse
|
| | import json
|
| | import sys
|
| | import os
|
| | from pathlib import Path
|
| | from collections import OrderedDict
|
| |
|
| | import numpy as np
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| |
|
| | import matplotlib
|
| | matplotlib.use("Agg")
|
| | import matplotlib.pyplot as plt
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def load_prisma_model(checkpoint_path: str, device: str = "cpu"):
    """Load a Prisma/Circuit checkpoint, return (model, config_dict, model_type)."""
    sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
    from circuits.config import CircuitConfig
    from circuits.model import CircuitTransformer
    from circuits.mirrored import MirroredConfig, MirroredTransformer

    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model_type = checkpoint.get("model_type", "standard")
    config_dict = checkpoint.get("config", {})

    if model_type == "mirrored":
        # Drop this key (only when set) before reconstructing the config.
        if config_dict.get("dual_gate_middle"):
            config_dict.pop("dual_gate_middle")
        model = MirroredTransformer(MirroredConfig.from_dict(config_dict))
    else:
        model = CircuitTransformer(CircuitConfig.from_dict(config_dict))

    weights = checkpoint["model"]
    # torch.compile prefixes parameter names with "_orig_mod."; strip it so
    # the keys line up with the freshly built model.
    if any(key.startswith("_orig_mod.") for key in weights):
        weights = {key.removeprefix("_orig_mod."): value for key, value in weights.items()}
    model.load_state_dict(weights, strict=False)
    model.to(device).eval()

    return model, config_dict, model_type
|
| |
|
| |
|
def load_hf_model(model_name: str, device: str = "cpu"):
    """Load a HuggingFace causal LM."""
    from transformers import AutoModelForCausalLM
    hf_model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float32, trust_remote_code=True
    )
    hf_model.to(device).eval()
    return hf_model
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def _load_bin(path: str, num_samples: int, context_length: int, device: str):
    """Read the first samples from a memmapped .bin chunk cache (header: 2 uint32)."""
    import struct
    with open(path, 'rb') as f:
        n_chunks, seq_len = struct.unpack('II', f.read(8))
    data = np.memmap(path, dtype=np.int32, mode='r',
                     offset=8, shape=(n_chunks, seq_len))
    n = min(num_samples, n_chunks)
    cl = min(context_length, seq_len)
    # .copy() detaches from the memmap before converting to a tensor.
    return torch.from_numpy(data[:n, :cl].copy()).long().to(device)


def _stream_hf_texts(data_source: str):
    """Yield text fields from a streaming HF dataset spec 'hf:name[:config[:split]]'."""
    from datasets import load_dataset
    parts = data_source[3:].split(":")
    ds_name = parts[0]
    ds_config = parts[1] if len(parts) > 1 else None
    ds_split = parts[2] if len(parts) > 2 else "train"
    for item in load_dataset(ds_name, ds_config, split=ds_split, streaming=True):
        text = item.get("text", "")
        if len(text) >= 100:  # skip trivially short documents
            yield text


def _encode_texts(texts, tokenizer, num_samples: int, context_length: int):
    """Tokenize texts, keeping the first num_samples with >= context_length tokens."""
    all_ids = []
    for text in texts:
        ids = tokenizer.encode(text)
        if len(ids) >= context_length:
            all_ids.append(ids[:context_length])
        if len(all_ids) >= num_samples:
            break  # stops streaming sources early as well
    return all_ids


def load_data(data_source: str, tokenizer_name: str, num_samples: int = 32,
              context_length: int = 512, device: str = "cpu"):
    """Load tokenized data. Returns (input_ids, tokenizer).

    Supports:
        - Memmap .bin files (from circuits training cache)
        - hf:dataset:config:split (streaming from HuggingFace)
        - Plain text files (one document per line)

    Returns (None, tokenizer) when no sample reaches context_length tokens.
    """
    sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
    from circuits.data import get_tokenizer

    tokenizer = get_tokenizer(tokenizer_name)

    # Pre-tokenized training cache: no encoding needed.
    if data_source.endswith(".bin"):
        return _load_bin(data_source, num_samples, context_length, device), tokenizer

    if data_source.startswith("hf:"):
        texts = _stream_hf_texts(data_source)
    else:
        # Plain text file, one document per line; drop short lines.
        with open(data_source) as f:
            texts = [line.strip() for line in f if len(line.strip()) > 100]

    all_ids = _encode_texts(texts, tokenizer, num_samples, context_length)
    if not all_ids:
        return None, tokenizer
    return torch.tensor(all_ids, device=device), tokenizer
|
| |
|
| |
|
def tokenize_for_hf(texts: list, model_name: str, context_length: int = 512,
                    device: str = "cpu"):
    """Tokenize texts for an HF model. Returns (input_ids, tokenizer)."""
    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained(model_name,
                                        use_fast=False,
                                        trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    batch = []
    for text in texts:
        ids = tok.encode(text, max_length=context_length, truncation=True)
        if len(ids) >= context_length:
            batch.append(ids[:context_length])
        elif len(ids) > 32:
            # Right-pad shorter (but non-trivial) samples with EOS.
            padding = [tok.eos_token_id] * (context_length - len(ids))
            batch.append(ids + padding)

    if not batch:
        return None, tok

    return torch.tensor(batch, device=device), tok
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def collect_mirrored_activations(model, input_ids, word_positions=None):
    """Collect activations from MirroredTransformer at every processing stage.

    Replays the model forward pass manually (embed -> expand -> middle ->
    compress -> final norm) and snapshots the hidden state after each stage.
    NOTE(review): this mirrors MirroredTransformer's own forward path —
    presumably must stay in sync with it; confirm against circuits.mirrored.

    Args:
        model: MirroredTransformer in eval mode.
        input_ids: [B, L] token ids on the model's device.
        word_positions: optional word-position tensor for word-RoPE models.

    Returns:
        OrderedDict[name] -> CPU tensor, keyed in processing order
        ("embedding", "expand_i", "middle_i", "compress_i", "final_norm").
    """
    activations = OrderedDict()

    with torch.no_grad():
        x = model.embed(input_ids)
        if model.embed_proj is not None:
            if model.embed_g3 is not None:
                # Dual-gate embedding variant: g4 gates g3, and g3 gates the
                # projected embedding.
                g4 = F.silu(model.embed_g4(x))
                g3 = F.silu(model.embed_g3(x) * g4)
                x = model.embed_proj(x) * g3
            else:
                # Single-projection variant.
                x = F.silu(model.embed_proj(x))
        x = x * model.embed_scale
        activations["embedding"] = x.detach().cpu()

        # Expansion phase: mirror blocks in forward order.
        for i, block in enumerate(model.mirror_blocks):
            x, _ = block(x, word_positions=word_positions)
            activations[f"expand_{i}"] = x.detach().cpu()

        # Middle (bottleneck) phase.
        for i, block in enumerate(model.middle_blocks):
            x, _ = block(x, word_positions=word_positions)
            activations[f"middle_{i}"] = x.detach().cpu()

        # Compression phase: the SAME mirror blocks reused in reverse order
        # (weight sharing); compress_0 corresponds to the last mirror block.
        for i in reversed(range(len(model.mirror_blocks))):
            x, _ = model.mirror_blocks[i](x, word_positions=word_positions)
            compress_idx = len(model.mirror_blocks) - 1 - i
            activations[f"compress_{compress_idx}"] = x.detach().cpu()

        x = model.norm(x)
        activations["final_norm"] = x.detach().cpu()

    return activations
|
| |
|
| |
|
def collect_standard_activations(model, input_ids, word_positions=None):
    """Collect activations from standard CircuitTransformer.

    Returns OrderedDict[name] -> CPU tensor, in processing order:
    "embedding", "layer_0" ... "layer_{n-1}", "final_norm".
    """
    acts = OrderedDict()

    with torch.no_grad():
        hidden = model.embed(input_ids)
        if model.embed_proj is not None:
            hidden = F.silu(model.embed_proj(hidden))
        hidden = hidden * model.embed_scale
        acts["embedding"] = hidden.detach().cpu()

        for idx, layer in enumerate(model.layers):
            hidden, _ = layer(hidden, word_positions=word_positions)
            acts[f"layer_{idx}"] = hidden.detach().cpu()

        hidden = model.norm(hidden)
        acts["final_norm"] = hidden.detach().cpu()

    return acts
|
| |
|
| |
|
def collect_hf_activations(model, input_ids):
    """Hook-based activation collection for HuggingFace models.

    Supports GPT-2-style layouts (``model.transformer``) and Llama-style
    layouts (``model.model``).

    Args:
        model: HF causal LM in eval mode.
        input_ids: [B, L] token ids on the model's device.

    Returns:
        OrderedDict[name] -> CPU tensor ("embedding", "layer_i", "final_norm").

    Raises:
        ValueError: if the model exposes neither known layout.
    """
    activations = OrderedDict()
    hooks = []

    if hasattr(model, 'transformer'):
        # GPT-2 family layout.
        blocks = model.transformer.h
        embed = model.transformer.wte
        final_norm = model.transformer.ln_f
    elif hasattr(model, 'model'):
        # Llama family layout.
        blocks = model.model.layers
        embed = model.model.embed_tokens
        final_norm = model.model.norm
    else:
        raise ValueError(f"Unsupported HF model: {type(model)}")

    def make_hook(name):
        def hook_fn(module, input, output):
            # Transformer blocks return tuples (hidden_states, ...); keep the
            # hidden states only.
            out = output[0] if isinstance(output, tuple) else output
            activations[name] = out.detach().cpu()
        return hook_fn

    hooks.append(embed.register_forward_hook(make_hook("embedding")))
    for i, block in enumerate(blocks):
        hooks.append(block.register_forward_hook(make_hook(f"layer_{i}")))
    hooks.append(final_norm.register_forward_hook(make_hook("final_norm")))

    # BUGFIX: remove hooks in a finally block so they never leak onto the
    # model if the forward pass raises (e.g. OOM, shape mismatch).
    try:
        with torch.no_grad():
            model(input_ids)
    finally:
        for h in hooks:
            h.remove()

    return activations
|
| |
|
| |
|
def collect_activations(model, model_type, config_dict, input_ids, device):
    """Dispatch to the right collector based on model type."""
    word_positions = None
    word_rope_dims = config_dict.get("word_rope_dims", 0) if config_dict else 0

    # Prisma models trained with word-RoPE need per-token word positions,
    # derived from the tokenizer's word-start table.
    if word_rope_dims > 0 and model_type in ("standard", "mirrored"):
        sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
        from circuits.data import get_tokenizer
        from circuits.layers import build_word_start_table, compute_word_positions
        tok = get_tokenizer(config_dict.get("tokenizer_name", "gpt2"))
        start_table = build_word_start_table(tok, len(tok)).to(device)
        word_positions = compute_word_positions(input_ids, start_table)

    if model_type == "mirrored":
        return collect_mirrored_activations(model, input_ids, word_positions)
    if model_type == "standard":
        return collect_standard_activations(model, input_ids, word_positions)
    return collect_hf_activations(model, input_ids)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> float:
    """Compute linear CKA between two [N, D] representation matrices.

    CKA(X, Y) = ||Yc^T Xc||_F^2 / (||Xc^T Xc||_F * ||Yc^T Yc||_F)
    """
    # Work in float32 and center every feature column.
    Xc = X.float()
    Yc = Y.float()
    Xc = Xc - Xc.mean(0, keepdim=True)
    Yc = Yc - Yc.mean(0, keepdim=True)

    n_rows = Xc.shape[0]

    if n_rows < min(Xc.shape[1], Yc.shape[1]):
        # Few samples, wide features: the NxN Gram matrices are cheaper.
        gram_x = Xc @ Xc.T
        gram_y = Yc @ Yc.T
        numerator = (gram_x * gram_y).sum()
        denominator = torch.sqrt((gram_x * gram_x).sum() * (gram_y * gram_y).sum())
    else:
        # Many samples: DxD (cross-)covariance form is cheaper.
        cross = Xc.T @ Yc
        cov_x = Xc.T @ Xc
        cov_y = Yc.T @ Yc
        numerator = (cross * cross).sum()
        denominator = torch.sqrt((cov_x * cov_x).sum() * (cov_y * cov_y).sum())

    # Guard against degenerate (constant) representations.
    if denominator < 1e-10:
        return 0.0

    return (numerator / denominator).item()
|
| |
|
| |
|
def compute_cka_matrix(activations: dict, subsample: int = 4) -> tuple:
    """Compute CKA between all layer pairs. Returns (cka_matrix, layer_names)."""
    names = list(activations)
    n_layers = len(names)

    # Subsample sequence positions, then flatten (batch, pos) into rows.
    flat_acts = {}
    for name, act in activations.items():
        subsampled = act[:, ::subsample, :]
        flat_acts[name] = subsampled.reshape(-1, subsampled.shape[-1])

    cka_matrix = np.zeros((n_layers, n_layers))

    for i in range(n_layers):
        cka_matrix[i, i] = 1.0  # a layer is identical to itself
        for j in range(i + 1, n_layers):
            # CKA is symmetric: compute the upper triangle, mirror to lower.
            value = linear_cka(flat_acts[names[i]], flat_acts[names[j]])
            cka_matrix[i, j] = value
            cka_matrix[j, i] = value
        if (i + 1) % 5 == 0 or i == n_layers - 1:
            print(f" CKA: {i+1}/{n_layers} rows computed")

    return cka_matrix, names
|
| |
|
| |
|
def compute_cross_model_cka(acts_a: dict, acts_b: dict) -> tuple:
    """Cross-model CKA using sample-level (avg-pooled) representations."""
    names_a = list(acts_a)
    names_b = list(acts_b)

    # Average over sequence positions -> one vector per sample; this makes
    # the two models comparable even when their tokenizations differ.
    pooled_a = {name: act.mean(dim=1) for name, act in acts_a.items()}
    pooled_b = {name: act.mean(dim=1) for name, act in acts_b.items()}

    # The two models may have kept different sample counts; align on the min.
    n_samples = min(
        next(iter(pooled_a.values())).shape[0],
        next(iter(pooled_b.values())).shape[0],
    )

    cka_matrix = np.zeros((len(names_a), len(names_b)))

    for i, name_a in enumerate(names_a):
        for j, name_b in enumerate(names_b):
            cka_matrix[i, j] = linear_cka(pooled_a[name_a][:n_samples],
                                          pooled_b[name_b][:n_samples])
        if (i + 1) % 5 == 0 or i == len(names_a) - 1:
            print(f" Cross-CKA: {i+1}/{len(names_a)} rows computed")

    return cka_matrix, names_a, names_b
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def get_unembed_components(model, model_type):
    """Extract (norm_module, unembed_weight) for logit lens projection."""
    # Prisma models project logits through the (tied) input embedding.
    if model_type in ("standard", "mirrored"):
        return model.norm, model.embed.weight
    # GPT-2-style HF layout.
    if hasattr(model, 'transformer'):
        return model.transformer.ln_f, model.transformer.wte.weight
    # Llama-style HF layout.
    if hasattr(model, 'model'):
        return model.model.norm, model.model.embed_tokens.weight
    raise ValueError(f"Unsupported model: {type(model)}")
|
| |
|
| |
|
def compute_logit_lens(activations: dict, norm: nn.Module, unembed_weight: torch.Tensor,
                       labels: torch.Tensor, device: str = "cpu",
                       chunk_size: int = 2048) -> OrderedDict:
    """Compute logit lens statistics at every layer.

    Projects intermediate hidden states through final norm + unembedding.
    Computes entropy, top-1 probability, correct token rank, and
    agreement with the final layer's predictions.

    Args:
        activations: OrderedDict[name] = [B, L, D]
        norm: final layer norm module
        unembed_weight: [V, D] unembedding matrix
        labels: [B, L-1] next-token labels (input_ids[:, 1:])
        device: computation device
        chunk_size: number of positions per batch for projection

    Returns:
        OrderedDict[name] = {entropy, top1_prob, correct_rank, ...}
    """
    names = list(activations.keys())
    # The last collected entry is treated as the model's true output layer.
    final_name = names[-1]
    results = OrderedDict()

    unembed = unembed_weight.to(device)
    norm_mod = norm.to(device)
    labels_flat = labels.reshape(-1).to(device)

    def process_layer(name, act, apply_norm=True):
        """Project one layer's activations and compute all metrics."""
        B, L, D = act.shape
        # Drop the last position: it has no next-token label. After reshape,
        # row k of `flat` lines up with labels_flat[k].
        flat = act[:, :-1, :].reshape(-1, D)
        N = flat.shape[0]

        all_entropy = []
        all_top1_prob = []
        all_correct_rank = []
        all_top1_idx = []

        # Chunk over positions to bound the [chunk, V] logits memory.
        for start in range(0, N, chunk_size):
            end = min(start + chunk_size, N)
            chunk = flat[start:end].to(device)
            chunk_labels = labels_flat[start:end]

            # The final layer was already normed during collection; only
            # intermediate layers need the final norm applied here.
            if apply_norm:
                chunk = norm_mod(chunk)

            logits = chunk @ unembed.T
            log_probs = F.log_softmax(logits, dim=-1)
            probs = log_probs.exp()

            # Shannon entropy of the predicted distribution (nats).
            entropy = -(probs * log_probs).sum(dim=-1)
            all_entropy.append(entropy.cpu())

            # Confidence: probability mass on the argmax token.
            top1_prob = probs.max(dim=-1).values
            all_top1_prob.append(top1_prob.cpu())

            # Rank of the correct token = 1 + number of strictly larger logits.
            correct_logits = logits.gather(1, chunk_labels.unsqueeze(1))
            rank = (logits > correct_logits).sum(dim=-1) + 1
            all_correct_rank.append(rank.cpu())

            # Kept temporarily to measure agreement with the final layer.
            all_top1_idx.append(logits.argmax(dim=-1).cpu())

        entropy_t = torch.cat(all_entropy)
        top1_t = torch.cat(all_top1_prob)
        rank_t = torch.cat(all_correct_rank).float()
        top1_idx = torch.cat(all_top1_idx)

        return {
            "entropy": entropy_t.mean().item(),
            "entropy_std": entropy_t.std().item(),
            "top1_prob": top1_t.mean().item(),
            "correct_rank_mean": rank_t.mean().item(),
            "correct_rank_median": rank_t.median().item(),
            "log_rank_mean": rank_t.log().mean().item(),
            # Private field, consumed below and then deleted.
            "_top1_idx": top1_idx,
        }

    for name in names:
        is_final = (name == final_name)
        act = activations[name]
        stats = process_layer(name, act, apply_norm=not is_final)
        results[name] = stats
        print(f" Logit lens: {name:20s} entropy={stats['entropy']:.2f} "
              f"top1={stats['top1_prob']:.4f} rank={stats['correct_rank_median']:.0f}")

    # Second pass: fraction of positions whose argmax matches the final layer's.
    final_top1 = results[final_name]["_top1_idx"]
    for name in names:
        layer_top1 = results[name]["_top1_idx"]
        agreement = (layer_top1 == final_top1).float().mean().item()
        results[name]["agreement_with_final"] = agreement

    # Drop the bulky per-position tensors before returning summary stats.
    for name in names:
        del results[name]["_top1_idx"]

    return results
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def compute_drift(activations: dict) -> OrderedDict:
    """Cosine similarity between consecutive layers' representations."""
    names = list(activations)
    drift = OrderedDict()

    # Walk consecutive (previous, current) layer pairs in collection order.
    for prev_name, curr_name in zip(names, names[1:]):
        prev_act = activations[prev_name]
        curr_act = activations[curr_name]

        prev_flat = prev_act.reshape(-1, prev_act.shape[-1])
        curr_flat = curr_act.reshape(-1, curr_act.shape[-1])

        cos = F.cosine_similarity(prev_flat, curr_flat, dim=-1)
        drift[curr_name] = {
            "cos_sim_mean": cos.mean().item(),
            "cos_sim_std": cos.std().item(),
            "l2_distance": (curr_flat - prev_flat).norm(dim=-1).mean().item(),
        }

    return drift
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | def _phase_color(name):
|
| | """Return color based on layer phase."""
|
| | if "expand" in name:
|
| | return "steelblue"
|
| | elif "middle" in name:
|
| | return "goldenrod"
|
| | elif "compress" in name:
|
| | return "coral"
|
| | elif "embedding" in name:
|
| | return "gray"
|
| | elif "final" in name:
|
| | return "gray"
|
| | else:
|
| | return "mediumpurple"
|
| |
|
| |
|
| | def _layer_sort_key(name):
|
| | """Sort key for processing order."""
|
| | order = {"embedding": -1, "final_norm": 9999}
|
| | if name in order:
|
| | return order[name]
|
| | parts = name.split("_")
|
| | phase = parts[0]
|
| | idx = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
|
| | phase_offset = {"expand": 0, "middle": 1000, "compress": 2000, "layer": 0}
|
| | return phase_offset.get(phase, 3000) + idx
|
| |
|
| |
|
| | def _short_name(name):
|
| | """Shorten layer name for plot labels."""
|
| | if name == "embedding":
|
| | return "emb"
|
| | if name == "final_norm":
|
| | return "out"
|
| | parts = name.split("_")
|
| | if parts[0] == "expand":
|
| | return f"E{parts[1]}"
|
| | elif parts[0] == "middle":
|
| | return f"M{parts[1]}"
|
| | elif parts[0] == "compress":
|
| | return f"C{parts[1]}"
|
| | elif parts[0] == "layer":
|
| | return f"L{parts[1]}"
|
| | return name[:6]
|
| |
|
| |
|
def plot_cka_self(cka_matrix: np.ndarray, names: list, output_dir: Path,
                  model_label: str):
    """Plot self-CKA heatmap."""
    n_layers = len(names)
    tick_labels = [_short_name(nm) for nm in names]

    fig, ax = plt.subplots(figsize=(max(10, n_layers * 0.35), max(8, n_layers * 0.3)))
    fig.suptitle(f"{model_label} -- CKA Self-Similarity", fontsize=14)

    im = ax.imshow(cka_matrix, cmap="inferno", vmin=0, vmax=1, aspect="equal")

    # Draw white separators wherever the phase prefix changes
    # (embedding/expand/middle/compress/final).
    for idx in range(1, n_layers):
        prev_phase = names[idx - 1].split("_")[0]
        curr_phase = names[idx].split("_")[0]
        if prev_phase != curr_phase:
            ax.axhline(idx - 0.5, color="white", linewidth=1.5, alpha=0.8)
            ax.axvline(idx - 0.5, color="white", linewidth=1.5, alpha=0.8)

    ax.set_xticks(range(n_layers))
    ax.set_xticklabels(tick_labels, rotation=90, fontsize=7)
    ax.set_yticks(range(n_layers))
    ax.set_yticklabels(tick_labels, fontsize=7)

    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="CKA")
    plt.tight_layout()
    fig.savefig(output_dir / "cka_self.png", dpi=150)
    plt.close(fig)
|
| |
|
| |
|
def plot_cka_cross(cka_matrix: np.ndarray, names_a: list, names_b: list,
                   output_dir: Path, label_a: str, label_b: str):
    """Plot cross-model CKA heatmap."""
    short_a = [_short_name(nm) for nm in names_a]
    short_b = [_short_name(nm) for nm in names_b]
    na = len(names_a)
    nb = len(names_b)

    fig, ax = plt.subplots(figsize=(max(10, nb * 0.35), max(8, na * 0.3)))
    fig.suptitle(f"Cross-CKA: {label_a} vs {label_b}", fontsize=14)

    # Model A on rows, model B on columns; aspect is free since the two
    # models may have different layer counts.
    im = ax.imshow(cka_matrix, cmap="inferno", vmin=0, vmax=1, aspect="auto")

    ax.set_xticks(range(nb))
    ax.set_xticklabels(short_b, rotation=90, fontsize=7)
    ax.set_xlabel(label_b)
    ax.set_yticks(range(na))
    ax.set_yticklabels(short_a, fontsize=7)
    ax.set_ylabel(label_a)

    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="CKA")
    plt.tight_layout()
    fig.savefig(output_dir / "cka_cross.png", dpi=150)
    plt.close(fig)
|
| |
|
| |
|
def plot_logit_lens(lens_results: OrderedDict, output_dir: Path,
                    model_label: str):
    """Plot logit lens summary: entropy, confidence, rank, agreement."""
    ordered = sorted(lens_results, key=_layer_sort_key)
    tick_labels = [_short_name(nm) for nm in ordered]
    colors = [_phase_color(nm) for nm in ordered]
    xs = range(len(ordered))

    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    fig.suptitle(f"{model_label} -- Logit Lens", fontsize=14)

    # (axis, metric key, y-label, title) for each of the four panels.
    panels = [
        (axes[0, 0], "entropy", "Entropy (nats)",
         "Prediction entropy per layer"),
        (axes[0, 1], "top1_prob", "Top-1 probability",
         "Prediction confidence per layer"),
        (axes[1, 0], "correct_rank_median", "Median rank of correct token",
         "When does the model find the answer?"),
        (axes[1, 1], "agreement_with_final", "Agreement with final layer",
         "Convergence toward final prediction"),
    ]
    for ax, key, ylabel, title in panels:
        ax.bar(xs, [lens_results[nm][key] for nm in ordered],
               color=colors, alpha=0.85)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(xs)
        ax.set_xticklabels(tick_labels, rotation=90, fontsize=7)

    # Panel-specific axis scaling.
    axes[1, 0].set_yscale("log")
    axes[1, 1].set_ylim(0, 1.05)

    plt.tight_layout()
    fig.savefig(output_dir / "logit_lens_summary.png", dpi=150)
    plt.close(fig)
|
| |
|
| |
|
def plot_logit_lens_trajectory(activations: dict, norm: nn.Module,
                               unembed_weight: torch.Tensor, input_ids: torch.Tensor,
                               tokenizer, output_dir: Path, model_label: str,
                               device: str = "cpu",
                               n_positions: int = 6, n_layers: int = 10):
    """Show top-5 predicted tokens at selected layers for a few positions.

    Picks positions spread across the first sample and shows how the
    model's prediction evolves through the network.
    """
    names = sorted(activations.keys(), key=_layer_sort_key)

    # Keep at most n_layers layers, spread evenly across the network depth.
    if len(names) > n_layers:
        indices = np.linspace(0, len(names) - 1, n_layers, dtype=int)
        selected_layers = [names[i] for i in indices]
    else:
        selected_layers = names

    # Positions spread over the first sample, skipping the earliest tokens
    # (too little context) and the last token (no label).
    # NOTE(review): assumes seq_len > 11 so the linspace bounds are ordered.
    seq_len = input_ids.shape[1]
    pos_indices = np.linspace(10, seq_len - 2, n_positions, dtype=int)

    unembed = unembed_weight.to(device)
    norm_mod = norm.to(device)
    # After sorting, the last entry is the final (already-normed) layer.
    final_name = names[-1]

    fig, axes = plt.subplots(n_positions, 1, figsize=(14, 3 * n_positions))
    if n_positions == 1:
        axes = [axes]  # subplots() returns a bare Axes when nrows == 1
    fig.suptitle(f"{model_label} -- Token prediction trajectory", fontsize=14, y=1.02)

    for pos_idx, pos in enumerate(pos_indices):
        ax = axes[pos_idx]
        # Ground-truth next token and a short context window for the title.
        actual_token = tokenizer.decode([input_ids[0, pos + 1].item()])
        context = tokenizer.decode(input_ids[0, max(0, pos - 5):pos + 1].tolist())

        layer_labels = []
        top_tokens_per_layer = []

        for name in selected_layers:
            is_final = (name == final_name)
            # Single position of the first sample, kept 2-D for the matmul.
            hidden = activations[name][0, pos:pos + 1, :].to(device)
            # Intermediate layers need the final norm before unembedding.
            if not is_final:
                hidden = norm_mod(hidden)
            logits = (hidden @ unembed.T).squeeze(0)
            probs = F.softmax(logits, dim=-1)
            top5_vals, top5_idx = probs.topk(5)

            # Render each candidate as "token(prob)"; escape newlines so
            # they don't break the cell layout.
            tokens_str = []
            for val, idx in zip(top5_vals, top5_idx):
                tok = tokenizer.decode([idx.item()]).replace("\n", "\\n")
                tokens_str.append(f"{tok}({val:.2f})")

            layer_labels.append(_short_name(name))
            top_tokens_per_layer.append("\n".join(tokens_str))

        # Grid of text cells: layers on x, top-5 rank on y.
        ax.set_xlim(-0.5, len(layer_labels) - 0.5)
        ax.set_ylim(-0.5, 5.5)
        ax.set_xticks(range(len(layer_labels)))
        ax.set_xticklabels(layer_labels, fontsize=8)
        ax.set_yticks([])

        for li, tokens_str in enumerate(top_tokens_per_layer):
            lines = tokens_str.split("\n")
            for rank, line in enumerate(lines):
                # Highlight cells containing the ground-truth token.
                color = "darkgreen" if actual_token.strip() in line else "black"
                fontweight = "bold" if actual_token.strip() in line else "normal"
                ax.text(li, rank, line, ha="center", va="center", fontsize=7,
                        color=color, fontweight=fontweight)

        ax.set_title(f'pos {pos}: "...{context}" -> [{actual_token.strip()}]',
                     fontsize=9, loc="left")
        # Rank 1 at the top of each column.
        ax.invert_yaxis()
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["left"].set_visible(False)

    plt.tight_layout()
    fig.savefig(output_dir / "logit_lens_trajectory.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
|
| |
|
| |
|
def plot_drift(drift: OrderedDict, output_dir: Path, model_label: str):
    """Plot representation drift between consecutive layers."""
    ordered = sorted(drift, key=_layer_sort_key)
    tick_labels = [_short_name(nm) for nm in ordered]
    colors = [_phase_color(nm) for nm in ordered]
    xs = range(len(ordered))

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(f"{model_label} -- Representation drift", fontsize=14)

    # (axis, metric key, y-label, title) for the two panels.
    panels = [
        (axes[0], "cos_sim_mean", "Cosine similarity with previous layer",
         "How much each layer preserves direction"),
        (axes[1], "l2_distance", "L2 distance from previous layer",
         "How much each layer changes magnitude"),
    ]
    for ax, key, ylabel, title in panels:
        ax.bar(xs, [drift[nm][key] for nm in ordered], color=colors, alpha=0.85)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(xs)
        ax.set_xticklabels(tick_labels, rotation=90, fontsize=7)

    plt.tight_layout()
    fig.savefig(output_dir / "representation_drift.png", dpi=150)
    plt.close(fig)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def save_results(cka_matrix, cka_names, lens_results, drift, cross_cka, output_dir):
    """Save all numerical results to JSON."""
    payload = {}

    if cka_matrix is not None:
        payload["cka_self"] = {
            "names": cka_names,
            "matrix": cka_matrix.tolist(),
        }

    if lens_results:
        payload["logit_lens"] = dict(lens_results)

    if drift:
        payload["drift"] = dict(drift)

    if cross_cka is not None:
        matrix, names_a, names_b = cross_cka
        payload["cka_cross"] = {
            "names_a": names_a,
            "names_b": names_b,
            "matrix": matrix.tolist(),
        }

    # default=str keeps stray non-JSON values (e.g. numpy scalars) from
    # crashing the dump.
    with open(output_dir / "results.json", "w") as f:
        json.dump(payload, f, indent=2, default=str)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def main():
    """CLI entry point: load model + data, run CKA / logit-lens / drift analyses."""
    parser = argparse.ArgumentParser(
        description="CKA and Logit Lens analysis for Prisma / Circuit Transformer")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to Prisma/Circuit checkpoint")
    parser.add_argument("--checkpoint-b", type=str, default=None,
                        help="Second Prisma checkpoint for cross-model CKA")
    parser.add_argument("--hf-model", type=str, default=None,
                        help="HuggingFace model for cross-model CKA (e.g. gpt2-medium)")
    parser.add_argument("--data", type=str, required=True,
                        help="Data source (hf:dataset:config:split or file path)")
    parser.add_argument("--num-samples", type=int, default=32,
                        help="Number of text samples (default: 32)")
    parser.add_argument("--context-length", type=int, default=512,
                        help="Sequence length (default: 512)")
    parser.add_argument("--cka-subsample", type=int, default=4,
                        help="Position subsampling for CKA (default: 4)")
    parser.add_argument("--no-logit-lens", action="store_true",
                        help="Skip logit lens analysis")
    parser.add_argument("--no-cka", action="store_true",
                        help="Skip CKA analysis")
    parser.add_argument("--output-dir", type=str, default=None,
                        help="Output directory (default: auto)")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index")
    args = parser.parse_args()

    device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    # Default output directory is named after the checkpoint's parent folder.
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        ckpt_name = Path(args.checkpoint).parent.name
        output_dir = Path("circuits/scripts/representation_output") / ckpt_name
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Output: {output_dir}")

    # --- Load the primary model -------------------------------------------
    print(f"\nLoading: {args.checkpoint}")
    model_a, config_a, model_type_a = load_prisma_model(args.checkpoint, device)
    label_a = Path(args.checkpoint).parent.name
    n_params = sum(p.numel() for p in model_a.parameters())
    print(f" Type: {model_type_a}, params: {n_params:,}")

    # Re-open the checkpoint just for the tokenizer name (top-level key wins
    # over the config entry), then free it.
    ckpt_data = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    tokenizer_name = ckpt_data.get("tokenizer_name", config_a.get("tokenizer_name", "gpt2"))
    del ckpt_data

    # --- Load evaluation data ---------------------------------------------
    print(f"\nLoading data ({args.num_samples} samples, ctx={args.context_length})...")
    result = load_data(
        args.data, tokenizer_name, args.num_samples, args.context_length, device
    )
    if result[0] is None:
        print("ERROR: No valid samples loaded.")
        return
    input_ids, tokenizer = result
    print(f" Data shape: {input_ids.shape}")

    # --- Collect activations ----------------------------------------------
    print(f"\nCollecting activations ({model_type_a})...")
    acts_a = collect_activations(model_a, model_type_a, config_a, input_ids, device)
    print(f" Collected {len(acts_a)} layers")

    # Activations are on CPU; drop the model to free (GPU) memory.
    del model_a
    if device.startswith("cuda"):
        torch.cuda.empty_cache()

    # --- Self-CKA ----------------------------------------------------------
    cka_matrix = None
    cka_names = None
    if not args.no_cka:
        print(f"\nComputing self-CKA (subsample={args.cka_subsample})...")
        cka_matrix, cka_names = compute_cka_matrix(acts_a, subsample=args.cka_subsample)
        plot_cka_self(cka_matrix, cka_names, output_dir, label_a)
        print(f" Saved: cka_self.png")

    # --- Cross-model CKA (optional second checkpoint or HF model) ----------
    cross_cka = None
    if not args.no_cka and (args.checkpoint_b or args.hf_model):
        if args.checkpoint_b:
            print(f"\nLoading comparison: {args.checkpoint_b}")
            model_b, config_b, model_type_b = load_prisma_model(args.checkpoint_b, device)
            label_b = Path(args.checkpoint_b).parent.name
            acts_b = collect_activations(model_b, model_type_b, config_b, input_ids, device)
            del model_b
        else:
            print(f"\nLoading HF model: {args.hf_model}")
            model_b = load_hf_model(args.hf_model, device)
            label_b = args.hf_model

            # The HF model has its own tokenizer: decode our samples back to
            # text and re-tokenize them for it.
            print(f" Re-tokenizing for {args.hf_model}...")
            raw_texts = [tokenizer.decode(input_ids[i].tolist()) for i in range(input_ids.shape[0])]
            input_ids_b, _ = tokenize_for_hf(
                raw_texts, args.hf_model, args.context_length, device
            )
            if input_ids_b is not None:
                print(f" HF data shape: {input_ids_b.shape}")
                acts_b = collect_hf_activations(model_b, input_ids_b)
            else:
                acts_b = None
            del model_b

        if device.startswith("cuda"):
            torch.cuda.empty_cache()

        if acts_b:
            print(f"\nComputing cross-model CKA...")
            cross_matrix, cross_names_a, cross_names_b = compute_cross_model_cka(acts_a, acts_b)
            cross_cka = (cross_matrix, cross_names_a, cross_names_b)
            plot_cka_cross(cross_matrix, cross_names_a, cross_names_b,
                           output_dir, label_a, label_b)
            print(f" Saved: cka_cross.png")
            del acts_b

    # --- Logit lens ---------------------------------------------------------
    lens_results = None
    if not args.no_logit_lens:
        # The primary model was freed above; reload it for norm + unembedding.
        print(f"\nReloading model for logit lens...")
        model_a, _, _ = load_prisma_model(args.checkpoint, device)
        norm, unembed_weight = get_unembed_components(model_a, model_type_a)

        # Next-token labels: token at position t+1 labels position t.
        labels = input_ids[:, 1:].cpu()

        print(f"Computing logit lens...")
        lens_results = compute_logit_lens(acts_a, norm, unembed_weight, labels, device)
        plot_logit_lens(lens_results, output_dir, label_a)
        print(f" Saved: logit_lens_summary.png")

        print(f" Generating token trajectories...")
        plot_logit_lens_trajectory(
            acts_a, norm, unembed_weight, input_ids.cpu(), tokenizer,
            output_dir, label_a, device
        )
        print(f" Saved: logit_lens_trajectory.png")

        del model_a
        if device.startswith("cuda"):
            torch.cuda.empty_cache()

    # --- Representation drift (always computed) -----------------------------
    print(f"\nComputing representation drift...")
    drift = compute_drift(acts_a)
    plot_drift(drift, output_dir, label_a)
    print(f" Saved: representation_drift.png")

    # --- Persist everything --------------------------------------------------
    save_results(cka_matrix, cka_names, lens_results, drift, cross_cka, output_dir)
    print(f"\nAll outputs saved to: {output_dir}")
    n_plots = len(list(output_dir.glob("*.png")))
    print(f" Plots: {n_plots} PNG files")
    print(f" Data: results.json")
|
| |
|
| |
|
# Script entry point: run the full analysis pipeline when executed directly.
if __name__ == "__main__":
    main()
|
| |
|