| """ |
| Stage 2: Dimension interaction analysis. |
| |
| Produces: |
| - Jaccard overlap of top-K experts |
| - Co-activation PMI between (plan_expert, mon_expert) pairs |
| - Cross-dim contrast visualization |
| - (Direction cosine matrix is produced later, after script 08) |
| """ |
| import sys |
| import argparse |
| import json |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| import numpy as np |
| import torch |
|
|
| from configs.paths import ( |
| ensure_dirs, LOGS_DIR, ROUTING_DIR, LABELED_COTS_PATH, |
| TOP_EXPERTS_PLAN_PATH, TOP_EXPERTS_MON_PATH, |
| RESULTS_DIR, INTERACTION_HEATMAP, |
| ) |
| from configs.model import MODEL_CONFIG |
| from src.utils import setup_logger, read_jsonl, read_json, write_json |
|
|
|
|
def compute_jaccard(set_a, set_b):
    """Return the Jaccard index ``|A ∩ B| / |A ∪ B|`` of two collections.

    Args:
        set_a: Iterable of hashable items (set, list, tuple, ...).
        set_b: Iterable of hashable items.

    Returns:
        A float in ``[0.0, 1.0]``. When both collections are empty the
        result is ``0.0`` rather than a division by zero.

    Raises:
        TypeError: if an item is unhashable (e.g. a list-valued pair).
    """
    # Coerce up front so callers may pass lists/tuples of pairs, not only
    # ready-made sets; duplicates collapse exactly as set semantics require.
    items_a = set(set_a)
    items_b = set(set_b)
    union = items_a | items_b
    if not union:
        return 0.0
    return len(items_a & items_b) / len(union)
|
|
|
|
def compute_pmi_matrix(topk_ids_by_layer, token_indices, n_layers, n_experts, eps=1e-6):
    """Placeholder for a dense, per-layer expert-by-expert co-activation PMI.

    This function is intentionally not implemented. A dense result would be
    one (E, E) matrix per layer, i.e. an (L, E, E) stack, which was judged
    too large for E=128 experts × 48 layers. PMI is instead computed only
    between the top planning experts and the top monitoring experts; see
    ``compute_pmi_pairwise``.

    All arguments are accepted for signature compatibility and are ignored.

    Raises:
        NotImplementedError: always.
    """
    message = "Use compute_pmi_pairwise instead."
    raise NotImplementedError(message)
|
|
|
|
def compute_pmi_pairwise(topk_ids_by_layer, token_indices, plan_experts, mon_experts, eps=1e-6):
    """Compute same-layer co-activation PMI between plan and mon experts.

    For every token in ``token_indices`` an expert counts as "active" when
    its id appears in that token's top-k row for the expert's layer. For
    each (plan_expert, mon_expert) pair that lives in the SAME layer:

        PMI = log( P(both active) / (P(plan active) * P(mon active)) )

    Cross-layer pairs are skipped to avoid a combinatorial explosion.

    Args:
        topk_ids_by_layer: Mapping ``layer -> torch.Tensor`` of routed expert
            ids, indexed by global token position along dim 0 and reduced
            along dim 1 (i.e. used as a 2-D ``(tokens, top_k)`` tensor).
            Tensors must live on CPU because ``.numpy()`` is called on them.
        token_indices: Sequence of global token positions to evaluate.
        plan_experts: Iterable of ``(layer, expert)`` pairs.
        mon_experts: Iterable of ``(layer, expert)`` pairs.
        eps: Additive smoothing applied to each of the three probabilities
            before taking the log. The reported ``P_*`` values include it.

    Returns:
        ``{(l_p, e_p, l_m, e_m): {"pmi", "P_plan", "P_mon", "P_joint"}}``
        in plan-major, mon-minor order. Empty when ``token_indices`` is empty
        or no two experts share a layer.

    Raises:
        KeyError: if a layer that has a same-layer pair is missing from
            ``topk_ids_by_layer`` (when that is a dict).
    """
    if len(token_indices) == 0:
        return {}
    idx_tensor = torch.tensor(token_indices, dtype=torch.long)

    # Group mon experts by layer (keeping their original order) so each plan
    # expert only walks its own layer instead of scanning the full mon list.
    mon_by_layer = {}
    for lm, em in mon_experts:
        mon_by_layer.setdefault(lm, []).append((lm, em))

    # Experts whose activity mask is actually needed, per layer. Layers with
    # no same-layer pair are never touched, matching the lazy original.
    needed = {}
    for lp, ep in plan_experts:
        if lp in mon_by_layer:
            needed.setdefault(lp, set()).add(ep)
    for layer, experts in needed.items():
        experts.update(em for _, em in mon_by_layer[layer])

    # One gather per layer and one mask per expert, instead of redoing the
    # full tensor gather for both experts of every pair. The gathered rows
    # are dropped after each layer so only the boolean masks stay resident.
    active = {}
    for layer, experts in needed.items():
        topk = topk_ids_by_layer[layer][idx_tensor].numpy()
        for expert in experts:
            active[(layer, expert)] = (topk == expert).any(axis=1)

    results = {}
    for lp, ep in plan_experts:
        for lm, em in mon_by_layer.get(lp, ()):
            act_p = active[(lp, ep)]
            act_m = active[(lm, em)]
            p_p = act_p.mean() + eps
            p_m = act_m.mean() + eps
            p_pm = (act_p & act_m).mean() + eps
            results[(lp, ep, lm, em)] = {
                "pmi": float(np.log(p_pm / (p_p * p_m))),
                "P_plan": float(p_p),
                "P_mon": float(p_m),
                "P_joint": float(p_pm),
            }
    return results
