"""
Stage 1 part B: Aggregate routing shards, compute frequency differentials,
select top-K experts for each dimension.
"""
import sys
import argparse
import json
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import numpy as np
import torch

from configs.paths import (
    ensure_dirs,
    LOGS_DIR,
    LABELED_COTS_PATH,
    ROUTING_DIR,
    TOP_EXPERTS_PLAN_PATH,
    TOP_EXPERTS_MON_PATH,
    TARGET_LAYERS_PATH,
    RESULTS_DIR,
    ROUTING_HEATMAP_PLAN,
    ROUTING_HEATMAP_MON,
)
from configs.model import MODEL_CONFIG, TOP_K_EXPERTS
from src.utils import setup_logger, read_jsonl, write_json
from src.expert_select import select_top_experts, get_target_layers

# If any of the plan / mon / exec decision-token sets is smaller than this,
# main() logs a reliability warning (it does not abort).
MIN_DECISION_POINTS = 20

# Recorded verbatim under "metric" in both top-expert JSON files.
METRIC_DESCRIPTION = "combined rank_norm(Δfreq) + rank_norm(log_ratio)"


def load_all_shards(shards_dir: Path, num_layers: int) -> dict:
    """
    Concatenate all routing shards into one in-memory set of tensors.

    Shards (``shard_*.pt``) are read in sorted filename order and their tokens
    are laid out back to back in that order. Each sample therefore occupies the
    half-open global token range ``[start, end)`` recorded in
    ``sample_id_to_range``.

    Args:
        shards_dir: Directory containing the ``shard_*.pt`` files.
        num_layers: Layer ids ``0 .. num_layers - 1`` are looked up in each
            shard. Layers that appear in no shard are omitted from the output.

    Returns:
        {
            "topk_ids": {layer: (N_total, top_k) tensor},
            "topk_gates": {layer: (N_total, top_k) tensor},
            "sample_id_to_range": {sample_id: (start, end)},
            "n_total": total number of tokens across all shards (int),
        }

    Raises:
        FileNotFoundError: If ``shards_dir`` contains no shards.
        ValueError: If a shard's sample ids and sample lengths differ in count.
            Also raised if a layer's concatenated tensors do not cover exactly
            ``n_total`` tokens (e.g. a layer missing from only some shards),
            which would otherwise silently misalign every global token index.
    """
    shard_files = sorted(shards_dir.glob("shard_*.pt"))
    if not shard_files:
        raise FileNotFoundError(f"No shards in {shards_dir}. Run 04 first.")

    per_layer_ids = {li: [] for li in range(num_layers)}
    per_layer_gates = {li: [] for li in range(num_layers)}
    sample_id_to_range = {}
    cursor = 0

    for sf in shard_files:
        # torch.load unpickles the file; shards are expected to come from the
        # local stage-04 run, never from an untrusted source.
        shard = torch.load(sf, map_location="cpu")
        sample_ids = shard["sample_ids"]
        sample_lengths = shard["sample_lengths"]
        if len(sample_ids) != len(sample_lengths):
            # zip() would silently drop the extra entries and shift offsets.
            raise ValueError(
                f"{sf}: {len(sample_ids)} sample_ids but "
                f"{len(sample_lengths)} sample_lengths"
            )
        for sid, slen in zip(sample_ids, sample_lengths):
            slen = int(slen)  # accept 0-d tensors as well as plain ints
            sample_id_to_range[sid] = (cursor, cursor + slen)
            cursor += slen
        for li in range(num_layers):
            if li in shard["topk_ids"]:
                per_layer_ids[li].append(shard["topk_ids"][li])
                per_layer_gates[li].append(shard["topk_gates"][li])

    topk_ids = {li: torch.cat(v, dim=0) for li, v in per_layer_ids.items() if v}
    topk_gates = {li: torch.cat(v, dim=0) for li, v in per_layer_gates.items() if v}

    # Global indices from sample_id_to_range are only valid if every kept
    # layer has exactly one row per token across all shards.
    for key, per_layer in (("topk_ids", topk_ids), ("topk_gates", topk_gates)):
        for li, tensor in per_layer.items():
            if tensor.shape[0] != cursor:
                raise ValueError(
                    f"{key}[layer {li}] has {tensor.shape[0]} rows but the "
                    f"shards declare {cursor} tokens; is this layer missing "
                    f"from some shards?"
                )

    return {
        "topk_ids": topk_ids,
        "topk_gates": topk_gates,
        "sample_id_to_range": sample_id_to_range,
        "n_total": cursor,
    }


