"""
Stage 1 part B: Aggregate routing shards, compute frequency differentials,
select top-K experts for each dimension.
"""
import sys
import argparse
import json
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import numpy as np
import torch
from configs.paths import (
ensure_dirs, LOGS_DIR, LABELED_COTS_PATH, ROUTING_DIR,
TOP_EXPERTS_PLAN_PATH, TOP_EXPERTS_MON_PATH, TARGET_LAYERS_PATH,
RESULTS_DIR, ROUTING_HEATMAP_PLAN, ROUTING_HEATMAP_MON,
)
from configs.model import MODEL_CONFIG, TOP_K_EXPERTS
from src.utils import setup_logger, read_jsonl, write_json
from src.expert_select import select_top_experts, get_target_layers
def load_all_shards(shards_dir: Path, num_layers: int):
    """
    Merge every routing shard under *shards_dir* into one set of in-memory
    tensors, concatenated along the token axis.

    Returns a dict with:
        "topk_ids":   {layer: (N_total, top_k) tensor}
        "topk_gates": {layer: (N_total, top_k) tensor}
        "sample_id_to_range": {sample_idx: (start, end)} global token offsets
        "n_total":    total token count across all shards
    """
    shard_paths = sorted(shards_dir.glob("shard_*.pt"))
    if not shard_paths:
        raise FileNotFoundError(f"No shards in {shards_dir}. Run 04 first.")

    ids_chunks = {layer: [] for layer in range(num_layers)}
    gate_chunks = {layer: [] for layer in range(num_layers)}
    sample_id_to_range = {}
    offset = 0
    for path in shard_paths:
        shard = torch.load(path, map_location="cpu")
        # Record the global [start, end) token span occupied by each sample.
        for sid, length in zip(shard["sample_ids"], shard["sample_lengths"]):
            sample_id_to_range[sid] = (offset, offset + length)
            offset += length
        # Collect per-layer chunks; a shard may cover only a subset of layers.
        for layer in range(num_layers):
            if layer in shard["topk_ids"]:
                ids_chunks[layer].append(shard["topk_ids"][layer])
                gate_chunks[layer].append(shard["topk_gates"][layer])

    return {
        "topk_ids": {layer: torch.cat(chunks, dim=0)
                     for layer, chunks in ids_chunks.items() if chunks},
        "topk_gates": {layer: torch.cat(chunks, dim=0)
                       for layer, chunks in gate_chunks.items() if chunks},
        "sample_id_to_range": sample_id_to_range,
        "n_total": offset,
    }
def collect_token_indices(labeled_records, sample_id_to_range, field: str):
    """
    Map per-CoT local token indices (stored under *field*, e.g.
    "plan_decision_tis") to GLOBAL positions in the concatenated routing
    tensors. Records whose sample id is unknown, and local indices that
    fall past the sample's end, are silently skipped.
    """
    global_indices = []
    for record in labeled_records:
        span = sample_id_to_range.get(record["idx"])
        if span is None:
            continue
        start, end = span
        global_indices.extend(
            start + local for local in record[field] if start + local < end
        )
    return global_indices
def plot_routing_heatmap(freq_diff: np.ndarray, title: str, path: Path):
    """Render a layer-by-expert heatmap of *freq_diff* and save it to *path*."""
    # Imported lazily so the plotting stack is only needed when plots are made.
    import matplotlib.pyplot as plt
    import seaborn as sns

    figure, axis = plt.subplots(figsize=(14, 8))
    sns.heatmap(
        freq_diff,
        cmap="coolwarm",
        center=0,
        cbar=True,
        ax=axis,
        xticklabels=False,
        yticklabels=False,
    )
    axis.set_xlabel("Expert ID")
    axis.set_ylabel("Layer ID")
    axis.set_title(title)
    plt.tight_layout()
    plt.savefig(path, dpi=120)
    plt.close()
