| """ |
| Stage 1 part B: Aggregate routing shards, compute frequency differentials, |
| select top-K experts for each dimension. |
| """ |
| import sys |
| import argparse |
| import json |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| import numpy as np |
| import torch |
|
|
| from configs.paths import ( |
| ensure_dirs, LOGS_DIR, LABELED_COTS_PATH, ROUTING_DIR, |
| TOP_EXPERTS_PLAN_PATH, TOP_EXPERTS_MON_PATH, TARGET_LAYERS_PATH, |
| RESULTS_DIR, ROUTING_HEATMAP_PLAN, ROUTING_HEATMAP_MON, |
| ) |
| from configs.model import MODEL_CONFIG, TOP_K_EXPERTS |
| from src.utils import setup_logger, read_jsonl, write_json |
| from src.expert_select import select_top_experts, get_target_layers |
|
|
|
|
def load_all_shards(shards_dir: Path, num_layers: int):
    """
    Concatenate all routing shards into one in-memory set of tensors.

    Also builds the map from sample index to its (start, end) token span in
    the concatenated tensors, so labels can be translated to global offsets.

    Returns:
        {
            "topk_ids": {layer: (N_total, top_k) tensor},
            "topk_gates": {layer: (N_total, top_k) tensor},
            "sample_id_to_range": {sample_idx: (start, end)},
            "n_total": total token count across all shards,
        }
    """
    shard_files = sorted(shards_dir.glob("shard_*.pt"))
    if not shard_files:
        raise FileNotFoundError(f"No shards in {shards_dir}. Run 04 first.")

    ids_chunks = {layer: [] for layer in range(num_layers)}
    gate_chunks = {layer: [] for layer in range(num_layers)}
    sample_id_to_range = {}
    offset = 0  # running global token offset across shards

    for shard_path in shard_files:
        payload = torch.load(shard_path, map_location="cpu")
        # Samples appear in shard order; record each one's global token span.
        for sid, length in zip(payload["sample_ids"], payload["sample_lengths"]):
            sample_id_to_range[sid] = (offset, offset + length)
            offset += length
        # Collect per-layer chunks; shards may cover only a subset of layers.
        for layer, ids in payload["topk_ids"].items():
            if layer in ids_chunks:
                ids_chunks[layer].append(ids)
                gate_chunks[layer].append(payload["topk_gates"][layer])

    # Layers with no data in any shard are dropped from the output dicts.
    return {
        "topk_ids": {layer: torch.cat(chunks, dim=0)
                     for layer, chunks in ids_chunks.items() if chunks},
        "topk_gates": {layer: torch.cat(chunks, dim=0)
                       for layer, chunks in gate_chunks.items() if chunks},
        "sample_id_to_range": sample_id_to_range,
        "n_total": offset,
    }
|
|
|
|
def collect_token_indices(labeled_records, sample_id_to_range, field: str):
    """
    Translate per-CoT local token indices (in `field`, e.g. "plan_decision_tis")
    into GLOBAL indices into the concatenated routing tensor.

    Records whose sample id has no recorded range are skipped; local indices
    that would fall past the sample's end are dropped (truncation guard).
    """
    global_indices = []
    for record in labeled_records:
        span = sample_id_to_range.get(record["idx"])
        if span is None:
            continue
        start, end = span
        global_indices.extend(
            start + local for local in record[field] if start + local < end
        )
    return global_indices
