# v2/scripts/08b_attention_diagnostic.py
# JulianHJR's picture
# Upload folder using huggingface_hub
# e53f10b verified
"""
Stage 3.5 (DIAGNOSTIC): Attention output direction analysis.
Goal: check whether planning/monitoring directions are well-represented
in the ATTENTION OUTPUT of each layer (not only in the post-MLP residual).
If attention output also shows strong plan-vs-exec separation, our
FFN-residual-based steering may miss this signal. This script prints
a comparison and saves a JSON / figure but does NOT modify steering.
Output: data/results/attention_diagnostic.{json,png}
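Decision rule (informational, not auto-applied; evaluated on the mean
attention-to-residual ratio across layers):
  - If attention mean-diff norm > 50% of MLP residual mean-diff norm,
    consider also hooking attention output during steering.
  - If < 30%, FFN-only steering is fine (current pipeline).
  - Between 30% and 50% is borderline: keep monitoring.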
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
from tqdm import tqdm
from configs.paths import (
ensure_dirs, LOGS_DIR, LABELED_COTS_PATH,
TARGET_LAYERS_PATH,
ATTN_RESIDUALS_PATH, ATTN_DIAGNOSTIC_PATH, ATTN_DIAGNOSTIC_FIG,
RESIDUALS_PATH,
)
from src.utils import setup_logger, read_jsonl, read_json, write_json, cleanup_memory
from src.model_io import load_model_and_tokenizer
from src.attention_capture import AttentionOutputCapture
def _stack_float32(chunks):
    """Concatenate a list of (n_i, d) tensors along dim 0 as float32.

    Returns None when ``chunks`` is empty so callers can skip the category.
    """
    if not chunks:
        return None
    return torch.cat(chunks, dim=0).to(torch.float32)


def _mean_diff_norm(group, baseline):
    """Return the L2 norm of ``mean(group) - mean(baseline)`` as a float.

    Returns None if either tensor is missing or has zero rows, in which
    case no mean difference can be computed.
    """
    if group is None or baseline is None:
        return None
    if group.shape[0] == 0 or baseline.shape[0] == 0:
        return None
    return (group.mean(0) - baseline.mean(0)).norm().item()


def _recommendation(mean_ratio):
    """Map a mean attention/residual ratio to the informational verdict."""
    if mean_ratio > 0.5:
        return "attention also matters — consider hooking it"
    if mean_ratio < 0.3:
        return "FFN-only steering OK"
    return "borderline, monitor"


def _summarize_ratios(layers_info):
    """Build the per-direction summary of attention-to-residual norm ratios.

    Only layers that have both the attention norm and a strictly positive
    residual norm contribute; directions with no such layer are omitted.
    """
    summary = {}
    for direction in ("plan", "mon"):
        attn_key = f"attn_{direction}_norm"
        res_key = f"res_{direction}_norm"
        ratios = [
            info[attn_key] / info[res_key]
            for info in layers_info.values()
            if attn_key in info and res_key in info and info[res_key] > 0
        ]
        if not ratios:
            continue
        avg = sum(ratios) / len(ratios)
        summary[direction] = {
            "mean_attn_to_residual_ratio": avg,
            "max_attn_to_residual_ratio": max(ratios),
            "n_layers": len(ratios),
            "recommendation": _recommendation(avg),
        }
    return summary


def _plot_diagnostic(layers_info, fig_path):
    """Save a two-panel bar chart (planning, monitoring) of norms per layer.

    Missing norms are drawn as 0. matplotlib/numpy are imported lazily so
    the rest of the script works without them.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    layers = sorted(int(l) for l in layers_info)
    x = np.arange(len(layers))
    w = 0.4
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    panels = (
        ("plan", "Planning mean-diff norm by source"),
        ("mon", "Monitoring mean-diff norm by source"),
    )
    for ax, (direction, title) in zip(axes, panels):
        attn = [layers_info[str(li)].get(f"attn_{direction}_norm", 0) for li in layers]
        res = [layers_info[str(li)].get(f"res_{direction}_norm", 0) for li in layers]
        ax.bar(x - w / 2, attn, w, label="attn output")
        ax.bar(x + w / 2, res, w, label="post-layer residual")
        ax.set_xticks(x)
        ax.set_xticklabels(layers, rotation=90)
        ax.set_title(title)
        ax.set_xlabel("layer")
        ax.legend()
    plt.tight_layout()
    plt.savefig(fig_path, dpi=120)
    plt.close(fig)


def main():
    """Run the attention-output diagnostic and write the JSON report / figure.

    Steps:
      1. Capture attention outputs at the target layers for the first
         ``--n_samples`` labeled CoTs and collect the rows at the plan /
         mon / exec decision token indices.
      2. Compute, per layer, the norm of (plan - exec) and (mon - exec)
         mean differences for the attention outputs.
      3. If the residuals file exists, compute the same norms from it and
         summarize the attention-to-residual ratios.
      4. Write the diagnostic JSON and (best effort) a comparison figure.

    With ``--resume`` the script returns immediately when the diagnostic
    JSON already exists. Nothing is returned; results go to disk and the log.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_samples", type=int, default=50,
                        help="# of training CoTs to use (subset of labeled set)")
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("08b_attn", LOGS_DIR / "08b_attn.log")

    if args.resume and ATTN_DIAGNOSTIC_PATH.exists():
        log.info("Diagnostic already done, skipping.")
        return

    target_layers = read_json(TARGET_LAYERS_PATH)["union_layers"]
    log.info(f"Target layers: {target_layers}")

    records = read_jsonl(LABELED_COTS_PATH)[:args.n_samples]
    log.info(f"Using {len(records)} labeled CoTs for attention capture")

    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()

    # Accumulate attention outputs at decision points
    cats = ["plan", "mon", "exec"]
    acc = {li: {c: [] for c in cats} for li in target_layers}
    n_skipped = 0
    for rec in tqdm(records, desc="capture attn"):
        text = rec["cot"]
        decision_tis = {
            "plan": rec["plan_decision_tis"],
            "mon": rec["mon_decision_tis"],
            "exec": rec["exec_decision_tis"],
        }
        enc = tokenizer(text, return_tensors="pt", add_special_tokens=False, truncation=False)
        # The stored decision indices are only meaningful if re-tokenizing
        # gives the same sequence length as when the record was labeled.
        if enc["input_ids"].shape[1] != len(rec["token_ids"]):
            n_skipped += 1
            continue
        input_ids = enc["input_ids"].to(model.device)

        cap = AttentionOutputCapture(model, target_layers=target_layers)
        cap.start()
        try:
            with torch.no_grad():
                model(input_ids)
        finally:
            # Always stop the capture so it is not left active on failure.
            attn_outs = cap.stop()

        for li in target_layers:
            if li not in attn_outs:
                continue
            h = attn_outs[li]
            for c in cats:
                tis = decision_tis[c]
                if tis:
                    # Keep only the selected rows, on CPU, so accumulated
                    # activations do not pin accelerator memory across records
                    # (no-op if the capture already returns CPU tensors).
                    acc[li][c].append(h[tis].detach().cpu())
        cleanup_memory()

    if n_skipped:
        log.warning(
            f"Skipped {n_skipped}/{len(records)} CoTs whose re-tokenized length "
            "did not match the stored token_ids"
        )

    # Free model
    del model
    cleanup_memory()

    # Compute attention mean-diff norms; compare to residual mean-diff norms
    log.info("Computing attention mean-diff norms...")
    diagnostic = {"layers": {}}
    for li in target_layers:
        stacked = {c: _stack_float32(acc[li][c]) for c in cats}
        layer_info = {}
        for direction in ("plan", "mon"):
            norm = _mean_diff_norm(stacked[direction], stacked["exec"])
            if norm is not None:
                layer_info[f"attn_{direction}_norm"] = norm
        diagnostic["layers"][str(li)] = layer_info

    # Compare to residual norms (load existing v1_raw directions and compute their norms)
    log.info("Comparing to FFN residual mean-diff norms...")
    if RESIDUALS_PATH.exists():
        # NOTE(review): torch.load unpickles the file; only load residuals
        # produced by this pipeline (consider weights_only=True if the
        # installed torch version supports it and the file is plain tensors).
        residuals = torch.load(RESIDUALS_PATH, map_location="cpu")
        for li in target_layers:
            li_str = str(li)
            if li_str not in residuals:
                continue
            r = residuals[li_str]
            res = {
                c: (r[c].to(torch.float32) if r[c].shape[0] > 0 else None)
                for c in cats
            }
            layer_info = diagnostic["layers"].setdefault(li_str, {})
            for direction in ("plan", "mon"):
                norm = _mean_diff_norm(res[direction], res["exec"])
                if norm is not None:
                    layer_info[f"res_{direction}_norm"] = norm
    else:
        log.warning(
            f"{RESIDUALS_PATH} not found; residual norms and summary ratios "
            "will be missing from the diagnostic"
        )

    # Compute summary ratios
    summary = _summarize_ratios(diagnostic["layers"])
    diagnostic["summary"] = summary

    write_json(diagnostic, ATTN_DIAGNOSTIC_PATH)
    log.info(f"Saved {ATTN_DIAGNOSTIC_PATH}")
    log.info(f"Summary: {summary}")

    # Plot (best effort: the JSON above is the primary artifact, so a
    # plotting failure is logged instead of aborting the script)
    try:
        _plot_diagnostic(diagnostic["layers"], ATTN_DIAGNOSTIC_FIG)
        log.info(f"Saved {ATTN_DIAGNOSTIC_FIG}")
    except Exception as e:
        log.warning(f"Plot failed: {e}")
# Script entry point: run the diagnostic only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()