Prisma / scripts /representation_analysis.py
y3i12's picture
Initial commit
56e82ec
#!/usr/bin/env python3
"""
Representation analysis: CKA and Logit Lens for Prisma / Circuit Transformer.
CKA (Centered Kernel Alignment):
Measures representational similarity between all layer pairs.
Produces a heatmap revealing mirror symmetry, phase transitions,
and cross-model alignment.
Logit Lens:
Projects intermediate representations to vocabulary space at every layer.
Reveals what the model "thinks" at each processing stage -- from raw
tokens through the semantic bottleneck back to specific predictions.
Also computes representation drift (cosine similarity between consecutive layers).
Usage:
# Full analysis (CKA + logit lens)
python -m circuits.scripts.representation_analysis \\
--checkpoint path/to/checkpoint.pt \\
--data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train
# Cross-model CKA
python -m circuits.scripts.representation_analysis \\
--checkpoint path/to/prisma.pt --hf-model gpt2-medium \\
--data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train
# CKA only (skip logit lens)
python -m circuits.scripts.representation_analysis \\
--checkpoint path/to/checkpoint.pt \\
--data hf:HuggingFaceFW/fineweb-edu:sample-10BT:train \\
--no-logit-lens
"""
import argparse
import json
import sys
import os
from pathlib import Path
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_prisma_model(checkpoint_path: str, device: str = "cpu"):
    """Load a Prisma/Circuit checkpoint, return (model, config_dict, model_type).

    Args:
        checkpoint_path: Path to a checkpoint holding a ``"model"`` state dict,
            plus optional ``"model_type"`` (default ``"standard"``) and
            ``"config"`` (default ``{}``) entries.
        device: Device the model is moved to before returning.

    Returns:
        ``(model, config_dict, model_type)`` with the model in eval mode.
    """
    repo_root = str(Path(__file__).resolve().parents[2])
    # Make the ``circuits`` package importable when run as a script, without
    # growing sys.path on every call.
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    from circuits.config import CircuitConfig
    from circuits.model import CircuitTransformer
    from circuits.mirrored import MirroredConfig, MirroredTransformer

    # NOTE: weights_only=False unpickles arbitrary objects -- only load
    # checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model_type = ckpt.get("model_type", "standard")
    config_dict = ckpt.get("config", {})

    if model_type == "mirrored":
        # A truthy ``dual_gate_middle`` entry is removed before building the config.
        if config_dict.get("dual_gate_middle"):
            config_dict.pop("dual_gate_middle")
        config = MirroredConfig.from_dict(config_dict)
        model = MirroredTransformer(config)
    else:
        config = CircuitConfig.from_dict(config_dict)
        model = CircuitTransformer(config)

    state_dict = ckpt["model"]
    # Models saved after torch.compile carry an ``_orig_mod.`` key prefix.
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}

    # strict=False tolerates mismatches; report them so a partially loaded
    # model is never analysed silently.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f" WARNING: {len(load_result.missing_keys)} parameter(s) missing from "
              f"checkpoint, e.g. {load_result.missing_keys[:3]}")
    if load_result.unexpected_keys:
        print(f" WARNING: {len(load_result.unexpected_keys)} unexpected key(s) in "
              f"checkpoint, e.g. {load_result.unexpected_keys[:3]}")

    model.to(device).eval()
    return model, config_dict, model_type
def load_hf_model(model_name: str, device: str = "cpu"):
    """Load a HuggingFace causal LM in float32, moved to ``device`` and in eval mode.

    ``trust_remote_code=True`` lets the hub repository run its own modelling
    code, so only pass model names you trust.
    """
    from transformers import AutoModelForCausalLM

    hf_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    hf_model.to(device)
    hf_model.eval()
    return hf_model
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def _take_windows(texts, tokenizer, context_length, num_samples):
    """Collect up to ``num_samples`` windows of ``context_length`` token ids.

    Texts are encoded in order and consumed lazily, so a streaming iterable
    is only read as far as needed. A text contributes its first
    ``context_length`` ids if it encodes to at least that many tokens.
    """
    windows = []
    if num_samples <= 0:
        return windows
    for text in texts:
        ids = tokenizer.encode(text)
        if len(ids) >= context_length:
            windows.append(ids[:context_length])
            if len(windows) >= num_samples:
                break
    return windows


