v2 / scripts /06_interaction_analysis.py
JulianHJR's picture
Upload folder using huggingface_hub
e53f10b verified
"""
Stage 2: Dimension interaction analysis.
Produces:
- Jaccard overlap of top-K experts
- Co-activation PMI between (plan_expert, mon_expert) pairs
- Cross-dim contrast visualization
- (Direction cosine matrix is produced later, after script 08)
"""
import sys
import argparse
import json
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import numpy as np
import torch
from configs.paths import (
ensure_dirs, LOGS_DIR, ROUTING_DIR, LABELED_COTS_PATH,
TOP_EXPERTS_PLAN_PATH, TOP_EXPERTS_MON_PATH,
RESULTS_DIR, INTERACTION_HEATMAP,
)
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, read_json, write_json
def compute_jaccard(set_a, set_b):
    """
    Return the Jaccard index |A ∩ B| / |A ∪ B| of two sets.

    When both inputs are empty the ratio is undefined; 0.0 is returned
    instead of dividing by zero.
    """
    if not (set_a or set_b):
        return 0.0
    shared = set_a & set_b
    combined = set_a | set_b
    return len(shared) / len(combined)
def compute_pmi_matrix(topk_ids_by_layer, token_indices, n_layers, n_experts, eps=1e-6):
    """
    Placeholder for a dense per-layer (E x E) co-activation PMI computation.

    A dense result (one E x E matrix per layer, e.g. E=128 across 48 layers)
    was judged too large, so this function is intentionally unimplemented.
    Use compute_pmi_pairwise, which restricts PMI to same-layer
    (plan_expert, mon_expert) pairs.

    Raises:
        NotImplementedError: always.
    """
    message = "Use compute_pmi_pairwise instead."
    raise NotImplementedError(message)
def compute_pmi_pairwise(topk_ids_by_layer, token_indices, plan_experts, mon_experts, eps=1e-6):
    """
    Compute same-layer co-activation PMI between (plan_expert, mon_expert) pairs.

    An expert (l, e) is "active" on token t when e appears anywhere in the
    gathered top-k routing row of layer l for t. Over the tokens listed in
    token_indices the smoothed PMI is

        PMI = log((P(plan & mon) + eps) / ((P(plan) + eps) * (P(mon) + eps)))

    Only pairs with l_p == l_m are evaluated (same-layer co-activation), which
    avoids the combinatorial explosion of cross-layer pairs.

    Args:
        topk_ids_by_layer: mapping layer -> tensor whose rows are selected with a
            LongTensor of token indices; each gathered row is compared against an
            expert id (assumed (num_tokens, top_k) routing ids — TODO confirm).
        token_indices: global row indices of the tokens to evaluate.
        plan_experts: iterable of (layer, expert) pairs.
        mon_experts: iterable of (layer, expert) pairs.
        eps: additive smoothing applied to each probability inside the PMI only.

    Returns:
        dict mapping (l_p, e_p, l_m, e_m) -> {"pmi", "P_plan", "P_mon", "P_joint"}.
        The P_* entries are the raw (unsmoothed) empirical activation rates.
        An empty dict is returned when token_indices is empty.

    Raises:
        KeyError: if a same-layer pair references a layer missing from
            topk_ids_by_layer.
    """
    n = len(token_indices)
    if n == 0:
        return {}
    idx_tensor = torch.tensor(token_indices, dtype=torch.long)

    # Group monitoring experts by layer once (preserving their original order),
    # so each planning expert only visits its same-layer candidates.
    mon_by_layer = {}
    for lm, em in mon_experts:
        mon_by_layer.setdefault(lm, []).append((lm, em))

    # Gathering rows from the full routing tensor is the expensive step: do it
    # at most once per layer, and build each expert's mask at most once.
    rows_by_layer = {}
    mask_cache = {}

    def expert_active(layer, expert):
        key = (layer, expert)
        if key not in mask_cache:
            if layer not in rows_by_layer:
                rows_by_layer[layer] = topk_ids_by_layer[layer][idx_tensor].numpy()
            mask_cache[key] = (rows_by_layer[layer] == expert).any(axis=1)  # (n,) bool
        return mask_cache[key]

    results = {}
    for lp, ep in plan_experts:
        for lm, em in mon_by_layer.get(lp, ()):
            act_p = expert_active(lp, ep)
            act_m = expert_active(lm, em)
            p_p = float(act_p.mean())
            p_m = float(act_m.mean())
            p_pm = float((act_p & act_m).mean())
            # eps smooths only the PMI ratio (avoids log(0) / division by zero);
            # the reported probabilities stay unbiased.
            pmi = float(np.log((p_pm + eps) / ((p_p + eps) * (p_m + eps))))
            results[(lp, ep, lm, em)] = {
                "pmi": pmi,
                "P_plan": p_p,
                "P_mon": p_m,
                "P_joint": p_pm,
            }
    return results
