| """ |
| Stage 3 part B: Compute 2 versions of direction(s) per dimension. |
| |
| v1_raw         : mean(X) - mean(exec), for X in {plan, mon}; one direction of shape (D,) |
| v_pca_subspace : top-k subspace from X-vs-exec inter-class scatter PCA; basis of shape (k, D) |
| |
| Removed: |
| v2_ortho_general — empirically had cosine > 0.9 to v1 (no signal) |
| v3_ortho_crossdim — same (>0.99 cosine to v2) |
| v4_pca (old) — was conceptually wrong (PCA over union, not contrast) |
| |
| The new v_pca_subspace is the principled subspace approach: extracts the top-k |
| directions of largest inter-class (plan-vs-exec) variation. |
| """ |
| import sys |
| import argparse |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| import torch |
| import numpy as np |
|
|
| from configs.paths import ( |
| ensure_dirs, LOGS_DIR, |
| RESIDUALS_PATH, GENERAL_RESIDUALS_PATH, GENERAL_DIR_PATH, |
| PLAN_V1_RAW, MON_V1_RAW, |
| PLAN_V_PCA_SUBSPACE, MON_V_PCA_SUBSPACE, |
| CHECKPOINTS_DIR, |
| RESULTS_DIR, DIRECTION_COSINE_MATRIX, |
| ) |
| from configs.model import PCA_SUBSPACE_K |
| from src.utils import setup_logger, write_json |
| from src.directions import ( |
| compute_mean_diff, compute_pca_subspace, |
| normalize_directions, |
| compute_cosine_similarity_matrix, |
| save_directions, load_directions, |
| ) |
|
|
|
|
def plot_cosine_matrix(cos_sim_dict, save_path):
    """Plot per-layer cosine (or principal-angle cosine) similarities as a heatmap.

    Only entries of ``cos_sim_dict`` whose key contains ``"__VS__"`` are
    plotted. Each such entry becomes one heatmap column and maps a layer id
    to a similarity value. Rows are the union of layer ids across all plotted
    columns, sorted ascending.

    Args:
        cos_sim_dict: mapping ``"<a>__VS__<b>" -> {layer_id: similarity}``.
            Keys without ``"__VS__"`` are ignored.
        save_path: destination image path handed to ``Figure.savefig``.

    Returns:
        None. When there are no ``"__VS__"`` entries, nothing is written and
        the plotting libraries are not imported.
    """
    pairs = [k for k in cos_sim_dict if "__VS__" in k]
    if not pairs:
        return

    # Imported lazily: only needed when there is actually something to draw.
    import matplotlib.pyplot as plt
    import seaborn as sns

    all_layers = sorted(set().union(*(cos_sim_dict[p].keys() for p in pairs)))

    # A layer missing for a given pair stays NaN (seaborn leaves the cell
    # blank) instead of 0.0, which would read as a genuine "orthogonal" result.
    mat = np.full((len(all_layers), len(pairs)), np.nan)
    for j, p in enumerate(pairs):
        for i, li in enumerate(all_layers):
            mat[i, j] = cos_sim_dict[p].get(li, np.nan)

    fig, ax = plt.subplots(figsize=(14, max(6, len(all_layers) * 0.25)))
    try:
        # Cosine similarity lives in [-1, 1]; pin the colour range so that
        # colours mean the same thing across runs and across figures.
        sns.heatmap(mat, cmap="coolwarm", center=0, vmin=-1.0, vmax=1.0, ax=ax,
                    xticklabels=[p.replace("__VS__", "\nVS\n") for p in pairs],
                    yticklabels=[f"L{li}" for li in all_layers],
                    annot=True, fmt=".2f", cbar=True)
        ax.set_title("Cosine / principal-angle similarity (per layer)")
        fig.tight_layout()
        fig.savefig(save_path, dpi=120)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
|
|
|
|
def _to_jsonable(sim):
    """Convert ``{name: {layer_id: value}}`` to ``{name: {str(layer_id): float(value)}}`` for JSON output."""
    return {
        name: {str(li): float(val) for li, val in per_layer.items()}
        for name, per_layer in sim.items()
    }