def load_data(data_source: str, tokenizer_name: str, num_samples: int = 32,
              context_length: int = 512, device: str = "cpu"):
    """Load tokenized data. Returns (input_ids, tokenizer).

    Supports:
      - Memmap ``.bin`` files (from circuits training cache): an 8-byte
        header of two native unsigned ints (n_chunks, seq_len) followed by
        an int32 [n_chunks, seq_len] array. Output width is
        ``min(context_length, seq_len)``.
      - ``hf:dataset[:config[:split]]``, streamed from HuggingFace (split
        defaults to "train"); only records whose ``text`` has at least
        100 characters are used.
      - Plain UTF-8 text files, one sample per line; only lines longer than
        100 characters after stripping are used.

    For text sources each sample is the first ``context_length`` tokens of a
    text that encodes to at least that many tokens.

    Returns:
        ``(input_ids, tokenizer)`` where ``input_ids`` is a long tensor on
        ``device``, or ``(None, tokenizer)`` when no usable sample exists.
    """
    repo_root = str(Path(__file__).resolve().parents[2])
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    from circuits.data import get_tokenizer
    tokenizer = get_tokenizer(tokenizer_name)

    # Memmap binary file (already tokenized)
    if data_source.endswith(".bin"):
        import struct
        with open(data_source, "rb") as f:
            n_chunks, seq_len = struct.unpack("II", f.read(8))
        n = min(num_samples, n_chunks)
        if n <= 0:
            # np.memmap cannot map an empty array; report "no data" instead.
            return None, tokenizer
        data = np.memmap(data_source, dtype=np.int32, mode="r",
                         offset=8, shape=(n_chunks, seq_len))
        # Slice to requested context length
        cl = min(context_length, seq_len)
        input_ids = torch.from_numpy(data[:n, :cl].copy()).long().to(device)
        return input_ids, tokenizer

    if data_source.startswith("hf:"):
        # HuggingFace dataset, streamed so only the needed records are read.
        from datasets import load_dataset
        parts = data_source[3:].split(":")
        ds_name = parts[0]
        ds_config = (parts[1] or None) if len(parts) > 1 else None
        ds_split = (parts[2] or "train") if len(parts) > 2 else "train"
        dataset = load_dataset(ds_name, ds_config, split=ds_split, streaming=True)
        texts = (item.get("text") or "" for item in dataset)
        windows = _take_windows((t for t in texts if len(t) >= 100),
                                tokenizer, context_length, num_samples)
    else:
        # Plain text file, read line by line instead of all at once.
        with open(data_source, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            windows = _take_windows((t for t in stripped if len(t) > 100),
                                    tokenizer, context_length, num_samples)

    if not windows:
        return None, tokenizer
    return torch.tensor(windows, device=device), tokenizer
def tokenize_for_hf(texts: list, model_name: str, context_length: int = 512,
                    device: str = "cpu"):
    """Tokenize texts for an HF model. Returns (input_ids, tokenizer).

    Each text is encoded with truncation to ``context_length`` tokens:
      - full-length encodings are kept as-is;
      - encodings longer than 32 tokens (but shorter than ``context_length``)
        are right-padded with the EOS token id, or the pad token id when the
        tokenizer has no EOS token;
      - encodings of 32 tokens or fewer are dropped.

    NOTE: a dropped text shifts every later row, so row ``i`` only
    corresponds to ``texts[i]`` when nothing was dropped. This matters for
    per-sample comparisons such as cross-model CKA.

    Returns:
        ``(input_ids, tokenizer)``, or ``(None, tokenizer)`` when no text
        survives.

    Raises:
        ValueError: if a text needs padding but the tokenizer defines
            neither an EOS nor a pad token.
    """
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              use_fast=False,
                                              trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Pad with EOS as before; fall back to a dedicated pad token if EOS is absent.
    pad_id = tokenizer.eos_token_id
    if pad_id is None:
        pad_id = tokenizer.pad_token_id

    rows = []
    for text in texts:
        ids = tokenizer.encode(text, max_length=context_length, truncation=True)
        if len(ids) >= context_length:
            rows.append(ids[:context_length])
        elif len(ids) > 32:
            if pad_id is None:
                raise ValueError(
                    f"Tokenizer for {model_name!r} defines neither an EOS nor a pad "
                    f"token; cannot pad to context_length={context_length}")
            rows.append(ids + [pad_id] * (context_length - len(ids)))
    if not rows:
        return None, tokenizer
    return torch.tensor(rows, device=device), tokenizer
# ---------------------------------------------------------------------------
# Activation collection
# ---------------------------------------------------------------------------
def collect_mirrored_activations(model, input_ids, word_positions=None):
    """Collect activations from MirroredTransformer at every processing stage.

    Runs the stages by hand, in this order:
      1. embedding, with the optional projection, then ``embed_scale``;
      2. ``model.mirror_blocks`` first-to-last ("expand");
      3. ``model.middle_blocks`` ("middle");
      4. ``model.mirror_blocks`` again, last-to-first ("compress");
      5. ``model.norm``.

    NOTE(review): this presumably mirrors MirroredTransformer's own forward
    pass. It is not checked against it here, so keep the two in sync.

    Args:
        model: a loaded MirroredTransformer.
        input_ids: token ids fed to ``model.embed``.
        word_positions: optional value forwarded to every block call.

    Returns:
        OrderedDict of stage name -> detached CPU tensor. Keys are in
        processing order: "embedding", "expand_i", "middle_i", "compress_i",
        "final_norm".
    """
    activations = OrderedDict()
    with torch.no_grad():
        x = model.embed(input_ids)
        if model.embed_proj is not None:
            if model.embed_g3 is not None:
                # Gated projection: proj(x) * silu(g3(x) * silu(g4(x))).
                g4 = F.silu(model.embed_g4(x))
                g3 = F.silu(model.embed_g3(x) * g4)
                x = model.embed_proj(x) * g3
            else:
                x = F.silu(model.embed_proj(x))
        x = x * model.embed_scale
        activations["embedding"] = x.detach().cpu()
        # Expand phase: mirror blocks in forward order. Each block returns a
        # pair; only the first element (the hidden state) is kept.
        for i, block in enumerate(model.mirror_blocks):
            x, _ = block(x, word_positions=word_positions)
            activations[f"expand_{i}"] = x.detach().cpu()
        for i, block in enumerate(model.middle_blocks):
            x, _ = block(x, word_positions=word_positions)
            activations[f"middle_{i}"] = x.detach().cpu()
        # Compress phase: the same mirror blocks reused in reverse order.
        # compress_0 is the pass through the LAST mirror block.
        for i in reversed(range(len(model.mirror_blocks))):
            x, _ = model.mirror_blocks[i](x, word_positions=word_positions)
            compress_idx = len(model.mirror_blocks) - 1 - i
            activations[f"compress_{compress_idx}"] = x.detach().cpu()
        x = model.norm(x)
        activations["final_norm"] = x.detach().cpu()
    return activations
def collect_standard_activations(model, input_ids, word_positions=None):
    """Run a standard CircuitTransformer stage by stage and capture hidden states.

    Stages: embedding (``SiLU(embed_proj(.))`` when a projection exists,
    then ``embed_scale``), each entry of ``model.layers`` (which returns a
    pair whose first element is the new hidden state), and ``model.norm``.

    Returns:
        OrderedDict "embedding", "layer_0" ... "layer_{n-1}", "final_norm"
        -> detached CPU tensors, in processing order.
    """
    captured = OrderedDict()
    with torch.no_grad():
        hidden = model.embed(input_ids)
        projection = model.embed_proj
        if projection is not None:
            hidden = F.silu(projection(hidden))
        hidden = hidden * model.embed_scale
        captured["embedding"] = hidden.detach().cpu()

        for idx, layer in enumerate(model.layers):
            hidden, _aux = layer(hidden, word_positions=word_positions)
            captured[f"layer_{idx}"] = hidden.detach().cpu()

        hidden = model.norm(hidden)
        captured["final_norm"] = hidden.detach().cpu()
    return captured
def collect_hf_activations(model, input_ids):
    """Hook-based activation collection for HuggingFace models.

    Supports GPT-2 style models (``model.transformer`` with ``h``, ``wte``,
    ``ln_f``) and Llama/Mistral style models (``model.model`` with
    ``layers``, ``embed_tokens``, ``norm``). Forward hooks capture the
    output of:
      - the token-embedding module ("embedding");
      - every block ("layer_i");
      - the final norm ("final_norm").

    For tuple outputs the first element is taken. Any position embeddings
    added outside the hooked embedding module are not included in
    "embedding".

    Hooks are always removed, even if the forward pass raises, so a failed
    call cannot leave stale hooks attached to the model.

    Raises:
        ValueError: for an unrecognised model layout.
    """
    activations = OrderedDict()

    if hasattr(model, 'transformer'):
        # GPT-2 style
        blocks = model.transformer.h
        embed = model.transformer.wte
        final_norm = model.transformer.ln_f
    elif hasattr(model, 'model'):
        # Llama / Mistral style
        blocks = model.model.layers
        embed = model.model.embed_tokens
        final_norm = model.model.norm
    else:
        raise ValueError(f"Unsupported HF model: {type(model)}")

    def make_hook(name):
        def hook_fn(module, inputs, output):
            out = output[0] if isinstance(output, tuple) else output
            activations[name] = out.detach().cpu()
        return hook_fn

    hooks = []
    try:
        hooks.append(embed.register_forward_hook(make_hook("embedding")))
        for i, block in enumerate(blocks):
            hooks.append(block.register_forward_hook(make_hook(f"layer_{i}")))
        hooks.append(final_norm.register_forward_hook(make_hook("final_norm")))
        with torch.no_grad():
            model(input_ids)
    finally:
        for handle in hooks:
            handle.remove()
    return activations
def collect_activations(model, model_type, config_dict, input_ids, device):
    """Dispatch to the right activation collector for ``model_type``.

    Circuit models ("standard" / "mirrored") whose config has a positive
    ``word_rope_dims`` also get per-token word positions. These are built
    from the tokenizer named by the config's ``tokenizer_name``, defaulting
    to "gpt2". Any other ``model_type`` goes to the HF hook-based collector.
    """
    word_positions = None
    rope_dims = config_dict.get("word_rope_dims", 0) if config_dict else 0
    is_circuit_model = model_type in ("standard", "mirrored")

    if rope_dims > 0 and is_circuit_model:
        sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
        from circuits.data import get_tokenizer
        from circuits.layers import build_word_start_table, compute_word_positions
        tok = get_tokenizer(config_dict.get("tokenizer_name", "gpt2"))
        start_table = build_word_start_table(tok, len(tok)).to(device)
        word_positions = compute_word_positions(input_ids, start_table)

    if model_type == "mirrored":
        return collect_mirrored_activations(model, input_ids, word_positions)
    if model_type == "standard":
        return collect_standard_activations(model, input_ids, word_positions)
    return collect_hf_activations(model, input_ids)
# ---------------------------------------------------------------------------
# Linear CKA
# ---------------------------------------------------------------------------
def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> float:
    """Compute linear CKA between two [N, D] representation matrices.

    CKA(X, Y) = ||Yc^T Xc||_F^2 / (||Xc^T Xc||_F * ||Yc^T Yc||_F)

    ``Xc`` and ``Yc`` are the column-centered inputs. Rows must correspond
    to the same N examples; the feature dims may differ. The [N, N] kernel
    form is used when N is smaller than both feature dims, otherwise the
    [D, D] feature form. Both give the same value.

    The computation runs in float64, and the two denominator norms are taken
    separately and then multiplied. Multiplying the two squared norms first
    overflows float32 for large activations, which silently turned the
    result into 0.

    Returns:
        CKA in [0, 1], or 0.0 when either input has (near) zero variance.
    """
    X = X.double()
    Y = Y.double()
    # Center
    X = X - X.mean(0, keepdim=True)
    Y = Y - Y.mean(0, keepdim=True)

    n = X.shape[0]
    if n < min(X.shape[1], Y.shape[1]):
        # Kernel formulation (N < D): Gram matrices XX^T, YY^T are [N, N].
        gram_x = X @ X.T
        gram_y = Y @ Y.T
        numerator = (gram_x * gram_y).sum()
    else:
        # Feature formulation (D <= N): covariances are [D, D].
        cross = X.T @ Y
        numerator = (cross * cross).sum()
        gram_x = X.T @ X
        gram_y = Y.T @ Y

    denominator = torch.linalg.norm(gram_x) * torch.linalg.norm(gram_y)
    if denominator < 1e-10:
        return 0.0
    return (numerator / denominator).item()
def compute_cka_matrix(activations: dict, subsample: int = 4) -> tuple:
    """Compute linear CKA between every pair of layers of one model.

    Each [B, L, D] activation keeps every ``subsample``-th position. Batch
    and position are then folded into one sample axis, giving [N, D]. The
    matrix is symmetric with 1.0 on the diagonal.

    Returns:
        ``(cka_matrix, layer_names)``; ``cka_matrix`` is an
        [n_layers, n_layers] numpy array ordered like ``layer_names``.
    """
    layer_names = list(activations)
    count = len(layer_names)
    samples = {
        key: tensor[:, ::subsample, :].reshape(-1, tensor.shape[-1])
        for key, tensor in activations.items()
    }

    result = np.zeros((count, count))
    for row in range(count):
        result[row, row] = 1.0
        for col in range(row + 1, count):
            value = linear_cka(samples[layer_names[row]], samples[layer_names[col]])
            result[row, col] = result[col, row] = value
        finished = row + 1
        if finished % 5 == 0 or finished == count:
            print(f" CKA: {finished}/{count} rows computed")
    return result, layer_names
def compute_cross_model_cka(acts_a: dict, acts_b: dict) -> tuple:
    """Cross-model CKA using sample-level (avg-pooled) representations.

    Each [B, L, D] activation is mean-pooled over positions to [B, D], so
    the two models may use different tokenizations. Row ``i`` of both sides
    is treated as the same sample, and only the first ``min(B_a, B_b)`` rows
    are compared.

    Returns:
        ``(cka_matrix, names_a, names_b)`` with ``cka_matrix`` shaped
        [len(acts_a), len(acts_b)].
    """
    names_a = list(acts_a)
    names_b = list(acts_b)
    pooled_a = {key: tensor.mean(dim=1) for key, tensor in acts_a.items()}
    pooled_b = {key: tensor.mean(dim=1) for key, tensor in acts_b.items()}

    # Ensure same number of samples
    rows_a = next(iter(pooled_a.values())).shape[0]
    rows_b = next(iter(pooled_b.values())).shape[0]
    n_samples = min(rows_a, rows_b)

    total = len(names_a)
    result = np.zeros((total, len(names_b)))
    for i, name_a in enumerate(names_a):
        left = pooled_a[name_a][:n_samples]
        for j, name_b in enumerate(names_b):
            result[i, j] = linear_cka(left, pooled_b[name_b][:n_samples])
        if (i + 1) % 5 == 0 or i == total - 1:
            print(f" Cross-CKA: {i+1}/{total} rows computed")
    return result, names_a, names_b
# ---------------------------------------------------------------------------
# Logit Lens
# ---------------------------------------------------------------------------
def get_unembed_components(model, model_type):
    """Extract (norm_module, unembed_weight) for logit lens projection.

    - Circuit models ("standard" / "mirrored"): ``model.norm`` and the
      input embedding matrix ``model.embed.weight``.
    - HF models: the final norm plus the model's OUTPUT embedding (via
      ``get_output_embeddings()``). Models with an untied LM head (e.g.
      Llama) must not reuse the input embedding. Models without that
      method, or whose head has no ``weight``, fall back to the input token
      embedding. A head bias, if any, is not returned.

    Raises:
        ValueError: for an unrecognised model layout.
    """
    if model_type in ("standard", "mirrored"):
        return model.norm, model.embed.weight

    if hasattr(model, 'transformer'):
        # GPT-2 style
        norm = model.transformer.ln_f
        input_embed = model.transformer.wte
    elif hasattr(model, 'model'):
        # Llama / Mistral style
        norm = model.model.norm
        input_embed = model.model.embed_tokens
    else:
        raise ValueError(f"Unsupported model: {type(model)}")

    get_output = getattr(model, "get_output_embeddings", None)
    head = get_output() if callable(get_output) else None
    weight = getattr(head, "weight", None)
    if weight is None:
        weight = input_embed.weight
    return norm, weight
def compute_logit_lens(activations: dict, norm: nn.Module, unembed_weight: torch.Tensor,
                       labels: torch.Tensor, device: str = "cpu",
                       chunk_size: int = 2048) -> OrderedDict:
    """Compute logit lens statistics at every layer.

    Projects intermediate hidden states through final norm + unembedding.
    Computes entropy, top-1 probability, correct token rank, and
    agreement with the final layer's predictions.

    The LAST entry of ``activations`` is treated as the final, already
    normalised layer and is projected without applying ``norm``. Every
    other layer gets ``norm`` applied first.

    Everything runs under ``torch.no_grad()``. The unembedding weight and
    the norm parameters normally require grad. Without it, each chunk's
    [chunk, V] logits and probabilities stayed referenced by the autograd
    graph of the per-chunk results until the layer finished, inflating
    device memory by several multiples of the vocabulary-sized activations.

    Args:
        activations: OrderedDict[name] = [B, L, D]
        norm: final layer norm module (moved to ``device``)
        unembed_weight: [V, D] unembedding matrix
        labels: [B, L-1] next-token labels (input_ids[:, 1:])
        device: computation device
        chunk_size: number of positions per batch for projection

    Returns:
        OrderedDict[name] = {entropy, entropy_std, top1_prob,
        correct_rank_mean, correct_rank_median, log_rank_mean,
        agreement_with_final}
    """
    names = list(activations.keys())
    final_name = names[-1]  # "final_norm" for the collectors in this file
    results = OrderedDict()
    unembed = unembed_weight.to(device)
    norm_mod = norm.to(device)
    labels_flat = labels.reshape(-1).to(device)

    def process_layer(act, apply_norm=True):
        """Project one layer's activations and compute all metrics."""
        D = act.shape[-1]
        # The last position has no next-token label. Flatten to
        # [B*(L-1), D] in the same order as labels_flat.
        flat = act[:, :-1, :].reshape(-1, D)
        N = flat.shape[0]
        entropies, top1_probs, ranks, top1_ids = [], [], [], []
        for start in range(0, N, chunk_size):
            end = min(start + chunk_size, N)
            chunk = flat[start:end].to(device)
            chunk_labels = labels_flat[start:end]
            if apply_norm:
                chunk = norm_mod(chunk)
            logits = chunk @ unembed.T  # [cs, V]
            log_probs = F.log_softmax(logits, dim=-1)
            probs = log_probs.exp()
            entropies.append((-(probs * log_probs).sum(dim=-1)).cpu())
            top1_probs.append(probs.max(dim=-1).values.cpu())
            # Rank = 1 + number of vocabulary entries scoring strictly higher
            # than the correct token (ties count in its favour).
            correct_logits = logits.gather(1, chunk_labels.unsqueeze(1))
            ranks.append(((logits > correct_logits).sum(dim=-1) + 1).cpu())
            top1_ids.append(logits.argmax(dim=-1).cpu())

        entropy_t = torch.cat(entropies)
        top1_t = torch.cat(top1_probs)
        rank_t = torch.cat(ranks).float()
        return {
            "entropy": entropy_t.mean().item(),
            "entropy_std": entropy_t.std().item(),
            "top1_prob": top1_t.mean().item(),
            "correct_rank_mean": rank_t.mean().item(),
            "correct_rank_median": rank_t.median().item(),
            "log_rank_mean": rank_t.log().mean().item(),
            "_top1_idx": torch.cat(top1_ids),
        }

    # Process all layers
    with torch.no_grad():
        for name in names:
            stats = process_layer(activations[name], apply_norm=(name != final_name))
            results[name] = stats
            print(f" Logit lens: {name:20s} entropy={stats['entropy']:.2f} "
                  f"top1={stats['top1_prob']:.4f} rank={stats['correct_rank_median']:.0f}")

    # Agreement with the final layer's argmax; internal index tensors are
    # removed from the results as they are consumed.
    final_top1 = results[final_name]["_top1_idx"]
    for name in names:
        layer_top1 = results[name].pop("_top1_idx")
        results[name]["agreement_with_final"] = (layer_top1 == final_top1).float().mean().item()
    return results
# ---------------------------------------------------------------------------
# Representation drift
# ---------------------------------------------------------------------------
def compute_drift(activations: dict) -> OrderedDict:
    """Measure how much each layer changes the previous layer's representation.

    Consecutive entries (in insertion order) are flattened to [N, D] and
    compared row by row; they are expected to have matching shapes. The
    result is keyed by the LATER layer's name with:
      cos_sim_mean / cos_sim_std: mean / std of per-row cosine similarity
      l2_distance: mean per-row Euclidean distance
    """
    layer_names = list(activations)
    result = OrderedDict()
    for before, after in zip(layer_names, layer_names[1:]):
        src = activations[before]
        dst = activations[after]
        src = src.reshape(-1, src.shape[-1])
        dst = dst.reshape(-1, dst.shape[-1])
        cosine = F.cosine_similarity(src, dst, dim=-1)
        result[after] = {
            "cos_sim_mean": cosine.mean().item(),
            "cos_sim_std": cosine.std().item(),
            "l2_distance": (dst - src).norm(dim=-1).mean().item(),
        }
    return result
# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------
def _phase_color(name):
"""Return color based on layer phase."""
if "expand" in name:
return "steelblue"
elif "middle" in name:
return "goldenrod"
elif "compress" in name:
return "coral"
elif "embedding" in name:
return "gray"
elif "final" in name:
return "gray"
else:
return "mediumpurple"
def _layer_sort_key(name):
"""Sort key for processing order."""
order = {"embedding": -1, "final_norm": 9999}
if name in order:
return order[name]
parts = name.split("_")
phase = parts[0]
idx = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
phase_offset = {"expand": 0, "middle": 1000, "compress": 2000, "layer": 0}
return phase_offset.get(phase, 3000) + idx
def _short_name(name):
"""Shorten layer name for plot labels."""
if name == "embedding":
return "emb"
if name == "final_norm":
return "out"
parts = name.split("_")
if parts[0] == "expand":
return f"E{parts[1]}"
elif parts[0] == "middle":
return f"M{parts[1]}"
elif parts[0] == "compress":
return f"C{parts[1]}"
elif parts[0] == "layer":
return f"L{parts[1]}"
return name[:6]
def plot_cka_self(cka_matrix: np.ndarray, names: list, output_dir: Path,
                  model_label: str):
    """Save the self-CKA heatmap to ``output_dir / "cka_self.png"``.

    White lines separate runs of layers whose name prefix (the part before
    the first "_") differs, i.e. phase boundaries.
    """
    count = len(names)
    tick_labels = [_short_name(layer) for layer in names]
    fig, ax = plt.subplots(figsize=(max(10, count * 0.35), max(8, count * 0.3)))
    fig.suptitle(f"{model_label} -- CKA Self-Similarity", fontsize=14)
    image = ax.imshow(cka_matrix, cmap="inferno", vmin=0, vmax=1, aspect="equal")

    # Phase separators
    phases = [layer.split("_")[0] for layer in names]
    for idx in range(1, count):
        if phases[idx] != phases[idx - 1]:
            edge = idx - 0.5
            ax.axhline(edge, color="white", linewidth=1.5, alpha=0.8)
            ax.axvline(edge, color="white", linewidth=1.5, alpha=0.8)

    ticks = range(count)
    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels, rotation=90, fontsize=7)
    ax.set_yticks(ticks)
    ax.set_yticklabels(tick_labels, fontsize=7)
    plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04, label="CKA")
    plt.tight_layout()
    fig.savefig(output_dir / "cka_self.png", dpi=150)
    plt.close(fig)
