"""Stage 3.5 (DIAGNOSTIC): attention-output direction analysis.

Goal: check whether the planning/monitoring directions are also well
represented in the ATTENTION OUTPUT of each target layer, not only in the
post-layer (post-MLP) residual stream. If the attention output also shows a
strong plan-vs-exec (or mon-vs-exec) separation, steering that is based only
on the FFN/residual stream may miss part of this signal.

This script prints a comparison and saves a JSON report plus a figure; it does
NOT modify steering.

Output: data/results/attention_diagnostic.{json,png}
(ATTN_DIAGNOSTIC_PATH / ATTN_DIAGNOSTIC_FIG in configs.paths)

Decision rule (informational only, never applied automatically). For each
contrast, take the ratio attn_norm / residual_norm per layer, where each norm
is the L2 norm of a class mean-difference (plan - exec, or mon - exec), and
average it over the target layers:
  - mean ratio > 0.5: consider also hooking the attention output during
    steering.
  - mean ratio < 0.3: FFN-only steering is fine (current pipeline).
  - otherwise: borderline; keep monitoring.
"""
import sys
import argparse
from pathlib import Path

# Make the project-root packages (configs, src) importable when this file is
# run directly as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import torch
from tqdm import tqdm

from configs.paths import (
    ensure_dirs, LOGS_DIR, LABELED_COTS_PATH, TARGET_LAYERS_PATH,
    ATTN_RESIDUALS_PATH, ATTN_DIAGNOSTIC_PATH, ATTN_DIAGNOSTIC_FIG,
    RESIDUALS_PATH,
)
from src.utils import setup_logger, read_jsonl, read_json, write_json, cleanup_memory
from src.model_io import load_model_and_tokenizer
from src.attention_capture import AttentionOutputCapture

# Token categories collected at decision points. "exec" is the baseline that
# "plan" and "mon" are contrasted against.
_CATEGORIES = ("plan", "mon", "exec")
_CONTRASTS = ("plan", "mon")

# Decision-rule thresholds on the mean attention/residual norm ratio.
_HOOK_ATTN_THRESHOLD = 0.5
_FFN_ONLY_THRESHOLD = 0.3


def _mean_diff_norm(a, b):
    """Return ||mean(a, dim=0) - mean(b, dim=0)||_2 as a float, or None.

    Both inputs are cast to float32 before averaging. Returns None when either
    input is None or has zero rows, so callers skip the metric instead of
    producing NaN from an empty mean.
    """
    if a is None or b is None or a.shape[0] == 0 or b.shape[0] == 0:
        return None
    a = a.to(torch.float32)
    b = b.to(torch.float32)
    return (a.mean(0) - b.mean(0)).norm().item()


def _capture_attention(model, tokenizer, records, target_layers, log):
    """Run each record through the model and collect attention outputs.

    Returns ``{layer: {category: [tensor, ...]}}``. Each tensor holds the rows
    of that layer's captured attention output at one record's decision-token
    indices (``<category>_decision_tis``). Records whose re-tokenized length
    differs from the stored ``token_ids`` are skipped, because their stored
    token indices would not line up. The number skipped is logged.
    """
    acc = {li: {c: [] for c in _CATEGORIES} for li in target_layers}
    n_skipped = 0
    for rec in tqdm(records, desc="capture attn"):
        tis_by_cat = {c: rec[f"{c}_decision_tis"] for c in _CATEGORIES}
        enc = tokenizer(rec["cot"], return_tensors="pt",
                        add_special_tokens=False, truncation=False)
        if enc["input_ids"].shape[1] != len(rec["token_ids"]):
            n_skipped += 1
            continue
        input_ids = enc["input_ids"].to(model.device)

        cap = AttentionOutputCapture(model, target_layers=target_layers)
        cap.start()
        try:
            with torch.no_grad():
                model(input_ids)
        finally:
            # Always detach the hooks, even if the forward pass fails.
            attn_outs = cap.stop()

        for li in target_layers:
            if li not in attn_outs:
                continue
            # NOTE(review): assumes each captured output is indexed by token
            # position along dim 0 (i.e. (seq, hidden)); confirm against
            # AttentionOutputCapture.
            h = attn_outs[li]
            for cat, tis in tis_by_cat.items():
                if tis:
                    # Keep accumulated rows on CPU so they do not pin device
                    # memory after the model is freed (no-op for CPU tensors).
                    acc[li][cat].append(h[tis].detach().cpu())
        cleanup_memory()

    if n_skipped:
        log.warning(
            f"Skipped {n_skipped}/{len(records)} records whose re-tokenized "
            f"length did not match stored token_ids"
        )
    return acc


def _attention_norms(acc, target_layers):
    """Compute the attention mean-diff norms (plan-exec, mon-exec) per layer.

    Returns ``{str(layer): {"attn_plan_norm": float, "attn_mon_norm": float}}``.
    A key is omitted when either side of its contrast has no rows.
    """
    layers = {}
    for li in target_layers:
        stacked = {
            c: torch.cat(rows, dim=0) if rows else None
            for c, rows in acc[li].items()
        }
        info = {}
        for cat in _CONTRASTS:
            norm = _mean_diff_norm(stacked[cat], stacked["exec"])
            if norm is not None:
                info[f"attn_{cat}_norm"] = norm
        layers[str(li)] = info
    return layers


