File size: 8,857 Bytes
e53f10b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""
Stage 2: Dimension interaction analysis.

Produces:
  - Jaccard overlap of top-K experts
  - Co-activation PMI between (plan_expert, mon_expert) pairs
  - Cross-dim contrast visualization
  - (Direction cosine matrix is produced later, after script 08)
"""
import sys
import argparse
import json
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import numpy as np
import torch

from configs.paths import (
    ensure_dirs, LOGS_DIR, ROUTING_DIR, LABELED_COTS_PATH,
    TOP_EXPERTS_PLAN_PATH, TOP_EXPERTS_MON_PATH,
    RESULTS_DIR, INTERACTION_HEATMAP,
)
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, read_json, write_json


def compute_jaccard(set_a, set_b):
    """Return the Jaccard index of two sets: |A & B| / |A | B|.

    When both inputs are empty the overlap is defined as 0.0, which also
    avoids dividing by the size of an empty union.
    """
    if not (set_a or set_b):
        return 0.0
    union_size = len(set_a | set_b)
    shared_size = len(set_a & set_b)
    return shared_size / union_size


def compute_pmi_matrix(topk_ids_by_layer, token_indices, n_layers, n_experts, eps=1e-6):
    """
    Placeholder for an all-pairs, per-layer co-activation PMI computation.

    A full result would be one (E, E) matrix per layer, i.e. an (L, E, E)
    stack, which is too large for E=128 experts across 48 layers. PMI is
    therefore computed only between the top planning experts and the top
    monitoring experts, in compute_pmi_pairwise.

    Raises:
        NotImplementedError: always; none of the arguments are inspected.
    """
    message = "Use compute_pmi_pairwise instead."
    raise NotImplementedError(message)


def compute_pmi_pairwise(topk_ids_by_layer, token_indices, plan_experts, mon_experts, eps=1e-6):
    """
    Compute same-layer co-activation PMI between (plan_expert, mon_expert) pairs.

    For every pair whose two layers are equal, the selected tokens are used to
    estimate:
      - P_plan:  fraction of tokens whose top-k routing contains the plan expert
      - P_mon:   fraction of tokens whose top-k routing contains the mon expert
      - P_joint: fraction of tokens whose top-k routing contains both
    and PMI = log(P_joint / (P_plan * P_mon)).

    ``eps`` is added to each of the three probabilities before the ratio is
    taken so the logarithm stays finite when an expert never fires; the
    probabilities reported in the result include that offset.

    Pairs on different layers are skipped to avoid a combinatorial explosion
    (same-layer pairs can be selected by top-k simultaneously).

    Args:
        topk_ids_by_layer: mapping layer -> 2-D torch tensor of routed expert
            ids, one row of top-k ids per global token position.
        token_indices: sequence of global token positions to evaluate.
        plan_experts: iterable of (layer, expert) pairs.
        mon_experts: iterable of (layer, expert) pairs.
        eps: additive smoothing applied to every probability estimate.

    Returns:
        dict keyed by (l_p, e_p, l_m, e_m). Each value is a dict with float
        entries "pmi", "P_plan", "P_mon" and "P_joint". The dict is empty
        when ``token_indices`` is empty or no pair shares a layer.

    Raises:
        KeyError: if a same-layer pair refers to a layer that is missing from
            ``topk_ids_by_layer``.
    """
    n = len(token_indices)
    if n == 0:
        return {}
    idx_tensor = torch.tensor(token_indices, dtype=torch.long)

    # Gathering the selected rows of a layer and scanning them for an expert
    # are the expensive steps. Both depend only on (layer) / (layer, expert),
    # so cache them instead of redoing the work for every pair.
    rows_by_layer = {}
    mask_by_expert = {}

    def expert_active(layer, expert):
        key = (layer, expert)
        mask = mask_by_expert.get(key)
        if mask is None:
            rows = rows_by_layer.get(layer)
            if rows is None:
                rows = topk_ids_by_layer[layer][idx_tensor].numpy()   # (n, top_k)
                rows_by_layer[layer] = rows
            mask = (rows == expert).any(axis=1)   # (n,) bool
            mask_by_expert[key] = mask
        return mask

    # Group monitoring experts by layer once so each plan expert only visits
    # the candidates on its own layer (original relative order is preserved).
    mon_by_layer = {}
    for lm, em in mon_experts:
        mon_by_layer.setdefault(lm, []).append(em)

    results = {}
    for lp, ep in plan_experts:
        same_layer_mon = mon_by_layer.get(lp)
        if not same_layer_mon:
            continue
        act_p = expert_active(lp, ep)
        p_p = act_p.mean() + eps
        for em in same_layer_mon:
            act_m = expert_active(lp, em)
            p_m = act_m.mean() + eps
            p_pm = (act_p & act_m).mean() + eps
            pmi = float(np.log(p_pm / (p_p * p_m)))
            results[(lp, ep, lp, em)] = {
                "pmi": pmi,
                "P_plan": float(p_p),
                "P_mon": float(p_m),
                "P_joint": float(p_pm),
            }
    return results


def load_all_shards(shards_dir, num_layers):
    """Load every ``shard_*.pt`` file and stitch the per-layer top-k ids together.

    Shards are visited in sorted filename order. Each shard is read as a
    mapping with "sample_ids", "sample_lengths" and "topk_ids" (layer ->
    tensor) entries; only the routed expert ids are kept.

    Returns:
        (topk_ids, sample_id_to_range) where ``topk_ids`` maps a layer index
        to the concatenation (along dim 0) of that layer's tensors from all
        shards, and ``sample_id_to_range`` maps a sample id to its
        half-open (start, end) span of global token positions. Layers that
        no shard provides are left out of ``topk_ids``.
    """
    # NOTE(review): the global ranges line up with the concatenated tensors
    # only if every shard stores one row per token for each layer -- confirm.
    chunks_by_layer = {layer: [] for layer in range(num_layers)}
    sample_id_to_range = {}
    offset = 0
    for shard_path in sorted(shards_dir.glob("shard_*.pt")):
        shard = torch.load(shard_path, map_location="cpu")
        for sample_id, length in zip(shard["sample_ids"], shard["sample_lengths"]):
            stop = offset + length
            sample_id_to_range[sample_id] = (offset, stop)
            offset = stop
        for layer, chunks in chunks_by_layer.items():
            if layer in shard["topk_ids"]:
                chunks.append(shard["topk_ids"][layer])

    topk_ids = {}
    for layer, chunks in chunks_by_layer.items():
        if chunks:
            topk_ids[layer] = torch.cat(chunks, dim=0)
    return topk_ids, sample_id_to_range


def collect_global_token_indices(labeled, sample_id_to_range, field):
    """Translate per-sample token indices into global token positions.

    Args:
        labeled: iterable of records; each record has an "idx" entry (the
            sample id) and a ``field`` entry holding sample-local token
            indices.
        sample_id_to_range: mapping sample id -> half-open (start, end) span
            of global token positions.
        field: name of the record entry that holds the local indices.

    Returns:
        A flat list of global positions (``start + local_index``) in record
        order. Records whose sample id is not in ``sample_id_to_range`` are
        skipped, and so is any local index that falls outside the sample's
        own span.
    """
    out = []
    for record in labeled:
        span = sample_id_to_range.get(record["idx"])
        if span is None:
            continue
        start, end = span
        for ti in record[field]:
            gi = start + ti
            # Enforce both bounds: a negative local index would otherwise
            # pass the upper-bound test and point at a token that belongs
            # to an earlier sample.
            if start <= gi < end:
                out.append(gi)
    return out


def plot_interaction_heatmap(
    jaccard_value, delta_crossdim, pmi_pairs, save_path,
    plan_experts, mon_experts,
):
    """Draw the three-panel interaction figure and save it to ``save_path``.

    Panels, left to right:
      1. a text box with the Jaccard overlap, the sizes of the two expert
         lists and the size of their intersection;
      2. a diverging heatmap of ``delta_crossdim`` centred on zero;
      3. a histogram of the "pmi" values in ``pmi_pairs``, or a placeholder
         message when ``pmi_pairs`` is empty.

    matplotlib and seaborn are imported lazily inside the function.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    fig, (overlap_ax, contrast_ax, pmi_ax) = plt.subplots(1, 3, figsize=(24, 7))

    # Panel 1: Jaccard overlap rendered as a text box.
    shared_experts = set(map(tuple, plan_experts)) & set(map(tuple, mon_experts))
    overlap_text = (
        f"Jaccard overlap of top-K experts\n\n"
        f"J = |E_plan ∩ E_mon| / |E_plan ∪ E_mon|\n\n"
        f"J = {jaccard_value:.3f}\n\n"
        f"|E_plan| = {len(plan_experts)}\n"
        f"|E_mon|  = {len(mon_experts)}\n"
        f"|intersection| = "
        f"{len(shared_experts)}"
    )
    box_style = dict(boxstyle="round,pad=0.8", facecolor="lightblue")
    overlap_ax.axis("off")
    overlap_ax.text(0.5, 0.5, overlap_text,
                    ha="center", va="center", fontsize=14,
                    bbox=box_style)
    overlap_ax.set_title("Top-K Expert Overlap", fontsize=14)

    # Panel 2: cross-dimension contrast, Δfreq(plan) - Δfreq(mon).
    sns.heatmap(delta_crossdim, cmap="coolwarm", center=0, ax=contrast_ax,
                xticklabels=False, yticklabels=False)
    contrast_ax.set_xlabel("Expert ID")
    contrast_ax.set_ylabel("Layer ID")
    contrast_ax.set_title("Δfreq(plan) − Δfreq(mon)\n(experts that distinguish plan from mon)",
                          fontsize=14)

    # Panel 3: distribution of same-layer PMI values.
    if not pmi_pairs:
        pmi_ax.text(0.5, 0.5, "No same-layer plan-mon pairs found", ha="center", va="center")
        pmi_ax.axis("off")
    else:
        pmi_vals = [entry["pmi"] for entry in pmi_pairs.values()]
        pmi_ax.hist(pmi_vals, bins=30, color="steelblue", edgecolor="black")
        pmi_ax.axvline(0, color="red", linestyle="--", label="independence (PMI=0)")
        pmi_ax.set_xlabel("Co-activation PMI")
        pmi_ax.set_ylabel("# pairs")
        pmi_title = (
            f"Co-activation PMI between\nplan and mon experts (same layer)\n"
            f"Mean PMI = {np.mean(pmi_vals):+.3f}"
        )
        pmi_ax.set_title(pmi_title, fontsize=12)
        pmi_ax.legend()

    fig.tight_layout()
    fig.savefig(save_path, dpi=120)
    plt.close(fig)


