"""
Stage 2: Dimension interaction analysis.

Produces:
  - Jaccard overlap of top-K experts
  - Co-activation PMI between (plan_expert, mon_expert) pairs
  - Cross-dim contrast visualization
  - (Direction cosine matrix is produced later, after script 08)
"""
import sys
import argparse
import json
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import numpy as np
import torch

from configs.paths import (
    ensure_dirs, LOGS_DIR, ROUTING_DIR, LABELED_COTS_PATH,
    TOP_EXPERTS_PLAN_PATH, TOP_EXPERTS_MON_PATH, RESULTS_DIR,
    INTERACTION_HEATMAP,
)
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, read_json, write_json


def compute_jaccard(set_a, set_b):
    """Return the Jaccard index |A ∩ B| / |A ∪ B| of two sets.

    Two empty sets are defined to have overlap 0.0 (the ratio would
    otherwise be 0/0).
    """
    if not set_a and not set_b:
        return 0.0
    return len(set_a & set_b) / len(set_a | set_b)


def compute_pmi_matrix(topk_ids_by_layer, token_indices, n_layers, n_experts, eps=1e-6):
    """
    Full per-layer (E, E) co-activation PMI -- intentionally not implemented.

    A list of (L, E, E) matrices is too large for E=128 x 48 layers; PMI is
    computed only between top planning and top monitoring experts instead.

    Raises:
        NotImplementedError: always; use ``compute_pmi_pairwise``.
    """
    raise NotImplementedError("Use compute_pmi_pairwise instead.")


def compute_pmi_pairwise(topk_ids_by_layer, token_indices, plan_experts,
                         mon_experts, eps=1e-6):
    """
    Compute co-activation PMI between same-layer (plan_expert, mon_expert) pairs.

    For the tokens in ``token_indices``, expert ``(layer, expert)`` counts as
    active on a token when ``expert`` appears in that token's row of
    ``topk_ids_by_layer[layer]``. For each pair:

        PMI = log( P(plan & mon) / (P(plan) * P(mon)) )

    with probabilities estimated as frequencies over ``token_indices``.

    Only pairs with l_p == l_m are computed. This avoids a combinatorial
    explosion, and same-layer experts can be co-selected by a single top-k
    routing decision.

    Pairs in which either expert is never active on these tokens are SKIPPED.
    Their PMI is undefined, and eps-smoothing the marginals (as an earlier
    version did) reports a spurious *positive* PMI of roughly
    log(1 / P(other)) for them.

    Args:
        topk_ids_by_layer: mapping layer -> tensor indexed by global token
            position. Rows are compared with ``expert`` and reduced over
            axis 1, so presumably shape (n_tokens, top_k) -- TODO confirm.
        token_indices: global token positions to evaluate.
        plan_experts: iterable of (layer, expert) pairs.
        mon_experts: iterable of (layer, expert) pairs.
        eps: added to the joint probability only, so pairs that never
            co-activate get a large negative (finite) PMI instead of -inf.

    Returns:
        dict mapping (l_p, e_p, l_m, e_m) -> {"pmi", "P_plan", "P_mon",
        "P_joint", "n_plan", "n_mon", "n_joint"}. P_* are raw frequencies
        and n_* raw token counts. Empty if ``token_indices`` is empty.
    """
    n = len(token_indices)
    if n == 0:
        return {}
    idx_tensor = torch.tensor(token_indices, dtype=torch.long)

    # An expert usually takes part in several pairs, so cache the per-layer
    # top-k slice and the per-expert activation mask instead of re-indexing
    # the routing tensor for every pair.
    layer_cache = {}
    mask_cache = {}

    def expert_active(layer, expert):
        key = (layer, expert)
        if key not in mask_cache:
            if layer not in layer_cache:
                layer_cache[layer] = topk_ids_by_layer[layer][idx_tensor].numpy()
            mask_cache[key] = (layer_cache[layer] == expert).any(axis=1)  # (n,) bool
        return mask_cache[key]

    results = {}
    for lp, ep in plan_experts:
        for lm, em in mon_experts:
            if lp != lm:
                continue
            act_p = expert_active(lp, ep)
            act_m = expert_active(lm, em)
            n_p = int(act_p.sum())
            n_m = int(act_m.sum())
            if n_p == 0 or n_m == 0:
                # PMI undefined: at least one expert never fires on these tokens.
                continue
            n_pm = int((act_p & act_m).sum())
            p_p = n_p / n
            p_m = n_m / n
            p_pm = n_pm / n
            pmi = float(np.log((p_pm + eps) / (p_p * p_m)))
            results[(lp, ep, lm, em)] = {
                "pmi": pmi,
                "P_plan": float(p_p),
                "P_mon": float(p_m),
                "P_joint": float(p_pm),
                "n_plan": n_p,
                "n_mon": n_m,
                "n_joint": n_pm,
            }
    return results