|
|
|
|
def load_all_shards(shards_dir, num_layers):
    """Load all routing shards and stitch the per-layer top-k ids together.

    Simplified loader: only the ``topk_ids`` payload of each shard is kept.

    Args:
        shards_dir: ``pathlib.Path`` of a directory with ``shard_*.pt`` files;
            they are read in lexicographically sorted filename order.
        num_layers: Layers ``0 .. num_layers - 1`` are collected. A layer is
            taken from a shard only if it is a key of ``shard["topk_ids"]``.

    Returns:
        ``(topk_ids, sample_id_to_range)`` where ``topk_ids`` maps each layer
        that appeared in at least one shard to its tensors concatenated along
        dim 0, and ``sample_id_to_range`` maps a sample id to its half-open
        ``(start, end)`` span in the running count of ``sample_lengths``.

    NOTE(review): the spans are presumably row offsets into the concatenated
    tensors. That only holds if every collected layer is present in every
    shard with ``sum(sample_lengths)`` rows — confirm against the writer.
    """
    buckets = [[] for _ in range(num_layers)]
    ranges = {}
    offset = 0
    for shard_path in sorted(shards_dir.glob("shard_*.pt")):
        shard = torch.load(shard_path, map_location="cpu")
        for sample_id, length in zip(shard["sample_ids"], shard["sample_lengths"]):
            stop = offset + length
            ranges[sample_id] = (offset, stop)
            offset = stop
        for layer, bucket in enumerate(buckets):
            if layer in shard["topk_ids"]:
                bucket.append(shard["topk_ids"][layer])

    # Layers that never showed up in any shard are left out of the result.
    merged = {}
    for layer, bucket in enumerate(buckets):
        if bucket:
            merged[layer] = torch.cat(bucket, dim=0)
    return merged, ranges
|
|
|
|
def collect_global_token_indices(labeled, sample_id_to_range, field):
    """Translate per-sample token indices into global token positions.

    Args:
        labeled: Iterable of records; each has an ``"idx"`` sample id and,
            under ``field``, an iterable of token indices local to the sample.
        sample_id_to_range: Mapping ``sample id -> (start, end)`` half-open
            global span of that sample.
        field: Key of the local-index list to read from each record.

    Returns:
        Flat list of global indices, in record order then list order. Records
        whose sample id is unknown are skipped, as are indices falling
        outside the sample's own span.

    Raises:
        KeyError: if a record with a known sample id lacks ``field``.
    """
    out = []
    for record in labeled:
        span = sample_id_to_range.get(record["idx"])
        if span is None:
            continue
        start, end = span
        for ti in record[field]:
            gi = start + ti
            # Check both bounds: a negative local index would otherwise point
            # into a preceding sample (or wrap around when used to index).
            if start <= gi < end:
                out.append(gi)
    return out
