File size: 5,137 Bytes
"""
Stage 3 part A: Capture the post-layer residual stream at decision points.

Only captures:
  - Target layers (union of planning and monitoring top-expert layers)
  - Decision point positions (plan / mon / exec / all_newline / non-newline-sample)

Two output files are written, each storing a per-layer dict of activations:
one with the plan / mon / exec categories and one with the all_nl / non_nl
categories (5 categories in total).
"""
import sys
import argparse
import random
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
from tqdm import tqdm
from configs.paths import (
ensure_dirs, LOGS_DIR, LABELED_COTS_PATH,
TARGET_LAYERS_PATH, RESIDUALS_PATH, GENERAL_RESIDUALS_PATH,
)
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, read_json, cleanup_memory, get_vram_mb
from src.model_io import load_model_and_tokenizer
from src.residual_capture import ResidualCapture
def _sample_non_newline_positions(n_tokens, newline_tis, k, rng):
    """Return up to ``k`` random token indices that are not newline positions.

    Args:
        n_tokens: Number of tokens in the CoT; candidates are ``range(n_tokens)``.
        newline_tis: Iterable of token indices to exclude.
        k: Maximum number of indices to return (must be >= 0).
        rng: ``random.Random`` instance used for the shuffle.

    Returns:
        A list of at most ``k`` indices, in shuffled order.
    """
    excluded = set(newline_tis)
    candidates = [ti for ti in range(n_tokens) if ti not in excluded]
    # Shuffle the whole candidate list and take a prefix. The RNG is consumed
    # once per CoT regardless of whether the CoT is skipped afterwards, so the
    # draw for a given record depends only on the records that precede it.
    rng.shuffle(candidates)
    return candidates[:k]


def _save_via_temp_file(obj, path):
    """``torch.save`` ``obj`` to a sibling ``.pt.tmp`` file, then rename it to ``path``.

    The final path is only replaced after the save call has returned, so an
    interrupted save does not leave a partially written file at ``path``.
    """
    tmp = path.with_suffix(".pt.tmp")
    torch.save(obj, tmp)
    tmp.replace(path)


def main():
    """Capture residual activations at labeled token positions and save them.

    Steps:
      1. Parse CLI flags (``--resume``, ``--non_nl_samples_per_cot``).
      2. With ``--resume``, return early if both output files already exist.
      3. Read the target layers (``union_layers``) and the labeled CoT records.
      4. For every record, run one forward pass under ``ResidualCapture`` and
         gather, per target layer, the rows at the plan / mon / exec /
         all-newline positions plus a random sample of non-newline positions.
      5. Concatenate per layer and category, then write two files:
         ``RESIDUALS_PATH`` (plan / mon / exec) and ``GENERAL_RESIDUALS_PATH``
         (all_nl / non_nl).

    Each record is read through the keys ``cot``, ``token_ids``, ``idx``,
    ``plan_decision_tis``, ``mon_decision_tis``, ``exec_decision_tis`` and
    ``all_newline_tis``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--non_nl_samples_per_cot", type=int, default=20,
                        help="# random non-newline tokens sampled per CoT (for general direction)")
    args = parser.parse_args()
    # A negative count would turn the prefix slice into "all but the last |k|"
    # candidates instead of a sample size, so reject it up front.
    if args.non_nl_samples_per_cot < 0:
        parser.error("--non_nl_samples_per_cot must be >= 0")

    ensure_dirs()
    log = setup_logger("07_residuals", LOGS_DIR / "07_residuals.log")

    if args.resume and RESIDUALS_PATH.exists() and GENERAL_RESIDUALS_PATH.exists():
        log.info("Residuals already saved. Skipping.")
        return

    # Target layers
    tgt = read_json(TARGET_LAYERS_PATH)
    target_layers = tgt["union_layers"]
    log.info(f"Target layers ({len(target_layers)}): {target_layers}")

    # Load labeled
    records = read_jsonl(LABELED_COTS_PATH)
    log.info(f"Got {len(records)} labeled CoTs")

    # Load model
    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()
    log.info(f"Model loaded. VRAM: {get_vram_mb():.0f} MB")

    # Accumulators (per-layer, per-category)
    cats = ["plan", "mon", "exec", "all_nl", "non_nl"]
    acc = {li: {c: [] for c in cats} for li in target_layers}
    rng = random.Random(42)
    # Layers already reported as absent from a capture (warn once per layer).
    missing_layers = set()

    for rec in tqdm(records, desc="capture residuals"):
        text = rec["cot"]
        n_tokens = len(rec["token_ids"])
        positions = {
            "plan": rec["plan_decision_tis"],
            "mon": rec["mon_decision_tis"],
            "exec": rec["exec_decision_tis"],
            "all_nl": rec["all_newline_tis"],
        }
        # Sample non-newline tokens: random tokens that are NOT in all_nl_tis.
        # Done before the length check on purpose (see helper) so the RNG
        # stream does not depend on which records get skipped.
        positions["non_nl"] = _sample_non_newline_positions(
            n_tokens, positions["all_nl"], args.non_nl_samples_per_cot, rng
        )

        # Re-tokenize and check length: the stored indices are only usable if
        # the fresh tokenization has the same number of tokens as the record.
        enc = tokenizer(
            text, return_tensors="pt", add_special_tokens=False, truncation=False
        )
        if enc["input_ids"].shape[1] != n_tokens:
            log.warning(f"idx={rec['idx']}: retokenize length mismatch. Skipping.")
            continue
        input_ids = enc["input_ids"].to(model.device)

        cap = ResidualCapture(model, target_layers=target_layers)
        cap.start()
        try:
            with torch.no_grad():
                _ = model(input_ids)
        finally:
            # Always stop the capture, even if the forward pass raised.
            residuals = cap.stop()

        # Slice per-category activations
        for li in target_layers:
            if li not in residuals:
                if li not in missing_layers:
                    missing_layers.add(li)
                    log.warning(
                        f"Layer {li} missing from captured residuals; skipping it."
                    )
                continue
            # Original note: (S, D) fp16 cpu — TODO confirm against ResidualCapture.
            h = residuals[li]
            for c in cats:
                tis = positions[c]
                if tis:
                    acc[li][c].append(h[tis])
        cleanup_memory()

    # Concatenate per-layer per-category
    log.info("Concatenating captures...")
    final = {}
    for li in target_layers:
        per_cat = {}
        for c in cats:
            if acc[li][c]:
                per_cat[c] = torch.cat(acc[li][c], dim=0)
            else:
                # Nothing captured: keep an empty (0, hidden_size) placeholder
                # so every layer/category key is present in the output.
                per_cat[c] = torch.empty(0, MODEL_CONFIG["hidden_size"], dtype=torch.float16)
            log.info(f" layer {li:3d} cat {c:<8s} shape {tuple(per_cat[c].shape)}")
        final[str(li)] = per_cat

    # Save main (plan/mon/exec)
    save_main = {str(li): {
        "plan": final[str(li)]["plan"],
        "mon": final[str(li)]["mon"],
        "exec": final[str(li)]["exec"],
    } for li in target_layers}
    _save_via_temp_file(save_main, RESIDUALS_PATH)
    log.info(f"Saved plan/mon/exec residuals: {RESIDUALS_PATH}")

    # Save general (for general direction computation)
    save_gen = {str(li): {
        "all_nl": final[str(li)]["all_nl"],
        "non_nl": final[str(li)]["non_nl"],
    } for li in target_layers}
    _save_via_temp_file(save_gen, GENERAL_RESIDUALS_PATH)
    log.info(f"Saved general residuals: {GENERAL_RESIDUALS_PATH}")
# Run the capture pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|