def load_all_shards(shards_dir, num_layers):
    """
    Load every ``shard_*.pt`` under shards_dir and concatenate top-k routing ids.

    Shards are read in sorted filename order. Each shard is expected to hold
    "sample_ids", "sample_lengths" and a "topk_ids" mapping layer -> tensor.
    Samples are laid out back to back, so each sample id is assigned the
    half-open global row range [offset, offset + length).

    Returns:
        (topk_ids, sample_id_to_range): topk_ids maps each layer in
        range(num_layers) that appeared in at least one shard to the
        dim-0 concatenation of its per-shard tensors; sample_id_to_range maps
        sample id -> (start, end).

    NOTE(review): rows only line up with the global ranges if every shard
    contains every layer that any shard contains — a layer missing from one
    shard silently shifts that layer's rows.
    """
    chunks_by_layer = {layer: [] for layer in range(num_layers)}
    ranges = {}
    offset = 0
    for shard_path in sorted(shards_dir.glob("shard_*.pt")):
        payload = torch.load(shard_path, map_location="cpu")
        for sample_id, length in zip(payload["sample_ids"], payload["sample_lengths"]):
            ranges[sample_id] = (offset, offset + length)
            offset += length
        layer_ids = payload["topk_ids"]
        for layer, chunks in chunks_by_layer.items():
            if layer in layer_ids:
                chunks.append(layer_ids[layer])
    merged = {}
    for layer, chunks in chunks_by_layer.items():
        if chunks:
            merged[layer] = torch.cat(chunks, dim=0)
    return merged, ranges
def collect_global_token_indices(labeled, sample_id_to_range, field):
    """
    Translate per-sample token positions into global routing-row indices.

    For each labeled record, r["idx"] is looked up in sample_id_to_range to get
    the sample's half-open global range [start, end); each position ti in
    r[field] becomes start + ti. Positions outside [0, end - start) are
    dropped, as are records whose sample id has no range.

    Args:
        labeled: iterable of dict records with an "idx" key and `field`.
        sample_id_to_range: mapping sample id -> (start, end).
        field: name of the record key holding per-sample token positions.

    Returns:
        list of global indices, in record order then position order.

    Raises:
        KeyError: if a record with a known sample id lacks `field`, or a
            record lacks "idx".
    """
    out = []
    for r in labeled:
        sid = r["idx"]
        if sid not in sample_id_to_range:
            continue
        start, end = sample_id_to_range[sid]
        for ti in r[field]:
            # A negative offset would land in the previous sample (or wrap to
            # the end of the tensor when used as an index), so reject it.
            if ti < 0:
                continue
            gi = start + ti
            if gi < end:
                out.append(gi)
    return out
