"""
Stage 3 part B: Compute 2 versions of direction(s) per dimension.

  v1_raw          - mean(plan) - mean(exec), single direction (D,)
  v_pca_subspace  - top-k subspace from inter-class scatter PCA, basis (k, D)

Removed:
  v2_ortho_general   — empirically had cosine > 0.9 to v1 (no signal)
  v3_ortho_crossdim  — same (>0.99 cosine to v2)
  v4_pca (old)       — was conceptually wrong (PCA over union, not contrast)

The new v_pca_subspace is the principled subspace approach: extracts the top-k
directions of largest inter-class (plan-vs-exec) variation.
"""

import sys
import argparse
from pathlib import Path

# Make the repository root importable so `configs` and `src` resolve when this
# file is executed directly as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import torch
import numpy as np

from configs.paths import (
    ensure_dirs, LOGS_DIR,
    RESIDUALS_PATH, GENERAL_RESIDUALS_PATH, GENERAL_DIR_PATH,
    PLAN_V1_RAW, MON_V1_RAW,
    PLAN_V_PCA_SUBSPACE, MON_V_PCA_SUBSPACE,
    CHECKPOINTS_DIR, RESULTS_DIR, DIRECTION_COSINE_MATRIX,
)
from configs.model import PCA_SUBSPACE_K
from src.utils import setup_logger, write_json
from src.directions import (
    compute_mean_diff,
    compute_pca_subspace,
    normalize_directions,
    compute_cosine_similarity_matrix,
    save_directions,
    load_directions,
    _subspace_cosine,
)


def plot_cosine_matrix(cos_sim_dict, save_path):
    """Plot per-layer cosine similarity (or principal-angle cosine) between versions.

    Only entries of ``cos_sim_dict`` whose key contains ``"__VS__"`` are
    plotted; each becomes one heatmap column, and each layer id found in any
    of those entries becomes one row.

    Args:
        cos_sim_dict: Mapping ``pair_name -> {layer_id: similarity}``.
        save_path: Destination passed to ``Figure.savefig``.

    Returns:
        ``True`` if a figure was written, ``False`` if there was no
        ``"__VS__"`` pair to plot (nothing is written in that case).
    """
    # Imported lazily so the plotting stack is only required when a figure is
    # actually produced.
    import matplotlib.pyplot as plt
    import seaborn as sns

    pairs = [name for name in cos_sim_dict if "__VS__" in name]
    if not pairs:
        return False

    all_layers = sorted({li for name in pairs for li in cos_sim_dict[name]})

    # Start from NaN rather than 0.0: a layer missing for a given pair must not
    # be drawn as "cosine = 0.00" (i.e. orthogonal). seaborn leaves NaN cells
    # blank and un-annotated.
    mat = np.full((len(all_layers), len(pairs)), np.nan)
    for col, name in enumerate(pairs):
        per_layer = cos_sim_dict[name]
        for row, li in enumerate(all_layers):
            if li in per_layer:
                mat[row, col] = per_layer[li]

    fig, ax = plt.subplots(figsize=(14, max(6, len(all_layers) * 0.25)))
    try:
        sns.heatmap(
            mat,
            cmap="coolwarm",
            center=0,
            ax=ax,
            xticklabels=[name.replace("__VS__", "\nVS\n") for name in pairs],
            yticklabels=[f"L{li}" for li in all_layers],
            annot=True,
            fmt=".2f",
            cbar=True,
        )
        ax.set_title("Cosine / principal-angle similarity (per layer)")
        fig.tight_layout()
        fig.savefig(save_path, dpi=120)
    finally:
        # Release the figure even if rendering or saving raised.
        plt.close(fig)
    return True


def _to_json_ready(per_pair):
    """Convert ``{pair: {layer: value}}`` to string layer keys and float values."""
    return {
        pair: {str(li): float(val) for li, val in per_layer.items()}
        for pair, per_layer in per_pair.items()
    }