def plot_cka_cross(cka_matrix: np.ndarray, names_a: list, names_b: list,
                   output_dir: Path, label_a: str, label_b: str):
    """Save the cross-model CKA heatmap to ``output_dir / "cka_cross.png"``.

    Rows are model A's layers (``names_a``); columns are model B's
    (``names_b``).
    """
    n_rows, n_cols = len(names_a), len(names_b)
    fig, ax = plt.subplots(figsize=(max(10, n_cols * 0.35), max(8, n_rows * 0.3)))
    fig.suptitle(f"Cross-CKA: {label_a} vs {label_b}", fontsize=14)
    image = ax.imshow(cka_matrix, cmap="inferno", vmin=0, vmax=1, aspect="auto")

    ax.set_xticks(range(n_cols))
    ax.set_xticklabels([_short_name(layer) for layer in names_b], rotation=90, fontsize=7)
    ax.set_xlabel(label_b)
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels([_short_name(layer) for layer in names_a], fontsize=7)
    ax.set_ylabel(label_a)

    plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04, label="CKA")
    plt.tight_layout()
    fig.savefig(output_dir / "cka_cross.png", dpi=150)
    plt.close(fig)
def plot_logit_lens(lens_results: OrderedDict, output_dir: Path,
                    model_label: str):
    """Save a 2x2 logit lens summary to ``output_dir / "logit_lens_summary.png"``.

    Panels show layers in processing order, with bars coloured by phase:
      - mean entropy;
      - mean top-1 probability;
      - median rank of the correct token (log scale);
      - top-1 agreement with the final layer.
    """
    ordered = sorted(lens_results.keys(), key=_layer_sort_key)
    tick_labels = [_short_name(layer) for layer in ordered]
    bar_colors = [_phase_color(layer) for layer in ordered]
    ticks = range(len(ordered))

    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    fig.suptitle(f"{model_label} -- Logit Lens", fontsize=14)

    # (axis, metric key, y label, title, y scale, y limits)
    panels = (
        (axes[0, 0], "entropy", "Entropy (nats)",
         "Prediction entropy per layer", None, None),
        (axes[0, 1], "top1_prob", "Top-1 probability",
         "Prediction confidence per layer", None, None),
        (axes[1, 0], "correct_rank_median", "Median rank of correct token",
         "When does the model find the answer?", "log", None),
        (axes[1, 1], "agreement_with_final", "Agreement with final layer",
         "Convergence toward final prediction", None, (0, 1.05)),
    )
    for ax, key, ylabel, title, yscale, ylim in panels:
        ax.bar(ticks, [lens_results[layer][key] for layer in ordered],
               color=bar_colors, alpha=0.85)
        ax.set_ylabel(ylabel)
        if yscale is not None:
            ax.set_yscale(yscale)
        ax.set_title(title)
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.set_xticks(ticks)
        ax.set_xticklabels(tick_labels, rotation=90, fontsize=7)

    plt.tight_layout()
    fig.savefig(output_dir / "logit_lens_summary.png", dpi=150)
    plt.close(fig)