def plot_interaction_heatmap(
    jaccard_value, delta_crossdim, pmi_pairs, save_path,
    plan_experts, mon_experts,
):
    """
    Render the three-panel interaction figure and save it to save_path.

    Panels:
      1. Text box with the Jaccard overlap, |E_plan|, |E_mon| and the size of
         their intersection.
      2. Heatmap of delta_crossdim (Δfreq(plan) − Δfreq(mon)), centered at 0.
      3. Histogram of the "pmi" values in pmi_pairs, or a placeholder message
         when pmi_pairs is empty.

    Args:
        jaccard_value: precomputed Jaccard overlap.
        delta_crossdim: 2-D array passed to seaborn.heatmap (rows labelled
            "Layer ID", columns "Expert ID").
        pmi_pairs: dict whose values are dicts containing a "pmi" float.
        save_path: output image path.
        plan_experts, mon_experts: sequences of (layer, expert) pairs.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    n_shared = len(set(map(tuple, plan_experts)) & set(map(tuple, mon_experts)))
    fig, axes = plt.subplots(1, 3, figsize=(24, 7))
    try:
        # (1) Jaccard as a text box
        axes[0].axis("off")
        axes[0].text(0.5, 0.5,
                     "Jaccard overlap of top-K experts\n\n"
                     "J = |E_plan ∩ E_mon| / |E_plan ∪ E_mon|\n\n"
                     f"J = {jaccard_value:.3f}\n\n"
                     f"|E_plan| = {len(plan_experts)}\n"
                     f"|E_mon| = {len(mon_experts)}\n"
                     f"|intersection| = {n_shared}",
                     ha="center", va="center", fontsize=14,
                     bbox=dict(boxstyle="round,pad=0.8", facecolor="lightblue"))
        axes[0].set_title("Top-K Expert Overlap", fontsize=14)

        # (2) Cross-dim contrast: Δfreq(plan) - Δfreq(mon)
        sns.heatmap(delta_crossdim, cmap="coolwarm", center=0, ax=axes[1],
                    xticklabels=False, yticklabels=False)
        axes[1].set_xlabel("Expert ID")
        axes[1].set_ylabel("Layer ID")
        axes[1].set_title("Δfreq(plan) − Δfreq(mon)\n(experts that distinguish plan from mon)",
                          fontsize=14)

        # (3) PMI pair distribution
        if pmi_pairs:
            pmi_vals = [v["pmi"] for v in pmi_pairs.values()]
            axes[2].hist(pmi_vals, bins=30, color="steelblue", edgecolor="black")
            axes[2].axvline(0, color="red", linestyle="--", label="independence (PMI=0)")
            axes[2].set_xlabel("Co-activation PMI")
            axes[2].set_ylabel("# pairs")
            axes[2].set_title(
                f"Co-activation PMI between\nplan and mon experts (same layer)\n"
                f"Mean PMI = {np.mean(pmi_vals):+.3f}", fontsize=12,
            )
            axes[2].legend()
        else:
            axes[2].text(0.5, 0.5, "No same-layer plan-mon pairs found", ha="center", va="center")
            axes[2].axis("off")

        fig.tight_layout()
        fig.savefig(save_path, dpi=120)
    finally:
        # Close this specific figure even if drawing or saving raised, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
def main():
    """
    Run the Stage-2 dimension interaction analysis end to end.

    Steps:
      1. Jaccard overlap between the top planning and top monitoring experts
         (read from TOP_EXPERTS_PLAN_PATH / TOP_EXPERTS_MON_PATH).
      2. Cross-dimension contrast delta_plan_vs_exec - delta_mon_vs_exec read
         from RESULTS_DIR / "routing_stats.npz".
      3. Same-layer co-activation PMI of (plan_expert, mon_expert) pairs over
         the "plan_decision_tis" tokens of the records in LABELED_COTS_PATH.
      4. Write RESULTS_DIR / "interaction_summary.json" and the figure at
         INTERACTION_HEATMAP.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --resume is parsed for CLI compatibility but not used here.
    parser.add_argument("--resume", action="store_true")
    parser.parse_args()
    ensure_dirs()
    log = setup_logger("06_interaction", LOGS_DIR / "06_interaction.log")

    # Load top experts as (layer, expert) tuples
    top_plan = read_json(TOP_EXPERTS_PLAN_PATH)
    top_mon = read_json(TOP_EXPERTS_MON_PATH)
    plan_pairs = [(d["layer"], d["expert"]) for d in top_plan["top_experts"]]
    mon_pairs = [(d["layer"], d["expert"]) for d in top_mon["top_experts"]]
    shared_pairs = set(plan_pairs) & set(mon_pairs)

    # 1) Jaccard overlap
    jac = compute_jaccard(set(plan_pairs), set(mon_pairs))
    log.info(f"Jaccard overlap (top-K experts): {jac:.3f}")

    # 2) Cross-dim contrast. The NpzFile is used as a context manager so its
    # underlying file handle is closed once both arrays have been read.
    with np.load(RESULTS_DIR / "routing_stats.npz") as stats:
        delta_plan = stats["delta_plan_vs_exec"]
        delta_mon = stats["delta_mon_vs_exec"]
    delta_crossdim = delta_plan - delta_mon  # positive => plan-selective, negative => mon-selective

    # 3) Same-layer PMI of plan-mon expert pairs
    log.info("Loading routing shards for PMI...")
    num_layers = MODEL_CONFIG["num_layers"]
    topk_ids, sample_id_to_range = load_all_shards(ROUTING_DIR, num_layers)
    labeled = read_jsonl(LABELED_COTS_PATH)
    plan_tis = collect_global_token_indices(labeled, sample_id_to_range, "plan_decision_tis")
    log.info(f"Computing PMI over {len(plan_tis)} planning decision points "
             f"for same-layer (plan_expert, mon_expert) pairs...")
    pmi_pairs = compute_pmi_pairwise(
        topk_ids, plan_tis, plan_pairs, mon_pairs,
    )
    log.info(f"Computed PMI for {len(pmi_pairs)} same-layer pairs")

    # 4) Summary & save
    summary = {
        "jaccard_overlap": float(jac),
        "n_plan_experts": len(plan_pairs),
        "n_mon_experts": len(mon_pairs),
        # Sorted for stable, readable output independent of set iteration order.
        "intersection": [list(p) for p in sorted(shared_pairs)],
        "n_pmi_pairs": len(pmi_pairs),
        "pmi_pairs": [
            {"plan_layer": k[0], "plan_expert": k[1],
             "mon_layer": k[2], "mon_expert": k[3], **v}
            for k, v in pmi_pairs.items()
        ],
    }
    if pmi_pairs:
        pmi_vals = [v["pmi"] for v in pmi_pairs.values()]
        summary["pmi_stats"] = {
            "mean": float(np.mean(pmi_vals)),
            "std": float(np.std(pmi_vals)),
            "max": float(np.max(pmi_vals)),
            "min": float(np.min(pmi_vals)),
        }
    write_json(summary, RESULTS_DIR / "interaction_summary.json")

    # Plot
    plot_interaction_heatmap(
        jac, delta_crossdim, pmi_pairs, INTERACTION_HEATMAP,
        plan_pairs, mon_pairs,
    )
    log.info(f"Saved interaction heatmap: {INTERACTION_HEATMAP}")
    log.info(f"Saved interaction summary: {RESULTS_DIR / 'interaction_summary.json'}")


if __name__ == "__main__":
    main()