File size: 2,514 Bytes
e53f10b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
Attention output capture (Apr 2026 addition).

Diagnoses whether planning/monitoring directions also live in the
attention block's output (rather than only in MoE/MLP residual).

For each target layer, hook the attention output BEFORE it gets added
back to the residual stream. If we find strong plan-vs-exec separation
in attention output, then the FFN-only steering may be incomplete.

This is a DIAGNOSTIC, not a steering mechanism. It produces a
side-by-side comparison: residual-stream direction strength vs
attention-output direction strength.
"""
import torch
from typing import Dict, List, Optional


class AttentionOutputCapture:
    """
    Hook the ``self_attn`` submodule output (the attention contribution to the
    residual stream, BEFORE the residual addition) for selected layers.

    Hooks are attached to ``model.model.layers[i].self_attn`` for every ``i``
    in ``target_layers``. Layers without a ``self_attn`` attribute are not
    hooked; their indices are recorded in ``skipped``.

    Captured shape: (S, D) per layer per CoT when the forward pass has batch
    size 1 (``squeeze(0)`` drops the batch dim). With a larger batch the
    tensor stays (B, S, D). Captured tensors are detached, cast to ``dtype``
    (float16 by default) and moved to CPU. Only the most recent call through
    each hooked module is kept.

    Use explicit ``start()`` / ``stop()``, or use it as a context manager::

        with AttentionOutputCapture(model, [10, 20]) as cap:
            model(input_ids)
        outs = cap.captured
    """

    def __init__(
        self,
        model,
        target_layers: List[int],
        dtype: torch.dtype = torch.float16,
    ):
        self.model = model
        self.target_layers = target_layers
        # dtype the captured tensors are cast to before being moved to CPU.
        self.dtype = dtype
        self.handles = []
        self.captured: Dict[int, torch.Tensor] = {}
        # Target layer indices that had no `self_attn` and were not hooked.
        self.skipped: List[int] = []

    def _make_hook(self, layer_id: int):
        """Build a forward hook that stores the attention output for ``layer_id``."""
        def fn(module, inputs, output):
            # Attention modules commonly return (attn_output, ...); take the
            # first element, otherwise use the output as-is.
            h = output[0] if isinstance(output, tuple) else output
            # h assumed (B, S, D); squeeze(0) only drops the batch dim when B == 1.
            # detach() so a stored copy never holds an autograd graph when the
            # forward pass runs with grad enabled.
            self.captured[layer_id] = h.detach().squeeze(0).to(self.dtype).cpu()
        return fn

    def _remove_handles(self):
        """Remove every hook registered by this instance."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def start(self):
        """
        Register forward hooks on each target layer's ``self_attn``.

        Clears previous captures. Hooks left over from an earlier ``start()``
        without a matching ``stop()`` are removed first, so hooks are never
        stacked. If registration fails part-way (e.g. an out-of-range layer
        index raises IndexError), the hooks registered so far are removed
        before the exception propagates.
        """
        self._remove_handles()
        self.captured = {}
        self.skipped = []
        try:
            for li in self.target_layers:
                layer = self.model.model.layers[li]
                attn_module = getattr(layer, "self_attn", None)
                if attn_module is None:
                    # Layer has no standard attribute: skip it, but record it.
                    self.skipped.append(li)
                    continue
                self.handles.append(
                    attn_module.register_forward_hook(self._make_hook(li))
                )
        except BaseException:
            # Never leave a partial set of hooks attached to the model.
            self._remove_handles()
            raise

    def stop(self):
        """Remove all hooks and return ``captured`` (layer index -> tensor)."""
        self._remove_handles()
        return self.captured

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.stop()
        return False


def run_forward_and_capture_attn(
    model, tokenizer, text: str, target_layers: List[int]
) -> Dict:
    """
    Tokenize ``text``, run one no-grad forward pass, and return the per-layer
    attention outputs captured by ``AttentionOutputCapture``.

    Args:
        model: model exposing ``.device`` and the ``.model.layers[i].self_attn``
            layout that ``AttentionOutputCapture`` hooks; it is called as
            ``model(input_ids)``.
        tokenizer: callable returning a mapping whose ``"input_ids"`` entry is
            a PyTorch tensor (requested via ``return_tensors="pt"``).
        text: text to run; encoded without special tokens and without
            truncation.
        target_layers: layer indices whose attention output to capture.

    Returns:
        dict with
          - ``"input_ids"``: the encoded ids on CPU with ``squeeze(0)`` applied
            (drops a leading batch dim of size 1).
          - ``"attn_outs"``: the mapping returned by
            ``AttentionOutputCapture.stop()`` (layer index -> tensor).
    """
    enc = tokenizer(text, return_tensors="pt", add_special_tokens=False, truncation=False)
    input_ids = enc["input_ids"].to(model.device)

    cap = AttentionOutputCapture(model, target_layers=target_layers)
    try:
        # start() sits inside the try so that, if hook registration fails
        # part-way (e.g. an out-of-range layer index), the hooks already
        # registered are still removed by stop() instead of staying attached
        # to the model and capturing on every later forward pass.
        cap.start()
        with torch.no_grad():
            model(input_ids)
    finally:
        attn_outs = cap.stop()

    return {
        "input_ids": input_ids.squeeze(0).cpu(),
        "attn_outs": attn_outs,
    }