def load_all_shards(shards_dir, num_layers):
    """
    Load top-k routing ids from every ``shard_*.pt`` file in ``shards_dir``.

    Shards are processed in sorted filename order. Each sample's tokens get a
    contiguous [start, end) range of global positions, in the same order the
    per-layer tensors are concatenated.

    Returns:
        (topk_ids, sample_id_to_range):
          - topk_ids maps layer index -> tensor concatenated over all shards.
            Layers absent from every shard are omitted.
          - sample_id_to_range maps sample id -> (start, end).

    Raises:
        ValueError: if a layer's concatenated length differs from the total
            token count, for example when the layer is missing from only
            some shards. Global token positions would otherwise be silently
            misaligned.
    """
    shard_files = sorted(shards_dir.glob("shard_*.pt"))
    per_layer_ids = {li: [] for li in range(num_layers)}
    sample_id_to_range = {}
    cursor = 0
    for sf in shard_files:
        shard = torch.load(sf, map_location="cpu")
        for sid, slen in zip(shard["sample_ids"], shard["sample_lengths"]):
            sample_id_to_range[sid] = (cursor, cursor + slen)
            cursor += slen
        for li in range(num_layers):
            if li in shard["topk_ids"]:
                per_layer_ids[li].append(shard["topk_ids"][li])
    topk_ids = {li: torch.cat(v, dim=0) for li, v in per_layer_ids.items() if v}

    total_tokens = int(cursor)
    for li, ids in topk_ids.items():
        if ids.shape[0] != total_tokens:
            raise ValueError(
                f"Layer {li}: {ids.shape[0]} routed tokens but shards declare "
                f"{total_tokens} tokens in total; the layer is probably missing "
                f"from some shards, which would misalign token positions."
            )
    return topk_ids, sample_id_to_range


def collect_global_token_indices(labeled, sample_id_to_range, field):
    """
    Convert per-sample token indices in ``r[field]`` to global token positions.

    Records whose ``r["idx"]`` has no routing data are skipped. Indices
    outside the sample's [0, length) range are dropped. Negative indices
    were previously accepted and mapped into the preceding sample.
    """
    out = []
    for r in labeled:
        sid = r["idx"]
        if sid not in sample_id_to_range:
            continue
        start, end = sample_id_to_range[sid]
        length = end - start
        out.extend(start + ti for ti in r[field] if 0 <= ti < length)
    return out


