"""
Stage 1, part A: run a forward pass over the labeled CoTs and capture the router's
top-k expert choices at every token.

Output: sharded .pt files in data/routing/, up to --shard_size (default 50) CoTs per shard.
Each shard: {
    "sample_ids": [...],
    "sample_lengths": [...],                              # number of tokens per CoT
    "topk_ids": {layer_id: (N_total, top_k) int16},       # N_total = sum(sample_lengths)
    "topk_gates": {layer_id: (N_total, top_k) float16},
}
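
Example (hypothetical downstream read; path and layer index are illustrative):
    import torch
    shard = torch.load("data/routing/shard_0000.pt")
    ids_l0 = shard["topk_ids"][0]       # (N_total, top_k) int16 expert ids, layer 0
    gates_l0 = shard["topk_gates"][0]   # (N_total, top_k) float16 gate weights, layer 0
    # rows for CoT j start at sum(shard["sample_lengths"][:j])

Usage:
    python scripts/04_capture_routing.py [--resume] [--shard_size N]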
"""
import sys
import argparse
from pathlib import Path

# Make the directory above scripts/ importable so the configs and src packages resolve.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import torch
from tqdm import tqdm

from configs.paths import ensure_dirs, LOGS_DIR, LABELED_COTS_PATH, ROUTING_DIR
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, cleanup_memory, get_vram_mb
from src.model_io import load_model_and_tokenizer
from src.routing_capture import RoutingCapture

SHARD_SIZE = 50  # CoTs per shard


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--shard_size", type=int, default=SHARD_SIZE)
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("04_routing", LOGS_DIR / "04_routing.log")

    log.info(f"Reading labeled CoTs: {LABELED_COTS_PATH}")
    records = read_jsonl(LABELED_COTS_PATH)
    log.info(f"Got {len(records)} CoTs")
    # ========== Figure out which shards exist (for resume) ==========
    def shard_path(shard_idx):
        return ROUTING_DIR / f"shard_{shard_idx:04d}.pt"

    n_shards = (len(records) + args.shard_size - 1) // args.shard_size
    log.info(f"Will produce {n_shards} shards of up to {args.shard_size} CoTs each")

    # Check which shards are missing
    missing = [i for i in range(n_shards) if not shard_path(i).exists()]
    if args.resume and not missing:
        log.info("All shards already exist. Skipping.")
        return
    log.info(f"Shards to compute: {len(missing)} / {n_shards}")
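    # Note: shards already present on disk are never recomputed; --resume only adds
    # the early exit above when every shard is already finished.
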
    # ========== Load model ==========
    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()
    log.info(f"Model loaded. VRAM: {get_vram_mb():.0f} MB")
    num_layers = MODEL_CONFIG["num_layers"]

    # ========== Process shards ==========
    for shard_idx in missing:
        shard_start = shard_idx * args.shard_size
        shard_end = min(shard_start + args.shard_size, len(records))
        shard_records = records[shard_start:shard_end]
        log.info(f"[shard {shard_idx}] processing records {shard_start}:{shard_end}")

        # Per-layer accumulators for this shard
        per_layer_ids = {li: [] for li in range(num_layers)}
        per_layer_gates = {li: [] for li in range(num_layers)}
        sample_ids = []
        sample_lengths = []
        for rec in tqdm(shard_records, desc=f"shard {shard_idx}"):
            text = rec["cot"]
            enc = tokenizer(
                text, return_tensors="pt", add_special_tokens=False, truncation=False
            )

            # Skip CoTs whose retokenization does not reproduce the stored token count,
            # so the captured routing stays aligned with the previously stored token_ids.
            S_expected = len(rec["token_ids"])
            if enc["input_ids"].shape[1] != S_expected:
                log.warning(f"idx={rec['idx']}: retokenize length mismatch "
                            f"({enc['input_ids'].shape[1]} vs {S_expected}). Skipping.")
                continue

            input_ids = enc["input_ids"].to(model.device)
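            # RoutingCapture (src/routing_capture.py) is assumed to register forward hooks on the
            # MoE routers and record each token's top-k expert ids and gate values per layer;
            # stop() detaches the hooks and returns {layer_id: {"topk_ids": ..., "topk_gates": ...}}.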
            cap = RoutingCapture(model)
            cap.start()
            try:
                with torch.no_grad():
                    _ = model(input_ids)
            finally:
                routing = cap.stop()

            for li in range(num_layers):
                if li not in routing:
                    continue
                per_layer_ids[li].append(routing[li]["topk_ids"])
                per_layer_gates[li].append(routing[li]["topk_gates"])

            sample_ids.append(rec["idx"])
            sample_lengths.append(input_ids.shape[1])
            cleanup_memory()
        # Concatenate per-layer captures along the token dimension -> (N_total, top_k)
        shard_data = {
            "sample_ids": sample_ids,
            "sample_lengths": sample_lengths,
            "topk_ids": {li: torch.cat(v, dim=0) for li, v in per_layer_ids.items() if v},
            "topk_gates": {li: torch.cat(v, dim=0) for li, v in per_layer_gates.items() if v},
        }

        # Atomic save: write to a temp file, then rename, so an interrupted run
        # never leaves a truncated shard behind.
        shard_p = shard_path(shard_idx)
        tmp = shard_p.with_suffix(".pt.tmp")
        torch.save(shard_data, tmp)
        tmp.replace(shard_p)

        n_tok = sum(sample_lengths)
        log.info(f"[shard {shard_idx}] saved {shard_p}: {len(sample_ids)} CoTs, "
                 f"{n_tok} total tokens, {shard_p.stat().st_size / 1e9:.2f} GB")

    log.info("=" * 60)
    log.info("Routing capture complete.")

if __name__ == "__main__":
    main()