v2 / scripts /07_capture_residuals.py
JulianHJR's picture
Upload folder using huggingface_hub
e53f10b verified
"""
Stage 3 part A: Capture the post-layer residual stream at labeled token positions.
Only captures:
- Target layers (union of planning and monitoring top-expert layers)
- Token positions: plan / mon / exec decision points, all newline tokens,
  and a random sample of non-newline tokens
Outputs two files, each a per-layer dict of activations:
- RESIDUALS_PATH: plan / mon / exec
- GENERAL_RESIDUALS_PATH: all_nl / non_nl (for general direction computation)
"""
import sys
import argparse
import random
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
from tqdm import tqdm
from configs.paths import (
ensure_dirs, LOGS_DIR, LABELED_COTS_PATH,
TARGET_LAYERS_PATH, RESIDUALS_PATH, GENERAL_RESIDUALS_PATH,
)
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, read_json, cleanup_memory, get_vram_mb
from src.model_io import load_model_and_tokenizer
from src.residual_capture import ResidualCapture
def _sample_non_newline_positions(n_tokens, newline_tis, k, rng):
    """Draw up to ``k`` random token positions that are not newline positions.

    Candidates are ``range(n_tokens)`` minus ``newline_tis``, listed in
    ascending order and shuffled in place with ``rng``; the first ``k`` are
    returned. This exact order of operations keeps the seeded draw sequence
    reproducible across runs.
    """
    nl_set = set(newline_tis)
    candidates = [ti for ti in range(n_tokens) if ti not in nl_set]
    rng.shuffle(candidates)
    return candidates[:k]


def _alignment_problem(retok_ids, stored_ids):
    """Explain why positions into ``stored_ids`` cannot be used on ``retok_ids``.

    Returns ``None`` when both id sequences are identical,
    ``"length mismatch"`` when their lengths differ, and
    ``"token-id mismatch"`` when the lengths agree but some ids differ
    (the same position would then refer to a different token).
    """
    if len(retok_ids) != len(stored_ids):
        return "length mismatch"
    if list(retok_ids) != list(stored_ids):
        return "token-id mismatch"
    return None


def _concat_or_empty(chunks, hidden_size):
    """Concatenate row chunks along dim 0, or return an empty ``(0, hidden_size)`` fp16 tensor."""
    if chunks:
        return torch.cat(chunks, dim=0)
    return torch.empty(0, hidden_size, dtype=torch.float16)


def _atomic_torch_save(obj, path):
    """``torch.save`` ``obj`` to a sibling ``.pt.tmp`` file, then rename it onto ``path``.

    Writing to a temp file first means an interrupted save does not leave a
    partial file at ``path`` for the ``--resume`` existence check to trust.
    """
    tmp = path.with_suffix(".pt.tmp")
    torch.save(obj, tmp)
    tmp.replace(path)


