# v2/scripts/05_select_top_experts.py
"""
Stage 1 part B: Aggregate routing shards, compute frequency differentials,
select top-K experts for each dimension.
"""
import sys
import argparse
import json
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import numpy as np
import torch
from configs.paths import (
ensure_dirs, LOGS_DIR, LABELED_COTS_PATH, ROUTING_DIR,
TOP_EXPERTS_PLAN_PATH, TOP_EXPERTS_MON_PATH, TARGET_LAYERS_PATH,
RESULTS_DIR, ROUTING_HEATMAP_PLAN, ROUTING_HEATMAP_MON,
)
from configs.model import MODEL_CONFIG, TOP_K_EXPERTS
from src.utils import setup_logger, read_jsonl, write_json
from src.expert_select import select_top_experts, get_target_layers
def load_all_shards(shards_dir: Path, num_layers: int):
    """
    Load every routing shard in *shards_dir* and merge them into one
    in-memory bundle of concatenated tensors.

    Returns a dict with:
        "topk_ids":   {layer: (N_total, top_k) tensor}
        "topk_gates": {layer: (N_total, top_k) tensor}
        "sample_id_to_range": {sample_idx: (start, end)} — global token
            offsets of each sample in the concatenated tensors
        "n_total": total token count across all shards

    Raises:
        FileNotFoundError: if *shards_dir* contains no ``shard_*.pt`` files.
    """
    shard_files = sorted(shards_dir.glob("shard_*.pt"))
    if not shard_files:
        raise FileNotFoundError(f"No shards in {shards_dir}. Run 04 first.")

    ids_chunks = {layer: [] for layer in range(num_layers)}
    gate_chunks = {layer: [] for layer in range(num_layers)}
    sample_id_to_range = {}
    offset = 0

    for shard_path in shard_files:
        shard = torch.load(shard_path, map_location="cpu")
        # Record the [start, end) global token span each sample occupies.
        for sid, length in zip(shard["sample_ids"], shard["sample_lengths"]):
            sample_id_to_range[sid] = (offset, offset + length)
            offset += length
        shard_ids = shard["topk_ids"]
        shard_gates = shard["topk_gates"]
        # Layers absent from a shard are simply skipped.
        for layer in range(num_layers):
            if layer in shard_ids:
                ids_chunks[layer].append(shard_ids[layer])
                gate_chunks[layer].append(shard_gates[layer])

    return {
        "topk_ids": {layer: torch.cat(chunks, dim=0)
                     for layer, chunks in ids_chunks.items() if chunks},
        "topk_gates": {layer: torch.cat(chunks, dim=0)
                       for layer, chunks in gate_chunks.items() if chunks},
        "sample_id_to_range": sample_id_to_range,
        "n_total": offset,
    }
def collect_token_indices(labeled_records, sample_id_to_range, field: str):
    """
    Map per-CoT local token indices (stored under *field*, e.g.
    "plan_decision_tis") to GLOBAL indices into the concatenated routing
    tensors.

    Records whose "idx" has no routing span are skipped, as are local
    indices that would land at or past the end of their sample's span.
    """
    global_indices = []
    for record in labeled_records:
        span = sample_id_to_range.get(record["idx"])
        if span is None:
            continue
        start, end = span
        global_indices.extend(
            start + local for local in record[field] if start + local < end
        )
    return global_indices