def _add_residual_norms(layers, target_layers, log):
    """Add the post-layer residual mean-diff norms to ``layers`` in place.

    Loads RESIDUALS_PATH, read here as ``{str(layer): {"plan", "mon", "exec"}}``
    with tensor rows along dim 0. Adds ``res_plan_norm`` / ``res_mon_norm``
    where both sides of the contrast are non-empty. If the file is missing,
    logs a warning and leaves ``layers`` unchanged.
    """
    if not RESIDUALS_PATH.exists():
        log.warning(
            f"{RESIDUALS_PATH} not found: no residual norms to compare "
            f"against, so the summary will be empty"
        )
        return
    residuals = torch.load(RESIDUALS_PATH, map_location="cpu")
    for li in target_layers:
        key = str(li)
        if key not in residuals:
            continue
        r = residuals[key]
        info = layers.setdefault(key, {})
        for cat in _CONTRASTS:
            norm = _mean_diff_norm(r[cat], r["exec"])
            if norm is not None:
                info[f"res_{cat}_norm"] = norm


def _recommendation(avg_ratio):
    """Map a mean attention/residual ratio to the informational decision text."""
    if avg_ratio > _HOOK_ATTN_THRESHOLD:
        return "attention also matters — consider hooking it"
    if avg_ratio < _FFN_ONLY_THRESHOLD:
        return "FFN-only steering OK"
    return "borderline, monitor"


def _summarize(layers):
    """Aggregate the per-layer attn/residual norm ratios for each contrast.

    Only layers that have both norms, with a positive residual norm, count.
    A contrast with no such layers is left out of the summary.

    NOTE(review): the attention norms come from the first ``n_samples``
    labeled CoTs, while the residual norms come from whatever set produced
    RESIDUALS_PATH. Confirm the two sample sets are comparable.
    """
    summary = {}
    for cat in _CONTRASTS:
        attn_key, res_key = f"attn_{cat}_norm", f"res_{cat}_norm"
        ratios = [
            info[attn_key] / info[res_key]
            for info in layers.values()
            if attn_key in info and res_key in info and info[res_key] > 0
        ]
        if not ratios:
            continue
        avg = sum(ratios) / len(ratios)
        summary[cat] = {
            "mean_attn_to_residual_ratio": avg,
            "max_attn_to_residual_ratio": max(ratios),
            "n_layers": len(ratios),
            "recommendation": _recommendation(avg),
        }
    return summary


def _plot(layers, log):
    """Best-effort per-layer bar chart of attention vs. residual norms.

    Any failure, such as matplotlib not being installed, is logged as a
    warning and not raised: the figure is optional.
    """
    try:
        import matplotlib.pyplot as plt
        import numpy as np

        layer_ids = sorted(int(key) for key in layers)

        def series(metric):
            return [layers[str(li)].get(metric, 0) for li in layer_ids]

        x = np.arange(len(layer_ids))
        w = 0.4
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        panels = (
            (axes[0], "plan", "Planning mean-diff norm by source"),
            (axes[1], "mon", "Monitoring mean-diff norm by source"),
        )
        for ax, cat, title in panels:
            ax.bar(x - w / 2, series(f"attn_{cat}_norm"), w, label="attn output")
            ax.bar(x + w / 2, series(f"res_{cat}_norm"), w, label="post-layer residual")
            ax.set_xticks(x)
            ax.set_xticklabels(layer_ids, rotation=90)
            ax.set_title(title)
            ax.set_xlabel("layer")
            ax.legend()
        plt.tight_layout()
        plt.savefig(ATTN_DIAGNOSTIC_FIG, dpi=120)
        plt.close(fig)
        log.info(f"Saved {ATTN_DIAGNOSTIC_FIG}")
    except Exception as e:
        log.warning(f"Plot failed: {e}")


def main():
    """Capture attention outputs, compare them with residual norms, save the report and plot."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_samples", type=int, default=50,
                        help="# of training CoTs to use (subset of labeled set)")
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()
    # A value <= 0 would slice to nothing or silently drop records from the end.
    if args.n_samples < 1:
        parser.error("--n_samples must be >= 1")

    ensure_dirs()
    log = setup_logger("08b_attn", LOGS_DIR / "08b_attn.log")

    if args.resume and ATTN_DIAGNOSTIC_PATH.exists():
        log.info("Diagnostic already done, skipping.")
        return

    target_layers = read_json(TARGET_LAYERS_PATH)["union_layers"]
    log.info(f"Target layers: {target_layers}")
    records = read_jsonl(LABELED_COTS_PATH)[:args.n_samples]
    log.info(f"Using {len(records)} labeled CoTs for attention capture")

    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()
    acc = _capture_attention(model, tokenizer, records, target_layers, log)

    # Free the model before the CPU-side analysis.
    del model
    cleanup_memory()

    log.info("Computing attention mean-diff norms...")
    diagnostic = {"layers": _attention_norms(acc, target_layers)}

    log.info("Comparing to FFN residual mean-diff norms...")
    _add_residual_norms(diagnostic["layers"], target_layers, log)

    summary = _summarize(diagnostic["layers"])
    diagnostic["summary"] = summary
    write_json(diagnostic, ATTN_DIAGNOSTIC_PATH)
    log.info(f"Saved {ATTN_DIAGNOSTIC_PATH}")
    log.info(f"Summary: {summary}")

    _plot(diagnostic["layers"], log)


if __name__ == "__main__":
    main()