def main():
    """Capture residual-stream activations at labeled CoT positions and save them.

    For each record in LABELED_COTS_PATH the CoT text is re-tokenized, run
    through the model once with ResidualCapture hooks on the target layers,
    and the rows at five position categories are kept: plan / mon / exec
    decision positions, all newline positions, and a seeded random sample of
    non-newline positions. A record is skipped unless its re-tokenization
    reproduces the stored ``token_ids`` exactly, because the labeled
    positions index into those stored ids.

    Outputs (each written via a temp file + rename):
      * RESIDUALS_PATH:         {str(layer): {"plan", "mon", "exec"}}
      * GENERAL_RESIDUALS_PATH: {str(layer): {"all_nl", "non_nl"}}
    Each value is a tensor with one row per captured position.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--non_nl_samples_per_cot", type=int, default=20,
                        help="# random non-newline tokens sampled per CoT (for general direction)")
    args = parser.parse_args()
    ensure_dirs()
    log = setup_logger("07_residuals", LOGS_DIR / "07_residuals.log")

    if args.resume and RESIDUALS_PATH.exists() and GENERAL_RESIDUALS_PATH.exists():
        log.info("Residuals already saved. Skipping.")
        return

    # Target layers
    tgt = read_json(TARGET_LAYERS_PATH)
    target_layers = tgt["union_layers"]
    log.info(f"Target layers ({len(target_layers)}): {target_layers}")

    # Load labeled CoTs
    records = read_jsonl(LABELED_COTS_PATH)
    log.info(f"Got {len(records)} labeled CoTs")

    # Load model
    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()
    log.info(f"Model loaded. VRAM: {get_vram_mb():.0f} MB")

    # Accumulators: acc[layer][category] -> list of row chunks, one per record.
    cats = ["plan", "mon", "exec", "all_nl", "non_nl"]
    acc = {li: {c: [] for c in cats} for li in target_layers}
    rng = random.Random(42)
    n_captured = 0
    n_skipped = 0

    for rec in tqdm(records, desc="capture residuals"):
        # Token positions per category; all index into rec["token_ids"].
        positions = {
            "plan": rec["plan_decision_tis"],
            "mon": rec["mon_decision_tis"],
            "exec": rec["exec_decision_tis"],
            "all_nl": rec["all_newline_tis"],
        }
        # Sampled before the alignment check so the RNG stream does not depend
        # on which records end up skipped.
        positions["non_nl"] = _sample_non_newline_positions(
            len(rec["token_ids"]), positions["all_nl"],
            args.non_nl_samples_per_cot, rng,
        )

        # Re-tokenize and require an exact match with the stored ids; a
        # length-only check would let shifted positions through silently.
        enc = tokenizer(
            rec["cot"], return_tensors="pt", add_special_tokens=False, truncation=False
        )
        problem = _alignment_problem(enc["input_ids"][0].tolist(), rec["token_ids"])
        if problem is not None:
            log.warning(f"idx={rec['idx']}: retokenize {problem}. Skipping.")
            n_skipped += 1
            continue
        input_ids = enc["input_ids"].to(model.device)

        cap = ResidualCapture(model, target_layers=target_layers)
        cap.start()
        try:
            with torch.no_grad():
                _ = model(input_ids)
        finally:
            # Always detach the capture, even if the forward pass raises.
            residuals = cap.stop()

        # Slice per-category activations
        for li in target_layers:
            if li not in residuals:
                # NOTE(review): a layer absent from the capture is dropped for
                # this record without a warning.
                continue
            h = residuals[li]  # (S, D) fp16 cpu, per ResidualCapture — confirm there
            for c in cats:
                tis = positions[c]
                if tis:
                    acc[li][c].append(h[tis])
        n_captured += 1
        cleanup_memory()

    log.info(f"Captured {n_captured} CoTs, skipped {n_skipped} (retokenization mismatch)")
    if n_captured == 0:
        log.warning("No CoTs were captured; all saved tensors will be empty.")

    # Concatenate per-layer per-category
    log.info("Concatenating captures...")
    hidden_size = MODEL_CONFIG["hidden_size"]
    final = {}
    for li, chunks_by_cat in acc.items():
        key = str(li)
        final[key] = {}
        for c in cats:
            # pop() releases the chunk list once merged, so the per-record
            # chunks and the merged tensor are not both held at peak.
            final[key][c] = _concat_or_empty(chunks_by_cat.pop(c), hidden_size)
            log.info(f" layer {li:3d} cat {c:<8s} shape {tuple(final[key][c].shape)}")

    # Save main (plan/mon/exec)
    save_main = {
        key: {c: by_cat[c] for c in ("plan", "mon", "exec")}
        for key, by_cat in final.items()
    }
    _atomic_torch_save(save_main, RESIDUALS_PATH)
    log.info(f"Saved plan/mon/exec residuals: {RESIDUALS_PATH}")

    # Save general (for general direction computation)
    save_gen = {
        key: {c: by_cat[c] for c in ("all_nl", "non_nl")}
        for key, by_cat in final.items()
    }
    _atomic_torch_save(save_gen, GENERAL_RESIDUALS_PATH)
    log.info(f"Saved general residuals: {GENERAL_RESIDUALS_PATH}")
if __name__ == "__main__":
    # Run the capture stage only when executed as a script, never on import.
    main()