"""
Stage 3 part A: Capture the post-layer residual stream at decision points.

Only the following are captured:
- Target layers (the union of the planning and monitoring top-expert layers)
- Decision-point positions (plan / mon / exec / all_newline / non-newline sample)

The five categories of activations are stored as per-layer dicts across two
output files: one holds plan / mon / exec, the other holds all_nl / non_nl.
"""
| import sys |
| import argparse |
| import random |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| import torch |
| from tqdm import tqdm |
|
|
| from configs.paths import ( |
| ensure_dirs, LOGS_DIR, LABELED_COTS_PATH, |
| TARGET_LAYERS_PATH, RESIDUALS_PATH, GENERAL_RESIDUALS_PATH, |
| ) |
| from configs.model import MODEL_CONFIG |
| from src.utils import setup_logger, read_jsonl, read_json, cleanup_memory, get_vram_mb |
| from src.model_io import load_model_and_tokenizer |
| from src.residual_capture import ResidualCapture |
|
|
|
|
def main():
    """Capture residual activations at labeled token positions and save them.

    Command-line flags:
        --resume: return immediately when both output files already exist.
        --non_nl_samples_per_cot: number of non-newline token positions drawn
            per CoT record (default 20).

    For every record read from LABELED_COTS_PATH, the CoT text is re-tokenized
    and run through the model under ``torch.no_grad()`` while a
    ``ResidualCapture`` is active on the target layers. Each captured per-layer
    tensor is then indexed with five position lists: plan, mon, exec, all_nl,
    and a seeded random sample of non-newline positions. Records whose
    re-tokenized length differs from ``len(rec["token_ids"])`` are skipped with
    a warning.

    The plan/mon/exec tensors are written to RESIDUALS_PATH and the
    all_nl/non_nl tensors to GENERAL_RESIDUALS_PATH; both files are dicts keyed
    by ``str(layer)``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--resume", action="store_true")
    cli.add_argument(
        "--non_nl_samples_per_cot",
        type=int,
        default=20,
        help="# random non-newline tokens sampled per CoT (for general direction)",
    )
    args = cli.parse_args()

    ensure_dirs()
    log = setup_logger("07_residuals", LOGS_DIR / "07_residuals.log")

    # With --resume, nothing is recomputed once both outputs are on disk.
    if args.resume and all(
        path.exists() for path in (RESIDUALS_PATH, GENERAL_RESIDUALS_PATH)
    ):
        log.info("Residuals already saved. Skipping.")
        return

    # Layers to capture.
    target_layers = read_json(TARGET_LAYERS_PATH)["union_layers"]
    log.info(f"Target layers ({len(target_layers)}): {target_layers}")

    # Labeled CoT records.
    records = read_jsonl(LABELED_COTS_PATH)
    log.info(f"Got {len(records)} labeled CoTs")

    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()
    log.info(f"Model loaded. VRAM: {get_vram_mb():.0f} MB")

    # collected[layer][category] -> list of per-record tensors, concatenated later.
    cats = ["plan", "mon", "exec", "all_nl", "non_nl"]
    collected = {layer: {cat: [] for cat in cats} for layer in target_layers}

    # Fixed seed so the non-newline sample is reproducible across runs.
    rng = random.Random(42)

    for rec in tqdm(records, desc="capture residuals"):
        text = rec["cot"]
        positions = {
            "plan": rec["plan_decision_tis"],
            "mon": rec["mon_decision_tis"],
            "exec": rec["exec_decision_tis"],
            "all_nl": rec["all_newline_tis"],
        }

        # The sample is drawn before the length check below, so a skipped
        # record still advances the RNG.
        n_tokens = len(rec["token_ids"])
        positions["non_nl"] = _sample_non_newline_positions(
            rng, n_tokens, positions["all_nl"], args.non_nl_samples_per_cot
        )

        enc = tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=False,
            truncation=False,
        )
        token_matrix = enc["input_ids"]
        # NOTE(review): only the token count is compared, not the ids
        # themselves -- confirm that is enough to keep the stored indices aligned.
        if token_matrix.shape[1] != n_tokens:
            log.warning(f"idx={rec['idx']}: retokenize length mismatch. Skipping.")
            continue

        input_ids = token_matrix.to(model.device)

        capture = ResidualCapture(model, target_layers=target_layers)
        capture.start()
        try:
            with torch.no_grad():
                # The model output is discarded; only what the capture
                # recorded is used.
                model(input_ids)
        finally:
            # Always stop the capture, even if the forward pass raised.
            residuals = capture.stop()

        for layer in target_layers:
            if layer not in residuals:
                continue
            # assumes the captured tensor is indexable along its first axis by
            # token position -- TODO confirm against ResidualCapture
            hidden = residuals[layer]
            for cat in cats:
                token_indices = positions[cat]
                if token_indices:
                    collected[layer][cat].append(hidden[token_indices])

        cleanup_memory()

    log.info("Concatenating captures...")
    final = {}
    for layer in target_layers:
        per_cat = {}
        for cat in cats:
            chunks = collected[layer][cat]
            if chunks:
                per_cat[cat] = torch.cat(chunks, dim=0)
            else:
                # Nothing captured for this layer/category: store an empty
                # (0, hidden_size) placeholder instead.
                per_cat[cat] = torch.empty(
                    0, MODEL_CONFIG["hidden_size"], dtype=torch.float16
                )
            log.info(f" layer {layer:3d} cat {cat:<8s} shape {tuple(per_cat[cat].shape)}")
        final[str(layer)] = per_cat

    _save_categories(final, target_layers, ("plan", "mon", "exec"), RESIDUALS_PATH)
    log.info(f"Saved plan/mon/exec residuals: {RESIDUALS_PATH}")

    _save_categories(final, target_layers, ("all_nl", "non_nl"), GENERAL_RESIDUALS_PATH)
    log.info(f"Saved general residuals: {GENERAL_RESIDUALS_PATH}")


def _sample_non_newline_positions(rng, n_tokens, newline_tis, k):
    """Return up to ``k`` shuffled positions in ``range(n_tokens)`` not in ``newline_tis``."""
    excluded = set(newline_tis)
    pool = [ti for ti in range(n_tokens) if ti not in excluded]
    rng.shuffle(pool)
    return pool[:k]


def _save_categories(final, target_layers, categories, dest):
    """Save the chosen categories of every layer to ``dest`` via a temp file and rename."""
    payload = {
        str(layer): {cat: final[str(layer)][cat] for cat in categories}
        for layer in target_layers
    }
    staging = dest.with_suffix(".pt.tmp")
    torch.save(payload, staging)
    staging.replace(dest)
|
|
|
# Run the capture pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|