|
|
|
|
def plot_interaction_heatmap(
    jaccard_value, delta_crossdim, pmi_pairs, save_path,
    plan_experts, mon_experts,
):
    """Draw the three-panel plan/mon interaction figure and save it.

    Panels, left to right:
      1. A text box with the Jaccard overlap and the sizes of both expert sets.
      2. A diverging heatmap of ``delta_crossdim``, centred at 0.
      3. A histogram of same-layer co-activation PMI values, or a placeholder
         message when ``pmi_pairs`` is empty.

    Args:
        jaccard_value: Precomputed Jaccard overlap (formatted to 3 decimals).
        delta_crossdim: 2-D array passed to ``seaborn.heatmap``; the axes are
            labelled "Layer ID" (y) and "Expert ID" (x).
        pmi_pairs: Mapping whose values are dicts carrying a ``"pmi"`` entry.
        save_path: Output image path handed to ``savefig`` (dpi=120).
        plan_experts: Sequence of (layer, expert) pairs.
        mon_experts: Sequence of (layer, expert) pairs.
    """
    # Imported lazily so the plotting stack is only needed when plotting.
    import matplotlib.pyplot as plt
    import seaborn as sns

    fig, (overlap_ax, contrast_ax, pmi_ax) = plt.subplots(1, 3, figsize=(24, 7))

    # --- Panel 1: overlap summary as text -------------------------------
    # Pairs are normalised to tuples so list-valued pairs are hashable.
    shared = set(map(tuple, plan_experts)) & set(map(tuple, mon_experts))
    overlap_text = (
        f"Jaccard overlap of top-K experts\n\n"
        f"J = |E_plan ∩ E_mon| / |E_plan ∪ E_mon|\n\n"
        f"J = {jaccard_value:.3f}\n\n"
        f"|E_plan| = {len(plan_experts)}\n"
        f"|E_mon| = {len(mon_experts)}\n"
        f"|intersection| = "
        f"{len(shared)}"
    )
    overlap_ax.axis("off")
    overlap_ax.text(
        0.5, 0.5, overlap_text,
        ha="center", va="center", fontsize=14,
        bbox=dict(boxstyle="round,pad=0.8", facecolor="lightblue"),
    )
    overlap_ax.set_title("Top-K Expert Overlap", fontsize=14)

    # --- Panel 2: cross-dimension contrast heatmap ----------------------
    sns.heatmap(
        delta_crossdim, cmap="coolwarm", center=0, ax=contrast_ax,
        xticklabels=False, yticklabels=False,
    )
    contrast_ax.set_xlabel("Expert ID")
    contrast_ax.set_ylabel("Layer ID")
    contrast_ax.set_title(
        "Δfreq(plan) − Δfreq(mon)\n(experts that distinguish plan from mon)",
        fontsize=14,
    )

    # --- Panel 3: PMI distribution --------------------------------------
    if not pmi_pairs:
        pmi_ax.text(0.5, 0.5, "No same-layer plan-mon pairs found", ha="center", va="center")
        pmi_ax.axis("off")
    else:
        pmi_vals = [entry["pmi"] for entry in pmi_pairs.values()]
        pmi_ax.hist(pmi_vals, bins=30, color="steelblue", edgecolor="black")
        pmi_ax.axvline(0, color="red", linestyle="--", label="independence (PMI=0)")
        pmi_ax.set_xlabel("Co-activation PMI")
        pmi_ax.set_ylabel("# pairs")
        pmi_ax.set_title(
            f"Co-activation PMI between\nplan and mon experts (same layer)\n"
            f"Mean PMI = {np.mean(pmi_vals):+.3f}", fontsize=12,
        )
        pmi_ax.legend()

    fig.tight_layout()
    fig.savefig(save_path, dpi=120)
    plt.close(fig)