def main():
    """Compute, normalize, save, and compare planning/monitoring directions.

    Steps:
      1. Load per-layer residual activations from ``RESIDUALS_PATH``.
      2. v1_raw: mean-difference direction for plan vs exec and mon vs exec.
      3. v_pca_subspace: top-``--pca_k`` inter-class scatter PCA basis for
         the same two contrasts.
      4. Normalize and save the four direction files.
      5. Compute within-dimension and cross-dimension (plan vs mon)
         similarities, write them to ``RESULTS_DIR / "direction_cosines.json"``,
         and plot a heatmap to ``DIRECTION_COSINE_MATRIX``.

    With ``--resume``, the whole run is skipped when all four direction files
    already exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pca_k", type=int, default=PCA_SUBSPACE_K)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()
    if args.pca_k < 1:
        parser.error(f"--pca_k must be >= 1, got {args.pca_k}")

    ensure_dirs()
    log = setup_logger("08_directions", LOGS_DIR / "08_directions.log")

    out_files = [PLAN_V1_RAW, MON_V1_RAW, PLAN_V_PCA_SUBSPACE, MON_V_PCA_SUBSPACE]
    if args.resume and all(p.exists() for p in out_files):
        # NOTE(review): this only checks that the files exist; it does not
        # verify the saved subspaces were built with the current --pca_k.
        log.info("All directions already saved. Skipping (resume).")
        return

    log.info(f"Loading {RESIDUALS_PATH}")
    residuals = torch.load(RESIDUALS_PATH, map_location="cpu")
    log.info(f"PCA subspace k = {args.pca_k}")

    # Layer keys must be int-like. Keep a map back to the stored key so that
    # both str keys ("12") and int keys (12) index correctly.
    key_of = {int(k): k for k in residuals}
    layer_ids = sorted(key_of)
    log.info(f"Target layers ({len(layer_ids)}): {layer_ids}")

    plan_acts = {li: residuals[key_of[li]]["plan"] for li in layer_ids}
    mon_acts = {li: residuals[key_of[li]]["mon"] for li in layer_ids}
    exec_acts = {li: residuals[key_of[li]]["exec"] for li in layer_ids}

    # ---- v1_raw: mean-difference direction --------------------------------
    log.info("=" * 60)
    log.info("v1_raw: mean-diff direction")
    w_plan_raw = compute_mean_diff(plan_acts, exec_acts)
    w_mon_raw = compute_mean_diff(mon_acts, exec_acts)
    for li in layer_ids:
        norm_p = w_plan_raw[li].norm()
        norm_m = w_mon_raw[li].norm()
        log.info(f" L{li:2d}: ||w_plan||={norm_p:.2f}, ||w_mon||={norm_m:.2f}")

    # ---- v_pca_subspace: inter-class scatter PCA ---------------------------
    log.info("=" * 60)
    log.info(f"v_pca_subspace: top-{args.pca_k} inter-class scatter PCA")
    Q_plan = compute_pca_subspace(plan_acts, exec_acts, k=args.pca_k)
    Q_mon = compute_pca_subspace(mon_acts, exec_acts, k=args.pca_k)
    for li in layer_ids:
        log.info(f" L{li:2d}: planning basis shape {tuple(Q_plan[li].shape)}, "
                 f"monitoring basis shape {tuple(Q_mon[li].shape)}")

    # ---- normalize + save --------------------------------------------------
    log.info("=" * 60)
    log.info("Normalizing and saving")

    versions_plan = {
        "v1_raw": normalize_directions(w_plan_raw),
        "v_pca_subspace": normalize_directions(Q_plan),
    }
    versions_mon = {
        "v1_raw": normalize_directions(w_mon_raw),
        "v_pca_subspace": normalize_directions(Q_mon),
    }

    save_directions(versions_plan["v1_raw"], PLAN_V1_RAW)
    save_directions(versions_plan["v_pca_subspace"], PLAN_V_PCA_SUBSPACE)
    save_directions(versions_mon["v1_raw"], MON_V1_RAW)
    save_directions(versions_mon["v_pca_subspace"], MON_V_PCA_SUBSPACE)
    log.info("All directions saved.")

    # ---- similarities ------------------------------------------------------
    log.info("Computing cosine / principal-angle similarities...")
    cos_plan = compute_cosine_similarity_matrix(versions_plan)
    cos_mon = compute_cosine_similarity_matrix(versions_mon)

    # Private helper; imported once here rather than on every loop iteration.
    from src.directions import _subspace_cosine

    # Same version, planning vs monitoring, per layer.
    cross_dim_cos = {
        f"plan_{v}__VS__mon_{v}": {
            li: _subspace_cosine(versions_plan[v][li], versions_mon[v][li])
            for li in versions_plan[v]
        }
        for v in versions_plan
    }

    summary = {
        "within_planning": _to_jsonable(cos_plan),
        "within_monitoring": _to_jsonable(cos_mon),
        "cross_dim_per_version": _to_jsonable(cross_dim_cos),
    }
    write_json(summary, RESULTS_DIR / "direction_cosines.json")
    log.info("Saved direction_cosines.json")

    # The heatmap shows planning-side and cross-dimension pairs only;
    # within-monitoring similarities are recorded in the JSON above.
    merged = {**cos_plan, **cross_dim_cos}
    plot_cosine_matrix(merged, DIRECTION_COSINE_MATRIX)
    log.info(f"Saved {DIRECTION_COSINE_MATRIX}")


if __name__ == "__main__":
    main()
|
|