File size: 5,137 Bytes
e53f10b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
Stage 3 part A: Capture post-layer residual stream at decision points.

Only captures:
  - Target layers (union of planning and monitoring top-expert layers)
  - Decision point positions (plan / mon / exec / all_newline / non-newline-sample)

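Two output files are written, each a dict keyed by str(layer): RESIDUALS_PATH holds the
plan/mon/exec activations and GENERAL_RESIDUALS_PATH holds the all_nl/non_nl activations.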
"""
import sys
import argparse
import random
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import torch
from tqdm import tqdm

from configs.paths import (
    ensure_dirs, LOGS_DIR, LABELED_COTS_PATH,
    TARGET_LAYERS_PATH, RESIDUALS_PATH, GENERAL_RESIDUALS_PATH,
)
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, read_json, cleanup_memory, get_vram_mb
from src.model_io import load_model_and_tokenizer
from src.residual_capture import ResidualCapture


def _non_negative_int(value):
    """argparse ``type=`` callable: parse *value* as an int and reject negatives.

    A negative sample count would otherwise be used as a negative slice bound
    (``candidates[:k]``) and silently keep all but the last ``|k|`` candidates.
    """
    parsed = int(value)
    if parsed < 0:
        raise argparse.ArgumentTypeError(
            f"must be a non-negative integer, got {value!r}"
        )
    return parsed


def _sample_non_newline_tis(n_tokens, newline_tis, k, rng):
    """Return up to *k* token indices from ``range(n_tokens)`` that are not in *newline_tis*.

    Candidates are listed in ascending order, shuffled in place with *rng*, then
    truncated to *k*, so the result is reproducible for a given *rng* state.
    """
    newline_set = set(newline_tis)
    candidates = [ti for ti in range(n_tokens) if ti not in newline_set]
    rng.shuffle(candidates)
    return candidates[:k]


def _atomic_torch_save(obj, path):
    """``torch.save`` *obj* to ``path.with_suffix(".pt.tmp")``, then rename it onto *path*.

    Writing to a temp file first means an interrupted save never leaves a partial
    file at *path*; ``main``'s ``--resume`` check only tests that *path* exists.
    """
    tmp = path.with_suffix(".pt.tmp")
    torch.save(obj, tmp)
    tmp.replace(path)


def main():
    """Capture residual-stream activations at decision points and save them.

    For every record in ``LABELED_COTS_PATH`` this:

    1. Re-tokenizes ``rec["cot"]``.
    2. Runs one no-grad forward pass with a ``ResidualCapture`` on the layers
       listed under ``"union_layers"`` in ``TARGET_LAYERS_PATH``.
    3. Gathers the captured rows at five sets of token indices:

       * ``plan`` / ``mon`` / ``exec``: ``rec["*_decision_tis"]``
       * ``all_nl``: ``rec["all_newline_tis"]``
       * ``non_nl``: up to ``--non_nl_samples_per_cot`` random indices not in
         ``all_newline_tis``, drawn from an RNG seeded with 42

    Records whose re-tokenized length differs from ``len(rec["token_ids"])`` are
    skipped with a warning. For each layer, every category is concatenated along
    dim 0. A category with no rows becomes an empty float16 tensor of shape
    ``(0, hidden_size)``.

    Two files keyed by ``str(layer)`` are written atomically:
    ``RESIDUALS_PATH`` (plan/mon/exec) and ``GENERAL_RESIDUALS_PATH``
    (all_nl/non_nl). With ``--resume``, returns early if both already exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--non_nl_samples_per_cot", type=_non_negative_int, default=20,
                        help="# random non-newline tokens sampled per CoT (for general direction)")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("07_residuals", LOGS_DIR / "07_residuals.log")

    # Only short-circuit when BOTH outputs exist; otherwise recompute everything.
    if args.resume and RESIDUALS_PATH.exists() and GENERAL_RESIDUALS_PATH.exists():
        log.info("Residuals already saved. Skipping.")
        return

    # Target layers
    tgt = read_json(TARGET_LAYERS_PATH)
    target_layers = tgt["union_layers"]
    log.info(f"Target layers ({len(target_layers)}): {target_layers}")

    # Load labeled
    records = read_jsonl(LABELED_COTS_PATH)
    log.info(f"Got {len(records)} labeled CoTs")

    # Load model
    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()
    log.info(f"Model loaded. VRAM: {get_vram_mb():.0f} MB")

    # Accumulators: acc[layer][category] -> list of per-CoT row slices.
    cats = ["plan", "mon", "exec", "all_nl", "non_nl"]
    acc = {li: {c: [] for c in cats} for li in target_layers}

    rng = random.Random(42)  # fixed seed -> reproducible non-newline samples
    n_skipped = 0
    warned_missing = set()

    for rec in tqdm(records, desc="capture residuals"):
        text = rec["cot"]
        n_tokens = len(rec["token_ids"])

        # Token indices to gather, per category.
        cat_tis = {
            "plan": rec["plan_decision_tis"],
            "mon": rec["mon_decision_tis"],
            "exec": rec["exec_decision_tis"],
            "all_nl": rec["all_newline_tis"],
        }
        # Sampled BEFORE the length check so the RNG advances for skipped
        # records too, keeping later records' samples stable.
        cat_tis["non_nl"] = _sample_non_newline_tis(
            n_tokens, cat_tis["all_nl"], args.non_nl_samples_per_cot, rng
        )

        # Re-tokenize and check length
        enc = tokenizer(
            text, return_tensors="pt", add_special_tokens=False, truncation=False
        )
        if enc["input_ids"].shape[1] != n_tokens:
            log.warning(f"idx={rec['idx']}: retokenize length mismatch. Skipping.")
            n_skipped += 1
            continue

        input_ids = enc["input_ids"].to(model.device)

        cap = ResidualCapture(model, target_layers=target_layers)
        cap.start()
        try:
            with torch.no_grad():
                # Deliberately not bound to a name: a local reference would keep
                # the logits (and whatever else the model returns) alive through
                # cleanup_memory() and into the next record's forward pass.
                model(input_ids)
        finally:
            # stop() runs even if the forward pass raises.
            residuals = cap.stop()

        # Slice per-category activations
        for li in target_layers:
            if li not in residuals:
                if li not in warned_missing:
                    warned_missing.add(li)
                    log.warning(f"layer {li} missing from captured residuals; "
                                f"affected CoTs contribute no rows for it.")
                continue
            h = residuals[li]   # expected (S, D) fp16 on CPU per ResidualCapture; not re-checked
            for c, tis in cat_tis.items():
                if tis:
                    # Indexing with a list gathers a copy, so the stored slice
                    # does not keep the full per-layer tensor alive.
                    acc[li][c].append(h[tis])

        cleanup_memory()

    log.info(f"Captured {len(records) - n_skipped}/{len(records)} CoTs "
             f"({n_skipped} skipped on length mismatch)")

    # Concatenate per-layer per-category
    log.info("Concatenating captures...")
    final = {}
    for li, per_cat in acc.items():
        key = str(li)
        final[key] = {}
        for c in cats:
            # pop() drops acc's reference so these chunks can be freed once the
            # loop moves on, instead of coexisting with every concatenated tensor.
            chunks = per_cat.pop(c)
            if chunks:
                final[key][c] = torch.cat(chunks, dim=0)
            else:
                final[key][c] = torch.empty(0, MODEL_CONFIG["hidden_size"], dtype=torch.float16)
            log.info(f"  layer {li:3d} cat {c:<8s} shape {tuple(final[key][c].shape)}")

    # Save main (plan/mon/exec)
    save_main = {
        str(li): {c: final[str(li)][c] for c in ("plan", "mon", "exec")}
        for li in target_layers
    }
    _atomic_torch_save(save_main, RESIDUALS_PATH)
    log.info(f"Saved plan/mon/exec residuals: {RESIDUALS_PATH}")

    # Save general (for general direction computation)
    save_gen = {
        str(li): {c: final[str(li)][c] for c in ("all_nl", "non_nl")}
        for li in target_layers
    }
    _atomic_torch_save(save_gen, GENERAL_RESIDUALS_PATH)
    log.info(f"Saved general residuals: {GENERAL_RESIDUALS_PATH}")


if __name__ == "__main__":
    # Run the capture pipeline only when executed as a script, not on import.
    main()