"""
Stage 3 part A: Capture post-layer residual stream at decision points.

Only captures:
  - Target layers (union of planning and monitoring top-expert layers)
  - Decision point positions (plan / mon / exec / all_newline / non-newline-sample)

Two files are written, each a dict keyed by ``str(layer_index)``:
  - RESIDUALS_PATH:         {"plan", "mon", "exec"} activations
  - GENERAL_RESIDUALS_PATH: {"all_nl", "non_nl"} activations
"""
import sys
import argparse
import random
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import torch
from tqdm import tqdm

from configs.paths import (
    ensure_dirs, LOGS_DIR,
    LABELED_COTS_PATH, TARGET_LAYERS_PATH,
    RESIDUALS_PATH, GENERAL_RESIDUALS_PATH,
)
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, read_json, cleanup_memory, get_vram_mb
from src.model_io import load_model_and_tokenizer
from src.residual_capture import ResidualCapture


def _sample_non_newline(rng, n_tokens, newline_tis, k):
    """Return up to *k* random token indices in ``range(n_tokens)`` that are
    not listed in *newline_tis*.

    Uses shuffle-then-slice so that, for a given ``rng`` state, the sample is
    the same one the pipeline has always produced.
    """
    nl_set = set(newline_tis)
    candidates = [ti for ti in range(n_tokens) if ti not in nl_set]
    rng.shuffle(candidates)
    return candidates[:k]


def _atomic_torch_save(obj, path):
    """Save *obj* with ``torch.save`` to a temp file, then rename it over *path*.

    Writing to a temp file first means an interrupted run never leaves a
    truncated file at the final location (which ``--resume`` would trust).
    """
    tmp = path.with_suffix(".pt.tmp")
    torch.save(obj, tmp)
    tmp.replace(path)


def main():
    """Run one forward pass per labeled CoT and save the residuals found at
    the labeled token positions, grouped by layer and category.

    CLI flags:
        --resume                  skip the whole stage if both output files exist
        --non_nl_samples_per_cot  number of random non-newline tokens sampled
                                  per CoT (must be >= 0)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--non_nl_samples_per_cot", type=int, default=20,
                        help="# random non-newline tokens sampled per CoT (for general direction)")
    args = parser.parse_args()
    # A negative count would turn the slice into candidates[:-k], i.e. almost
    # every non-newline token instead of a small sample.
    if args.non_nl_samples_per_cot < 0:
        parser.error("--non_nl_samples_per_cot must be >= 0")

    ensure_dirs()
    log = setup_logger("07_residuals", LOGS_DIR / "07_residuals.log")

    if args.resume and RESIDUALS_PATH.exists() and GENERAL_RESIDUALS_PATH.exists():
        log.info("Residuals already saved. Skipping.")
        return

    # Target layers
    tgt = read_json(TARGET_LAYERS_PATH)
    target_layers = tgt["union_layers"]
    log.info(f"Target layers ({len(target_layers)}): {target_layers}")

    # Load labeled
    records = read_jsonl(LABELED_COTS_PATH)
    log.info(f"Got {len(records)} labeled CoTs")

    # Load model
    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()
    log.info(f"Model loaded. VRAM: {get_vram_mb():.0f} MB")

    # Accumulators (per-layer, per-category)
    cats = ["plan", "mon", "exec", "all_nl", "non_nl"]
    acc = {li: {c: [] for c in cats} for li in target_layers}

    rng = random.Random(42)  # fixed seed: the non-newline sample is reproducible
    n_skipped = 0
    missing_layers = set()  # layers already warned about (warn once per layer)

    for rec in tqdm(records, desc="capture residuals"):
        text = rec["cot"]
        positions = {
            "plan": rec["plan_decision_tis"],
            "mon": rec["mon_decision_tis"],
            "exec": rec["exec_decision_tis"],
            "all_nl": rec["all_newline_tis"],
        }

        # Sample non-newline tokens: random tokens that are NOT in all_nl_tis
        n_tokens = len(rec["token_ids"])
        positions["non_nl"] = _sample_non_newline(
            rng, n_tokens, positions["all_nl"], args.non_nl_samples_per_cot
        )

        # Re-tokenize and check length: the stored indices are only meaningful
        # if the text tokenizes to the same number of tokens as when labeled.
        enc = tokenizer(
            text, return_tensors="pt",
            add_special_tokens=False, truncation=False
        )
        if enc["input_ids"].shape[1] != n_tokens:
            log.warning(f"idx={rec['idx']}: retokenize length mismatch. Skipping.")
            n_skipped += 1
            continue

        input_ids = enc["input_ids"].to(model.device)

        cap = ResidualCapture(model, target_layers=target_layers)
        cap.start()
        try:
            with torch.no_grad():
                _ = model(input_ids)
        finally:
            # Always stop the capture, even if the forward pass raised.
            residuals = cap.stop()

        # Slice per-category activations
        for li in target_layers:
            if li not in residuals:
                if li not in missing_layers:
                    missing_layers.add(li)
                    log.warning(
                        f"layer {li}: no residual captured "
                        f"(first seen at idx={rec['idx']}); its output may be incomplete."
                    )
                continue
            h = residuals[li]  # (S, D) fp16 cpu per the original note; not verifiable from this file
            for cat, tis in positions.items():
                if tis:
                    # List indexing copies the selected rows, so the full
                    # per-sequence tensor is not kept alive by the accumulator.
                    acc[li][cat].append(h[tis])

        # Drop the last reference to the full captures before cleanup.
        del residuals
        cleanup_memory()

    log.info(
        f"Processed {len(records) - n_skipped}/{len(records)} CoTs "
        f"({n_skipped} skipped on retokenize mismatch)"
    )

    # Concatenate per-layer per-category
    log.info("Concatenating captures...")
    final = {}
    for li in target_layers:
        layer_out = {}
        for c in cats:
            if acc[li][c]:
                layer_out[c] = torch.cat(acc[li][c], dim=0)
            else:
                # Keep a well-formed (0, hidden_size) tensor so that consumers
                # can index every category unconditionally.
                layer_out[c] = torch.empty(0, MODEL_CONFIG["hidden_size"], dtype=torch.float16)
            log.info(f" layer {li:3d} cat {c:<8s} shape {tuple(layer_out[c].shape)}")
        final[str(li)] = layer_out

    # Save main (plan/mon/exec)
    save_main = {
        str(li): {c: final[str(li)][c] for c in ("plan", "mon", "exec")}
        for li in target_layers
    }
    _atomic_torch_save(save_main, RESIDUALS_PATH)
    log.info(f"Saved plan/mon/exec residuals: {RESIDUALS_PATH}")

    # Save general (for general direction computation)
    save_gen = {
        str(li): {c: final[str(li)][c] for c in ("all_nl", "non_nl")}
        for li in target_layers
    }
    _atomic_torch_save(save_gen, GENERAL_RESIDUALS_PATH)
    log.info(f"Saved general residuals: {GENERAL_RESIDUALS_PATH}")


if __name__ == "__main__":
    main()