def plot_logit_lens_trajectory(activations: dict, norm: nn.Module,
                               unembed_weight: torch.Tensor, input_ids: torch.Tensor,
                               tokenizer, output_dir: Path, model_label: str,
                               device: str = "cpu",
                               n_positions: int = 6, n_layers: int = 10):
    """Show top-5 predicted tokens at selected layers for a few positions.

    Picks positions spread across the first sample and shows how the
    model's prediction evolves through the network. Up to ``n_layers``
    layers are chosen evenly in processing order. The last one is projected
    without ``norm``; the others get ``norm`` first.

    The true next token is highlighted by comparing token ids. A decoded
    substring test mis-highlights: a whitespace/newline target strips to ""
    and matched every entry, and short tokens matched unrelated ones.
    Positions start at 10, or earlier for short sequences, and end at
    ``seq_len - 2`` so every position has a next token.

    Writes ``output_dir / "logit_lens_trajectory.png"``.
    """
    names = sorted(activations.keys(), key=_layer_sort_key)
    # Select layers evenly spread across the network
    if len(names) > n_layers:
        picks = np.linspace(0, len(names) - 1, n_layers, dtype=int)
        selected_layers = [names[i] for i in picks]
    else:
        selected_layers = names

    # Select positions from the first sample; clamp the start so short
    # sequences do not produce out-of-range indices.
    seq_len = input_ids.shape[1]
    last_pos = seq_len - 2
    first_pos = min(10, max(last_pos, 0))
    pos_indices = np.linspace(first_pos, last_pos, n_positions, dtype=int)

    unembed = unembed_weight.to(device)
    norm_mod = norm.to(device)
    final_name = names[-1]

    fig, axes = plt.subplots(n_positions, 1, figsize=(14, 3 * n_positions))
    if n_positions == 1:
        axes = [axes]
    fig.suptitle(f"{model_label} -- Token prediction trajectory", fontsize=14, y=1.02)

    for ax, pos in zip(axes, pos_indices):
        target_id = input_ids[0, pos + 1].item()
        actual_token = tokenizer.decode([target_id])
        context = tokenizer.decode(input_ids[0, max(0, pos - 5):pos + 1].tolist())

        layer_labels = []
        top_entries = []  # per layer: list of (label, is_target) pairs
        with torch.no_grad():
            for name in selected_layers:
                hidden = activations[name][0, pos:pos + 1, :].to(device)  # [1, D]
                if name != final_name:
                    hidden = norm_mod(hidden)
                logits = (hidden @ unembed.T).squeeze(0)  # [V]
                probs = F.softmax(logits, dim=-1)
                top5_vals, top5_idx = probs.topk(5)
                entries = []
                for val, idx in zip(top5_vals.tolist(), top5_idx.tolist()):
                    tok = tokenizer.decode([idx]).replace("\n", "\\n")
                    entries.append((f"{tok}({val:.2f})", idx == target_id))
                layer_labels.append(_short_name(name))
                top_entries.append(entries)

        # Create a text table
        ax.set_xlim(-0.5, len(layer_labels) - 0.5)
        ax.set_ylim(-0.5, 5.5)
        ax.set_xticks(range(len(layer_labels)))
        ax.set_xticklabels(layer_labels, fontsize=8)
        ax.set_yticks([])
        for li, entries in enumerate(top_entries):
            for rank, (text, is_target) in enumerate(entries):
                ax.text(li, rank, text, ha="center", va="center", fontsize=7,
                        color="darkgreen" if is_target else "black",
                        fontweight="bold" if is_target else "normal")

        # Escape newlines so the title stays on one line; show whitespace-only
        # targets via repr instead of an empty bracket.
        shown_context = context.replace("\n", "\\n")
        shown_target = actual_token.strip() or repr(actual_token)
        ax.set_title(f'pos {pos}: "...{shown_context}" -> [{shown_target}]',
                     fontsize=9, loc="left")
        ax.invert_yaxis()
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["left"].set_visible(False)

    plt.tight_layout()
    fig.savefig(output_dir / "logit_lens_trajectory.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
def plot_drift(drift: OrderedDict, output_dir: Path, model_label: str):
    """Save drift bar charts to ``output_dir / "representation_drift.png"``.

    Left: mean cosine similarity with the previous layer. Right: mean L2
    distance from it. Layers are in processing order, coloured by phase.
    """
    ordered = sorted(drift.keys(), key=_layer_sort_key)
    tick_labels = [_short_name(layer) for layer in ordered]
    bar_colors = [_phase_color(layer) for layer in ordered]
    ticks = range(len(ordered))

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(f"{model_label} -- Representation drift", fontsize=14)

    panels = (
        (axes[0], "cos_sim_mean", "Cosine similarity with previous layer",
         "How much each layer preserves direction"),
        (axes[1], "l2_distance", "L2 distance from previous layer",
         "How much each layer changes magnitude"),
    )
    for ax, key, ylabel, title in panels:
        ax.bar(ticks, [drift[layer][key] for layer in ordered],
               color=bar_colors, alpha=0.85)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(ticks)
        ax.set_xticklabels(tick_labels, rotation=90, fontsize=7)

    plt.tight_layout()
    fig.savefig(output_dir / "representation_drift.png", dpi=150)
    plt.close(fig)
# ---------------------------------------------------------------------------
# Results saving
# ---------------------------------------------------------------------------
def save_results(cka_matrix, cka_names, lens_results, drift, cross_cka, output_dir):
    """Write all numerical results to ``output_dir / "results.json"``.

    Sections are included only when available:
      - "cka_self": when ``cka_matrix`` is not None;
      - "logit_lens" and "drift": when non-empty;
      - "cka_cross": when ``cross_cka`` is a ``(matrix, names_a, names_b)``
        tuple.

    Values json cannot encode natively are stringified (``default=str``).
    """
    payload = {}
    if cka_matrix is not None:
        payload["cka_self"] = {"names": cka_names, "matrix": cka_matrix.tolist()}
    if lens_results:
        payload["logit_lens"] = dict(lens_results)
    if drift:
        payload["drift"] = dict(drift)
    if cross_cka is not None:
        cross_matrix, cross_names_a, cross_names_b = cross_cka
        payload["cka_cross"] = {
            "names_a": cross_names_a,
            "names_b": cross_names_b,
            "matrix": cross_matrix.tolist(),
        }
    target = output_dir / "results.json"
    with open(target, "w") as fh:
        json.dump(payload, fh, indent=2, default=str)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Command-line entry point.

    Loads the checkpoint and data, collects per-layer activations, then
    computes the analyses enabled by the flags:
      - self-CKA;
      - cross-model CKA;
      - logit lens;
      - representation drift.

    PNG plots and ``results.json`` are written to the output directory.
    Invalid numeric arguments are rejected by argparse. Exits with a
    non-zero status when no usable data samples could be loaded.
    """
    parser = argparse.ArgumentParser(
        description="CKA and Logit Lens analysis for Prisma / Circuit Transformer")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to Prisma/Circuit checkpoint")
    parser.add_argument("--checkpoint-b", type=str, default=None,
                        help="Second Prisma checkpoint for cross-model CKA")
    parser.add_argument("--hf-model", type=str, default=None,
                        help="HuggingFace model for cross-model CKA (e.g. gpt2-medium)")
    parser.add_argument("--data", type=str, required=True,
                        help="Data source (hf:dataset:config:split or file path)")
    parser.add_argument("--num-samples", type=int, default=32,
                        help="Number of text samples (default: 32)")
    parser.add_argument("--context-length", type=int, default=512,
                        help="Sequence length (default: 512)")
    parser.add_argument("--cka-subsample", type=int, default=4,
                        help="Position subsampling for CKA (default: 4)")
    parser.add_argument("--no-logit-lens", action="store_true",
                        help="Skip logit lens analysis")
    parser.add_argument("--no-cka", action="store_true",
                        help="Skip CKA analysis")
    parser.add_argument("--output-dir", type=str, default=None,
                        help="Output directory (default: auto)")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index")
    args = parser.parse_args()

    # Fail fast on values that would otherwise crash deep inside the
    # analysis (a zero slice step in CKA, or no next-token labels).
    if args.num_samples < 1:
        parser.error("--num-samples must be >= 1")
    if args.context_length < 2:
        parser.error("--context-length must be >= 2")
    if args.cka_subsample < 1:
        parser.error("--cka-subsample must be >= 1")

    device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    # Output directory
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        ckpt_name = Path(args.checkpoint).parent.name
        output_dir = Path("circuits/scripts/representation_output") / ckpt_name
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Output: {output_dir}")

    # === Load model A ===
    print(f"\nLoading: {args.checkpoint}")
    model_a, config_a, model_type_a = load_prisma_model(args.checkpoint, device)
    label_a = Path(args.checkpoint).parent.name
    n_params = sum(p.numel() for p in model_a.parameters())
    print(f" Type: {model_type_a}, params: {n_params:,}")

    # === Load data ===
    # The tokenizer name may be stored at the checkpoint's top level rather
    # than in its config, so the checkpoint is read again here.
    ckpt_data = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    tokenizer_name = ckpt_data.get("tokenizer_name", config_a.get("tokenizer_name", "gpt2"))
    del ckpt_data
    print(f"\nLoading data ({args.num_samples} samples, ctx={args.context_length})...")
    input_ids, tokenizer = load_data(
        args.data, tokenizer_name, args.num_samples, args.context_length, device
    )
    if input_ids is None:
        # Exit non-zero (message goes to stderr) so callers see the failure.
        sys.exit("ERROR: No valid samples loaded.")
    print(f" Data shape: {input_ids.shape}")

    # === Collect activations (model A) ===
    print(f"\nCollecting activations ({model_type_a})...")
    acts_a = collect_activations(model_a, model_type_a, config_a, input_ids, device)
    print(f" Collected {len(acts_a)} layers")

    # Free GPU memory
    del model_a
    if device.startswith("cuda"):
        torch.cuda.empty_cache()

    # === CKA (self) ===
    cka_matrix = None
    cka_names = None
    if not args.no_cka:
        print(f"\nComputing self-CKA (subsample={args.cka_subsample})...")
        cka_matrix, cka_names = compute_cka_matrix(acts_a, subsample=args.cka_subsample)
        plot_cka_self(cka_matrix, cka_names, output_dir, label_a)
        print(f" Saved: cka_self.png")

    # === Cross-model CKA ===
    cross_cka = None
    if not args.no_cka and (args.checkpoint_b or args.hf_model):
        if args.checkpoint_b:
            print(f"\nLoading comparison: {args.checkpoint_b}")
            model_b, config_b, model_type_b = load_prisma_model(args.checkpoint_b, device)
            label_b = Path(args.checkpoint_b).parent.name
            acts_b = collect_activations(model_b, model_type_b, config_b, input_ids, device)
            del model_b
        else:
            print(f"\nLoading HF model: {args.hf_model}")
            model_b = load_hf_model(args.hf_model, device)
            label_b = args.hf_model
            # Decode texts from our tokens and re-tokenize for HF model
            print(f" Re-tokenizing for {args.hf_model}...")
            raw_texts = [tokenizer.decode(input_ids[i].tolist()) for i in range(input_ids.shape[0])]
            input_ids_b, _ = tokenize_for_hf(
                raw_texts, args.hf_model, args.context_length, device
            )
            if input_ids_b is not None:
                print(f" HF data shape: {input_ids_b.shape}")
                acts_b = collect_hf_activations(model_b, input_ids_b)
            else:
                print(" WARNING: no usable samples after re-tokenization; "
                      "skipping cross-model CKA")
                acts_b = None
            del model_b
        if device.startswith("cuda"):
            torch.cuda.empty_cache()

        if acts_b:
            print(f"\nComputing cross-model CKA...")
            cross_matrix, cross_names_a, cross_names_b = compute_cross_model_cka(acts_a, acts_b)
            cross_cka = (cross_matrix, cross_names_a, cross_names_b)
            plot_cka_cross(cross_matrix, cross_names_a, cross_names_b,
                           output_dir, label_a, label_b)
            print(f" Saved: cka_cross.png")
        del acts_b

    # === Logit lens ===
    lens_results = None
    if not args.no_logit_lens:
        # Reload model for unembedding components (we deleted it for memory)
        print(f"\nReloading model for logit lens...")
        model_a, _, _ = load_prisma_model(args.checkpoint, device)
        norm, unembed_weight = get_unembed_components(model_a, model_type_a)
        labels = input_ids[:, 1:].cpu()  # next-token labels
        print(f"Computing logit lens...")
        lens_results = compute_logit_lens(acts_a, norm, unembed_weight, labels, device)
        plot_logit_lens(lens_results, output_dir, label_a)
        print(f" Saved: logit_lens_summary.png")

        # Token trajectory visualization
        print(f" Generating token trajectories...")
        plot_logit_lens_trajectory(
            acts_a, norm, unembed_weight, input_ids.cpu(), tokenizer,
            output_dir, label_a, device
        )
        print(f" Saved: logit_lens_trajectory.png")
        del model_a
        if device.startswith("cuda"):
            torch.cuda.empty_cache()

    # === Representation drift ===
    print(f"\nComputing representation drift...")
    drift = compute_drift(acts_a)
    plot_drift(drift, output_dir, label_a)
    print(f" Saved: representation_drift.png")

    # === Save results ===
    save_results(cka_matrix, cka_names, lens_results, drift, cross_cka, output_dir)
    print(f"\nAll outputs saved to: {output_dir}")
    n_plots = len(list(output_dir.glob("*.png")))
    print(f" Plots: {n_plots} PNG files")
    print(f" Data: results.json")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()