File size: 8,927 Bytes
e53f10b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 | """
Direction extraction (rewritten Apr 2026).
Two versions kept:
v1_raw — single mean-diff direction (D,)
v_pca_subspace — top-k subspace from inter-class scatter PCA (k, D)
Earlier v2_ortho_general / v3_ortho_crossdim / v4_pca were removed because:
- v2/v3 had cosine > 0.95 to v1 in v1 results (no signal added)
- v4 was conceptually wrong (PCA over all decision points, not over the
plan-vs-exec contrast)
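The new v_pca_subspace is motivated by the **inter-class scatter**:
S_b = sum_c (mu_c - mu) (mu_c - mu)^T
where c ∈ {plan, exec}. With only two classes S_b has rank 1 (it is spanned
by mu_plan - mu_exec), so on its own it cannot supply k > 1 directions.
compute_pca_subspace therefore takes the top-k right singular vectors of
M = [H_plan - mu_exec ; H_exec - mu_plan] (classes balanced), whose scatter is
the pooled within-class scatter plus a rank-1 term along mu_plan - mu_exec:
the leading direction tracks the plan-vs-exec mean shift, later ones follow
within-class variance.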
Steering with this subspace:
h_new = h - (1 - alpha) · Q^T Q · h
where Q ∈ R^(k × D) is row-orthonormal.
"""
import torch
import numpy as np
from typing import Dict, List, Optional
def _safe_normalize(v: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
n = v.norm(dim=-1, keepdim=True).clamp(min=eps)
return v / n
# ============================================================
# v1_raw: single direction mean-diff
# ============================================================
def compute_mean_diff(
    plan_acts_per_layer: Dict[int, torch.Tensor],
    exec_acts_per_layer: Dict[int, torch.Tensor],
) -> Dict[int, torch.Tensor]:
    """
    v1: raw mean-diff per layer.

    For every layer in ``plan_acts_per_layer`` computes
    ``mean(plan rows) - mean(exec rows)`` in float32.

    Args:
        plan_acts_per_layer: {layer_id: (n_plan, D) activations}.
        exec_acts_per_layer: {layer_id: (n_exec, D) activations}; must
            contain every layer of ``plan_acts_per_layer`` (KeyError
            otherwise).

    Returns:
        {layer_id: (D,) float32 direction (NOT normalized)}. If either class
        has no rows for a layer, that layer gets a (D,) float32 zero vector.
        D is read from any 2-D input, even an empty (0, D) one, so the
        placeholder has the same width as the other layers. D is 0 only when
        no input carries a feature dimension.
    """
    directions: Dict[int, torch.Tensor] = {}
    for li, plan_acts in plan_acts_per_layer.items():
        h_plan = plan_acts.to(torch.float32)
        h_exec = exec_acts_per_layer[li].to(torch.float32)
        if h_plan.shape[0] == 0 or h_exec.shape[0] == 0:
            # An empty (0, D) tensor still knows D -- use it instead of
            # collapsing to a length-0 vector when both classes are empty.
            D = next((t.shape[-1] for t in (h_plan, h_exec) if t.dim() == 2), 0)
            directions[li] = torch.zeros(D, dtype=torch.float32, device=h_plan.device)
            continue
        directions[li] = h_plan.mean(dim=0) - h_exec.mean(dim=0)
    return directions
# ============================================================
# v_pca_subspace: top-k PCA on plan-vs-exec inter-class structure
# ============================================================
def compute_pca_subspace(
plan_acts_per_layer: Dict[int, torch.Tensor],
exec_acts_per_layer: Dict[int, torch.Tensor],
k: int = 3,
) -> Dict[int, torch.Tensor]:
"""
For each layer, compute a top-k subspace basis Q ∈ R^(k × D) capturing the
directions of largest variation between plan and exec activations.
Approach: build a balanced "contrast set" — for each plan token, sample
the same number of exec tokens. Center each class to its own mean, then
take PCA on the centered union (variance within both classes around
their respective means; the leading eigenvectors capture the axes along
which the two classes most differ).
More principled than fitting plain PCA on union — this version emphasises
the inter-class scatter direction.
Returns:
{layer_id: (k, D) row-orthonormal basis}
"""
bases = {}
for li in plan_acts_per_layer:
H_plan = plan_acts_per_layer[li].to(torch.float32)
H_exec = exec_acts_per_layer[li].to(torch.float32)
if H_plan.shape[0] == 0 or H_exec.shape[0] == 0:
D = H_plan.shape[1] if H_plan.shape[0] else (H_exec.shape[1] if H_exec.shape[0] else 0)
bases[li] = torch.zeros(0, D)
continue
D = H_plan.shape[1]
# Balance classes — sample same number from larger class
n_plan, n_exec = H_plan.shape[0], H_exec.shape[0]
n_use = min(n_plan, n_exec)
if n_plan > n_use:
idx = torch.randperm(n_plan)[:n_use]
H_plan = H_plan[idx]
if n_exec > n_use:
idx = torch.randperm(n_exec)[:n_use]
H_exec = H_exec[idx]
# Class means
mu_plan = H_plan.mean(dim=0) # (D,)
mu_exec = H_exec.mean(dim=0)
# Inter-class scatter contribution: signed displacement of each sample
# from the OPPOSITE class mean. This captures inter-class variance.
# Fisher-style: project each sample onto axis spanning the two means,
# but extract a k-D subspace via eigendecomposition.
#
# Build M = [H_plan - mu_exec ; H_exec - mu_plan] (2N, D)
# then Cov(M) has top eigenvectors aligned with directions of
# plan-vs-exec separation (not with within-class noise).
M_plan = H_plan - mu_exec.unsqueeze(0) # (n_use, D)
M_exec = H_exec - mu_plan.unsqueeze(0) # (n_use, D)
M = torch.cat([M_plan, M_exec], dim=0) # (2N, D)
# Center M overall
M = M - M.mean(dim=0, keepdim=True)
# SVD: M = U S V^T, top rows of V^T are the desired basis
n_comp = min(k, M.shape[0] - 1, D)
if n_comp <= 0:
bases[li] = torch.zeros(0, D)
continue
try:
U, S, Vt = torch.linalg.svd(M, full_matrices=False)
Q = Vt[:n_comp] # (n_comp, D)
except Exception:
cov = (M.T @ M) / max(M.shape[0] - 1, 1)
eigvals, eigvecs = torch.linalg.eigh(cov)
idx = torch.argsort(eigvals, descending=True)
Q = eigvecs[:, idx[:n_comp]].T
# Row-orthonormalize defensively (already orthonormal from SVD, but...)
Q = _row_orthonormalize(Q)
bases[li] = Q
return bases
def _row_orthonormalize(Q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
"""Gram-Schmidt row-orthonormalization."""
if Q.shape[0] == 0:
return Q
out = []
for i in range(Q.shape[0]):
v = Q[i].clone()
for u in out:
v = v - (v @ u) * u
n = v.norm()
if n < eps:
continue
out.append(v / n)
if not out:
return torch.zeros(0, Q.shape[1])
return torch.stack(out, dim=0)
# ============================================================
# Normalization
# ============================================================
def normalize_directions(
    directions: Dict[int, torch.Tensor],
) -> Dict[int, torch.Tensor]:
    """Return a float32, normalized copy of each layer's direction.

    - (D,) vectors are scaled to unit length. Vectors with norm below 1e-8
      are returned unchanged (cast to float32) rather than divided by ~0.
    - (k, D) bases are re-orthonormalized row-wise with Gram-Schmidt. Rows
      with a (near-)zero residual may be dropped.
    - Tensors of any other rank are passed through (cast to float32).
    """
    def _normalize_one(w: torch.Tensor) -> torch.Tensor:
        w = w.to(torch.float32)
        rank = w.dim()
        if rank == 2:
            return _row_orthonormalize(w)
        if rank == 1 and not (w.norm() < 1e-8):
            return _safe_normalize(w)
        return w

    return {li: _normalize_one(w) for li, w in directions.items()}
# ============================================================
# Save / load
# ============================================================
def save_directions(directions: Dict[int, torch.Tensor], path):
    """Serialize ``{layer_id: tensor}`` to ``path`` with ``torch.save``.

    Layer ids are stored as strings; ``load_directions`` turns them back
    into ints.
    """
    payload = {str(layer_id): tensor for layer_id, tensor in directions.items()}
    torch.save(payload, path)
def load_directions(path) -> Dict[int, torch.Tensor]:
    """Read a file written by ``save_directions``.

    Tensors are mapped to CPU and the string keys are converted back into
    int layer ids.

    NOTE(review): ``torch.load`` unpickles. On torch versions where
    ``weights_only`` defaults to False this can execute arbitrary code, so
    only load trusted files (or pass ``weights_only=True`` where supported).
    """
    state = torch.load(path, map_location="cpu")
    return {int(name): tensor for name, tensor in state.items()}
# ============================================================
# Cosine analysis
# ============================================================
def compute_cosine_similarity_matrix(
    dirs_dict: Dict[str, Dict[int, torch.Tensor]]
) -> Dict[str, Dict[int, float]]:
    """
    Compare every pair of direction versions, layer by layer.

    Each unordered pair of versions, including each version with itself, is
    visited once, in the order the versions appear in ``dirs_dict``. For each
    layer present in both versions, the score comes from ``_subspace_cosine``
    (tensors cast to float32):
      - (D,) vs (D,): signed cosine between the two vectors;
      - a (k, D) basis on either side: cosine of the smallest principal angle
        between the spanned subspaces.

    Args:
        dirs_dict: {version_name: {layer_id: direction (D,) or basis (k, D)}}.

    Returns:
        {"<v1>__VS__<v2>": {layer_id: cosine}}. Layers missing from either
        version are omitted.
    """
    from itertools import combinations_with_replacement

    out: Dict[str, Dict[int, float]] = {}
    # Yields (a, a), (a, b), ..., (b, b), ... -- same order as i <= j loops.
    for v1, v2 in combinations_with_replacement(dirs_dict, 2):
        dirs1, dirs2 = dirs_dict[v1], dirs_dict[v2]
        out[f"{v1}__VS__{v2}"] = {
            li: _subspace_cosine(
                dirs1[li].to(torch.float32), dirs2[li].to(torch.float32)
            )
            for li in dirs1
            if li in dirs2
        }
    return out
def _subspace_cosine(a: torch.Tensor, b: torch.Tensor) -> float:
"""
Cosine for (D,) directions or principal-angle cosine for (k,D) bases.
"""
if a.numel() == 0 or b.numel() == 0:
return 0.0
if a.dim() == 1 and b.dim() == 1:
if a.norm() < 1e-8 or b.norm() < 1e-8:
return 0.0
return float((a @ b) / (a.norm() * b.norm()))
# Subspace case: largest singular value of A^T B (where A, B are row-orthonormal)
A = a if a.dim() == 2 else a.unsqueeze(0)
B = b if b.dim() == 2 else b.unsqueeze(0)
if A.shape[0] == 0 or B.shape[0] == 0:
return 0.0
# row-normalize A and B in case
A = _row_orthonormalize(A)
B = _row_orthonormalize(B)
if A.shape[0] == 0 or B.shape[0] == 0:
return 0.0
M = A @ B.T # (k_a, k_b)
s = torch.linalg.svdvals(M)
return float(s[0]) if s.numel() > 0 else 0.0
|