| """ |
| Stage 3.5 (DIAGNOSTIC): Attention output direction analysis. |
| |
| Goal: check whether planning/monitoring directions are well-represented |
| in the ATTENTION OUTPUT of each layer (not only in the post-MLP residual). |
| |
| If attention output also shows strong plan-vs-exec separation, our |
| FFN-residual-based steering may miss this signal. This script prints |
| a comparison and saves a JSON / figure but does NOT modify steering. |
| |
| Output: data/results/attention_diagnostic.{json,png} |
| |
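| Decision rule (informational, not auto-applied), based on the ratio of |
| attention mean-diff norm to MLP residual mean-diff norm, averaged over |
| the layers where both norms are available: |
| - If the mean ratio is > 50%, consider also hooking attention output |
| during steering. |
| - If the mean ratio is < 30%, FFN-only steering is fine (current pipeline). |
| - Otherwise (30%-50%) the result is borderline and should be monitored. |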
| """ |
| import sys |
| import argparse |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| import torch |
| from tqdm import tqdm |
|
|
| from configs.paths import ( |
| ensure_dirs, LOGS_DIR, LABELED_COTS_PATH, |
| TARGET_LAYERS_PATH, |
| ATTN_RESIDUALS_PATH, ATTN_DIAGNOSTIC_PATH, ATTN_DIAGNOSTIC_FIG, |
| RESIDUALS_PATH, |
| ) |
| from src.utils import setup_logger, read_jsonl, read_json, write_json, cleanup_memory |
| from src.model_io import load_model_and_tokenizer |
| from src.attention_capture import AttentionOutputCapture |
|
|
|
|
def _stack_rows(chunks):
    """Concatenate per-record row selections into one float32 tensor.

    Returns None when no record contributed rows for the category.
    """
    if not chunks:
        return None
    return torch.cat(chunks, dim=0).to(torch.float32)


def _mean_diff_norm(group, baseline):
    """Return the L2 norm of mean(group) - mean(baseline) over dim 0.

    Returns None when either tensor is missing or has zero rows, so callers
    can simply omit the metric instead of recording a NaN.
    """
    if group is None or baseline is None:
        return None
    if group.shape[0] == 0 or baseline.shape[0] == 0:
        return None
    return (group.mean(0) - baseline.mean(0)).norm().item()


def _recommendation(avg_ratio):
    """Map a mean attention-to-residual ratio to the informational verdict."""
    if avg_ratio > 0.5:
        return "attention also matters — consider hooking it"
    if avg_ratio < 0.3:
        return "FFN-only steering OK"
    return "borderline, monitor"


def _plot_diagnostic(layer_infos, fig_path):
    """Save side-by-side bar charts of attention vs. residual mean-diff norms.

    `layer_infos` maps layer index (as a string) to its metric dict; metrics
    absent for a layer are drawn as 0. Raises whatever matplotlib/numpy raise;
    the caller decides whether a plotting failure is fatal.
    """
    # Imported lazily so the numeric diagnostic still runs without matplotlib.
    import matplotlib.pyplot as plt
    import numpy as np

    layers = sorted(int(key) for key in layer_infos)

    def series(metric):
        return [layer_infos[str(li)].get(metric, 0) for li in layers]

    panels = (
        ("Planning", "attn_plan_norm", "res_plan_norm"),
        ("Monitoring", "attn_mon_norm", "res_mon_norm"),
    )

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    try:
        x = np.arange(len(layers))
        w = 0.4
        for ax, (title, attn_key, res_key) in zip(axes, panels):
            ax.bar(x - w / 2, series(attn_key), w, label="attn output")
            ax.bar(x + w / 2, series(res_key), w, label="post-layer residual")
            ax.set_xticks(x)
            ax.set_xticklabels(layers, rotation=90)
            ax.set_title(f"{title} mean-diff norm by source")
            ax.set_xlabel("layer")
            ax.legend()
        fig.tight_layout()
        fig.savefig(fig_path, dpi=120)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)


def main():
    """Run the attention-output diagnostic and write its JSON report and figure.

    Steps:
      1. Capture the attention output of every target layer for the first
         ``--n_samples`` labeled CoTs and collect the rows selected by each
         record's plan / mon / exec index lists.
      2. Per layer, compute the norm of (plan mean - exec mean) and
         (mon mean - exec mean) on those rows.
      3. If the residuals file exists, compute the same norms from it.
      4. Summarise the per-layer attention/residual ratios and attach an
         informational recommendation.
      5. Write the JSON report and, best effort, a bar-chart figure.

    With ``--resume`` the function returns immediately when the JSON report
    already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_samples", type=int, default=50,
                        help="# of training CoTs to use (subset of labeled set)")
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("08b_attn", LOGS_DIR / "08b_attn.log")

    if args.resume and ATTN_DIAGNOSTIC_PATH.exists():
        log.info("Diagnostic already done, skipping.")
        return

    target_layers = read_json(TARGET_LAYERS_PATH)["union_layers"]
    log.info(f"Target layers: {target_layers}")

    records = read_jsonl(LABELED_COTS_PATH)[:args.n_samples]
    log.info(f"Using {len(records)} labeled CoTs for attention capture")

    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()

    cats = ["plan", "mon", "exec"]
    acc = {li: {c: [] for c in cats} for li in target_layers}

    # Bookkeeping so silently dropped data shows up in the log.
    n_used = 0
    n_token_mismatch = 0
    never_captured = set(target_layers)

    for rec in tqdm(records, desc="capture attn"):
        text = rec["cot"]
        index_lists = (
            ("plan", rec["plan_decision_tis"]),
            ("mon", rec["mon_decision_tis"]),
            ("exec", rec["exec_decision_tis"]),
        )

        enc = tokenizer(text, return_tensors="pt", add_special_tokens=False, truncation=False)
        # Skip records whose re-tokenized length differs from the stored
        # token_ids; presumably the stored indices would then no longer line
        # up with the captured rows.
        if enc["input_ids"].shape[1] != len(rec["token_ids"]):
            n_token_mismatch += 1
            continue
        input_ids = enc["input_ids"].to(model.device)

        cap = AttentionOutputCapture(model, target_layers=target_layers)
        cap.start()
        try:
            with torch.no_grad():
                _ = model(input_ids)
        finally:
            # stop() is called even when the forward pass raises.
            attn_outs = cap.stop()

        n_used += 1
        for li in target_layers:
            if li not in attn_outs:
                continue
            never_captured.discard(li)
            h = attn_outs[li]
            for cat, tis in index_lists:
                if tis:
                    acc[li][cat].append(h[tis])

        cleanup_memory()

    log.info(
        f"Attention captured for {n_used}/{len(records)} records "
        f"({n_token_mismatch} skipped: token count mismatch)"
    )
    if n_used == 0:
        log.warning("No record contributed attention outputs; attention norms will be empty.")
    elif never_captured:
        missing = [li for li in target_layers if li in never_captured]
        log.warning(f"No attention output was ever captured for layers: {missing}")

    # The model is no longer needed; release it before the analysis phase.
    del model
    cleanup_memory()

    log.info("Computing attention mean-diff norms...")
    diagnostic = {"layers": {}}

    for li in target_layers:
        plan = _stack_rows(acc[li]["plan"])
        mon = _stack_rows(acc[li]["mon"])
        execu = _stack_rows(acc[li]["exec"])

        layer_info = {}
        plan_norm = _mean_diff_norm(plan, execu)
        if plan_norm is not None:
            layer_info["attn_plan_norm"] = plan_norm
        mon_norm = _mean_diff_norm(mon, execu)
        if mon_norm is not None:
            layer_info["attn_mon_norm"] = mon_norm

        diagnostic["layers"][str(li)] = layer_info

    log.info("Comparing to FFN residual mean-diff norms...")
    if RESIDUALS_PATH.exists():
        # NOTE(review): torch.load unpickles the file; only use it on residuals
        # produced by this pipeline, never on files from an untrusted source.
        residuals = torch.load(RESIDUALS_PATH, map_location="cpu")
        for li in target_layers:
            li_str = str(li)
            if li_str not in residuals:
                continue
            r = residuals[li_str]
            res_plan = r["plan"].to(torch.float32)
            res_mon = r["mon"].to(torch.float32)
            res_exec = r["exec"].to(torch.float32)

            layer_info = diagnostic["layers"].setdefault(li_str, {})
            res_plan_norm = _mean_diff_norm(res_plan, res_exec)
            if res_plan_norm is not None:
                layer_info["res_plan_norm"] = res_plan_norm
            res_mon_norm = _mean_diff_norm(res_mon, res_exec)
            if res_mon_norm is not None:
                layer_info["res_mon_norm"] = res_mon_norm

    # Per-layer attention/residual ratios; layers lacking either norm, or with
    # a zero residual norm, are left out.
    ratios = {"plan": [], "mon": []}
    for info in diagnostic["layers"].values():
        for d in ratios:
            attn_key = f"attn_{d}_norm"
            res_key = f"res_{d}_norm"
            if attn_key in info and res_key in info and info[res_key] > 0:
                ratios[d].append(info[attn_key] / info[res_key])

    summary = {}
    for d in ["plan", "mon"]:
        if not ratios[d]:
            continue
        avg = sum(ratios[d]) / len(ratios[d])
        summary[d] = {
            "mean_attn_to_residual_ratio": avg,
            "max_attn_to_residual_ratio": max(ratios[d]),
            "n_layers": len(ratios[d]),
            "recommendation": _recommendation(avg),
        }
    diagnostic["summary"] = summary

    write_json(diagnostic, ATTN_DIAGNOSTIC_PATH)
    log.info(f"Saved {ATTN_DIAGNOSTIC_PATH}")
    log.info(f"Summary: {summary}")

    # The figure is best effort: the JSON report above is the primary output,
    # so a plotting failure is logged rather than raised.
    try:
        _plot_diagnostic(diagnostic["layers"], ATTN_DIAGNOSTIC_FIG)
        log.info(f"Saved {ATTN_DIAGNOSTIC_FIG}")
    except Exception as e:
        log.warning(f"Plot failed: {e}")
|
|
|
|
# Run the diagnostic only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|