"""
Stage 1 part A: Run forward pass over labeled CoTs, capture router top-k at every token.

Output: sharded .pt files in data/routing/, 50 CoTs per shard by default
(override with --shard_size).

Each shard:
    {
        "sample_ids": [...],
        "sample_lengths": [...],   # number of tokens per CoT
        "topk_ids":   {layer_id: (N_total, top_k) int16},
        "topk_gates": {layer_id: (N_total, top_k) float16},
    }

Notes:
    * CoTs whose re-tokenized length differs from the stored ``token_ids``
      length are skipped (with a warning) and do not appear in the shard.
    * Layers for which the capture returns nothing are omitted from the
      ``topk_ids`` / ``topk_gates`` dicts.
    * With ``--resume`` only shards whose file is missing are computed.
      Without it every shard is recomputed and existing files are replaced.
"""
import sys
import argparse
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import torch
from tqdm import tqdm

from configs.paths import ensure_dirs, LOGS_DIR, LABELED_COTS_PATH, ROUTING_DIR
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, cleanup_memory, get_vram_mb
from src.model_io import load_model_and_tokenizer
from src.routing_capture import RoutingCapture

SHARD_SIZE = 50  # CoTs per shard


def _shard_path(shard_idx):
    """Return the output path of shard number ``shard_idx`` inside ROUTING_DIR."""
    return ROUTING_DIR / f"shard_{shard_idx:04d}.pt"


def _capture_shard(model, tokenizer, shard_records, num_layers, log, desc):
    """Run the model over every record of one shard and collect router top-k.

    Args:
        model: Loaded model; must expose ``.device`` and be callable on input ids.
        tokenizer: Tokenizer matching ``model``.
        shard_records: Records of this shard; each needs the keys ``"cot"``,
            ``"token_ids"`` and ``"idx"``.
        num_layers: Number of layer indices (``0 .. num_layers - 1``) to look up
            in the capture result.
        log: Logger used for skip warnings.
        desc: Label shown on the tqdm progress bar.

    Returns:
        The shard dict described in the module docstring. Per-layer tensors are
        the per-record captures concatenated along dim 0, in record order.
    """
    per_layer_ids = {li: [] for li in range(num_layers)}
    per_layer_gates = {li: [] for li in range(num_layers)}
    sample_ids = []
    sample_lengths = []

    for rec in tqdm(shard_records, desc=desc):
        enc = tokenizer(
            rec["cot"],
            return_tensors="pt",
            add_special_tokens=False,
            truncation=False,
        )
        n_tokens = enc["input_ids"].shape[1]
        n_expected = len(rec["token_ids"])
        # Re-tokenizing must reproduce the stored token count; otherwise the
        # captured rows could not be matched back to the stored tokens.
        if n_tokens != n_expected:
            log.warning(f"idx={rec['idx']}: retokenize length mismatch "
                        f"({n_tokens} vs {n_expected}). Skipping.")
            continue

        input_ids = enc["input_ids"].to(model.device)

        cap = RoutingCapture(model)
        cap.start()
        try:
            with torch.no_grad():
                model(input_ids)
        finally:
            # Always stop the capture, even if the forward pass raised.
            routing = cap.stop()

        for li in range(num_layers):
            if li not in routing:
                continue
            # .cpu() is a no-op for tensors already on the host; it keeps the
            # saved shard loadable without a GPU in case the capture returns
            # device tensors.
            per_layer_ids[li].append(routing[li]["topk_ids"].cpu())
            per_layer_gates[li].append(routing[li]["topk_gates"].cpu())

        sample_ids.append(rec["idx"])
        sample_lengths.append(n_tokens)

        cleanup_memory()

    return {
        "sample_ids": sample_ids,
        "sample_lengths": sample_lengths,
        "topk_ids": {li: torch.cat(v, dim=0) for li, v in per_layer_ids.items() if v},
        "topk_gates": {li: torch.cat(v, dim=0) for li, v in per_layer_gates.items() if v},
    }


def _save_shard(shard_data, shard_p):
    """Write ``shard_data`` to ``shard_p`` via a temp file plus rename.

    An interrupted run therefore never leaves a partially written
    ``shard_XXXX.pt`` behind that ``--resume`` would mistake for a finished one.
    """
    tmp = shard_p.with_suffix(".pt.tmp")
    torch.save(shard_data, tmp)
    tmp.replace(shard_p)


def main():
    """CLI entry point: capture router top-k for all labeled CoTs, shard by shard.

    Flags:
        --resume      Only compute shards whose output file does not exist yet.
                      Without this flag all shards are recomputed.
        --shard_size  Number of CoTs per shard (positive integer).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action="store_true",
                        help="only compute shards that do not exist yet")
    parser.add_argument("--shard_size", type=int, default=SHARD_SIZE,
                        help="number of CoTs per shard (default: %(default)s)")
    args = parser.parse_args()
    if args.shard_size < 1:
        # Zero would divide by zero below; negatives give nonsensical shard counts.
        parser.error("--shard_size must be a positive integer")

    ensure_dirs()
    log = setup_logger("04_routing", LOGS_DIR / "04_routing.log")

    log.info(f"Reading labeled CoTs: {LABELED_COTS_PATH}")
    records = read_jsonl(LABELED_COTS_PATH)
    log.info(f"Got {len(records)} CoTs")
    if not records:
        log.warning("No labeled CoTs found. Nothing to do.")
        return

    # ========== Decide which shards to compute ==========
    n_shards = (len(records) + args.shard_size - 1) // args.shard_size
    log.info(f"Will produce {n_shards} shards of up to {args.shard_size} CoTs each")

    missing = [i for i in range(n_shards) if not _shard_path(i).exists()]
    if args.resume:
        if not missing:
            log.info("All shards already exist. Skipping.")
            return
        pending = missing
    else:
        # A fresh run recomputes everything so that shards left over from an
        # earlier run (possibly over different data) are not silently reused.
        pending = list(range(n_shards))
        n_existing = n_shards - len(missing)
        if n_existing:
            log.info(f"{n_existing} existing shard(s) will be overwritten "
                     f"(pass --resume to keep them)")
    log.info(f"Shards to compute: {len(pending)} / {n_shards}")

    # ========== Load model ==========
    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()
    log.info(f"Model loaded. VRAM: {get_vram_mb():.0f} MB")

    num_layers = MODEL_CONFIG["num_layers"]

    # ========== Process shards ==========
    for shard_idx in pending:
        shard_start = shard_idx * args.shard_size
        shard_end = min(shard_start + args.shard_size, len(records))
        log.info(f"[shard {shard_idx}] processing records {shard_start}:{shard_end}")

        shard_data = _capture_shard(
            model,
            tokenizer,
            records[shard_start:shard_end],
            num_layers,
            log,
            desc=f"shard {shard_idx}",
        )
        if not shard_data["sample_ids"]:
            # Still saved below so that --resume treats the shard as done.
            log.warning(f"[shard {shard_idx}] every record was skipped; "
                        f"saving an empty shard")

        shard_p = _shard_path(shard_idx)
        _save_shard(shard_data, shard_p)

        n_tok = sum(shard_data["sample_lengths"])
        log.info(f"[shard {shard_idx}] saved {shard_p}: {len(shard_data['sample_ids'])} CoTs, "
                 f"{n_tok} total tokens, {shard_p.stat().st_size / 1e9:.2f} GB")

    log.info("=" * 60)
    log.info("Routing capture complete.")


if __name__ == "__main__":
    main()