def plot_routing_heatmap(freq_diff: np.ndarray, title: str, path: Path):
    """
    Render a layer-by-expert heatmap of *freq_diff* and save it to *path*.

    Uses a diverging colormap centered on 0 so positive and negative
    frequency shifts are visually symmetric. Plotting libraries are
    imported lazily so the rest of the pipeline does not depend on them.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    figure, axis = plt.subplots(figsize=(14, 8))
    sns.heatmap(
        freq_diff,
        cmap="coolwarm",
        center=0,
        cbar=True,
        ax=axis,
        xticklabels=False,
        yticklabels=False,
    )
    axis.set_xlabel("Expert ID")
    axis.set_ylabel("Layer ID")
    axis.set_title(title)
    plt.tight_layout()
    plt.savefig(path, dpi=120)
    plt.close()
def main():
    """
    Stage 1B driver: aggregate routing shards, map labeled decision tokens
    to global indices, select top-K experts for the planning and monitoring
    dimensions, then persist the selections, unified target layers,
    routing heatmaps, and raw frequency stats.

    Fixes vs. previous revision: removed the unused ``num_experts`` local
    and deduplicated the plan/mon output-payload construction.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--top_k", type=int, default=TOP_K_EXPERTS)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()
    ensure_dirs()
    log = setup_logger("05_select", LOGS_DIR / "05_select.log")

    # Idempotency: skip entirely when both outputs already exist.
    if args.resume and TOP_EXPERTS_PLAN_PATH.exists() and TOP_EXPERTS_MON_PATH.exists():
        log.info("Top-experts already saved. Skipping.")
        return

    num_layers = MODEL_CONFIG["num_layers"]

    log.info("Loading routing shards...")
    routing_data = load_all_shards(ROUTING_DIR, num_layers)
    log.info(f"Total tokens: {routing_data['n_total']}")

    log.info("Loading labels...")
    labeled = read_jsonl(LABELED_COTS_PATH)
    sample_ranges = routing_data["sample_id_to_range"]
    plan_tis = collect_token_indices(labeled, sample_ranges, "plan_decision_tis")
    mon_tis = collect_token_indices(labeled, sample_ranges, "mon_decision_tis")
    exec_tis = collect_token_indices(labeled, sample_ranges, "exec_decision_tis")
    log.info(f"Global indices: plan={len(plan_tis)}, mon={len(mon_tis)}, exec={len(exec_tis)}")
    if len(plan_tis) < 20 or len(mon_tis) < 20:
        log.warning("Very few decision points — results will be unreliable")

    log.info("Computing expert selection scores...")
    results = select_top_experts(
        routing_data, plan_tis, mon_tis, exec_tis, top_k=args.top_k,
    )

    def _make_output(pairs):
        # Shared payload shape for both the planning and monitoring JSONs.
        return {
            "top_experts": [{"layer": l, "expert": e} for l, e in pairs],
            "target_layers": get_target_layers(pairs),
            "n_plan_tokens": len(plan_tis),
            "n_mon_tokens": len(mon_tis),
            "n_exec_tokens": len(exec_tis),
            "metric": "combined rank_norm(Δfreq) + rank_norm(log_ratio)",
        }

    top_plan_out = _make_output(results["top_experts_planning"])
    top_mon_out = _make_output(results["top_experts_monitoring"])
    write_json(top_plan_out, TOP_EXPERTS_PLAN_PATH)
    write_json(top_mon_out, TOP_EXPERTS_MON_PATH)
    log.info(f"Top-{args.top_k} planning experts saved: {TOP_EXPERTS_PLAN_PATH}")
    log.info(f"Top-{args.top_k} monitoring experts saved: {TOP_EXPERTS_MON_PATH}")

    # Save unified target layers (union of plan and mon)
    plan_layers = set(top_plan_out["target_layers"])
    mon_layers = set(top_mon_out["target_layers"])
    all_layers = sorted(plan_layers | mon_layers)
    write_json({
        "planning_layers": top_plan_out["target_layers"],
        "monitoring_layers": top_mon_out["target_layers"],
        "union_layers": all_layers,
    }, TARGET_LAYERS_PATH)
    log.info(f"Target layers: planning={sorted(plan_layers)}")
    log.info(f" monitoring={sorted(mon_layers)}")
    log.info(f" union={all_layers}")

    # Plot heatmaps
    log.info("Plotting routing heatmaps...")
    plot_routing_heatmap(
        results["delta_plan_vs_exec"],
        "Planning vs Exec — P(expert in top-K | S_plan) − P(... | S_exec)",
        ROUTING_HEATMAP_PLAN,
    )
    plot_routing_heatmap(
        results["delta_mon_vs_exec"],
        "Monitoring vs Exec — P(expert in top-K | S_mon) − P(... | S_exec)",
        ROUTING_HEATMAP_MON,
    )

    # Save raw stats to results dir for later inspection
    np.savez(
        RESULTS_DIR / "routing_stats.npz",
        freq_plan=results["freq_plan"],
        freq_mon=results["freq_mon"],
        freq_exec=results["freq_exec"],
        delta_plan_vs_exec=results["delta_plan_vs_exec"],
        delta_mon_vs_exec=results["delta_mon_vs_exec"],
        delta_plan_vs_mon=results["delta_plan_vs_mon"],
        logratio_plan_vs_exec=results["logratio_plan_vs_exec"],
        logratio_mon_vs_exec=results["logratio_mon_vs_exec"],
    )
    log.info("Saved raw stats -> routing_stats.npz")
    log.info("Done.")
# Script entry point: run the full selection pipeline when invoked directly.
if __name__ == "__main__":
    main()