def main():
    """Run the plan/monitor interaction analysis end to end.

    Steps:
      1. Read the top planning and top monitoring experts and compute the
         Jaccard overlap of the two (layer, expert) sets.
      2. Load ``routing_stats.npz`` and form the cross-dimension contrast
         ``delta_plan_vs_exec - delta_mon_vs_exec``.
      3. Load the routing shards and compute same-layer co-activation PMI
         over the planning decision tokens.
      4. Write ``interaction_summary.json`` and the interaction figure.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --resume is accepted on the command line but nothing in
    # this script reads it.
    parser.add_argument("--resume", action="store_true")
    parser.parse_args()

    ensure_dirs()
    log = setup_logger("06_interaction", LOGS_DIR / "06_interaction.log")

    # Load top experts
    top_plan = read_json(TOP_EXPERTS_PLAN_PATH)
    top_mon = read_json(TOP_EXPERTS_MON_PATH)
    plan_pairs = [(d["layer"], d["expert"]) for d in top_plan["top_experts"]]
    mon_pairs = [(d["layer"], d["expert"]) for d in top_mon["top_experts"]]
    plan_set = set(plan_pairs)
    mon_set = set(mon_pairs)

    # 1) Jaccard overlap
    jac = compute_jaccard(plan_set, mon_set)
    log.info(f"Jaccard overlap (top-K experts): {jac:.3f}")

    # 2) Cross-dim contrast
    # The context manager closes the underlying .npz archive; the arrays are
    # read into memory when indexed, so they remain usable afterwards.
    with np.load(RESULTS_DIR / "routing_stats.npz") as stats:
        delta_plan = stats["delta_plan_vs_exec"]
        delta_mon = stats["delta_mon_vs_exec"]
    delta_crossdim = delta_plan - delta_mon   # positive => plan-selective, negative => mon-selective

    # 3) Same-layer PMI of plan-mon expert pairs
    log.info("Loading routing shards for PMI...")
    num_layers = MODEL_CONFIG["num_layers"]
    topk_ids, sample_id_to_range = load_all_shards(ROUTING_DIR, num_layers)
    labeled = read_jsonl(LABELED_COTS_PATH)
    plan_tis = collect_global_token_indices(labeled, sample_id_to_range, "plan_decision_tis")
    log.info(f"Computing PMI over {len(plan_tis)} planning decision points "
             f"for same-layer (plan_expert, mon_expert) pairs...")
    pmi_pairs = compute_pmi_pairwise(
        topk_ids, plan_tis, plan_pairs, mon_pairs,
    )
    log.info(f"Computed PMI for {len(pmi_pairs)} same-layer pairs")

    # 4) Summary & save
    summary_path = RESULTS_DIR / "interaction_summary.json"
    summary = {
        "jaccard_overlap": float(jac),
        "n_plan_experts": len(plan_pairs),
        "n_mon_experts": len(mon_pairs),
        # Sorted so the JSON output is identical from run to run; plain set
        # iteration order is arbitrary.
        "intersection": [list(p) for p in sorted(plan_set & mon_set)],
        "n_pmi_pairs": len(pmi_pairs),
        "pmi_pairs": [
            {"plan_layer": k[0], "plan_expert": k[1],
             "mon_layer":  k[2], "mon_expert":  k[3], **v}
            for k, v in pmi_pairs.items()
        ],
    }
    if pmi_pairs:
        pmi_vals = [v["pmi"] for v in pmi_pairs.values()]
        summary["pmi_stats"] = {
            "mean": float(np.mean(pmi_vals)),
            "std": float(np.std(pmi_vals)),
            "max": float(np.max(pmi_vals)),
            "min": float(np.min(pmi_vals)),
        }
    write_json(summary, summary_path)

    # Plot
    plot_interaction_heatmap(
        jac, delta_crossdim, pmi_pairs, INTERACTION_HEATMAP,
        plan_pairs, mon_pairs,
    )
    log.info(f"Saved interaction heatmap: {INTERACTION_HEATMAP}")
    log.info(f"Saved interaction summary: {summary_path}")


# Script entry point: run the analysis only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()