"""
Attention output capture (Apr 2026 addition).
Diagnoses whether planning/monitoring directions also live in the
attention block's output (rather than only in MoE/MLP residual).
For each target layer, hook the attention output BEFORE it gets added
back to the residual stream. If we find strong plan-vs-exec separation
in attention output, then the FFN-only steering may be incomplete.
This is a DIAGNOSTIC, not a steering mechanism. It produces a
side-by-side comparison: residual-stream direction strength vs
attention-output direction strength.
"""
import torch
from typing import Dict, List


class AttentionOutputCapture:
    """
    Hook self_attn submodule output (the attention contribution to residual,
    BEFORE the residual addition).

    Captured shape: (S, D) per layer per CoT.
    """

    def __init__(self, model, target_layers: List[int]):
        self.model = model
        self.target_layers = target_layers
        self.handles = []
        self.captured: Dict[int, torch.Tensor] = {}

    def _make_hook(self, layer_id: int):
        def fn(module, inputs, output):
            # Standard transformer self_attn returns (attn_output, ...)
            if isinstance(output, tuple):
                h = output[0]
            else:
                h = output
            # h: (B, S, D)
            self.captured[layer_id] = h.squeeze(0).to(torch.float16).cpu()
        return fn

    def start(self):
        self.captured = {}
        for li in self.target_layers:
            layer = self.model.model.layers[li]
            try:
                attn_module = layer.self_attn
            except AttributeError:
                # Skip layers without standard attribute
                continue
            h = attn_module.register_forward_hook(self._make_hook(li))
            self.handles.append(h)

    def stop(self):
        for h in self.handles:
            h.remove()
        self.handles = []
        return self.captured
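

# Illustrative helper: the module docstring describes a side-by-side
# comparison of direction strength in the residual stream vs the attention
# output. A minimal sketch of the attention-output half follows, assuming
# the caller already has a (D,)-dim plan-vs-exec direction (e.g. from the
# existing residual-stream analysis). The name and signature below are
# illustrative, not an established API in this repo.
def attn_direction_strength(
    attn_outs: Dict[int, torch.Tensor], direction: torch.Tensor
) -> Dict[int, float]:
    """Mean |projection| of captured attention outputs onto `direction`, per layer."""
    d = direction.to(torch.float32)
    d = d / (d.norm() + 1e-8)  # unit-normalize so strengths are comparable across layers
    strengths: Dict[int, float] = {}
    for layer_id, h in attn_outs.items():
        # h: (S, D) fp16 on CPU -> per-token projection, then mean magnitude
        proj = h.to(torch.float32) @ d
        strengths[layer_id] = proj.abs().mean().item()
    return strengths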


def run_forward_and_capture_attn(
    model, tokenizer, text: str, target_layers: List[int]
) -> Dict:
    """Run a single no-grad forward pass over `text` and return the
    attention outputs captured at each of `target_layers`."""
    enc = tokenizer(text, return_tensors="pt", add_special_tokens=False, truncation=False)
    input_ids = enc["input_ids"].to(model.device)

    cap = AttentionOutputCapture(model, target_layers=target_layers)
    cap.start()
    try:
        with torch.no_grad():
            _ = model(input_ids)
    finally:
        attn_outs = cap.stop()

    return {
        "input_ids": input_ids.squeeze(0).cpu(),
        "attn_outs": attn_outs,
    }
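

# Usage sketch (illustrative; `model`, `tokenizer`, `cot_text`, `target_layers`,
# and `plan_exec_direction` are placeholders assumed to come from the rest of
# the pipeline, e.g. the residual-stream direction-finding code):
#
#     out = run_forward_and_capture_attn(model, tokenizer, cot_text, target_layers)
#     attn_strength = attn_direction_strength(out["attn_outs"], plan_exec_direction)
#
# Comparing `attn_strength` per layer against the analogous residual-stream
# strengths indicates whether the plan-vs-exec direction also appears in the
# attention output (in which case FFN-only steering may be incomplete).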