|
|
|
|
def main():
    """Run the stage-2 interaction analysis and write its outputs.

    Steps:
      1. Read the top planning / monitoring expert lists.
      2. Compute and log their Jaccard overlap.
      3. Build the cross-dimension contrast ``delta_plan - delta_mon`` from
         ``routing_stats.npz``.
      4. Load the routing shards and compute same-layer co-activation PMI
         over the planning decision tokens.
      5. Write ``interaction_summary.json`` and the interaction heatmap.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --resume is parsed but nothing below reads it, so it has
    # no effect here; it is kept so existing invocations keep working.
    parser.add_argument("--resume", action="store_true")
    parser.parse_args()

    ensure_dirs()
    log = setup_logger("06_interaction", LOGS_DIR / "06_interaction.log")
    summary_path = RESULTS_DIR / "interaction_summary.json"

    # 1. Top experts per dimension, as (layer, expert) pairs.
    top_plan = read_json(TOP_EXPERTS_PLAN_PATH)
    top_mon = read_json(TOP_EXPERTS_MON_PATH)
    plan_pairs = [(d["layer"], d["expert"]) for d in top_plan["top_experts"]]
    mon_pairs = [(d["layer"], d["expert"]) for d in top_mon["top_experts"]]

    # 2. Set overlap between the two expert lists.
    jac = compute_jaccard(set(plan_pairs), set(mon_pairs))
    log.info(f"Jaccard overlap (top-K experts): {jac:.3f}")
    # Sorted so the summary JSON does not depend on set iteration order.
    shared_pairs = sorted(set(plan_pairs) & set(mon_pairs))

    # 3. Cross-dimension contrast. The archive is closed as soon as both
    #    arrays have been read; np.load otherwise leaves the file open.
    with np.load(RESULTS_DIR / "routing_stats.npz") as stats:
        delta_crossdim = stats["delta_plan_vs_exec"] - stats["delta_mon_vs_exec"]

    # 4. Co-activation PMI over planning decision tokens.
    log.info("Loading routing shards for PMI...")
    num_layers = MODEL_CONFIG["num_layers"]
    topk_ids, sample_id_to_range = load_all_shards(ROUTING_DIR, num_layers)
    labeled = read_jsonl(LABELED_COTS_PATH)
    plan_tis = collect_global_token_indices(labeled, sample_id_to_range, "plan_decision_tis")
    log.info(f"Computing PMI over {len(plan_tis)} planning decision points "
             f"for same-layer (plan_expert, mon_expert) pairs...")
    pmi_pairs = compute_pmi_pairwise(
        topk_ids, plan_tis, plan_pairs, mon_pairs,
    )
    log.info(f"Computed PMI for {len(pmi_pairs)} same-layer pairs")

    # 5a. JSON summary (tuple keys are flattened into named fields).
    summary = {
        "jaccard_overlap": float(jac),
        "n_plan_experts": len(plan_pairs),
        "n_mon_experts": len(mon_pairs),
        "intersection": [list(p) for p in shared_pairs],
        "n_pmi_pairs": len(pmi_pairs),
        "pmi_pairs": [
            {"plan_layer": k[0], "plan_expert": k[1],
             "mon_layer": k[2], "mon_expert": k[3], **v}
            for k, v in pmi_pairs.items()
        ],
    }
    if pmi_pairs:
        pmi_vals = [v["pmi"] for v in pmi_pairs.values()]
        summary["pmi_stats"] = {
            "mean": float(np.mean(pmi_vals)),
            "std": float(np.std(pmi_vals)),
            "max": float(np.max(pmi_vals)),
            "min": float(np.min(pmi_vals)),
        }
    write_json(summary, summary_path)

    # 5b. Figure.
    plot_interaction_heatmap(
        jac, delta_crossdim, pmi_pairs, INTERACTION_HEATMAP,
        plan_pairs, mon_pairs,
    )
    log.info(f"Saved interaction heatmap: {INTERACTION_HEATMAP}")
    log.info(f"Saved interaction summary: {summary_path}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|