"""
Stage 3.5 (DIAGNOSTIC): Attention output direction analysis.
Goal: check whether planning/monitoring directions are well-represented
in the ATTENTION OUTPUT of each layer (not only in the post-MLP residual).
If attention output also shows strong plan-vs-exec separation, our
FFN-residual-based steering may miss this signal. This script prints
a comparison and saves a JSON / figure but does NOT modify steering.
Output: data/results/attention_diagnostic.{json,png}
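Decision rule (informational, not auto-applied), applied to the
attention-to-residual ratio of mean-diff norms averaged over target layers:
- If attention mean-diff norm > 50% of MLP residual mean-diff norm,
  consider also hooking attention output during steering.
- If < 30%, FFN-only steering is fine (current pipeline).
- Otherwise (30%-50%), the result is borderline and should be monitored.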
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
from tqdm import tqdm
from configs.paths import (
ensure_dirs, LOGS_DIR, LABELED_COTS_PATH,
TARGET_LAYERS_PATH,
ATTN_RESIDUALS_PATH, ATTN_DIAGNOSTIC_PATH, ATTN_DIAGNOSTIC_FIG,
RESIDUALS_PATH,
)
from src.utils import setup_logger, read_jsonl, read_json, write_json, cleanup_memory
from src.model_io import load_model_and_tokenizer
from src.attention_capture import AttentionOutputCapture
def _mean_diff_norm(group, baseline):
    """Return the L2 norm of ``mean(group) - mean(baseline)`` as a float.

    Rows (dim 0) of each tensor are averaged before differencing. Returns
    ``None`` when either argument is ``None`` or has zero rows, so callers
    can simply omit the corresponding entry.
    """
    if group is None or baseline is None:
        return None
    if group.shape[0] == 0 or baseline.shape[0] == 0:
        return None
    return (group.mean(0) - baseline.mean(0)).norm().item()


def _capture_attention_outputs(model, tokenizer, records, target_layers, log):
    """Run each record through the model and collect attention outputs at decision tokens.

    Returns ``{layer: {"plan" | "mon" | "exec": [tensor, ...]}}`` where each
    tensor holds the captured rows selected by that record's decision indices.
    Records whose re-tokenized length differs from ``len(rec["token_ids"])``
    are skipped; the number skipped is logged as a warning.
    """
    cats = ("plan", "mon", "exec")
    acc = {li: {c: [] for c in cats} for li in target_layers}
    n_skipped = 0
    for rec in tqdm(records, desc="capture attn"):
        decision_tis = {c: rec[f"{c}_decision_tis"] for c in cats}
        enc = tokenizer(rec["cot"], return_tensors="pt",
                        add_special_tokens=False, truncation=False)
        # Presumably the stored decision indices refer to positions in
        # rec["token_ids"]; a length mismatch would misalign them, so skip.
        if enc["input_ids"].shape[1] != len(rec["token_ids"]):
            n_skipped += 1
            continue
        input_ids = enc["input_ids"].to(model.device)
        cap = AttentionOutputCapture(model, target_layers=target_layers)
        cap.start()
        try:
            with torch.no_grad():
                model(input_ids)
        finally:
            # Always stop the capture and fetch its outputs, even if the
            # forward pass raises.
            attn_outs = cap.stop()
        for li in target_layers:
            if li not in attn_outs:
                continue
            # Assumes attn_outs[li] is indexed by token position along
            # dim 0 -- TODO confirm against AttentionOutputCapture.
            h = attn_outs[li]
            for c, tis in decision_tis.items():
                if tis:
                    # Detach and keep the selected rows on CPU so accumulated
                    # activations do not pile up on the model's device
                    # (a no-op if the capture already returns CPU tensors).
                    acc[li][c].append(h[tis].detach().cpu())
        cleanup_memory()
    if n_skipped:
        log.warning(
            f"Skipped {n_skipped}/{len(records)} CoTs whose re-tokenized "
            f"length did not match the stored token_ids"
        )
    return acc


def _attention_norms(acc, target_layers):
    """Compute per-layer plan-vs-exec and mon-vs-exec mean-diff norms of attention outputs.

    Returns ``{str(layer): {"attn_plan_norm": float, "attn_mon_norm": float}}``;
    a key is omitted when either side of the comparison has no samples.
    """
    layers_info = {}
    for li in target_layers:
        stacked = {
            c: torch.cat(chunks, dim=0).to(torch.float32) if chunks else None
            for c, chunks in acc[li].items()
        }
        info = {}
        for c in ("plan", "mon"):
            norm = _mean_diff_norm(stacked[c], stacked["exec"])
            if norm is not None:
                info[f"attn_{c}_norm"] = norm
        layers_info[str(li)] = info
    return layers_info


def _add_residual_norms(layers_info, target_layers, log):
    """Add ``res_plan_norm`` / ``res_mon_norm`` to ``layers_info`` in place.

    Loads the saved per-layer residual activations from ``RESIDUALS_PATH``
    (keyed by ``str(layer)``, each with "plan", "mon" and "exec" tensors) and
    computes the same mean-diff norms as for the attention outputs. Logs a
    warning and leaves ``layers_info`` untouched if the file does not exist.
    """
    if not RESIDUALS_PATH.exists():
        log.warning(
            f"{RESIDUALS_PATH} not found; skipping residual comparison "
            f"(summary ratios will be empty)"
        )
        return
    # NOTE(review): torch.load unpickles the file; RESIDUALS_PATH should only
    # ever point at an artifact produced by this pipeline, never untrusted data.
    residuals = torch.load(RESIDUALS_PATH, map_location="cpu")
    for li in target_layers:
        li_str = str(li)
        if li_str not in residuals:
            continue
        r = residuals[li_str]
        res = {c: r[c].to(torch.float32) for c in ("plan", "mon", "exec")}
        info = layers_info.setdefault(li_str, {})
        for c in ("plan", "mon"):
            norm = _mean_diff_norm(res[c], res["exec"])
            if norm is not None:
                info[f"res_{c}_norm"] = norm


def _summarize_ratios(layers_info):
    """Summarize attention-to-residual norm ratios for "plan" and "mon".

    A layer contributes only when it has both norms and the residual norm is
    positive. The recommendation is based on the mean ratio: > 0.5 suggests
    hooking attention, < 0.3 says FFN-only is fine, otherwise borderline.
    """
    summary = {}
    for d in ("plan", "mon"):
        ratios = [
            info[f"attn_{d}_norm"] / info[f"res_{d}_norm"]
            for info in layers_info.values()
            if f"attn_{d}_norm" in info and info.get(f"res_{d}_norm", 0) > 0
        ]
        if not ratios:
            continue
        avg = sum(ratios) / len(ratios)
        if avg > 0.5:
            recommendation = "attention also matters — consider hooking it"
        elif avg < 0.3:
            recommendation = "FFN-only steering OK"
        else:
            recommendation = "borderline, monitor"
        summary[d] = {
            "mean_attn_to_residual_ratio": avg,
            "max_attn_to_residual_ratio": max(ratios),
            "n_layers": len(ratios),
            "recommendation": recommendation,
        }
    return summary


def _plot_norms(layers_info, log):
    """Save a two-panel bar chart (planning, monitoring) of attention vs residual norms.

    Plotting is best-effort: any failure (including matplotlib being
    unavailable) is logged as a warning and does not abort the script, since
    the JSON diagnostic has already been written by then.
    """
    try:
        import matplotlib.pyplot as plt
        import numpy as np

        layers = sorted(int(key) for key in layers_info)
        x = np.arange(len(layers))
        w = 0.4
        _, axes = plt.subplots(1, 2, figsize=(14, 5))
        panels = (
            ("plan", "Planning mean-diff norm by source"),
            ("mon", "Monitoring mean-diff norm by source"),
        )
        for ax, (d, title) in zip(axes, panels):
            # Missing norms are drawn as zero-height bars.
            attn = [layers_info[str(li)].get(f"attn_{d}_norm", 0) for li in layers]
            res = [layers_info[str(li)].get(f"res_{d}_norm", 0) for li in layers]
            ax.bar(x - w / 2, attn, w, label="attn output")
            ax.bar(x + w / 2, res, w, label="post-layer residual")
            ax.set_xticks(x)
            ax.set_xticklabels(layers, rotation=90)
            ax.set_title(title)
            ax.set_xlabel("layer")
            ax.legend()
        plt.tight_layout()
        plt.savefig(ATTN_DIAGNOSTIC_FIG, dpi=120)
        plt.close()
        log.info(f"Saved {ATTN_DIAGNOSTIC_FIG}")
    except Exception as e:
        log.warning(f"Plot failed: {e}")


def main():
    """Run the attention-output diagnostic and write the JSON report and figure.

    Command-line flags:
        --n_samples: number of labeled CoTs (taken from the start of the
            labeled set) to run through the model; default 50.
        --resume: exit immediately if the diagnostic JSON already exists.

    Steps: capture attention outputs at plan/mon/exec decision tokens for the
    target layers, compute mean-diff norms against exec, compare them with
    the norms of the saved residual activations, then write the summary to
    ``ATTN_DIAGNOSTIC_PATH`` and a bar chart to ``ATTN_DIAGNOSTIC_FIG``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_samples", type=int, default=50,
                        help="# of training CoTs to use (subset of labeled set)")
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("08b_attn", LOGS_DIR / "08b_attn.log")

    if args.resume and ATTN_DIAGNOSTIC_PATH.exists():
        log.info("Diagnostic already done, skipping.")
        return

    target_layers = read_json(TARGET_LAYERS_PATH)["union_layers"]
    log.info(f"Target layers: {target_layers}")
    records = read_jsonl(LABELED_COTS_PATH)[:args.n_samples]
    log.info(f"Using {len(records)} labeled CoTs for attention capture")

    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()

    # Accumulate attention outputs at decision points
    acc = _capture_attention_outputs(model, tokenizer, records, target_layers, log)

    # Free the model before the CPU-side analysis; it is no longer needed.
    del model
    cleanup_memory()

    log.info("Computing attention mean-diff norms...")
    diagnostic = {"layers": _attention_norms(acc, target_layers)}

    log.info("Comparing to FFN residual mean-diff norms...")
    _add_residual_norms(diagnostic["layers"], target_layers, log)

    summary = _summarize_ratios(diagnostic["layers"])
    diagnostic["summary"] = summary
    write_json(diagnostic, ATTN_DIAGNOSTIC_PATH)
    log.info(f"Saved {ATTN_DIAGNOSTIC_PATH}")
    log.info(f"Summary: {summary}")

    _plot_norms(diagnostic["layers"], log)
# Script entry point: run the diagnostic only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|