# [Hugging Face file-view residue, commented out so the module parses]
# v2 / src/directions.py
# JulianHJR's picture
# Upload folder using huggingface_hub
# e53f10b verified
"""
Direction extraction (rewritten Apr 2026).
Two versions kept:
v1_raw — single mean-diff direction (D,)
v_pca_subspace — top-k subspace from inter-class scatter PCA (k, D)
Earlier v2_ortho_general / v3_ortho_crossdim / v4_pca were removed because:
- v2/v3 had cosine > 0.95 to v1 in v1 results (no signal added)
- v4 was conceptually wrong (PCA over all decision points, not over the
plan-vs-exec contrast)
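The new v_pca_subspace targets the plan-vs-exec contrast. Note that the
textbook inter-class scatter
S_b = sum_c (mu_c - mu) (mu_c - mu)^T,   c ∈ {plan, exec}
has rank <= 1 for two classes, so on its own it cannot supply k > 1 directions.
The implementation instead balances the classes (n rows each) and takes the
SVD of the cross-mean offsets M = [H_plan - mu_exec ; H_exec - mu_plan].
With d = mu_plan - mu_exec, M^T M = S_w + 2n · d d^T (S_w = pooled
within-class scatter). The leading component is therefore pulled toward d,
while the remaining components follow the largest within-class variation.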
Steering with this subspace:
h_new = h - (1 - alpha) · Q^T Q · h
where Q ∈ R^(k × D) is row-orthonormal.
"""
import torch
import numpy as np
from typing import Dict, List, Optional
def _safe_normalize(v: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
n = v.norm(dim=-1, keepdim=True).clamp(min=eps)
return v / n
# ============================================================
# v1_raw: single direction mean-diff
# ============================================================
def compute_mean_diff(
    plan_acts_per_layer: Dict[int, torch.Tensor],
    exec_acts_per_layer: Dict[int, torch.Tensor],
) -> Dict[int, torch.Tensor]:
    """
    v1: raw mean-diff direction per layer.

    For every layer key in ``plan_acts_per_layer`` the direction is
    ``mean(plan rows) - mean(exec rows)``, computed in float32.

    Args:
        plan_acts_per_layer: {layer_id: plan activations}; rows are averaged
            over dim 0 (presumably (N_plan, D) -- confirm against caller).
        exec_acts_per_layer: same for exec activations. Must contain every
            layer key of ``plan_acts_per_layer`` (otherwise KeyError).

    Returns:
        {layer_id: (D,) float32 direction (NOT normalized)}. If either class
        has no rows for a layer, a zero vector is returned. Its length is the
        feature dimension of the first 2-D input (plan, then exec), or 0 if
        neither is 2-D.
    """
    directions = {}
    for li, h_plan in plan_acts_per_layer.items():
        h_plan = h_plan.to(torch.float32)
        h_exec = exec_acts_per_layer[li].to(torch.float32)
        if h_plan.shape[0] == 0 or h_exec.shape[0] == 0:
            # An empty (0, D) tensor still carries D in shape[-1]. Use it so an
            # empty layer gets a correctly sized zero direction, not shape (0,).
            dim = next((t.shape[-1] for t in (h_plan, h_exec) if t.dim() >= 2), 0)
            directions[li] = torch.zeros(dim, device=h_plan.device)
            continue
        directions[li] = h_plan.mean(dim=0) - h_exec.mean(dim=0)
    return directions
# ============================================================
# v_pca_subspace: top-k PCA on plan-vs-exec inter-class structure
# ============================================================
def compute_pca_subspace(
    plan_acts_per_layer: Dict[int, torch.Tensor],
    exec_acts_per_layer: Dict[int, torch.Tensor],
    k: int = 3,
    generator: Optional[torch.Generator] = None,
) -> Dict[int, torch.Tensor]:
    """
    Per layer, compute a row-orthonormal basis Q in R^(k x D) for the
    plan-vs-exec contrast.

    Procedure:
      1. Balance the classes: randomly subsample the larger one down to
         n = min(n_plan, n_exec) rows via ``torch.randperm``.
      2. Build the cross-mean offset set
             M = [H_plan - mu_exec ; H_exec - mu_plan]      (2n, D)
         and center it. Its mean is already ~0 because the classes are
         balanced.
      3. Take the top right-singular vectors of M.

    What the basis captures: with d = mu_plan - mu_exec,
        M^T M = S_w + 2n * d d^T
    where S_w is the pooled within-class scatter. The leading component is
    pulled toward d, but the remaining components follow the largest
    within-class variation. This is NOT a pure between-class subspace; the
    two-class between-class scatter has rank <= 1.

    Args:
        plan_acts_per_layer: {layer_id: plan activations}, rows on dim 0
            (presumably (N_plan, D) -- confirm against caller).
        exec_acts_per_layer: same for exec. Must contain every plan layer key.
        k: requested subspace dimension, clipped to min(k, 2n - 1, D).
        generator: optional RNG for the class-balancing subsample, so results
            can be reproduced. ``torch.randperm`` draws on the CPU here, so
            pass a CPU generator. ``None`` keeps the old behavior (global RNG).

    Returns:
        {layer_id: (k', D) float32 basis with orthonormal rows}, k' <= k.
        A (0, D) tensor is returned when either class is empty or k' <= 0.
    """
    bases = {}
    for li in plan_acts_per_layer:
        H_plan = plan_acts_per_layer[li].to(torch.float32)
        H_exec = exec_acts_per_layer[li].to(torch.float32)
        if H_plan.shape[0] == 0 or H_exec.shape[0] == 0:
            # An empty (0, D) tensor still carries D in shape[-1]. Keep it so
            # the empty basis has the right width instead of (0, 0).
            D = next((t.shape[-1] for t in (H_plan, H_exec) if t.dim() >= 2), 0)
            bases[li] = torch.zeros(0, D, device=H_plan.device)
            continue
        D = H_plan.shape[1]

        # Balance classes: subsample the larger class down to the smaller one.
        n_plan, n_exec = H_plan.shape[0], H_exec.shape[0]
        n_use = min(n_plan, n_exec)
        if n_plan > n_use:
            H_plan = H_plan[torch.randperm(n_plan, generator=generator)[:n_use]]
        if n_exec > n_use:
            H_exec = H_exec[torch.randperm(n_exec, generator=generator)[:n_use]]

        mu_plan = H_plan.mean(dim=0)  # (D,)
        mu_exec = H_exec.mean(dim=0)

        # Offset each sample from the OPPOSITE class mean. Plan rows then have
        # mean +d and exec rows mean -d, so the between-class term enters the
        # scatter as 2n * d d^T on top of the within-class scatter.
        M = torch.cat(
            [H_plan - mu_exec.unsqueeze(0), H_exec - mu_plan.unsqueeze(0)],
            dim=0,
        )  # (2n, D)
        M = M - M.mean(dim=0, keepdim=True)

        n_comp = min(k, M.shape[0] - 1, D)
        if n_comp <= 0:
            bases[li] = torch.zeros(0, D, device=H_plan.device)
            continue
        try:
            # Rows of Vt are right-singular vectors, in descending singular-value order.
            _, _, Vt = torch.linalg.svd(M, full_matrices=False)
            Q = Vt[:n_comp]  # (n_comp, D)
        except RuntimeError:
            # torch.linalg.LinAlgError subclasses RuntimeError. Fall back to
            # an eigendecomposition of the covariance.
            cov = (M.T @ M) / max(M.shape[0] - 1, 1)
            eigvals, eigvecs = torch.linalg.eigh(cov)
            order = torch.argsort(eigvals, descending=True)
            Q = eigvecs[:, order[:n_comp]].T
        # Defensive: SVD/eigh rows are already orthonormal up to rounding.
        bases[li] = _row_orthonormalize(Q)
    return bases
def _row_orthonormalize(Q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
"""Gram-Schmidt row-orthonormalization."""
if Q.shape[0] == 0:
return Q
out = []
for i in range(Q.shape[0]):
v = Q[i].clone()
for u in out:
v = v - (v @ u) * u
n = v.norm()
if n < eps:
continue
out.append(v / n)
if not out:
return torch.zeros(0, Q.shape[1])
return torch.stack(out, dim=0)
# ============================================================
# Normalization
# ============================================================
def normalize_directions(
    directions: Dict[int, torch.Tensor],
) -> Dict[int, torch.Tensor]:
    """Cast every direction to float32 and normalize it according to its rank.

    - (D,) vectors are scaled to unit L2 norm. Vectors whose norm is below
      1e-8 are returned unchanged (as float32).
    - (k, D) bases are row-orthonormalized with Gram-Schmidt.
    - Tensors of any other rank are only cast to float32.
    """
    normalized = {}
    for layer, tensor in directions.items():
        as_f32 = tensor.to(torch.float32)
        rank = as_f32.dim()
        if rank == 1:
            is_degenerate = bool(as_f32.norm() < 1e-8)
            normalized[layer] = as_f32 if is_degenerate else _safe_normalize(as_f32)
        elif rank == 2:
            normalized[layer] = _row_orthonormalize(as_f32)
        else:
            normalized[layer] = as_f32
    return normalized
# ============================================================
# Save / load
# ============================================================
def save_directions(directions: Dict[int, torch.Tensor], path):
    """Serialize ``directions`` to ``path`` with ``torch.save``.

    Layer ids are stored as string keys. ``load_directions`` converts them
    back to int.
    """
    payload = dict(zip(map(str, directions), directions.values()))
    torch.save(payload, path)
def load_directions(path) -> Dict[int, torch.Tensor]:
    """Load a mapping written by ``save_directions``, mapping tensors to CPU.

    String layer keys are converted back to ``int``.
    NOTE: ``torch.load`` is pickle-based, so only load trusted files.
    """
    stored = torch.load(path, map_location="cpu")
    result = {}
    for key, tensor in stored.items():
        result[int(key)] = tensor
    return result
# ============================================================
# Cosine analysis
# ============================================================
def compute_cosine_similarity_matrix(
    dirs_dict: Dict[str, Dict[int, torch.Tensor]]
) -> Dict[str, Dict[int, float]]:
    """
    Pairwise similarity between direction versions, computed per layer.

    Every unordered pair of version names is compared, including each version
    with itself. Only layers present in both versions are compared. Per layer,
    with both tensors cast to float32:
      - (D,) vs (D,): signed cosine similarity (0.0 if either is ~zero).
      - if any (k, D) basis is involved: the (nonnegative) cosine of the
        smallest principal angle between the two subspaces.
      - if either tensor is empty: 0.0.

    Args:
        dirs_dict: {version_name: {layer_id: direction or basis}}.

    Returns:
        {"<v1>__VS__<v2>": {layer_id: similarity}}. The string key is not a
        tuple. v1 precedes or equals v2 in ``dirs_dict`` iteration order.
        Layers of v1 that are missing from v2 are skipped.
    """
    names = list(dirs_dict)
    out = {}
    for i, v1 in enumerate(names):
        for v2 in names[i:]:
            layers_v2 = dirs_dict[v2]
            out[f"{v1}__VS__{v2}"] = {
                li: _subspace_cosine(
                    a.to(torch.float32), layers_v2[li].to(torch.float32)
                )
                for li, a in dirs_dict[v1].items()
                if li in layers_v2
            }
    return out
def _subspace_cosine(a: torch.Tensor, b: torch.Tensor) -> float:
"""
Cosine for (D,) directions or principal-angle cosine for (k,D) bases.
"""
if a.numel() == 0 or b.numel() == 0:
return 0.0
if a.dim() == 1 and b.dim() == 1:
if a.norm() < 1e-8 or b.norm() < 1e-8:
return 0.0
return float((a @ b) / (a.norm() * b.norm()))
# Subspace case: largest singular value of A^T B (where A, B are row-orthonormal)
A = a if a.dim() == 2 else a.unsqueeze(0)
B = b if b.dim() == 2 else b.unsqueeze(0)
if A.shape[0] == 0 or B.shape[0] == 0:
return 0.0
# row-normalize A and B in case
A = _row_orthonormalize(A)
B = _row_orthonormalize(B)
if A.shape[0] == 0 or B.shape[0] == 0:
return 0.0
M = A @ B.T # (k_a, k_b)
s = torch.linalg.svdvals(M)
return float(s[0]) if s.numel() > 0 else 0.0