"""
Stage 1 part A: Run a forward pass over the labeled CoTs, capturing the router's
top-k expert choices at every token.

Output: sharded .pt files in data/routing/, up to --shard_size (default 50) CoTs
per shard.
Each shard: {
  "sample_ids": [...],
  "sample_lengths": [...],         # number of tokens per CoT
  "topk_ids":   {layer_id: (N_total, top_k) int16},
  "topk_gates": {layer_id: (N_total, top_k) float16},
}
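
Example of reading a shard downstream (a minimal sketch; the offset arithmetic
assumes tokens are concatenated in sample order, per the format above):

    shard = torch.load(ROUTING_DIR / "shard_0000.pt")
    offsets = [0]
    for n in shard["sample_lengths"]:
        offsets.append(offsets[-1] + n)
    i = 0                                # first CoT in the shard
    layer_id = 3                         # any MoE layer id present in the shard
    ids_i = shard["topk_ids"][layer_id][offsets[i]:offsets[i + 1]]   # (S_i, top_k)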
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import torch
from tqdm import tqdm

from configs.paths import ensure_dirs, LOGS_DIR, LABELED_COTS_PATH, ROUTING_DIR
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, cleanup_memory, get_vram_mb
from src.model_io import load_model_and_tokenizer
from src.routing_capture import RoutingCapture


SHARD_SIZE = 50   # CoTs per shard


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--shard_size", type=int, default=SHARD_SIZE)
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("04_routing", LOGS_DIR / "04_routing.log")

    log.info(f"Reading labeled CoTs: {LABELED_COTS_PATH}")
    records = read_jsonl(LABELED_COTS_PATH)
    log.info(f"Got {len(records)} CoTs")

    # ========== Figure out which shards exist (for resume) ==========
    def shard_path(shard_idx):
        return ROUTING_DIR / f"shard_{shard_idx:04d}.pt"

    n_shards = (len(records) + args.shard_size - 1) // args.shard_size  # ceil division
    log.info(f"Will produce {n_shards} shards of up to {args.shard_size} CoTs each")

    # Determine which shards still need to be computed. Existing shards are
    # always skipped; with --resume we also return early (before loading the
    # model) when nothing is missing.
    missing = [i for i in range(n_shards) if not shard_path(i).exists()]
    if args.resume and not missing:
        log.info("All shards already exist. Skipping.")
        return
    log.info(f"Shards to compute: {len(missing)} / {n_shards}")

    # ========== Load model ==========
    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()
    log.info(f"Model loaded. VRAM: {get_vram_mb():.0f} MB")

    num_layers = MODEL_CONFIG["num_layers"]

    # ========== Process shards ==========
    for shard_idx in missing:
        shard_start = shard_idx * args.shard_size
        shard_end = min(shard_start + args.shard_size, len(records))
        shard_records = records[shard_start:shard_end]

        log.info(f"[shard {shard_idx}] processing records {shard_start}:{shard_end}")

        per_layer_ids = {li: [] for li in range(num_layers)}
        per_layer_gates = {li: [] for li in range(num_layers)}
        sample_ids = []
        sample_lengths = []

        for rec in tqdm(shard_records, desc=f"shard {shard_idx}"):
            text = rec["cot"]

            enc = tokenizer(
                text, return_tensors="pt", add_special_tokens=False, truncation=False
            )
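            # Guard: the re-tokenized length must match the token_ids stored with
            # the labels, so per-token routing rows stay aligned with them.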
            S_expected = len(rec["token_ids"])
            if enc["input_ids"].shape[1] != S_expected:
                log.warning(f"idx={rec['idx']}: retokenize length mismatch "
                            f"({enc['input_ids'].shape[1]} vs {S_expected}). Skipping.")
                continue

            input_ids = enc["input_ids"].to(model.device)

            # RoutingCapture records each MoE layer's top-k routing decision for
            # every token of the forward pass below.
            cap = RoutingCapture(model)
            cap.start()
            try:
                with torch.no_grad():
                    _ = model(input_ids)
            finally:
                routing = cap.stop()

            for li in range(num_layers):
                if li not in routing:
                    continue
                per_layer_ids[li].append(routing[li]["topk_ids"])
                per_layer_gates[li].append(routing[li]["topk_gates"])

            sample_ids.append(rec["idx"])
            sample_lengths.append(input_ids.shape[1])
            cleanup_memory()

        # Concatenate per-sample tensors along the token axis -> (N_total, top_k) per layer
        shard_data = {
            "sample_ids": sample_ids,
            "sample_lengths": sample_lengths,
            "topk_ids":   {li: torch.cat(v, dim=0) for li, v in per_layer_ids.items() if v},
            "topk_gates": {li: torch.cat(v, dim=0) for li, v in per_layer_gates.items() if v},
        }
        # Atomic save: write to a temp file, then rename it over the final path so
        # a partially written shard is never mistaken for a complete one on resume.
        shard_p = shard_path(shard_idx)
        tmp = shard_p.with_suffix(".pt.tmp")
        torch.save(shard_data, tmp)
        tmp.replace(shard_p)

        n_tok = sum(sample_lengths)
        log.info(f"[shard {shard_idx}] saved {shard_p}: {len(sample_ids)} CoTs, "
                 f"{n_tok} total tokens, {shard_p.stat().st_size / 1e9:.2f} GB")

    log.info("=" * 60)
    log.info("Routing capture complete.")


if __name__ == "__main__":
    main()