def plot_interaction_heatmap(
    jaccard_value, delta_crossdim, pmi_pairs, save_path,
    plan_experts, mon_experts,
):
    """
    Save a 1x3 figure to ``save_path``.

    Panels: a Jaccard summary text box, a Δfreq(plan) − Δfreq(mon) heatmap,
    and a histogram of same-layer plan/mon co-activation PMI values.

    Args:
        jaccard_value: Jaccard overlap of the two top-expert sets.
        delta_crossdim: 2-D array drawn with rows labeled "Layer ID" and
            columns "Expert ID".
        pmi_pairs: output of ``compute_pmi_pairwise`` (may be empty).
        save_path: output image path.
        plan_experts: sequence of (layer, expert) pairs.
        mon_experts: sequence of (layer, expert) pairs.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    n_shared = len(set(map(tuple, plan_experts)) & set(map(tuple, mon_experts)))

    fig, axes = plt.subplots(1, 3, figsize=(24, 7))
    try:
        # (1) Jaccard as a text box
        axes[0].axis("off")
        axes[0].text(
            0.5, 0.5,
            f"Jaccard overlap of top-K experts\n\n"
            f"J = |E_plan ∩ E_mon| / |E_plan ∪ E_mon|\n\n"
            f"J = {jaccard_value:.3f}\n\n"
            f"|E_plan| = {len(plan_experts)}\n"
            f"|E_mon| = {len(mon_experts)}\n"
            f"|intersection| = {n_shared}",
            ha="center", va="center", fontsize=14,
            bbox=dict(boxstyle="round,pad=0.8", facecolor="lightblue"),
        )
        axes[0].set_title("Top-K Expert Overlap", fontsize=14)

        # (2) Cross-dim contrast: Δfreq(plan) - Δfreq(mon)
        sns.heatmap(delta_crossdim, cmap="coolwarm", center=0, ax=axes[1],
                    xticklabels=False, yticklabels=False)
        axes[1].set_xlabel("Expert ID")
        axes[1].set_ylabel("Layer ID")
        axes[1].set_title("Δfreq(plan) − Δfreq(mon)\n(experts that distinguish plan from mon)",
                          fontsize=14)

        # (3) PMI pair distribution
        if pmi_pairs:
            pmi_vals = [v["pmi"] for v in pmi_pairs.values()]
            axes[2].hist(pmi_vals, bins=30, color="steelblue", edgecolor="black")
            axes[2].axvline(0, color="red", linestyle="--", label="independence (PMI=0)")
            axes[2].set_xlabel("Co-activation PMI")
            axes[2].set_ylabel("# pairs")
            axes[2].set_title(
                f"Co-activation PMI between\nplan and mon experts (same layer)\n"
                f"Mean PMI = {np.mean(pmi_vals):+.3f}",
                fontsize=12,
            )
            axes[2].legend()
        else:
            axes[2].text(0.5, 0.5, "No same-layer plan-mon pairs found",
                         ha="center", va="center")
            axes[2].axis("off")

        plt.tight_layout()
        plt.savefig(save_path, dpi=120)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def main():
    """Run the interaction analysis: Jaccard, cross-dim contrast, same-layer PMI."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action="store_true",
                        help="Currently unused by this script.")
    parser.parse_args()
    ensure_dirs()
    log = setup_logger("06_interaction", LOGS_DIR / "06_interaction.log")

    # Load top experts
    top_plan = read_json(TOP_EXPERTS_PLAN_PATH)
    top_mon = read_json(TOP_EXPERTS_MON_PATH)
    plan_pairs = [(d["layer"], d["expert"]) for d in top_plan["top_experts"]]
    mon_pairs = [(d["layer"], d["expert"]) for d in top_mon["top_experts"]]

    # 1) Jaccard overlap
    jac = compute_jaccard(set(plan_pairs), set(mon_pairs))
    log.info(f"Jaccard overlap (top-K experts): {jac:.3f}")

    # 2) Cross-dim contrast (context manager closes the npz file handle)
    with np.load(RESULTS_DIR / "routing_stats.npz") as stats:
        delta_plan = stats["delta_plan_vs_exec"]
        delta_mon = stats["delta_mon_vs_exec"]
    delta_crossdim = delta_plan - delta_mon  # positive => plan-selective, negative => mon-selective

    # 3) Same-layer PMI of plan-mon expert pairs
    log.info("Loading routing shards for PMI...")
    num_layers = MODEL_CONFIG["num_layers"]
    topk_ids, sample_id_to_range = load_all_shards(ROUTING_DIR, num_layers)
    labeled = read_jsonl(LABELED_COTS_PATH)
    plan_tis = collect_global_token_indices(labeled, sample_id_to_range, "plan_decision_tis")
    log.info(f"Computing PMI over {len(plan_tis)} planning decision points "
             f"for same-layer (plan_expert, mon_expert) pairs...")
    pmi_pairs = compute_pmi_pairwise(
        topk_ids, plan_tis, plan_pairs, mon_pairs,
    )
    n_same_layer = sum(1 for lp, _ in plan_pairs for lm, _ in mon_pairs if lp == lm)
    log.info(f"Computed PMI for {len(pmi_pairs)} same-layer pairs "
             f"({n_same_layer - len(pmi_pairs)} of {n_same_layer} skipped: "
             f"an expert was never active on these tokens)")

    # 4) Summary & save
    summary = {
        "jaccard_overlap": float(jac),
        "n_plan_experts": len(plan_pairs),
        "n_mon_experts": len(mon_pairs),
        # Sorted so the JSON output is stable across runs.
        "intersection": [list(p) for p in sorted(set(plan_pairs) & set(mon_pairs))],
        "n_pmi_pairs": len(pmi_pairs),
        "pmi_pairs": [
            {"plan_layer": k[0], "plan_expert": k[1],
             "mon_layer": k[2], "mon_expert": k[3], **v}
            for k, v in pmi_pairs.items()
        ],
    }
    if pmi_pairs:
        pmi_vals = [v["pmi"] for v in pmi_pairs.values()]
        summary["pmi_stats"] = {
            "mean": float(np.mean(pmi_vals)),
            "std": float(np.std(pmi_vals)),
            "max": float(np.max(pmi_vals)),
            "min": float(np.min(pmi_vals)),
        }

    write_json(summary, RESULTS_DIR / "interaction_summary.json")

    # Plot
    plot_interaction_heatmap(
        jac, delta_crossdim, pmi_pairs, INTERACTION_HEATMAP,
        plan_pairs, mon_pairs,
    )
    log.info(f"Saved interaction heatmap: {INTERACTION_HEATMAP}")
    log.info(f"Saved interaction summary: {RESULTS_DIR / 'interaction_summary.json'}")


if __name__ == "__main__":
    main()