#!/usr/bin/env python3 """ Representation analysis: CKA and Logit Lens for Prisma / Circuit Transformer. CKA (Centered Kernel Alignment): Measures representational similarity between all layer pairs. Produces a heatmap revealing mirror symmetry, phase transitions, and cross-model alignment. Logit Lens: Projects intermediate representations to vocabulary space at every layer. Reveals what the model "thinks" at each processing stage -- from raw tokens through the semantic bottleneck back to specific predictions. Also computes representation drift (cosine similarity between consecutive layers). Usage: # Full analysis (CKA + logit lens) python -m circuits.scripts.representation_analysis \\ --checkpoint path/to/checkpoint.pt \\ --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train # Cross-model CKA python -m circuits.scripts.representation_analysis \\ --checkpoint path/to/prisma.pt --hf-model gpt2-medium \\ --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train # CKA only (skip logit lens) python -m circuits.scripts.representation_analysis \\ --checkpoint path/to/checkpoint.pt \\ --data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train \\ --no-logit-lens """ import argparse import json import sys import os from pathlib import Path from collections import OrderedDict import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt # --------------------------------------------------------------------------- # Model loading # --------------------------------------------------------------------------- def load_prisma_model(checkpoint_path: str, device: str = "cpu"): """Load a Prisma/Circuit checkpoint, return (model, config_dict, model_type).""" sys.path.insert(0, str(Path(__file__).resolve().parents[2])) from circuits.config import CircuitConfig from circuits.model import CircuitTransformer from circuits.mirrored import MirroredConfig, MirroredTransformer ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) model_type = ckpt.get("model_type", "standard") config_dict = ckpt.get("config", {}) if model_type == "mirrored": if config_dict.get("dual_gate_middle"): config_dict.pop("dual_gate_middle") config = MirroredConfig.from_dict(config_dict) model = MirroredTransformer(config) else: config = CircuitConfig.from_dict(config_dict) model = CircuitTransformer(config) state_dict = ckpt["model"] if any(k.startswith("_orig_mod.") for k in state_dict): state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()} model.load_state_dict(state_dict, strict=False) model.to(device).eval() return model, config_dict, model_type def load_hf_model(model_name: str, device: str = "cpu"): """Load a HuggingFace causal LM.""" from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, trust_remote_code=True) model.to(device).eval() return model # --------------------------------------------------------------------------- # Data loading # --------------------------------------------------------------------------- def load_data(data_source: str, tokenizer_name: str, num_samples: int = 32, context_length: int = 512, device: str = "cpu"): """Load tokenized data. Returns (input_ids, tokenizer). Supports: - Memmap .bin files (from circuits training cache) - hf:dataset:config:split (streaming from HuggingFace) - Plain text files """ sys.path.insert(0, str(Path(__file__).resolve().parents[2])) from circuits.data import get_tokenizer tokenizer = get_tokenizer(tokenizer_name) # Memmap binary file (already tokenized) if data_source.endswith(".bin"): import struct with open(data_source, 'rb') as f: n_chunks, seq_len = struct.unpack('II', f.read(8)) data = np.memmap(data_source, dtype=np.int32, mode='r', offset=8, shape=(n_chunks, seq_len)) n = min(num_samples, n_chunks) # Slice to requested context length cl = min(context_length, seq_len) input_ids = torch.from_numpy(data[:n, :cl].copy()).long().to(device) return input_ids, tokenizer # HuggingFace dataset if data_source.startswith("hf:"): from datasets import load_dataset parts = data_source[3:].split(":") ds_name = parts[0] ds_config = parts[1] if len(parts) > 1 else None ds_split = parts[2] if len(parts) > 2 else "train" dataset = load_dataset(ds_name, ds_config, split=ds_split, streaming=True) all_ids = [] for item in dataset: text = item.get("text", "") if len(text) < 100: continue ids = tokenizer.encode(text) if len(ids) >= context_length: all_ids.append(ids[:context_length]) if len(all_ids) >= num_samples: break if not all_ids: return None, tokenizer return torch.tensor(all_ids, device=device), tokenizer # Plain text file with open(data_source) as f: texts = [line.strip() for line in f if len(line.strip()) > 100] all_ids = [] for text in texts: ids = tokenizer.encode(text) if len(ids) >= context_length: all_ids.append(ids[:context_length]) if len(all_ids) >= num_samples: break if not all_ids: return None, tokenizer return torch.tensor(all_ids, device=device), tokenizer def tokenize_for_hf(texts: list, model_name: str, context_length: int = 512, device: str = "cpu"): """Tokenize texts for an HF model. Returns (input_ids, tokenizer).""" from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token all_ids = [] for text in texts: ids = tokenizer.encode(text, max_length=context_length, truncation=True) if len(ids) >= context_length: all_ids.append(ids[:context_length]) elif len(ids) > 32: all_ids.append(ids + [tokenizer.eos_token_id] * (context_length - len(ids))) if not all_ids: return None, tokenizer return torch.tensor(all_ids, device=device), tokenizer # --------------------------------------------------------------------------- # Activation collection # --------------------------------------------------------------------------- def collect_mirrored_activations(model, input_ids, word_positions=None): """Collect activations from MirroredTransformer at every processing stage.""" activations = OrderedDict() with torch.no_grad(): x = model.embed(input_ids) if model.embed_proj is not None: if model.embed_g3 is not None: g4 = F.silu(model.embed_g4(x)) g3 = F.silu(model.embed_g3(x) * g4) x = model.embed_proj(x) * g3 else: x = F.silu(model.embed_proj(x)) x = x * model.embed_scale activations["embedding"] = x.detach().cpu() for i, block in enumerate(model.mirror_blocks): x, _ = block(x, word_positions=word_positions) activations[f"expand_{i}"] = x.detach().cpu() for i, block in enumerate(model.middle_blocks): x, _ = block(x, word_positions=word_positions) activations[f"middle_{i}"] = x.detach().cpu() for i in reversed(range(len(model.mirror_blocks))): x, _ = model.mirror_blocks[i](x, word_positions=word_positions) compress_idx = len(model.mirror_blocks) - 1 - i activations[f"compress_{compress_idx}"] = x.detach().cpu() x = model.norm(x) activations["final_norm"] = x.detach().cpu() return activations def collect_standard_activations(model, input_ids, word_positions=None): """Collect activations from standard CircuitTransformer.""" activations = OrderedDict() with torch.no_grad(): x = model.embed(input_ids) if model.embed_proj is not None: x = F.silu(model.embed_proj(x)) x = x * model.embed_scale activations["embedding"] = x.detach().cpu() for i, layer in enumerate(model.layers): x, _ = layer(x, word_positions=word_positions) activations[f"layer_{i}"] = x.detach().cpu() x = model.norm(x) activations["final_norm"] = x.detach().cpu() return activations def collect_hf_activations(model, input_ids): """Hook-based activation collection for HuggingFace models.""" activations = OrderedDict() hooks = [] if hasattr(model, 'transformer'): # GPT-2 style blocks = model.transformer.h embed = model.transformer.wte final_norm = model.transformer.ln_f elif hasattr(model, 'model'): # Llama / Mistral style blocks = model.model.layers embed = model.model.embed_tokens final_norm = model.model.norm else: raise ValueError(f"Unsupported HF model: {type(model)}") def make_hook(name): def hook_fn(module, input, output): out = output[0] if isinstance(output, tuple) else output activations[name] = out.detach().cpu() return hook_fn hooks.append(embed.register_forward_hook(make_hook("embedding"))) for i, block in enumerate(blocks): hooks.append(block.register_forward_hook(make_hook(f"layer_{i}"))) hooks.append(final_norm.register_forward_hook(make_hook("final_norm"))) with torch.no_grad(): model(input_ids) for h in hooks: h.remove() return activations def collect_activations(model, model_type, config_dict, input_ids, device): """Dispatch to the right collector based on model type.""" word_positions = None word_rope_dims = config_dict.get("word_rope_dims", 0) if config_dict else 0 if word_rope_dims > 0 and model_type in ("standard", "mirrored"): sys.path.insert(0, str(Path(__file__).resolve().parents[2])) from circuits.data import get_tokenizer from circuits.layers import build_word_start_table, compute_word_positions tokenizer_name = config_dict.get("tokenizer_name", "gpt2") # Try to get tokenizer from the model's config tokenizer = get_tokenizer(tokenizer_name) word_start_table = build_word_start_table(tokenizer, len(tokenizer)).to(device) word_positions = compute_word_positions(input_ids, word_start_table) if model_type == "mirrored": return collect_mirrored_activations(model, input_ids, word_positions) elif model_type == "standard": return collect_standard_activations(model, input_ids, word_positions) else: return collect_hf_activations(model, input_ids) # --------------------------------------------------------------------------- # Linear CKA # --------------------------------------------------------------------------- def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> float: """Compute linear CKA between two [N, D] representation matrices. CKA(X, Y) = ||Yc^T Xc||_F^2 / (||Xc^T Xc||_F * ||Yc^T Yc||_F) """ X = X.float() Y = Y.float() # Center X = X - X.mean(0, keepdim=True) Y = Y - Y.mean(0, keepdim=True) N = X.shape[0] if N < min(X.shape[1], Y.shape[1]): # Kernel formulation (N < D): K=XX^T, L=YY^T — [N,N] matrices K = X @ X.T L = Y @ Y.T numerator = (K * L).sum() denominator = torch.sqrt((K * K).sum() * (L * L).sum()) else: # Feature formulation (D <= N) XtY = X.T @ Y XtX = X.T @ X YtY = Y.T @ Y numerator = (XtY * XtY).sum() denominator = torch.sqrt((XtX * XtX).sum() * (YtY * YtY).sum()) if denominator < 1e-10: return 0.0 return (numerator / denominator).item() def compute_cka_matrix(activations: dict, subsample: int = 4) -> tuple: """Compute CKA between all layer pairs. Returns (cka_matrix, layer_names).""" names = list(activations.keys()) n_layers = len(names) # Flatten and subsample: [B, L, D] -> [N, D] flat_acts = {} for name, act in activations.items(): act_sub = act[:, ::subsample, :] flat_acts[name] = act_sub.reshape(-1, act_sub.shape[-1]) cka_matrix = np.zeros((n_layers, n_layers)) for i in range(n_layers): cka_matrix[i, i] = 1.0 for j in range(i + 1, n_layers): cka_val = linear_cka(flat_acts[names[i]], flat_acts[names[j]]) cka_matrix[i, j] = cka_val cka_matrix[j, i] = cka_val if (i + 1) % 5 == 0 or i == n_layers - 1: print(f" CKA: {i+1}/{n_layers} rows computed") return cka_matrix, names def compute_cross_model_cka(acts_a: dict, acts_b: dict) -> tuple: """Cross-model CKA using sample-level (avg-pooled) representations.""" names_a = list(acts_a.keys()) names_b = list(acts_b.keys()) def pool(activations): return {name: act.mean(dim=1) for name, act in activations.items()} pooled_a = pool(acts_a) pooled_b = pool(acts_b) # Ensure same number of samples n_samples = min( next(iter(pooled_a.values())).shape[0], next(iter(pooled_b.values())).shape[0] ) cka_matrix = np.zeros((len(names_a), len(names_b))) for i, na in enumerate(names_a): for j, nb in enumerate(names_b): cka_matrix[i, j] = linear_cka(pooled_a[na][:n_samples], pooled_b[nb][:n_samples]) if (i + 1) % 5 == 0 or i == len(names_a) - 1: print(f" Cross-CKA: {i+1}/{len(names_a)} rows computed") return cka_matrix, names_a, names_b # --------------------------------------------------------------------------- # Logit Lens # --------------------------------------------------------------------------- def get_unembed_components(model, model_type): """Extract (norm_module, unembed_weight) for logit lens projection.""" if model_type in ("standard", "mirrored"): return model.norm, model.embed.weight elif hasattr(model, 'transformer'): return model.transformer.ln_f, model.transformer.wte.weight elif hasattr(model, 'model'): return model.model.norm, model.model.embed_tokens.weight else: raise ValueError(f"Unsupported model: {type(model)}") def compute_logit_lens(activations: dict, norm: nn.Module, unembed_weight: torch.Tensor, labels: torch.Tensor, device: str = "cpu", chunk_size: int = 2048) -> OrderedDict: """Compute logit lens statistics at every layer. Projects intermediate hidden states through final norm + unembedding. Computes entropy, top-1 probability, correct token rank, and agreement with the final layer's predictions. Args: activations: OrderedDict[name] = [B, L, D] norm: final layer norm module unembed_weight: [V, D] unembedding matrix labels: [B, L-1] next-token labels (input_ids[:, 1:]) device: computation device chunk_size: number of positions per batch for projection Returns: OrderedDict[name] = {entropy, top1_prob, correct_rank, ...} """ names = list(activations.keys()) final_name = names[-1] # "final_norm" results = OrderedDict() unembed = unembed_weight.to(device) norm_mod = norm.to(device) labels_flat = labels.reshape(-1).to(device) def process_layer(name, act, apply_norm=True): """Project one layer's activations and compute all metrics.""" B, L, D = act.shape flat = act[:, :-1, :].reshape(-1, D) # [B*(L-1), D] N = flat.shape[0] all_entropy = [] all_top1_prob = [] all_correct_rank = [] all_top1_idx = [] for start in range(0, N, chunk_size): end = min(start + chunk_size, N) chunk = flat[start:end].to(device) chunk_labels = labels_flat[start:end] if apply_norm: chunk = norm_mod(chunk) logits = chunk @ unembed.T # [cs, V] log_probs = F.log_softmax(logits, dim=-1) probs = log_probs.exp() # Entropy entropy = -(probs * log_probs).sum(dim=-1) all_entropy.append(entropy.cpu()) # Top-1 probability top1_prob = probs.max(dim=-1).values all_top1_prob.append(top1_prob.cpu()) # Correct token rank correct_logits = logits.gather(1, chunk_labels.unsqueeze(1)) rank = (logits > correct_logits).sum(dim=-1) + 1 all_correct_rank.append(rank.cpu()) # Top-1 index all_top1_idx.append(logits.argmax(dim=-1).cpu()) entropy_t = torch.cat(all_entropy) top1_t = torch.cat(all_top1_prob) rank_t = torch.cat(all_correct_rank).float() top1_idx = torch.cat(all_top1_idx) return { "entropy": entropy_t.mean().item(), "entropy_std": entropy_t.std().item(), "top1_prob": top1_t.mean().item(), "correct_rank_mean": rank_t.mean().item(), "correct_rank_median": rank_t.median().item(), "log_rank_mean": rank_t.log().mean().item(), "_top1_idx": top1_idx, } # Process all layers for name in names: is_final = (name == final_name) act = activations[name] stats = process_layer(name, act, apply_norm=not is_final) results[name] = stats print(f" Logit lens: {name:20s} entropy={stats['entropy']:.2f} " f"top1={stats['top1_prob']:.4f} rank={stats['correct_rank_median']:.0f}") # Compute agreement with final layer final_top1 = results[final_name]["_top1_idx"] for name in names: layer_top1 = results[name]["_top1_idx"] agreement = (layer_top1 == final_top1).float().mean().item() results[name]["agreement_with_final"] = agreement # Clean up internal tensors for name in names: del results[name]["_top1_idx"] return results # --------------------------------------------------------------------------- # Representation drift # --------------------------------------------------------------------------- def compute_drift(activations: dict) -> OrderedDict: """Cosine similarity between consecutive layers' representations.""" names = list(activations.keys()) drift = OrderedDict() for i in range(1, len(names)): prev = activations[names[i - 1]] curr = activations[names[i]] # Flatten to [N, D] prev_flat = prev.reshape(-1, prev.shape[-1]) curr_flat = curr.reshape(-1, curr.shape[-1]) # Mean cosine similarity cos = F.cosine_similarity(prev_flat, curr_flat, dim=-1) drift[names[i]] = { "cos_sim_mean": cos.mean().item(), "cos_sim_std": cos.std().item(), "l2_distance": (curr_flat - prev_flat).norm(dim=-1).mean().item(), } return drift # --------------------------------------------------------------------------- # Plotting # --------------------------------------------------------------------------- def _phase_color(name): """Return color based on layer phase.""" if "expand" in name: return "steelblue" elif "middle" in name: return "goldenrod" elif "compress" in name: return "coral" elif "embedding" in name: return "gray" elif "final" in name: return "gray" else: return "mediumpurple" def _layer_sort_key(name): """Sort key for processing order.""" order = {"embedding": -1, "final_norm": 9999} if name in order: return order[name] parts = name.split("_") phase = parts[0] idx = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0 phase_offset = {"expand": 0, "middle": 1000, "compress": 2000, "layer": 0} return phase_offset.get(phase, 3000) + idx def _short_name(name): """Shorten layer name for plot labels.""" if name == "embedding": return "emb" if name == "final_norm": return "out" parts = name.split("_") if parts[0] == "expand": return f"E{parts[1]}" elif parts[0] == "middle": return f"M{parts[1]}" elif parts[0] == "compress": return f"C{parts[1]}" elif parts[0] == "layer": return f"L{parts[1]}" return name[:6] def plot_cka_self(cka_matrix: np.ndarray, names: list, output_dir: Path, model_label: str): """Plot self-CKA heatmap.""" n = len(names) short = [_short_name(n) for n in names] fig, ax = plt.subplots(figsize=(max(10, n * 0.35), max(8, n * 0.3))) fig.suptitle(f"{model_label} -- CKA Self-Similarity", fontsize=14) im = ax.imshow(cka_matrix, cmap="inferno", vmin=0, vmax=1, aspect="equal") # Phase separators for i, name in enumerate(names): if i > 0: prev = names[i - 1].split("_")[0] curr = name.split("_")[0] if prev != curr: ax.axhline(i - 0.5, color="white", linewidth=1.5, alpha=0.8) ax.axvline(i - 0.5, color="white", linewidth=1.5, alpha=0.8) ax.set_xticks(range(n)) ax.set_xticklabels(short, rotation=90, fontsize=7) ax.set_yticks(range(n)) ax.set_yticklabels(short, fontsize=7) plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="CKA") plt.tight_layout() fig.savefig(output_dir / "cka_self.png", dpi=150) plt.close(fig) def plot_cka_cross(cka_matrix: np.ndarray, names_a: list, names_b: list, output_dir: Path, label_a: str, label_b: str): """Plot cross-model CKA heatmap.""" short_a = [_short_name(n) for n in names_a] short_b = [_short_name(n) for n in names_b] na, nb = len(names_a), len(names_b) fig, ax = plt.subplots(figsize=(max(10, nb * 0.35), max(8, na * 0.3))) fig.suptitle(f"Cross-CKA: {label_a} vs {label_b}", fontsize=14) im = ax.imshow(cka_matrix, cmap="inferno", vmin=0, vmax=1, aspect="auto") ax.set_xticks(range(nb)) ax.set_xticklabels(short_b, rotation=90, fontsize=7) ax.set_xlabel(label_b) ax.set_yticks(range(na)) ax.set_yticklabels(short_a, fontsize=7) ax.set_ylabel(label_a) plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="CKA") plt.tight_layout() fig.savefig(output_dir / "cka_cross.png", dpi=150) plt.close(fig) def plot_logit_lens(lens_results: OrderedDict, output_dir: Path, model_label: str): """Plot logit lens summary: entropy, confidence, rank, agreement.""" names = list(lens_results.keys()) sorted_names = sorted(names, key=_layer_sort_key) short = [_short_name(n) for n in sorted_names] colors = [_phase_color(n) for n in sorted_names] x = range(len(sorted_names)) fig, axes = plt.subplots(2, 2, figsize=(16, 10)) fig.suptitle(f"{model_label} -- Logit Lens", fontsize=14) # Entropy vals = [lens_results[n]["entropy"] for n in sorted_names] axes[0, 0].bar(x, vals, color=colors, alpha=0.85) axes[0, 0].set_ylabel("Entropy (nats)") axes[0, 0].set_title("Prediction entropy per layer") axes[0, 0].set_xticks(x) axes[0, 0].set_xticklabels(short, rotation=90, fontsize=7) # Top-1 probability vals = [lens_results[n]["top1_prob"] for n in sorted_names] axes[0, 1].bar(x, vals, color=colors, alpha=0.85) axes[0, 1].set_ylabel("Top-1 probability") axes[0, 1].set_title("Prediction confidence per layer") axes[0, 1].set_xticks(x) axes[0, 1].set_xticklabels(short, rotation=90, fontsize=7) # Correct rank (log scale) vals = [lens_results[n]["correct_rank_median"] for n in sorted_names] axes[1, 0].bar(x, vals, color=colors, alpha=0.85) axes[1, 0].set_ylabel("Median rank of correct token") axes[1, 0].set_yscale("log") axes[1, 0].set_title("When does the model find the answer?") axes[1, 0].set_xticks(x) axes[1, 0].set_xticklabels(short, rotation=90, fontsize=7) # Agreement with final layer vals = [lens_results[n]["agreement_with_final"] for n in sorted_names] axes[1, 1].bar(x, vals, color=colors, alpha=0.85) axes[1, 1].set_ylabel("Agreement with final layer") axes[1, 1].set_title("Convergence toward final prediction") axes[1, 1].set_ylim(0, 1.05) axes[1, 1].set_xticks(x) axes[1, 1].set_xticklabels(short, rotation=90, fontsize=7) plt.tight_layout() fig.savefig(output_dir / "logit_lens_summary.png", dpi=150) plt.close(fig) def plot_logit_lens_trajectory(activations: dict, norm: nn.Module, unembed_weight: torch.Tensor, input_ids: torch.Tensor, tokenizer, output_dir: Path, model_label: str, device: str = "cpu", n_positions: int = 6, n_layers: int = 10): """Show top-5 predicted tokens at selected layers for a few positions. Picks positions spread across the first sample and shows how the model's prediction evolves through the network. """ names = sorted(activations.keys(), key=_layer_sort_key) # Select layers evenly spread across the network if len(names) > n_layers: indices = np.linspace(0, len(names) - 1, n_layers, dtype=int) selected_layers = [names[i] for i in indices] else: selected_layers = names # Select positions from the first sample seq_len = input_ids.shape[1] pos_indices = np.linspace(10, seq_len - 2, n_positions, dtype=int) unembed = unembed_weight.to(device) norm_mod = norm.to(device) final_name = names[-1] fig, axes = plt.subplots(n_positions, 1, figsize=(14, 3 * n_positions)) if n_positions == 1: axes = [axes] fig.suptitle(f"{model_label} -- Token prediction trajectory", fontsize=14, y=1.02) for pos_idx, pos in enumerate(pos_indices): ax = axes[pos_idx] actual_token = tokenizer.decode([input_ids[0, pos + 1].item()]) context = tokenizer.decode(input_ids[0, max(0, pos - 5):pos + 1].tolist()) layer_labels = [] top_tokens_per_layer = [] for name in selected_layers: is_final = (name == final_name) hidden = activations[name][0, pos:pos + 1, :].to(device) # [1, D] if not is_final: hidden = norm_mod(hidden) logits = (hidden @ unembed.T).squeeze(0) # [V] probs = F.softmax(logits, dim=-1) top5_vals, top5_idx = probs.topk(5) tokens_str = [] for val, idx in zip(top5_vals, top5_idx): tok = tokenizer.decode([idx.item()]).replace("\n", "\\n") tokens_str.append(f"{tok}({val:.2f})") layer_labels.append(_short_name(name)) top_tokens_per_layer.append("\n".join(tokens_str)) # Create a text table ax.set_xlim(-0.5, len(layer_labels) - 0.5) ax.set_ylim(-0.5, 5.5) ax.set_xticks(range(len(layer_labels))) ax.set_xticklabels(layer_labels, fontsize=8) ax.set_yticks([]) for li, tokens_str in enumerate(top_tokens_per_layer): lines = tokens_str.split("\n") for rank, line in enumerate(lines): color = "darkgreen" if actual_token.strip() in line else "black" fontweight = "bold" if actual_token.strip() in line else "normal" ax.text(li, rank, line, ha="center", va="center", fontsize=7, color=color, fontweight=fontweight) ax.set_title(f'pos {pos}: "...{context}" -> [{actual_token.strip()}]', fontsize=9, loc="left") ax.invert_yaxis() ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.spines["left"].set_visible(False) plt.tight_layout() fig.savefig(output_dir / "logit_lens_trajectory.png", dpi=150, bbox_inches="tight") plt.close(fig) def plot_drift(drift: OrderedDict, output_dir: Path, model_label: str): """Plot representation drift between consecutive layers.""" names = list(drift.keys()) sorted_names = sorted(names, key=_layer_sort_key) short = [_short_name(n) for n in sorted_names] colors = [_phase_color(n) for n in sorted_names] x = range(len(sorted_names)) fig, axes = plt.subplots(1, 2, figsize=(14, 5)) fig.suptitle(f"{model_label} -- Representation drift", fontsize=14) # Cosine similarity with previous layer vals = [drift[n]["cos_sim_mean"] for n in sorted_names] axes[0].bar(x, vals, color=colors, alpha=0.85) axes[0].set_ylabel("Cosine similarity with previous layer") axes[0].set_title("How much each layer preserves direction") axes[0].set_xticks(x) axes[0].set_xticklabels(short, rotation=90, fontsize=7) # L2 distance vals = [drift[n]["l2_distance"] for n in sorted_names] axes[1].bar(x, vals, color=colors, alpha=0.85) axes[1].set_ylabel("L2 distance from previous layer") axes[1].set_title("How much each layer changes magnitude") axes[1].set_xticks(x) axes[1].set_xticklabels(short, rotation=90, fontsize=7) plt.tight_layout() fig.savefig(output_dir / "representation_drift.png", dpi=150) plt.close(fig) # --------------------------------------------------------------------------- # Results saving # --------------------------------------------------------------------------- def save_results(cka_matrix, cka_names, lens_results, drift, cross_cka, output_dir): """Save all numerical results to JSON.""" out = {} if cka_matrix is not None: out["cka_self"] = { "names": cka_names, "matrix": cka_matrix.tolist(), } if lens_results: out["logit_lens"] = {name: data for name, data in lens_results.items()} if drift: out["drift"] = {name: data for name, data in drift.items()} if cross_cka is not None: matrix, names_a, names_b = cross_cka out["cka_cross"] = { "names_a": names_a, "names_b": names_b, "matrix": matrix.tolist(), } with open(output_dir / "results.json", "w") as f: json.dump(out, f, indent=2, default=str) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main(): parser = argparse.ArgumentParser( description="CKA and Logit Lens analysis for Prisma / Circuit Transformer") parser.add_argument("--checkpoint", type=str, required=True, help="Path to Prisma/Circuit checkpoint") parser.add_argument("--checkpoint-b", type=str, default=None, help="Second Prisma checkpoint for cross-model CKA") parser.add_argument("--hf-model", type=str, default=None, help="HuggingFace model for cross-model CKA (e.g. gpt2-medium)") parser.add_argument("--data", type=str, required=True, help="Data source (hf:dataset:config:split or file path)") parser.add_argument("--num-samples", type=int, default=32, help="Number of text samples (default: 32)") parser.add_argument("--context-length", type=int, default=512, help="Sequence length (default: 512)") parser.add_argument("--cka-subsample", type=int, default=4, help="Position subsampling for CKA (default: 4)") parser.add_argument("--no-logit-lens", action="store_true", help="Skip logit lens analysis") parser.add_argument("--no-cka", action="store_true", help="Skip CKA analysis") parser.add_argument("--output-dir", type=str, default=None, help="Output directory (default: auto)") parser.add_argument("--gpu", type=int, default=0, help="GPU index") args = parser.parse_args() device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu" print(f"Device: {device}") # Output directory if args.output_dir: output_dir = Path(args.output_dir) else: ckpt_name = Path(args.checkpoint).parent.name output_dir = Path("circuits/scripts/representation_output") / ckpt_name output_dir.mkdir(parents=True, exist_ok=True) print(f"Output: {output_dir}") # === Load model A === print(f"\nLoading: {args.checkpoint}") model_a, config_a, model_type_a = load_prisma_model(args.checkpoint, device) label_a = Path(args.checkpoint).parent.name n_params = sum(p.numel() for p in model_a.parameters()) print(f" Type: {model_type_a}, params: {n_params:,}") # === Load data === ckpt_data = torch.load(args.checkpoint, map_location="cpu", weights_only=False) tokenizer_name = ckpt_data.get("tokenizer_name", config_a.get("tokenizer_name", "gpt2")) del ckpt_data print(f"\nLoading data ({args.num_samples} samples, ctx={args.context_length})...") result = load_data( args.data, tokenizer_name, args.num_samples, args.context_length, device ) if result[0] is None: print("ERROR: No valid samples loaded.") return input_ids, tokenizer = result print(f" Data shape: {input_ids.shape}") # === Collect activations (model A) === print(f"\nCollecting activations ({model_type_a})...") acts_a = collect_activations(model_a, model_type_a, config_a, input_ids, device) print(f" Collected {len(acts_a)} layers") # Free GPU memory del model_a if device.startswith("cuda"): torch.cuda.empty_cache() # === CKA (self) === cka_matrix = None cka_names = None if not args.no_cka: print(f"\nComputing self-CKA (subsample={args.cka_subsample})...") cka_matrix, cka_names = compute_cka_matrix(acts_a, subsample=args.cka_subsample) plot_cka_self(cka_matrix, cka_names, output_dir, label_a) print(f" Saved: cka_self.png") # === Cross-model CKA === cross_cka = None if not args.no_cka and (args.checkpoint_b or args.hf_model): if args.checkpoint_b: print(f"\nLoading comparison: {args.checkpoint_b}") model_b, config_b, model_type_b = load_prisma_model(args.checkpoint_b, device) label_b = Path(args.checkpoint_b).parent.name acts_b = collect_activations(model_b, model_type_b, config_b, input_ids, device) del model_b else: print(f"\nLoading HF model: {args.hf_model}") model_b = load_hf_model(args.hf_model, device) label_b = args.hf_model # Decode texts from our tokens and re-tokenize for HF model print(f" Re-tokenizing for {args.hf_model}...") raw_texts = [tokenizer.decode(input_ids[i].tolist()) for i in range(input_ids.shape[0])] input_ids_b, _ = tokenize_for_hf( raw_texts, args.hf_model, args.context_length, device ) if input_ids_b is not None: print(f" HF data shape: {input_ids_b.shape}") acts_b = collect_hf_activations(model_b, input_ids_b) else: acts_b = None del model_b if device.startswith("cuda"): torch.cuda.empty_cache() if acts_b: print(f"\nComputing cross-model CKA...") cross_matrix, cross_names_a, cross_names_b = compute_cross_model_cka(acts_a, acts_b) cross_cka = (cross_matrix, cross_names_a, cross_names_b) plot_cka_cross(cross_matrix, cross_names_a, cross_names_b, output_dir, label_a, label_b) print(f" Saved: cka_cross.png") del acts_b # === Logit lens === lens_results = None if not args.no_logit_lens: # Reload model for unembedding components (we deleted it for memory) print(f"\nReloading model for logit lens...") model_a, _, _ = load_prisma_model(args.checkpoint, device) norm, unembed_weight = get_unembed_components(model_a, model_type_a) labels = input_ids[:, 1:].cpu() # next-token labels print(f"Computing logit lens...") lens_results = compute_logit_lens(acts_a, norm, unembed_weight, labels, device) plot_logit_lens(lens_results, output_dir, label_a) print(f" Saved: logit_lens_summary.png") # Token trajectory visualization print(f" Generating token trajectories...") plot_logit_lens_trajectory( acts_a, norm, unembed_weight, input_ids.cpu(), tokenizer, output_dir, label_a, device ) print(f" Saved: logit_lens_trajectory.png") del model_a if device.startswith("cuda"): torch.cuda.empty_cache() # === Representation drift === print(f"\nComputing representation drift...") drift = compute_drift(acts_a) plot_drift(drift, output_dir, label_a) print(f" Saved: representation_drift.png") # === Save results === save_results(cka_matrix, cka_names, lens_results, drift, cross_cka, output_dir) print(f"\nAll outputs saved to: {output_dir}") n_plots = len(list(output_dir.glob("*.png"))) print(f" Plots: {n_plots} PNG files") print(f" Data: results.json") if __name__ == "__main__": main()