File size: 6,699 Bytes
e53f10b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | """
Stage 3 part B: Compute 2 versions of direction(s) per dimension.
v1_raw: mean(plan) - mean(exec) (and mean(mon) - mean(exec) for monitoring), a single direction of shape (D,)
v_pca_subspace: top-k subspace from inter-class scatter PCA, a basis of shape (k, D)
Removed:
v2_ortho_general — empirically had cosine > 0.9 to v1 (no signal)
v3_ortho_crossdim — same (>0.99 cosine to v2)
v4_pca (old) — was conceptually wrong (PCA over union, not contrast)
The new v_pca_subspace is the principled subspace approach: extracts the top-k
directions of largest inter-class (plan-vs-exec) variation.
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
import numpy as np
from configs.paths import (
ensure_dirs, LOGS_DIR,
RESIDUALS_PATH, GENERAL_RESIDUALS_PATH, GENERAL_DIR_PATH,
PLAN_V1_RAW, MON_V1_RAW,
PLAN_V_PCA_SUBSPACE, MON_V_PCA_SUBSPACE,
CHECKPOINTS_DIR,
RESULTS_DIR, DIRECTION_COSINE_MATRIX,
)
from configs.model import PCA_SUBSPACE_K
from src.utils import setup_logger, write_json
from src.directions import (
compute_mean_diff, compute_pca_subspace,
normalize_directions,
compute_cosine_similarity_matrix,
save_directions, load_directions,
)
def plot_cosine_matrix(cos_sim_dict, save_path):
    """Plot per-layer cosine similarity (or principal-angle cosine) between versions.

    Args:
        cos_sim_dict: Mapping from a pair label to a ``{layer: value}`` mapping.
            Only labels containing ``"__VS__"`` are plotted, one heatmap
            column per label; rows are the sorted union of all layer keys.
        save_path: Destination handed to ``Figure.savefig``.

    Returns:
        None. Nothing is drawn or written when no label contains
        ``"__VS__"`` or when the selected labels carry no layer entries.
    """
    pairs = [label for label in cos_sim_dict if "__VS__" in label]
    if not pairs:
        return

    all_layers = sorted(set().union(*(cos_sim_dict[p].keys() for p in pairs)))
    if not all_layers:
        # An empty matrix cannot be rendered as a heatmap; nothing to plot.
        return

    # Deferred until we know there is something to draw, so the early-return
    # paths above do not pay for (or require) the plotting stack.
    import matplotlib.pyplot as plt
    import seaborn as sns

    # NaN marks a layer that a given pair has no value for. seaborn masks NaN
    # cells, whereas a 0.0 fill would be annotated as "0.00" and read as a
    # genuinely orthogonal pair.
    mat = np.full((len(all_layers), len(pairs)), np.nan)
    for j, p in enumerate(pairs):
        per_layer = cos_sim_dict[p]
        for i, li in enumerate(all_layers):
            if li in per_layer:
                mat[i, j] = per_layer[li]

    # Height grows with the number of layers so row annotations stay legible.
    fig, ax = plt.subplots(figsize=(14, max(6, len(all_layers) * 0.25)))
    sns.heatmap(
        mat,
        cmap="coolwarm",
        center=0,
        ax=ax,
        xticklabels=[p.replace("__VS__", "\nVS\n") for p in pairs],
        yticklabels=[f"L{li}" for li in all_layers],
        annot=True,
        fmt=".2f",
        cbar=True,
    )
    ax.set_title("Cosine / principal-angle similarity (per layer)")
    fig.tight_layout()
    fig.savefig(save_path, dpi=120)
    # Close this specific figure rather than whichever one happens to be current.
    plt.close(fig)
def _cosines_to_jsonable(cos_by_pair):
    """Convert ``{pair: {layer: value}}`` into ``{pair: {str(layer): float(value)}}``."""
    return {
        pair: {str(li): float(val) for li, val in per_layer.items()}
        for pair, per_layer in cos_by_pair.items()
    }


def main():
    """Compute, save and compare the planning / monitoring directions.

    Command-line flags:
        --pca_k: subspace size passed to ``compute_pca_subspace``
            (defaults to ``PCA_SUBSPACE_K``).
        --resume: return immediately when all four direction files already
            exist. Only the direction files are checked; the cosine JSON and
            the plot are not part of that check.

    Steps: load the residuals from ``RESIDUALS_PATH``, build the ``v1_raw``
    and ``v_pca_subspace`` directions against the ``exec`` activations,
    normalize and save them, then write ``direction_cosines.json`` and the
    cosine heatmap.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pca_k", type=int, default=PCA_SUBSPACE_K)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("08_directions", LOGS_DIR / "08_directions.log")

    # Resume — check if all output files exist
    out_files = [PLAN_V1_RAW, MON_V1_RAW, PLAN_V_PCA_SUBSPACE, MON_V_PCA_SUBSPACE]
    if args.resume and all(p.exists() for p in out_files):
        log.info("All directions already saved. Skipping (resume).")
        return

    log.info(f"Loading {RESIDUALS_PATH}")
    # NOTE(review): torch.load unpickles the file; only point this at
    # residual dumps produced by this pipeline.
    residuals = torch.load(RESIDUALS_PATH, map_location="cpu")
    log.info(f"PCA subspace k = {args.pca_k}")

    # Normalise layer keys to int once, so the lookups below work whether the
    # file was saved with "12", 12 or zero-padded "05" style keys.
    by_layer = {int(key): acts for key, acts in residuals.items()}
    layer_ids = sorted(by_layer)
    log.info(f"Target layers ({len(layer_ids)}): {layer_ids}")

    plan_acts = {li: by_layer[li]["plan"] for li in layer_ids}
    mon_acts = {li: by_layer[li]["mon"] for li in layer_ids}
    exec_acts = {li: by_layer[li]["exec"] for li in layer_ids}

    # ============================================================
    # v1_raw
    # ============================================================
    log.info("=" * 60)
    log.info("v1_raw: mean-diff direction")
    w_plan_raw = compute_mean_diff(plan_acts, exec_acts)
    w_mon_raw = compute_mean_diff(mon_acts, exec_acts)
    for li in layer_ids:
        norm_p = w_plan_raw[li].norm()
        norm_m = w_mon_raw[li].norm()
        log.info(f" L{li:2d}: ||w_plan||={norm_p:.2f}, ||w_mon||={norm_m:.2f}")

    # ============================================================
    # v_pca_subspace
    # ============================================================
    log.info("=" * 60)
    log.info(f"v_pca_subspace: top-{args.pca_k} inter-class scatter PCA")
    Q_plan = compute_pca_subspace(plan_acts, exec_acts, k=args.pca_k)
    Q_mon = compute_pca_subspace(mon_acts, exec_acts, k=args.pca_k)
    for li in layer_ids:
        log.info(f" L{li:2d}: planning basis shape {tuple(Q_plan[li].shape)}, "
                 f"monitoring basis shape {tuple(Q_mon[li].shape)}")

    # ============================================================
    # Normalize and save
    # ============================================================
    log.info("=" * 60)
    log.info("Normalizing and saving")
    versions_plan = {
        "v1_raw": normalize_directions(w_plan_raw),
        "v_pca_subspace": normalize_directions(Q_plan),
    }
    versions_mon = {
        "v1_raw": normalize_directions(w_mon_raw),
        "v_pca_subspace": normalize_directions(Q_mon),
    }
    save_directions(versions_plan["v1_raw"], PLAN_V1_RAW)
    save_directions(versions_plan["v_pca_subspace"], PLAN_V_PCA_SUBSPACE)
    save_directions(versions_mon["v1_raw"], MON_V1_RAW)
    save_directions(versions_mon["v_pca_subspace"], MON_V_PCA_SUBSPACE)
    log.info("All directions saved.")

    # ============================================================
    # Cosine analysis
    # ============================================================
    log.info("Computing cosine / principal-angle similarities...")
    cos_plan = compute_cosine_similarity_matrix(versions_plan)
    cos_mon = compute_cosine_similarity_matrix(versions_mon)

    # Private helper: imported once here rather than on every loop iteration.
    from src.directions import _subspace_cosine

    # Same version, planning vs. monitoring, compared layer by layer.
    cross_dim_cos = {}
    for v in versions_plan:
        cross_dim_cos[f"plan_{v}__VS__mon_{v}"] = {
            li: _subspace_cosine(versions_plan[v][li], versions_mon[v][li])
            for li in versions_plan[v]
        }

    summary = {
        "within_planning": _cosines_to_jsonable(cos_plan),
        "within_monitoring": _cosines_to_jsonable(cos_mon),
        "cross_dim_per_version": _cosines_to_jsonable(cross_dim_cos),
    }
    write_json(summary, RESULTS_DIR / "direction_cosines.json")
    log.info("Saved direction_cosines.json")

    # The heatmap shows the planning-side and cross-dimension pairs only;
    # cos_mon is reported in the JSON summary but not merged into the plot.
    merged = {**cos_plan, **cross_dim_cos}
    plot_cosine_matrix(merged, DIRECTION_COSINE_MATRIX)
    log.info(f"Saved {DIRECTION_COSINE_MATRIX}")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|