""" Direction extraction (rewritten Apr 2026). Two versions kept: v1_raw — single mean-diff direction (D,) v_pca_subspace — top-k subspace from inter-class scatter PCA (k, D) Earlier v2_ortho_general / v3_ortho_crossdim / v4_pca were removed because: - v2/v3 had cosine > 0.95 to v1 in v1 results (no signal added) - v4 was conceptually wrong (PCA over all decision points, not over the plan-vs-exec contrast) The new v_pca_subspace performs PCA on the **inter-class scatter**: S_b = sum_c (mu_c - mu) (mu_c - mu)^T where c ∈ {plan, exec}. Top-k eigenvectors form a k-D subspace capturing the directions of largest plan-vs-exec variation. Steering with this subspace: h_new = h - (1 - alpha) · Q^T Q · h where Q ∈ R^(k × D) is row-orthonormal. """ import torch import numpy as np from typing import Dict, List, Optional def _safe_normalize(v: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: n = v.norm(dim=-1, keepdim=True).clamp(min=eps) return v / n # ============================================================ # v1_raw: single direction mean-diff # ============================================================ def compute_mean_diff( plan_acts_per_layer: Dict[int, torch.Tensor], exec_acts_per_layer: Dict[int, torch.Tensor], ) -> Dict[int, torch.Tensor]: """ v1: raw mean-diff per layer. Returns {layer_id: (D,) float32 direction (NOT normalized)}. """ directions = {} for li in plan_acts_per_layer: h_plan = plan_acts_per_layer[li].to(torch.float32) h_exec = exec_acts_per_layer[li].to(torch.float32) if h_plan.shape[0] == 0 or h_exec.shape[0] == 0: directions[li] = torch.zeros(h_plan.shape[1] if h_plan.shape[0] else (h_exec.shape[1] if h_exec.shape[0] else 0)) continue mu_plan = h_plan.mean(dim=0) mu_exec = h_exec.mean(dim=0) directions[li] = mu_plan - mu_exec return directions # ============================================================ # v_pca_subspace: top-k PCA on plan-vs-exec inter-class structure # ============================================================ def compute_pca_subspace( plan_acts_per_layer: Dict[int, torch.Tensor], exec_acts_per_layer: Dict[int, torch.Tensor], k: int = 3, ) -> Dict[int, torch.Tensor]: """ For each layer, compute a top-k subspace basis Q ∈ R^(k × D) capturing the directions of largest variation between plan and exec activations. Approach: build a balanced "contrast set" — for each plan token, sample the same number of exec tokens. Center each class to its own mean, then take PCA on the centered union (variance within both classes around their respective means; the leading eigenvectors capture the axes along which the two classes most differ). More principled than fitting plain PCA on union — this version emphasises the inter-class scatter direction. Returns: {layer_id: (k, D) row-orthonormal basis} """ bases = {} for li in plan_acts_per_layer: H_plan = plan_acts_per_layer[li].to(torch.float32) H_exec = exec_acts_per_layer[li].to(torch.float32) if H_plan.shape[0] == 0 or H_exec.shape[0] == 0: D = H_plan.shape[1] if H_plan.shape[0] else (H_exec.shape[1] if H_exec.shape[0] else 0) bases[li] = torch.zeros(0, D) continue D = H_plan.shape[1] # Balance classes — sample same number from larger class n_plan, n_exec = H_plan.shape[0], H_exec.shape[0] n_use = min(n_plan, n_exec) if n_plan > n_use: idx = torch.randperm(n_plan)[:n_use] H_plan = H_plan[idx] if n_exec > n_use: idx = torch.randperm(n_exec)[:n_use] H_exec = H_exec[idx] # Class means mu_plan = H_plan.mean(dim=0) # (D,) mu_exec = H_exec.mean(dim=0) # Inter-class scatter contribution: signed displacement of each sample # from the OPPOSITE class mean. This captures inter-class variance. # Fisher-style: project each sample onto axis spanning the two means, # but extract a k-D subspace via eigendecomposition. # # Build M = [H_plan - mu_exec ; H_exec - mu_plan] (2N, D) # then Cov(M) has top eigenvectors aligned with directions of # plan-vs-exec separation (not with within-class noise). M_plan = H_plan - mu_exec.unsqueeze(0) # (n_use, D) M_exec = H_exec - mu_plan.unsqueeze(0) # (n_use, D) M = torch.cat([M_plan, M_exec], dim=0) # (2N, D) # Center M overall M = M - M.mean(dim=0, keepdim=True) # SVD: M = U S V^T, top rows of V^T are the desired basis n_comp = min(k, M.shape[0] - 1, D) if n_comp <= 0: bases[li] = torch.zeros(0, D) continue try: U, S, Vt = torch.linalg.svd(M, full_matrices=False) Q = Vt[:n_comp] # (n_comp, D) except Exception: cov = (M.T @ M) / max(M.shape[0] - 1, 1) eigvals, eigvecs = torch.linalg.eigh(cov) idx = torch.argsort(eigvals, descending=True) Q = eigvecs[:, idx[:n_comp]].T # Row-orthonormalize defensively (already orthonormal from SVD, but...) Q = _row_orthonormalize(Q) bases[li] = Q return bases def _row_orthonormalize(Q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: """Gram-Schmidt row-orthonormalization.""" if Q.shape[0] == 0: return Q out = [] for i in range(Q.shape[0]): v = Q[i].clone() for u in out: v = v - (v @ u) * u n = v.norm() if n < eps: continue out.append(v / n) if not out: return torch.zeros(0, Q.shape[1]) return torch.stack(out, dim=0) # ============================================================ # Normalization # ============================================================ def normalize_directions( directions: Dict[int, torch.Tensor], ) -> Dict[int, torch.Tensor]: """Return unit vectors. Works for (D,) tensors (single dir). For (k, D) bases, returns row-orthonormal (already from SVD).""" out = {} for li, w in directions.items(): w32 = w.to(torch.float32) if w32.dim() == 1: if w32.norm() < 1e-8: out[li] = w32 else: out[li] = _safe_normalize(w32) elif w32.dim() == 2: out[li] = _row_orthonormalize(w32) else: out[li] = w32 return out # ============================================================ # Save / load # ============================================================ def save_directions(directions: Dict[int, torch.Tensor], path): torch.save({str(li): w for li, w in directions.items()}, path) def load_directions(path) -> Dict[int, torch.Tensor]: raw = torch.load(path, map_location="cpu") return {int(k): v for k, v in raw.items()} # ============================================================ # Cosine analysis # ============================================================ def compute_cosine_similarity_matrix( dirs_dict: Dict[str, Dict[int, torch.Tensor]] ) -> Dict[str, Dict[int, float]]: """ For (D,) directions: cosine between unit vectors. For (k, D) bases: principal angle (smallest angle between subspaces). Returns {(v1, v2): {layer: cos}} per layer. """ versions = list(dirs_dict.keys()) out = {} for i, v1 in enumerate(versions): for v2 in versions[i:]: key = f"{v1}__VS__{v2}" per_layer = {} for li in dirs_dict[v1]: if li not in dirs_dict[v2]: continue a = dirs_dict[v1][li].to(torch.float32) b = dirs_dict[v2][li].to(torch.float32) per_layer[li] = _subspace_cosine(a, b) out[key] = per_layer return out def _subspace_cosine(a: torch.Tensor, b: torch.Tensor) -> float: """ Cosine for (D,) directions or principal-angle cosine for (k,D) bases. """ if a.numel() == 0 or b.numel() == 0: return 0.0 if a.dim() == 1 and b.dim() == 1: if a.norm() < 1e-8 or b.norm() < 1e-8: return 0.0 return float((a @ b) / (a.norm() * b.norm())) # Subspace case: largest singular value of A^T B (where A, B are row-orthonormal) A = a if a.dim() == 2 else a.unsqueeze(0) B = b if b.dim() == 2 else b.unsqueeze(0) if A.shape[0] == 0 or B.shape[0] == 0: return 0.0 # row-normalize A and B in case A = _row_orthonormalize(A) B = _row_orthonormalize(B) if A.shape[0] == 0 or B.shape[0] == 0: return 0.0 M = A @ B.T # (k_a, k_b) s = torch.linalg.svdvals(M) return float(s[0]) if s.numel() > 0 else 0.0