|
|
|
|
def plot_routing_heatmap(freq_diff: np.ndarray, title: str, path: Path):
    """Render a layer x expert frequency-differential heatmap and save it to *path*.

    Imports are local so matplotlib/seaborn are only required when plotting.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    fig, ax = plt.subplots(figsize=(14, 8))
    # Diverging colormap centered on zero: red = more routing, blue = less.
    sns.heatmap(
        freq_diff,
        cmap="coolwarm",
        center=0,
        cbar=True,
        ax=ax,
        xticklabels=False,
        yticklabels=False,
    )
    ax.set(xlabel="Expert ID", ylabel="Layer ID", title=title)
    plt.tight_layout()
    plt.savefig(path, dpi=120)
    plt.close()
|
|
|
|
def main():
    """
    Stage 05 entry point.

    Aggregates routing shards, maps labeled decision-token indices to global
    offsets, scores and selects the top-K experts for the planning and
    monitoring dimensions, then writes the expert plans, target-layer summary,
    heatmaps, and raw routing statistics.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--top_k", type=int, default=TOP_K_EXPERTS)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("05_select", LOGS_DIR / "05_select.log")

    # Resume guard: both artifacts must exist before we skip the stage.
    if args.resume and TOP_EXPERTS_PLAN_PATH.exists() and TOP_EXPERTS_MON_PATH.exists():
        log.info("Top-experts already saved. Skipping.")
        return

    num_layers = MODEL_CONFIG["num_layers"]

    log.info("Loading routing shards...")
    routing_data = load_all_shards(ROUTING_DIR, num_layers)
    log.info(f"Total tokens: {routing_data['n_total']}")

    log.info("Loading labels...")
    labeled = read_jsonl(LABELED_COTS_PATH)
    id_to_range = routing_data["sample_id_to_range"]
    plan_tis = collect_token_indices(labeled, id_to_range, "plan_decision_tis")
    mon_tis = collect_token_indices(labeled, id_to_range, "mon_decision_tis")
    exec_tis = collect_token_indices(labeled, id_to_range, "exec_decision_tis")
    log.info(f"Global indices: plan={len(plan_tis)}, mon={len(mon_tis)}, exec={len(exec_tis)}")

    if len(plan_tis) < 20 or len(mon_tis) < 20:
        log.warning("Very few decision points — results will be unreliable")

    log.info("Computing expert selection scores...")
    results = select_top_experts(
        routing_data, plan_tis, mon_tis, exec_tis, top_k=args.top_k,
    )

    def build_payload(pairs):
        # Shared output schema for both dimensions — the per-dimension outputs
        # differ only in which (layer, expert) list they serialize.
        return {
            "top_experts": [{"layer": l, "expert": e} for l, e in pairs],
            "target_layers": get_target_layers(pairs),
            "n_plan_tokens": len(plan_tis),
            "n_mon_tokens": len(mon_tis),
            "n_exec_tokens": len(exec_tis),
            "metric": "combined rank_norm(Δfreq) + rank_norm(log_ratio)",
        }

    top_plan_out = build_payload(results["top_experts_planning"])
    top_mon_out = build_payload(results["top_experts_monitoring"])
    write_json(top_plan_out, TOP_EXPERTS_PLAN_PATH)
    write_json(top_mon_out, TOP_EXPERTS_MON_PATH)
    log.info(f"Top-{args.top_k} planning experts saved: {TOP_EXPERTS_PLAN_PATH}")
    log.info(f"Top-{args.top_k} monitoring experts saved: {TOP_EXPERTS_MON_PATH}")

    # Layer summary: downstream stages hook only these layers.
    plan_layers = set(top_plan_out["target_layers"])
    mon_layers = set(top_mon_out["target_layers"])
    all_layers = sorted(plan_layers | mon_layers)
    write_json({
        "planning_layers": top_plan_out["target_layers"],
        "monitoring_layers": top_mon_out["target_layers"],
        "union_layers": all_layers,
    }, TARGET_LAYERS_PATH)
    log.info(f"Target layers: planning={sorted(plan_layers)}")
    log.info(f"               monitoring={sorted(mon_layers)}")
    log.info(f"               union={all_layers}")

    log.info("Plotting routing heatmaps...")
    plot_routing_heatmap(
        results["delta_plan_vs_exec"],
        "Planning vs Exec — P(expert in top-K | S_plan) − P(... | S_exec)",
        ROUTING_HEATMAP_PLAN,
    )
    plot_routing_heatmap(
        results["delta_mon_vs_exec"],
        "Monitoring vs Exec — P(expert in top-K | S_mon) − P(... | S_exec)",
        ROUTING_HEATMAP_MON,
    )

    # Raw arrays for offline analysis / re-plotting without recomputation.
    np.savez(
        RESULTS_DIR / "routing_stats.npz",
        freq_plan=results["freq_plan"],
        freq_mon=results["freq_mon"],
        freq_exec=results["freq_exec"],
        delta_plan_vs_exec=results["delta_plan_vs_exec"],
        delta_mon_vs_exec=results["delta_mon_vs_exec"],
        delta_plan_vs_mon=results["delta_plan_vs_mon"],
        logratio_plan_vs_exec=results["logratio_plan_vs_exec"],
        logratio_mon_vs_exec=results["logratio_mon_vs_exec"],
    )
    log.info("Saved raw stats -> routing_stats.npz")
    log.info("Done.")
|
|
|
|
# Script entry point — keeps the module importable without side effects.
if __name__ == "__main__":
    main()
|
|