def collect_token_indices(labeled_records, sample_id_to_range, field: str) -> list:
    """
    Convert per-CoT local token indices to GLOBAL indices into the concatenated tensor.

    Args:
        labeled_records: Iterable of dicts. Each has an ``"idx"`` key (the
            sample id used in the routing shards) and a list of local token
            indices under ``field``.
        sample_id_to_range: Mapping sample id -> ``(start, end)`` global range,
            as produced by ``load_all_shards``.
        field: Name of the per-record index list, e.g. ``"plan_decision_tis"``.

    Returns:
        Global token indices, in record order.
        - Records whose sample id has no routing data are skipped.
        - Local indices outside ``[0, end - start)`` are dropped.
    """
    out = []
    for r in labeled_records:
        span = sample_id_to_range.get(r["idx"])
        if span is None:
            continue
        start, end = span
        length = end - start
        # Reject negatives too: start + ti would otherwise land inside the
        # preceding sample's tokens.
        out.extend(start + ti for ti in r[field] if 0 <= ti < length)
    return out


def plot_routing_heatmap(freq_diff: np.ndarray, title: str, path: Path) -> None:
    """
    Save a diverging heatmap of a frequency-difference matrix to ``path``.

    Rows are drawn on the "Layer ID" axis and columns on the "Expert ID" axis.
    The colormap is centred at 0, so positive and negative differences get
    opposite hues.
    """
    # Imported here so matplotlib/seaborn are only loaded when plotting.
    import matplotlib.pyplot as plt
    import seaborn as sns

    fig, ax = plt.subplots(figsize=(14, 8))
    try:
        sns.heatmap(
            freq_diff,
            cmap="coolwarm",
            center=0,
            cbar=True,
            ax=ax,
            xticklabels=False,
            yticklabels=False,
        )
        ax.set_xlabel("Expert ID")
        ax.set_ylabel("Layer ID")
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(path, dpi=120)
    finally:
        # Close this specific figure even if plotting fails, so repeated
        # calls don't accumulate open figures.
        plt.close(fig)


def _serialize_experts(pairs):
    """Turn ``(layer, expert)`` pairs into JSON-friendly dicts."""
    return [{"layer": layer, "expert": expert} for layer, expert in pairs]


def _build_top_output(pairs, n_plan, n_mon, n_exec):
    """Assemble the JSON payload for one dimension's top-expert file."""
    return {
        "top_experts": _serialize_experts(pairs),
        "target_layers": get_target_layers(pairs),
        "n_plan_tokens": n_plan,
        "n_mon_tokens": n_mon,
        "n_exec_tokens": n_exec,
        "metric": METRIC_DESCRIPTION,
    }


