| """ |
| Stage 1 part A: Run forward pass over labeled CoTs, capture router top-k at every token. |
| |
| Output: sharded .pt files in data/routing/, 50 CoTs per shard. |
| Each shard: { |
| "sample_ids": [...], |
| "sample_lengths": [...], # number of tokens per CoT |
| "topk_ids": {layer_id: (N_total, top_k) int16}, |
| "topk_gates": {layer_id: (N_total, top_k) float16}, |
| } |
| """ |
import sys
import argparse
from pathlib import Path

# Make the repo root importable when this file is run directly as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import torch
from tqdm import tqdm

from configs.paths import ensure_dirs, LOGS_DIR, LABELED_COTS_PATH, ROUTING_DIR
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, cleanup_memory, get_vram_mb
from src.model_io import load_model_and_tokenizer
from src.routing_capture import RoutingCapture


# Number of CoTs per output shard (override with --shard_size).
SHARD_SIZE = 50


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action="store_true",
                        help="Only compute shards whose output file is missing.")
    parser.add_argument("--shard_size", type=int, default=SHARD_SIZE)
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("04_routing", LOGS_DIR / "04_routing.log")

    log.info(f"Reading labeled CoTs: {LABELED_COTS_PATH}")
    records = read_jsonl(LABELED_COTS_PATH)
    log.info(f"Got {len(records)} CoTs")

    def shard_path(shard_idx):
        return ROUTING_DIR / f"shard_{shard_idx:04d}.pt"

    # Ceiling division: the last shard may hold fewer than shard_size CoTs.
    n_shards = (len(records) + args.shard_size - 1) // args.shard_size
    log.info(f"Will produce {n_shards} shards of up to {args.shard_size} CoTs each")

    # With --resume, skip shards whose output already exists; otherwise
    # recompute (and overwrite) every shard.
    if args.resume:
        todo = [i for i in range(n_shards) if not shard_path(i).exists()]
    else:
        todo = list(range(n_shards))
    if not todo:
        log.info("All shards already exist. Skipping.")
        return
    log.info(f"Shards to compute: {len(todo)} / {n_shards}")

    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()
    log.info(f"Model loaded. VRAM: {get_vram_mb():.0f} MB")

    num_layers = MODEL_CONFIG["num_layers"]

    for shard_idx in todo:
        shard_start = shard_idx * args.shard_size
        shard_end = min(shard_start + args.shard_size, len(records))
        shard_records = records[shard_start:shard_end]

        log.info(f"[shard {shard_idx}] processing records {shard_start}:{shard_end}")

        # Per-layer lists of (S_i, top_k) tensors, one entry per CoT;
        # concatenated into a single (N_total, top_k) tensor at save time.
        per_layer_ids = {li: [] for li in range(num_layers)}
        per_layer_gates = {li: [] for li in range(num_layers)}
        sample_ids = []
        sample_lengths = []

        for rec in tqdm(shard_records, desc=f"shard {shard_idx}"):
            text = rec["cot"]

            enc = tokenizer(
                text, return_tensors="pt", add_special_tokens=False, truncation=False
            )
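            # Routing rows must line up 1:1 with the token ids stored at
            # labeling time; if retokenizing yields a different length, that
            # alignment cannot be trusted, so the sample is skipped.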
            S_expected = len(rec["token_ids"])
            if enc["input_ids"].shape[1] != S_expected:
                log.warning(f"idx={rec['idx']}: retokenize length mismatch "
                            f"({enc['input_ids'].shape[1]} vs {S_expected}). Skipping.")
                continue

            input_ids = enc["input_ids"].to(model.device)

            # stop() runs in `finally` so the capture is always torn down,
            # even if the forward pass raises.
            cap = RoutingCapture(model)
            cap.start()
            try:
                with torch.no_grad():
                    _ = model(input_ids)
            finally:
                routing = cap.stop()
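
            # RoutingCapture is assumed to return one entry per captured MoE
            # layer: {li: {"topk_ids": (S, top_k), "topk_gates": (S, top_k)}};
            # layers absent from `routing` are simply skipped below.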
            for li in range(num_layers):
                if li not in routing:
                    continue
                per_layer_ids[li].append(routing[li]["topk_ids"])
                per_layer_gates[li].append(routing[li]["topk_gates"])

            sample_ids.append(rec["idx"])
            sample_lengths.append(input_ids.shape[1])
            cleanup_memory()

        shard_data = {
            "sample_ids": sample_ids,
            "sample_lengths": sample_lengths,
            # `if v` drops layers for which no routing output was captured.
            "topk_ids": {li: torch.cat(v, dim=0) for li, v in per_layer_ids.items() if v},
            "topk_gates": {li: torch.cat(v, dim=0) for li, v in per_layer_gates.items() if v},
        }
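
        # Write to a temp file and rename, so an interrupted run never leaves
        # a truncated shard that --resume would mistake for a complete one.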
        shard_p = shard_path(shard_idx)
        tmp = shard_p.with_suffix(".pt.tmp")
        torch.save(shard_data, tmp)
        tmp.replace(shard_p)

        n_tok = sum(sample_lengths)
        log.info(f"[shard {shard_idx}] saved {shard_p}: {len(sample_ids)} CoTs, "
                 f"{n_tok} total tokens, {shard_p.stat().st_size / 1e9:.2f} GB")

    log.info("=" * 60)
    log.info("Routing capture complete.")


if __name__ == "__main__":
    main()
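
# Typical invocations (script name illustrative, matching the "04_routing"
# logger above):
#     python 04_routing.py
#     python 04_routing.py --resume --shard_size 25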