def main():
    """
    CLI entry point.

    Aggregates routing shards, maps labeled decision-point token indices to
    global positions, scores experts per dimension, then writes the top-K
    expert lists, the unified target-layer file, heatmap plots, and raw stats.

    Improvements over the previous revision: removed the unused
    ``num_experts`` local, and factored the duplicated planning/monitoring
    output-dict construction into one helper.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--top_k", type=int, default=TOP_K_EXPERTS)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("05_select", LOGS_DIR / "05_select.log")

    # Idempotent resume: skip if both output files already exist.
    if args.resume and TOP_EXPERTS_PLAN_PATH.exists() and TOP_EXPERTS_MON_PATH.exists():
        log.info("Top-experts already saved. Skipping.")
        return

    num_layers = MODEL_CONFIG["num_layers"]

    log.info("Loading routing shards...")
    routing_data = load_all_shards(ROUTING_DIR, num_layers)
    log.info(f"Total tokens: {routing_data['n_total']}")

    log.info("Loading labels...")
    labeled = read_jsonl(LABELED_COTS_PATH)
    ranges = routing_data["sample_id_to_range"]
    plan_tis = collect_token_indices(labeled, ranges, "plan_decision_tis")
    mon_tis = collect_token_indices(labeled, ranges, "mon_decision_tis")
    exec_tis = collect_token_indices(labeled, ranges, "exec_decision_tis")
    log.info(f"Global indices: plan={len(plan_tis)}, mon={len(mon_tis)}, exec={len(exec_tis)}")
    if len(plan_tis) < 20 or len(mon_tis) < 20:
        log.warning("Very few decision points — results will be unreliable")

    log.info("Computing expert selection scores...")
    results = select_top_experts(
        routing_data, plan_tis, mon_tis, exec_tis, top_k=args.top_k,
    )

    def summarize(pairs):
        # One JSON-serializable summary per dimension; the token counts and
        # metric description are shared between planning and monitoring.
        return {
            "top_experts": [{"layer": l, "expert": e} for l, e in pairs],
            "target_layers": get_target_layers(pairs),
            "n_plan_tokens": len(plan_tis),
            "n_mon_tokens": len(mon_tis),
            "n_exec_tokens": len(exec_tis),
            "metric": "combined rank_norm(Δfreq) + rank_norm(log_ratio)",
        }

    top_plan_out = summarize(results["top_experts_planning"])
    top_mon_out = summarize(results["top_experts_monitoring"])
    write_json(top_plan_out, TOP_EXPERTS_PLAN_PATH)
    write_json(top_mon_out, TOP_EXPERTS_MON_PATH)
    log.info(f"Top-{args.top_k} planning experts saved: {TOP_EXPERTS_PLAN_PATH}")
    log.info(f"Top-{args.top_k} monitoring experts saved: {TOP_EXPERTS_MON_PATH}")

    # Save unified target layers (union of plan and mon).
    plan_layers = set(top_plan_out["target_layers"])
    mon_layers = set(top_mon_out["target_layers"])
    all_layers = sorted(plan_layers | mon_layers)
    write_json({
        "planning_layers": top_plan_out["target_layers"],
        "monitoring_layers": top_mon_out["target_layers"],
        "union_layers": all_layers,
    }, TARGET_LAYERS_PATH)
    log.info(f"Target layers: planning={sorted(plan_layers)}")
    log.info(f"               monitoring={sorted(mon_layers)}")
    log.info(f"               union={all_layers}")

    # Plot heatmaps of the frequency differentials.
    log.info("Plotting routing heatmaps...")
    plot_routing_heatmap(
        results["delta_plan_vs_exec"],
        "Planning vs Exec — P(expert in top-K | S_plan) − P(... | S_exec)",
        ROUTING_HEATMAP_PLAN,
    )
    plot_routing_heatmap(
        results["delta_mon_vs_exec"],
        "Monitoring vs Exec — P(expert in top-K | S_mon) − P(... | S_exec)",
        ROUTING_HEATMAP_MON,
    )

    # Save raw stats to results dir for later inspection.
    np.savez(
        RESULTS_DIR / "routing_stats.npz",
        freq_plan=results["freq_plan"],
        freq_mon=results["freq_mon"],
        freq_exec=results["freq_exec"],
        delta_plan_vs_exec=results["delta_plan_vs_exec"],
        delta_mon_vs_exec=results["delta_mon_vs_exec"],
        delta_plan_vs_mon=results["delta_plan_vs_mon"],
        logratio_plan_vs_exec=results["logratio_plan_vs_exec"],
        logratio_mon_vs_exec=results["logratio_mon_vs_exec"],
    )
    log.info("Saved raw stats -> routing_stats.npz")
    log.info("Done.")
# Script entry point; guarded so importing this module has no side effects.
if __name__ == "__main__":
    main()