"""
Attention output capture (Apr 2026 addition).

Diagnoses whether planning/monitoring directions also live in the attention
block's output, rather than only in the MoE/MLP contribution to the residual
stream.

For each target layer, hook the attention output BEFORE it is added back to
the residual stream. If we find strong plan-vs-exec separation in the
attention output, then FFN-only steering may be incomplete.

This is a DIAGNOSTIC, not a steering mechanism. It produces a side-by-side
comparison: residual-stream direction strength vs attention-output direction
strength.
"""
import torch
from typing import Dict, List, Optional


class AttentionOutputCapture:
    """
    Hook the ``self_attn`` submodule output of selected decoder layers.

    The forward hook records whatever ``layer.self_attn`` returns (first
    element if it returns a tuple). Whether the architecture applies any
    further op (e.g. a post-attention norm) before the residual addition is
    model-specific -- NOTE(review): confirm for the model in use.

    Captured tensors are stored in ``self.captured`` keyed by layer index,
    with dim 0 squeezed away: a batch-size-1 forward yields ``(S, D)`` per
    layer (assumes B == 1; with B > 1 the batch dim is kept).

    Args:
        model: object exposing ``model.model.layers`` (indexable layers).
        target_layers: layer indices to hook.
        store_dtype: dtype captured tensors are cast to before moving to CPU.
            Defaults to ``torch.float16`` (original behavior); use
            ``torch.float32`` if activations may exceed the fp16 range.
    """

    def __init__(
        self,
        model,
        target_layers: List[int],
        store_dtype: torch.dtype = torch.float16,
    ):
        self.model = model
        self.target_layers = target_layers
        self.store_dtype = store_dtype
        self.handles = []
        self.captured: Dict[int, torch.Tensor] = {}
        # Layer indices that had no ``self_attn`` attribute at the last
        # ``start()``; these never appear in ``captured``.
        self.skipped_layers: List[int] = []

    def _make_hook(self, layer_id: int):
        """Build a forward hook that stores layer ``layer_id``'s output."""

        def fn(module, inputs, output):
            # Standard transformer self_attn returns (attn_output, ...)
            h = output[0] if isinstance(output, tuple) else output
            # detach(): when the forward runs with grad enabled, the stored
            # copy would otherwise keep the autograd graph (and the device
            # activations it references) alive.
            self.captured[layer_id] = (
                h.detach().squeeze(0).to(self.store_dtype).cpu()
            )

        return fn

    def start(self):
        """Register hooks on every target layer and reset ``captured``.

        Hooks left over from a previous ``start()`` are removed first, so
        repeated calls never double-register.
        """
        self.stop()
        self.captured = {}
        self.skipped_layers = []
        for li in self.target_layers:
            layer = self.model.model.layers[li]
            try:
                attn_module = layer.self_attn
            except AttributeError:
                # Skip layers without the standard attribute, but record it.
                self.skipped_layers.append(li)
                continue
            handle = attn_module.register_forward_hook(self._make_hook(li))
            self.handles.append(handle)

    def stop(self):
        """Remove all registered hooks and return the captured dict."""
        for handle in self.handles:
            handle.remove()
        self.handles = []
        return self.captured


def run_forward_and_capture_attn(
    model, tokenizer, text: str, target_layers: List[int]
) -> Dict:
    """Tokenize ``text``, run one no-grad forward, and capture attn outputs.

    Hooks are always removed, even if the forward raises.

    Returns:
        dict with ``"input_ids"`` (token ids with dim 0 squeezed, on CPU) and
        ``"attn_outs"`` (layer index -> captured tensor, as produced by
        ``AttentionOutputCapture``).
    """
    enc = tokenizer(
        text, return_tensors="pt", add_special_tokens=False, truncation=False
    )
    input_ids = enc["input_ids"].to(model.device)
    cap = AttentionOutputCapture(model, target_layers=target_layers)
    cap.start()
    try:
        with torch.no_grad():
            _ = model(input_ids)
    finally:
        attn_outs = cap.stop()
    return {
        "input_ids": input_ids.squeeze(0).cpu(),
        "attn_outs": attn_outs,
    }