| """ |
| Attention output capture (Apr 2026 addition). |
| |
| Diagnoses whether planning/monitoring directions also live in the |
| attention block's output (rather than only in MoE/MLP residual). |
| |
| For each target layer, hook the attention output BEFORE it gets added |
| back to the residual stream. If we find strong plan-vs-exec separation |
| in attention output, then the FFN-only steering may be incomplete. |
| |
| This is a DIAGNOSTIC, not a steering mechanism. It produces a |
| side-by-side comparison: residual-stream direction strength vs |
| attention-output direction strength. |
| """ |
| import torch |
| from typing import Dict, List, Optional |
|
|
|
|
| class AttentionOutputCapture: |
| """ |
| Hook self_attn submodule output (the attention contribution to residual, |
| BEFORE the residual addition). |
| |
| Captured shape: (S, D) per layer per CoT. |
| """ |
| def __init__(self, model, target_layers: List[int]): |
| self.model = model |
| self.target_layers = target_layers |
| self.handles = [] |
| self.captured: Dict[int, torch.Tensor] = {} |
|
|
| def _make_hook(self, layer_id: int): |
        def fn(module, inputs, output):
            # HF attention modules typically return a tuple of
            # (attn_output, attn_weights, ...); keep only the output.
            if isinstance(output, tuple):
                h = output[0]
            else:
                h = output
            # Assumes batch size 1: squeeze to (S, D) and store on CPU.
            self.captured[layer_id] = h.squeeze(0).to(torch.float16).cpu()
        return fn

    def start(self):
        self.captured = {}
        for li in self.target_layers:
            layer = self.model.model.layers[li]
            try:
                attn_module = layer.self_attn
            except AttributeError:
                # Skip layers that do not expose a self_attn submodule.
                continue
            h = attn_module.register_forward_hook(self._make_hook(li))
            self.handles.append(h)

    def stop(self):
        for h in self.handles:
            h.remove()
        self.handles = []
        return self.captured


def run_forward_and_capture_attn(
    model, tokenizer, text: str, target_layers: List[int]
) -> Dict:
    enc = tokenizer(text, return_tensors="pt", add_special_tokens=False, truncation=False)
    input_ids = enc["input_ids"].to(model.device)

    cap = AttentionOutputCapture(model, target_layers=target_layers)
    cap.start()
    try:
        with torch.no_grad():
            _ = model(input_ids)
    finally:
        # Always remove hooks, even if the forward pass raises.
        attn_outs = cap.stop()

    return {
        "input_ids": input_ids.squeeze(0).cpu(),
        "attn_outs": attn_outs,
    }
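

# Illustrative sketch, not part of the capture code above: the side-by-side
# comparison mentioned in the module docstring. It assumes a plan-vs-exec
# direction per layer is available from elsewhere (e.g. a probe fit on
# residual-stream activations); `directions` below is that hypothetical
# input, keyed by layer index. Residual-stream activations come from a
# second forward pass with output_hidden_states=True.
def compare_direction_strength(
    model, tokenizer, text: str, directions: Dict[int, torch.Tensor]
) -> Dict[int, Dict[str, float]]:
    target_layers = sorted(directions.keys())
    out = run_forward_and_capture_attn(model, tokenizer, text, target_layers)

    enc = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    input_ids = enc["input_ids"].to(model.device)
    with torch.no_grad():
        hidden_states = model(input_ids, output_hidden_states=True).hidden_states

    results = {}
    for li in target_layers:
        if li not in out["attn_outs"]:
            continue
        # Unit-normalize the direction on CPU so it can hit both tensors.
        d = directions[li].detach().to(torch.float32).cpu()
        d = d / d.norm()
        attn = out["attn_outs"][li].to(torch.float32)                     # (S, D)
        resid = hidden_states[li + 1].squeeze(0).to(torch.float32).cpu()  # (S, D)
        results[li] = {
            # Mean |projection| onto the direction: attention output vs
            # post-block residual stream at the same layer.
            "attn_proj_mean_abs": (attn @ d).abs().mean().item(),
            "resid_proj_mean_abs": (resid @ d).abs().mean().item(),
        }
    return results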