def main():
    """
    Select the top planning and monitoring experts and save them to disk.

    Steps:
      1. Load the routing shards and the labelled CoTs.
      2. Map decision tokens to global indices.
      3. Score experts with ``select_top_experts``.
      4. Write the top-expert JSON files, the target-layer JSON, the heatmaps
         and the raw statistics.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--top_k", type=int, default=TOP_K_EXPERTS)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("05_select", LOGS_DIR / "05_select.log")

    if args.resume and TOP_EXPERTS_PLAN_PATH.exists() and TOP_EXPERTS_MON_PATH.exists():
        log.info("Top-experts already saved. Skipping.")
        return

    num_layers = MODEL_CONFIG["num_layers"]

    log.info("Loading routing shards...")
    routing_data = load_all_shards(ROUTING_DIR, num_layers)
    log.info(f"Total tokens: {routing_data['n_total']}")

    log.info("Loading labels...")
    labeled = read_jsonl(LABELED_COTS_PATH)
    ranges = routing_data["sample_id_to_range"]
    plan_tis = collect_token_indices(labeled, ranges, "plan_decision_tis")
    mon_tis = collect_token_indices(labeled, ranges, "mon_decision_tis")
    exec_tis = collect_token_indices(labeled, ranges, "exec_decision_tis")
    log.info(f"Global indices: plan={len(plan_tis)}, mon={len(mon_tis)}, exec={len(exec_tis)}")
    # exec is the baseline of the plan-vs-exec and mon-vs-exec differentials,
    # so a tiny exec set is as damaging as a tiny plan or mon set.
    if min(len(plan_tis), len(mon_tis), len(exec_tis)) < MIN_DECISION_POINTS:
        log.warning("Very few decision points — results will be unreliable")

    log.info("Computing expert selection scores...")
    results = select_top_experts(
        routing_data, plan_tis, mon_tis, exec_tis,
        top_k=args.top_k,
    )

    # Save top experts
    counts = (len(plan_tis), len(mon_tis), len(exec_tis))
    top_plan_out = _build_top_output(results["top_experts_planning"], *counts)
    top_mon_out = _build_top_output(results["top_experts_monitoring"], *counts)
    write_json(top_plan_out, TOP_EXPERTS_PLAN_PATH)
    write_json(top_mon_out, TOP_EXPERTS_MON_PATH)
    log.info(f"Top-{args.top_k} planning experts saved: {TOP_EXPERTS_PLAN_PATH}")
    log.info(f"Top-{args.top_k} monitoring experts saved: {TOP_EXPERTS_MON_PATH}")

    # Save unified target layers (union of plan and mon)
    plan_layers = set(top_plan_out["target_layers"])
    mon_layers = set(top_mon_out["target_layers"])
    all_layers = sorted(plan_layers | mon_layers)
    write_json({
        "planning_layers": top_plan_out["target_layers"],
        "monitoring_layers": top_mon_out["target_layers"],
        "union_layers": all_layers,
    }, TARGET_LAYERS_PATH)
    log.info(f"Target layers: planning={sorted(plan_layers)}")
    log.info(f"               monitoring={sorted(mon_layers)}")
    log.info(f"               union={all_layers}")

    # Plot heatmaps
    log.info("Plotting routing heatmaps...")
    plot_routing_heatmap(
        results["delta_plan_vs_exec"],
        "Planning vs Exec — P(expert in top-K | S_plan) − P(... | S_exec)",
        ROUTING_HEATMAP_PLAN,
    )
    plot_routing_heatmap(
        results["delta_mon_vs_exec"],
        "Monitoring vs Exec — P(expert in top-K | S_mon) − P(... | S_exec)",
        ROUTING_HEATMAP_MON,
    )

    # Save raw stats to results dir for later inspection
    np.savez(
        RESULTS_DIR / "routing_stats.npz",
        freq_plan=results["freq_plan"],
        freq_mon=results["freq_mon"],
        freq_exec=results["freq_exec"],
        delta_plan_vs_exec=results["delta_plan_vs_exec"],
        delta_mon_vs_exec=results["delta_mon_vs_exec"],
        delta_plan_vs_mon=results["delta_plan_vs_mon"],
        logratio_plan_vs_exec=results["logratio_plan_vs_exec"],
        logratio_mon_vs_exec=results["logratio_mon_vs_exec"],
    )
    log.info("Saved raw stats -> routing_stats.npz")
    log.info("Done.")


if __name__ == "__main__":
    main()