# v2/scripts/08_compute_directions.py
"""
Stage 3 part B: Compute 2 versions of direction(s) per dimension.
v1_raw - mean(plan) - mean(exec), single direction (D,)
v_pca_subspace - top-k subspace from inter-class scatter PCA, basis (k, D)
Removed:
v2_ortho_general — empirically had cosine > 0.9 to v1 (no signal)
v3_ortho_crossdim — same (>0.99 cosine to v2)
v4_pca (old) — was conceptually wrong (PCA over union, not contrast)
The new v_pca_subspace is the principled subspace approach: extracts the top-k
directions of largest inter-class (plan-vs-exec) variation.
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
import numpy as np
from configs.paths import (
ensure_dirs, LOGS_DIR,
RESIDUALS_PATH, GENERAL_RESIDUALS_PATH, GENERAL_DIR_PATH,
PLAN_V1_RAW, MON_V1_RAW,
PLAN_V_PCA_SUBSPACE, MON_V_PCA_SUBSPACE,
CHECKPOINTS_DIR,
RESULTS_DIR, DIRECTION_COSINE_MATRIX,
)
from configs.model import PCA_SUBSPACE_K
from src.utils import setup_logger, write_json
from src.directions import (
compute_mean_diff, compute_pca_subspace,
normalize_directions,
compute_cosine_similarity_matrix,
save_directions, load_directions,
)
def plot_cosine_matrix(cos_sim_dict, save_path):
    """Render a per-layer heatmap of similarity values for each version pair.

    Args:
        cos_sim_dict: mapping of pair-name -> {layer_id: similarity}. Only
            keys containing "__VS__" are treated as pairs and plotted.
        save_path: destination file for the figure (forwarded to savefig).

    Returns None; silently does nothing when no "__VS__" keys are present.
    """
    # Heavy plotting deps are imported lazily so the module stays importable
    # on headless / minimal environments.
    import matplotlib.pyplot as plt
    import seaborn as sns

    pair_names = [name for name in cos_sim_dict if "__VS__" in name]
    if not pair_names:
        return

    # Union of layer ids seen across all pairs, in ascending order.
    layer_ids = sorted({li for name in pair_names for li in cos_sim_dict[name]})

    # Missing (pair, layer) entries render as 0.0.
    grid = np.zeros((len(layer_ids), len(pair_names)))
    for col, name in enumerate(pair_names):
        per_layer = cos_sim_dict[name]
        for row, li in enumerate(layer_ids):
            grid[row, col] = per_layer.get(li, 0.0)

    fig, ax = plt.subplots(figsize=(14, max(6, len(layer_ids) * 0.25)))
    sns.heatmap(grid, cmap="coolwarm", center=0, ax=ax,
                xticklabels=[name.replace("__VS__", "\nVS\n") for name in pair_names],
                yticklabels=[f"L{li}" for li in layer_ids],
                annot=True, fmt=".2f", cbar=True)
    ax.set_title("Cosine / principal-angle similarity (per layer)")
    plt.tight_layout()
    plt.savefig(save_path, dpi=120)
    plt.close()
def main():
    """Compute, normalize, save, and cross-compare direction versions.

    Pipeline:
      1. Load cached residual activations (plan / mon / exec, per layer).
      2. v1_raw: mean-difference direction per layer, shape (D,).
      3. v_pca_subspace: top-k inter-class scatter PCA basis, shape (k, D).
      4. Normalize all directions and save them to the checkpoint paths.
      5. Compute within-dimension and cross-dimension cosine /
         principal-angle similarities; dump JSON summary and a heatmap.

    Flags: --pca_k (subspace size), --resume (skip if all outputs exist).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pca_k", type=int, default=PCA_SUBSPACE_K)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("08_directions", LOGS_DIR / "08_directions.log")

    # Resume — skip only when every expected artifact already exists.
    out_files = [PLAN_V1_RAW, MON_V1_RAW, PLAN_V_PCA_SUBSPACE, MON_V_PCA_SUBSPACE]
    if args.resume and all(p.exists() for p in out_files):
        log.info("All directions already saved. Skipping (resume).")
        return

    log.info(f"Loading {RESIDUALS_PATH}")
    residuals = torch.load(RESIDUALS_PATH, map_location="cpu")
    log.info(f"PCA subspace k = {args.pca_k}")

    # Residuals are keyed by stringified layer id; normalize to sorted ints.
    layer_ids = sorted(int(k) for k in residuals.keys())
    log.info(f"Target layers ({len(layer_ids)}): {layer_ids}")
    plan_acts = {li: residuals[str(li)]["plan"] for li in layer_ids}
    mon_acts = {li: residuals[str(li)]["mon"] for li in layer_ids}
    exec_acts = {li: residuals[str(li)]["exec"] for li in layer_ids}

    # ============================================================
    # v1_raw: mean-difference direction per layer
    # ============================================================
    log.info("=" * 60)
    log.info("v1_raw: mean-diff direction")
    w_plan_raw = compute_mean_diff(plan_acts, exec_acts)
    w_mon_raw = compute_mean_diff(mon_acts, exec_acts)
    for li in layer_ids:
        norm_p = w_plan_raw[li].norm()
        norm_m = w_mon_raw[li].norm()
        log.info(f"  L{li:2d}: ||w_plan||={norm_p:.2f}, ||w_mon||={norm_m:.2f}")

    # ============================================================
    # v_pca_subspace: top-k inter-class scatter PCA basis
    # ============================================================
    log.info("=" * 60)
    log.info(f"v_pca_subspace: top-{args.pca_k} inter-class scatter PCA")
    Q_plan = compute_pca_subspace(plan_acts, exec_acts, k=args.pca_k)
    Q_mon = compute_pca_subspace(mon_acts, exec_acts, k=args.pca_k)
    for li in layer_ids:
        log.info(f"  L{li:2d}: planning basis shape {tuple(Q_plan[li].shape)}, "
                 f"monitoring basis shape {tuple(Q_mon[li].shape)}")

    # ============================================================
    # Normalize and save
    # ============================================================
    log.info("=" * 60)
    log.info("Normalizing and saving")
    versions_plan = {
        "v1_raw": normalize_directions(w_plan_raw),
        "v_pca_subspace": normalize_directions(Q_plan),
    }
    versions_mon = {
        "v1_raw": normalize_directions(w_mon_raw),
        "v_pca_subspace": normalize_directions(Q_mon),
    }
    save_directions(versions_plan["v1_raw"], PLAN_V1_RAW)
    save_directions(versions_plan["v_pca_subspace"], PLAN_V_PCA_SUBSPACE)
    save_directions(versions_mon["v1_raw"], MON_V1_RAW)
    save_directions(versions_mon["v_pca_subspace"], MON_V_PCA_SUBSPACE)
    log.info("All directions saved.")

    # ============================================================
    # Cosine analysis
    # ============================================================
    log.info("Computing cosine / principal-angle similarities...")
    cos_plan = compute_cosine_similarity_matrix(versions_plan)
    cos_mon = compute_cosine_similarity_matrix(versions_mon)

    # _subspace_cosine is a private helper of src.directions; imported once
    # here (previously the import sat inside the inner loop below).
    from src.directions import _subspace_cosine

    # Cross-dimension similarity: planning vs monitoring, per version, per layer.
    cross_dim_cos = {}
    for v in versions_plan:
        cross_dim_cos[f"plan_{v}__VS__mon_{v}"] = {
            li: _subspace_cosine(versions_plan[v][li], versions_mon[v][li])
            for li in versions_plan[v]
        }

    def _jsonable(cos_dict):
        # JSON requires string keys; tensors/np scalars become plain floats.
        return {k: {str(li): float(v) for li, v in d.items()}
                for k, d in cos_dict.items()}

    summary = {
        "within_planning": _jsonable(cos_plan),
        "within_monitoring": _jsonable(cos_mon),
        "cross_dim_per_version": _jsonable(cross_dim_cos),
    }
    write_json(summary, RESULTS_DIR / "direction_cosines.json")
    log.info("Saved direction_cosines.json")

    # NOTE(review): cos_mon is intentionally (?) absent from the plot — only
    # within-planning and cross-dimension pairs are drawn, as in the original.
    merged = {**cos_plan, **cross_dim_cos}
    plot_cosine_matrix(merged, DIRECTION_COSINE_MATRIX)
    log.info(f"Saved {DIRECTION_COSINE_MATRIX}")
# Script entry point: run the full direction-computation pipeline.
if __name__ == "__main__":
    main()