def main():
    """Compute, save and compare the planning / monitoring directions.

    Steps, in order:
      1. Load the residuals from ``RESIDUALS_PATH``.
      2. Compute ``v1_raw`` (``compute_mean_diff``) and ``v_pca_subspace``
         (``compute_pca_subspace``) for plan-vs-exec and mon-vs-exec.
      3. Normalize and save the four direction files.
      4. Write similarity statistics to ``direction_cosines.json`` and plot
         them to ``DIRECTION_COSINE_MATRIX``.

    Command-line flags:
      --pca_k   Value passed as ``k`` to ``compute_pca_subspace``
                (default ``PCA_SUBSPACE_K``); must be >= 1.
      --resume  Exit early when all four direction files already exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pca_k", type=int, default=PCA_SUBSPACE_K)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()
    if args.pca_k < 1:
        parser.error(f"--pca_k must be a positive integer, got {args.pca_k}")

    ensure_dirs()
    log = setup_logger("08_directions", LOGS_DIR / "08_directions.log")

    # Resume — check if all output files exist
    out_files = [PLAN_V1_RAW, MON_V1_RAW, PLAN_V_PCA_SUBSPACE, MON_V_PCA_SUBSPACE]
    if args.resume and all(p.exists() for p in out_files):
        log.info("All directions already saved. Skipping (resume).")
        return

    log.info(f"Loading {RESIDUALS_PATH}")
    # NOTE(review): torch.load unpickles the file; RESIDUALS_PATH must only
    # ever point at a trusted, locally produced artifact.
    residuals = torch.load(RESIDUALS_PATH, map_location="cpu")
    log.info(f"PCA subspace k = {args.pca_k}")

    # Index by integer layer id so the lookup works whether the saved keys are
    # strings ("12") or ints (12).
    by_layer = {int(key): acts for key, acts in residuals.items()}
    layer_ids = sorted(by_layer)
    log.info(f"Target layers ({len(layer_ids)}): {layer_ids}")

    plan_acts = {li: by_layer[li]["plan"] for li in layer_ids}
    mon_acts = {li: by_layer[li]["mon"] for li in layer_ids}
    exec_acts = {li: by_layer[li]["exec"] for li in layer_ids}

    # ============================================================
    # v1_raw
    # ============================================================
    log.info("=" * 60)
    log.info("v1_raw: mean-diff direction")
    w_plan_raw = compute_mean_diff(plan_acts, exec_acts)
    w_mon_raw = compute_mean_diff(mon_acts, exec_acts)
    for li in layer_ids:
        norm_p = w_plan_raw[li].norm()
        norm_m = w_mon_raw[li].norm()
        log.info(f"  L{li:2d}: ||w_plan||={norm_p:.2f}, ||w_mon||={norm_m:.2f}")

    # ============================================================
    # v_pca_subspace
    # ============================================================
    log.info("=" * 60)
    log.info(f"v_pca_subspace: top-{args.pca_k} inter-class scatter PCA")
    Q_plan = compute_pca_subspace(plan_acts, exec_acts, k=args.pca_k)
    Q_mon = compute_pca_subspace(mon_acts, exec_acts, k=args.pca_k)
    for li in layer_ids:
        log.info(f"  L{li:2d}: planning basis shape {tuple(Q_plan[li].shape)}, "
                 f"monitoring basis shape {tuple(Q_mon[li].shape)}")

    # ============================================================
    # Normalize and save
    # ============================================================
    log.info("=" * 60)
    log.info("Normalizing and saving")

    versions_plan = {
        "v1_raw": normalize_directions(w_plan_raw),
        "v_pca_subspace": normalize_directions(Q_plan),
    }
    versions_mon = {
        "v1_raw": normalize_directions(w_mon_raw),
        "v_pca_subspace": normalize_directions(Q_mon),
    }

    save_directions(versions_plan["v1_raw"], PLAN_V1_RAW)
    save_directions(versions_plan["v_pca_subspace"], PLAN_V_PCA_SUBSPACE)
    save_directions(versions_mon["v1_raw"], MON_V1_RAW)
    save_directions(versions_mon["v_pca_subspace"], MON_V_PCA_SUBSPACE)
    log.info("All directions saved.")

    # ============================================================
    # Cosine analysis
    # ============================================================
    log.info("Computing cosine / principal-angle similarities...")
    cos_plan = compute_cosine_similarity_matrix(versions_plan)
    cos_mon = compute_cosine_similarity_matrix(versions_mon)

    # Same version, planning vs monitoring, compared layer by layer.
    cross_dim_cos = {}
    for name, plan_dirs in versions_plan.items():
        mon_dirs = versions_mon[name]
        cross_dim_cos[f"plan_{name}__VS__mon_{name}"] = {
            li: _subspace_cosine(plan_dirs[li], mon_dirs[li]) for li in plan_dirs
        }

    summary = {
        "within_planning": _to_json_ready(cos_plan),
        "within_monitoring": _to_json_ready(cos_mon),
        "cross_dim_per_version": _to_json_ready(cross_dim_cos),
    }
    write_json(summary, RESULTS_DIR / "direction_cosines.json")
    log.info("Saved direction_cosines.json")

    merged = {**cos_plan, **cross_dim_cos}
    # Only report the figure as saved when one was actually written.
    if plot_cosine_matrix(merged, DIRECTION_COSINE_MATRIX):
        log.info(f"Saved {DIRECTION_COSINE_MATRIX}")
    else:
        log.warning(f"No '__VS__' pairs to plot; skipped {DIRECTION_COSINE_MATRIX}")


if __name__ == "__main__":
    main()