File size: 7,615 Bytes
e53f10b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""
Stage 1 part B: Aggregate routing shards, compute frequency differentials,
select top-K experts for each dimension.
"""
import sys
import argparse
import json
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import numpy as np
import torch

from configs.paths import (
    ensure_dirs, LOGS_DIR, LABELED_COTS_PATH, ROUTING_DIR,
    TOP_EXPERTS_PLAN_PATH, TOP_EXPERTS_MON_PATH, TARGET_LAYERS_PATH,
    RESULTS_DIR, ROUTING_HEATMAP_PLAN, ROUTING_HEATMAP_MON,
)
from configs.model import MODEL_CONFIG, TOP_K_EXPERTS
from src.utils import setup_logger, read_jsonl, write_json
from src.expert_select import select_top_experts, get_target_layers


def load_all_shards(shards_dir: Path, num_layers: int):
    """
    Load every routing shard under *shards_dir* and merge them into one
    in-memory structure.

    Shards are ``shard_*.pt`` files (written by step 04); each contains
    ``sample_ids``, ``sample_lengths`` and per-layer ``topk_ids`` /
    ``topk_gates`` tensors. Tokens are concatenated in shard order and each
    sample's slice of the concatenated token axis is recorded.

    Returns a dict with:
      "topk_ids":           {layer: (N_total, top_k) tensor}
      "topk_gates":         {layer: (N_total, top_k) tensor}
      "sample_id_to_range": {sample_idx: (start, end)} global token offsets
      "n_total":            total token count across all shards

    Raises:
      FileNotFoundError: if *shards_dir* holds no ``shard_*.pt`` files.
    """
    paths = sorted(shards_dir.glob("shard_*.pt"))
    if not paths:
        raise FileNotFoundError(f"No shards in {shards_dir}. Run 04 first.")

    ids_by_layer = {layer: [] for layer in range(num_layers)}
    gates_by_layer = {layer: [] for layer in range(num_layers)}
    ranges = {}
    offset = 0

    for path in paths:
        shard = torch.load(path, map_location="cpu")
        # Record each sample's [start, end) slice on the concatenated axis.
        for sid, length in zip(shard["sample_ids"], shard["sample_lengths"]):
            ranges[sid] = (offset, offset + length)
            offset += length
        shard_ids = shard["topk_ids"]
        shard_gates = shard["topk_gates"]
        for layer in range(num_layers):
            if layer in shard_ids:
                ids_by_layer[layer].append(shard_ids[layer])
                gates_by_layer[layer].append(shard_gates[layer])

    return {
        "topk_ids": {
            layer: torch.cat(chunks, dim=0)
            for layer, chunks in ids_by_layer.items() if chunks
        },
        "topk_gates": {
            layer: torch.cat(chunks, dim=0)
            for layer, chunks in gates_by_layer.items() if chunks
        },
        "sample_id_to_range": ranges,
        "n_total": offset,
    }


def collect_token_indices(labeled_records, sample_id_to_range, field: str):
    """
    Convert per-CoT local token indices (field, e.g. "plan_decision_tis") to
    GLOBAL indices into the concatenated routing tensor.

    Args:
        labeled_records: iterable of dicts, each with an "idx" sample id and a
            *field* list of sample-local token positions.
        sample_id_to_range: {sample_idx: (start, end)} global token offsets,
            as produced by load_all_shards().
        field: name of the per-record key holding the local token indices.

    Returns:
        Flat list of global indices. Records whose sample id has no routing
        range are skipped; indices outside the sample's [start, end) slice
        are dropped.
    """
    out = []
    for record in labeled_records:
        sid = record["idx"]
        if sid not in sample_id_to_range:
            continue
        start, end = sample_id_to_range[sid]
        for ti in record[field]:
            gi = start + ti
            # Guard BOTH ends: `gi < end` alone would let a negative local
            # index ti alias into the *preceding* sample's token range.
            if start <= gi < end:
                out.append(gi)
    return out


def plot_routing_heatmap(freq_diff: np.ndarray, title: str, path: Path):
    """Render a (layers x experts) frequency-differential matrix as a heatmap.

    The colormap is centered at zero so positive and negative differentials
    read symmetrically. The figure is saved to *path* and then closed.
    """
    # Imported lazily so the rest of the module works without plotting deps.
    import matplotlib.pyplot as plt
    import seaborn as sns

    figure, axis = plt.subplots(figsize=(14, 8))
    sns.heatmap(
        freq_diff,
        cmap="coolwarm",
        center=0,
        cbar=True,
        ax=axis,
        xticklabels=False,
        yticklabels=False,
    )
    axis.set_xlabel("Expert ID")
    axis.set_ylabel("Layer ID")
    axis.set_title(title)
    plt.tight_layout()
    plt.savefig(path, dpi=120)
    plt.close()


def main():
    """Stage-1B entry point: aggregate shards, score experts, save artifacts.

    Pipeline:
      1. Load and concatenate routing shards from ROUTING_DIR.
      2. Map labeled per-sample decision-token indices to global indices.
      3. Score and select top-K experts for planning and monitoring.
      4. Persist top-expert lists, target layers, heatmaps, and raw stats.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--top_k", type=int, default=TOP_K_EXPERTS)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("05_select", LOGS_DIR / "05_select.log")

    # Idempotent resume: if both outputs already exist there is nothing to do.
    if args.resume and TOP_EXPERTS_PLAN_PATH.exists() and TOP_EXPERTS_MON_PATH.exists():
        log.info("Top-experts already saved. Skipping.")
        return

    num_layers = MODEL_CONFIG["num_layers"]

    log.info("Loading routing shards...")
    routing_data = load_all_shards(ROUTING_DIR, num_layers)
    log.info(f"Total tokens: {routing_data['n_total']}")

    log.info("Loading labels...")
    labeled = read_jsonl(LABELED_COTS_PATH)
    sample_ranges = routing_data["sample_id_to_range"]
    plan_tis = collect_token_indices(labeled, sample_ranges, "plan_decision_tis")
    mon_tis = collect_token_indices(labeled, sample_ranges, "mon_decision_tis")
    exec_tis = collect_token_indices(labeled, sample_ranges, "exec_decision_tis")
    log.info(f"Global indices: plan={len(plan_tis)}, mon={len(mon_tis)}, exec={len(exec_tis)}")

    if len(plan_tis) < 20 or len(mon_tis) < 20:
        log.warning("Very few decision points — results will be unreliable")

    log.info("Computing expert selection scores...")
    results = select_top_experts(
        routing_data, plan_tis, mon_tis, exec_tis, top_k=args.top_k,
    )

    def _experts_payload(pairs):
        # One helper for both dimensions: serializes (layer, expert) pairs
        # and attaches the shared token-count / metric metadata.
        return {
            "top_experts": [{"layer": l, "expert": e} for l, e in pairs],
            "target_layers": get_target_layers(pairs),
            "n_plan_tokens": len(plan_tis),
            "n_mon_tokens":  len(mon_tis),
            "n_exec_tokens": len(exec_tis),
            "metric": "combined rank_norm(Δfreq) + rank_norm(log_ratio)",
        }

    top_plan_out = _experts_payload(results["top_experts_planning"])
    top_mon_out = _experts_payload(results["top_experts_monitoring"])
    write_json(top_plan_out, TOP_EXPERTS_PLAN_PATH)
    write_json(top_mon_out, TOP_EXPERTS_MON_PATH)
    log.info(f"Top-{args.top_k} planning experts saved: {TOP_EXPERTS_PLAN_PATH}")
    log.info(f"Top-{args.top_k} monitoring experts saved: {TOP_EXPERTS_MON_PATH}")

    # Unified target layers: the union drives downstream hooks that must
    # cover both planning and monitoring experts.
    plan_layers = set(top_plan_out["target_layers"])
    mon_layers = set(top_mon_out["target_layers"])
    all_layers = sorted(plan_layers | mon_layers)
    write_json({
        "planning_layers":   top_plan_out["target_layers"],
        "monitoring_layers": top_mon_out["target_layers"],
        "union_layers":      all_layers,
    }, TARGET_LAYERS_PATH)
    log.info(f"Target layers: planning={sorted(plan_layers)}")
    log.info(f"                monitoring={sorted(mon_layers)}")
    log.info(f"                union={all_layers}")

    # Plot heatmaps of the routing-frequency differentials.
    log.info("Plotting routing heatmaps...")
    plot_routing_heatmap(
        results["delta_plan_vs_exec"],
        "Planning vs Exec  —  P(expert in top-K | S_plan) − P(... | S_exec)",
        ROUTING_HEATMAP_PLAN,
    )
    plot_routing_heatmap(
        results["delta_mon_vs_exec"],
        "Monitoring vs Exec  —  P(expert in top-K | S_mon) − P(... | S_exec)",
        ROUTING_HEATMAP_MON,
    )

    # Raw per-(layer, expert) stats for offline inspection / re-plotting.
    np.savez(
        RESULTS_DIR / "routing_stats.npz",
        freq_plan=results["freq_plan"],
        freq_mon=results["freq_mon"],
        freq_exec=results["freq_exec"],
        delta_plan_vs_exec=results["delta_plan_vs_exec"],
        delta_mon_vs_exec=results["delta_mon_vs_exec"],
        delta_plan_vs_mon=results["delta_plan_vs_mon"],
        logratio_plan_vs_exec=results["logratio_plan_vs_exec"],
        logratio_mon_vs_exec=results["logratio_mon_vs_exec"],
    )
    log.info("Saved raw stats -> routing_stats.npz")
    log.info("Done.")


# Run this pipeline stage only when executed as a script (not on import).
if __name__ == "__main__":
    main()