"""
Stage 1 part A: Run forward pass over labeled CoTs, capture router top-k at every token.
Output: sharded .pt files in data/routing/, 50 CoTs per shard.
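
Run (script filename assumed from the logger name below):
    python 04_routing.py [--resume] [--shard_size N]
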
Each shard: {
    "sample_ids": [...],
    "sample_lengths": [...],  # number of tokens per CoT
    "topk_ids": {layer_id: (N_total, top_k) int16},
    "topk_gates": {layer_id: (N_total, top_k) float16},
}
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
from tqdm import tqdm
from configs.paths import ensure_dirs, LOGS_DIR, LABELED_COTS_PATH, ROUTING_DIR
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, cleanup_memory, get_vram_mb
from src.model_io import load_model_and_tokenizer
from src.routing_capture import RoutingCapture
SHARD_SIZE = 50  # CoTs per shard

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--shard_size", type=int, default=SHARD_SIZE)
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("04_routing", LOGS_DIR / "04_routing.log")
    log.info(f"Reading labeled CoTs: {LABELED_COTS_PATH}")
    records = read_jsonl(LABELED_COTS_PATH)
    log.info(f"Got {len(records)} CoTs")
    # ========== Figure out which shards exist (for resume) ==========
    def shard_path(shard_idx):
        return ROUTING_DIR / f"shard_{shard_idx:04d}.pt"

    n_shards = (len(records) + args.shard_size - 1) // args.shard_size
    log.info(f"Will produce {n_shards} shards of up to {args.shard_size} CoTs each")
    # Existing shards are always skipped (saves are atomic, so an existing file
    # is complete); --resume only makes a fully finished run exit early.
    missing = [i for i in range(n_shards) if not shard_path(i).exists()]
    if args.resume and not missing:
        log.info("All shards already exist. Skipping.")
        return
    log.info(f"Shards to compute: {len(missing)} / {n_shards}")
    # ========== Load model ==========
    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()
    log.info(f"Model loaded. VRAM: {get_vram_mb():.0f} MB")
    num_layers = MODEL_CONFIG["num_layers"]
    # ========== Process shards ==========
    for shard_idx in missing:
        shard_start = shard_idx * args.shard_size
        shard_end = min(shard_start + args.shard_size, len(records))
        shard_records = records[shard_start:shard_end]
        log.info(f"[shard {shard_idx}] processing records {shard_start}:{shard_end}")
        per_layer_ids = {li: [] for li in range(num_layers)}
        per_layer_gates = {li: [] for li in range(num_layers)}
        sample_ids = []
        sample_lengths = []
        for rec in tqdm(shard_records, desc=f"shard {shard_idx}"):
            text = rec["cot"]
            enc = tokenizer(
                text, return_tensors="pt", add_special_tokens=False, truncation=False
            )
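            # Guard against tokenizer drift: every captured routing row must
            # align 1:1 with the token_ids stored alongside the labels.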
            S_expected = len(rec["token_ids"])
            if enc["input_ids"].shape[1] != S_expected:
                log.warning(f"idx={rec['idx']}: retokenize length mismatch "
                            f"({enc['input_ids'].shape[1]} vs {S_expected}). Skipping.")
                continue
            input_ids = enc["input_ids"].to(model.device)
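            # RoutingCapture records each MoE layer's router top-k expert ids
            # and gate weights during the forward pass (likely via forward
            # hooks); stop() in the finally block ensures the capture is torn
            # down even if the forward pass raises.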
            cap = RoutingCapture(model)
            cap.start()
            try:
                with torch.no_grad():
                    _ = model(input_ids)
            finally:
                routing = cap.stop()
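            # Layers absent from `routing` (e.g. dense, non-MoE layers) are
            # simply skipped.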
            for li in range(num_layers):
                if li not in routing:
                    continue
                per_layer_ids[li].append(routing[li]["topk_ids"])
                per_layer_gates[li].append(routing[li]["topk_gates"])
            sample_ids.append(rec["idx"])
            sample_lengths.append(input_ids.shape[1])
            cleanup_memory()
        # Concatenate per-layer lists into single flat tensors
        shard_data = {
            "sample_ids": sample_ids,
            "sample_lengths": sample_lengths,
            "topk_ids": {li: torch.cat(v, dim=0) for li, v in per_layer_ids.items() if v},
            "topk_gates": {li: torch.cat(v, dim=0) for li, v in per_layer_gates.items() if v},
        }

        # Atomic save
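        # (write to .tmp, then rename: Path.replace wraps os.replace, which is
        # atomic on POSIX, so a crash mid-write cannot leave a truncated shard
        # that a later resume would mistake for a finished one)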
        shard_p = shard_path(shard_idx)
        tmp = shard_p.with_suffix(".pt.tmp")
        torch.save(shard_data, tmp)
        tmp.replace(shard_p)

        n_tok = sum(sample_lengths)
        log.info(f"[shard {shard_idx}] saved {shard_p}: {len(sample_ids)} CoTs, "
                 f"{n_tok} total tokens, {shard_p.stat().st_size / 1e9:.2f} GB")
log.info("=" * 60)
log.info("Routing capture complete.")
if __name__ == "__main__":
    main()