File size: 16,766 Bytes
37ed739
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
"""
Attention analysis utilities for interpretability.

Implements:
1. Attention rollout (Kovaleva et al., 2019) - composition across layers
2. Head ranking by contribution
3. Helper functions for attention pattern analysis

References:
- Kovaleva et al. (2019): "Revealing the Dark Secrets of BERT"
- Clark et al. (2019): "What Does BERT Look At?"
"""

import torch
import numpy as np
from typing import Dict, List, Tuple, Optional
import logging

logger = logging.getLogger(__name__)


class AttentionRollout:
    """
    Compute attention rollout to track information flow through transformer layers.

    Attention rollout composes attention weights across layers to show which
    input tokens contribute most to each output token through the entire network.

    For layer l, rollout is computed as:
        A_rollout(l) = A(l) @ A_rollout(l-1)

    Where @ is matrix multiplication and A(l) is the attention matrix at layer l.
    Matrices follow the [query, key] convention used throughout this module:
    row i is position i's distribution over source positions.
    """

    def __init__(self, attention_tensor: torch.Tensor, num_layers: int, num_heads: int):
        """
        Args:
            attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len]
            num_layers: Number of layers
            num_heads: Number of attention heads per layer
        """
        self.attention_tensor = attention_tensor
        self.num_layers = num_layers
        self.num_heads = num_heads

        # Populated by compute_rollout(); None until then.
        self.rollout: Optional[torch.Tensor] = None

    def compute_rollout(self, token_idx: int = -1, average_heads: bool = True) -> torch.Tensor:
        """
        Compute attention rollout for a specific generated token.

        Args:
            token_idx: Which generated token to analyze (-1 = last token)
            average_heads: Whether to average across heads before composition

        Returns:
            Rollout tensor [num_layers + 1, seq_len, seq_len], or
            [num_layers + 1, num_heads, seq_len, seq_len] if not averaging.
            Index 0 is the identity (no layers composed); index l is the
            composition of layers 0..l-1. Rows are renormalized to sum to 1.
        """
        # Attention for the requested token.
        # Shape: [num_layers, num_heads, seq_len, seq_len]
        attn = self.attention_tensor[token_idx]

        if average_heads:
            # Collapse the head dimension first: [num_layers, seq_len, seq_len]
            attn = attn.mean(dim=1)

        seq_len = attn.shape[-1]

        # Start from the identity so layer 0's rollout equals its own attention.
        # Match the input's dtype/device so GPU/half-precision tensors compose
        # without a device-mismatch error (previously eye() was always CPU fp32).
        eye = torch.eye(seq_len, dtype=attn.dtype, device=attn.device)
        if average_heads:
            rollout = [eye]
        else:
            # One identity per head: [num_heads, seq_len, seq_len]
            rollout = [eye.unsqueeze(0).repeat(self.num_heads, 1, 1)]

        # Compose across layers, applying each new layer on the LEFT:
        # rollout_L = attn[L] @ attn[L-1] @ ... @ attn[0]
        for layer_idx in range(self.num_layers):
            layer_attn = attn[layer_idx]
            if average_heads:
                rollout.append(layer_attn @ rollout[-1])
            else:
                # Batched matmul over the head dimension.
                rollout.append(torch.bmm(layer_attn, rollout[-1]))

        # Shape: [num_layers+1, seq_len, seq_len]
        # or [num_layers+1, num_heads, seq_len, seq_len]
        self.rollout = torch.stack(rollout)

        # Renormalize rows: after composition they no longer sum to 1, and
        # normalizing keeps the result interpretable as attention weights.
        # The same reduction works for both shapes since rows are the last dim.
        row_sums = self.rollout.sum(dim=-1, keepdim=True).clamp(min=1e-10)
        self.rollout = self.rollout / row_sums

        logging.getLogger(__name__).info(f"Computed attention rollout: shape={self.rollout.shape}")

        return self.rollout

    def get_top_sources(self, target_token_idx: int, layer_idx: int, k: int = 8) -> List[Tuple[int, float]]:
        """
        Get top-k source tokens that contribute most to target token at a specific layer.

        Args:
            target_token_idx: Index of target token in sequence
            layer_idx: Which layer's rollout to use
            k: Number of top sources to return

        Returns:
            List of (source_idx, weight) tuples, sorted by weight descending

        Raises:
            ValueError: If compute_rollout() has not been called yet.
        """
        if self.rollout is None:
            raise ValueError("Must call compute_rollout() first")

        layer_rollout = self.rollout[layer_idx]
        # If heads were kept separate ([num_heads, seq_len, seq_len]), average
        # them so indexing below works (the old code crashed on 4-D rollouts).
        if layer_rollout.dim() == 3:
            layer_rollout = layer_rollout.mean(dim=0)

        # rollout[i, j] = contribution of source j to position i, so the
        # target's sources live in ROW target_token_idx. The previous
        # implementation indexed the COLUMN, which measures attention
        # received BY the target rather than drawn FROM the sources.
        # Shape: [seq_len]
        weights = layer_rollout[target_token_idx, :]

        # Get top-k (clipped to sequence length)
        top_values, top_indices = torch.topk(weights, k=min(k, weights.shape[0]))

        return [
            (idx.item(), val.item())
            for idx, val in zip(top_indices, top_values)
        ]


class HeadRanker:
    """
    Rank attention heads by their contribution to model predictions.

    Multiple ranking strategies:
    1. Rollout contribution: How much each head's attention flows to output
    2. Mean max weight: Average of maximum attention weight per head
    3. Entropy: Uncertainty in head's attention distribution

    Convention: attention matrices are [query, key]; each query's softmax
    distribution over sources lives along the LAST dimension.
    """

    def __init__(self, attention_tensor: torch.Tensor, num_layers: int, num_heads: int):
        """
        Args:
            attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len]
            num_layers: Number of layers
            num_heads: Number of heads per layer
        """
        self.attention_tensor = attention_tensor
        self.num_layers = num_layers
        self.num_heads = num_heads

    def rank_by_rollout_contribution(self, token_idx: int = -1, top_k: int = 20) -> List[Tuple[int, int, float]]:
        """
        Rank heads by their rollout contribution.

        This measures how sharply information flowing through each head is
        focused by the time its own layer has been composed into the rollout.

        Args:
            token_idx: Which generated token to analyze
            top_k: Number of top heads to return

        Returns:
            List of (layer_idx, head_idx, contribution_score) tuples,
            sorted descending; scores are in (0, 1], ~1/seq_len for
            uniform flow and ~1 for sharply focused flow.
        """
        # Compute rollout without averaging heads
        rollout_computer = AttentionRollout(self.attention_tensor, self.num_layers, self.num_heads)
        rollout = rollout_computer.compute_rollout(token_idx=token_idx, average_heads=False)

        # FIX: rollout rows are renormalized to sum to 1, so summing every
        # weight of the final rollout gave the constant seq_len for ALL heads
        # (and ignored layer_idx entirely). Instead, score head (layer, head)
        # by the mean per-query peak weight of the rollout composed up to and
        # including that head's own layer (rollout index layer_idx + 1).
        head_contributions = []
        for layer_idx in range(self.num_layers):
            # Shape: [num_heads, seq_len, seq_len]
            layer_rollout = rollout[layer_idx + 1]
            for head_idx in range(self.num_heads):
                contribution = layer_rollout[head_idx].max(dim=-1).values.mean().item()
                head_contributions.append((layer_idx, head_idx, contribution))

        # Sort by contribution descending
        head_contributions.sort(key=lambda x: x[2], reverse=True)

        return head_contributions[:top_k]

    def rank_by_max_weight(self, top_k: int = 20) -> List[Tuple[int, int, float]]:
        """
        Rank heads by average maximum attention weight.

        Heads with high max weights are focusing strongly on specific tokens.

        Args:
            top_k: Number of top heads to return

        Returns:
            List of (layer_idx, head_idx, avg_max_weight) tuples
        """
        head_scores = []

        # Average across all generated tokens
        attn = self.attention_tensor.mean(dim=0)  # [num_layers, num_heads, seq_len, seq_len]

        for layer_idx in range(self.num_layers):
            for head_idx in range(self.num_heads):
                head_attn = attn[layer_idx, head_idx]  # [seq_len, seq_len]
                # FIX: take the peak of each QUERY's distribution (last dim).
                # The old code reduced over dim=0 (over queries per key),
                # which measures attention *received* by a token, not focus.
                max_weights = head_attn.max(dim=-1).values
                avg_max = max_weights.mean().item()

                head_scores.append((layer_idx, head_idx, avg_max))

        # Sort by score descending
        head_scores.sort(key=lambda x: x[2], reverse=True)

        return head_scores[:top_k]

    def rank_by_entropy(self, top_k: int = 20, high_entropy: bool = False) -> List[Tuple[int, int, float]]:
        """
        Rank heads by attention distribution entropy.

        Low entropy = focused attention (head attends to few tokens)
        High entropy = diffuse attention (head attends to many tokens)

        Args:
            top_k: Number of top heads to return
            high_entropy: If True, return highest entropy heads; if False, return lowest

        Returns:
            List of (layer_idx, head_idx, entropy) tuples
        """
        head_entropies = []

        # Average across all generated tokens
        attn = self.attention_tensor.mean(dim=0)  # [num_layers, num_heads, seq_len, seq_len]

        for layer_idx in range(self.num_layers):
            for head_idx in range(self.num_heads):
                head_attn = attn[layer_idx, head_idx]  # [seq_len, seq_len]

                # H = -sum(p * log(p)) per query; the distribution is over
                # the last dim (FIX: old code summed over dim=0, mixing the
                # entropies of different queries' columns together).
                entropy_per_token = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=-1)
                avg_entropy = entropy_per_token.mean().item()

                head_entropies.append((layer_idx, head_idx, avg_entropy))

        # Sort by entropy (descending when high_entropy is requested)
        head_entropies.sort(key=lambda x: x[2], reverse=high_entropy)

        return head_entropies[:top_k]


def identify_head_roles(attention_tensor: torch.Tensor, tokens: List[str],
                        num_layers: int, num_heads: int) -> Dict[str, List[Tuple[int, int]]]:
    """
    Identify potential roles of attention heads based on attention patterns.

    Heuristics:
    - Delimiter heads: High attention to brackets, colons, etc.
    - Positional heads: Attend primarily to adjacent tokens
    - Broad heads: Uniform attention across many tokens

    Args:
        attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len]
        tokens: List of token strings (assumed aligned with the attention
                matrix's key positions)
        num_layers: Number of layers
        num_heads: Number of heads

    Returns:
        Dictionary mapping role names to list of (layer_idx, head_idx) tuples
    """
    delimiter_tokens = {'(', ')', '{', '}', '[', ']', ':', ',', ';'}
    roles: Dict[str, List[Tuple[int, int]]] = {
        'delimiter_focused': [],
        'positional': [],
        'broad': []
    }

    # Average across all generated tokens
    attn = attention_tensor.mean(dim=0)  # [num_layers, num_heads, seq_len, seq_len]
    seq_len = attn.shape[-1]

    # Loop-invariant work hoisted out of the per-head loop:
    # positions of delimiter tokens in the key axis.
    delimiter_indices = [i for i, tok in enumerate(tokens) if tok in delimiter_tokens]

    # Mask selecting (i, i-1) and (i, i+1) entries WITHOUT wrap-around.
    # FIX: the previous roll()-based construction wrapped, spuriously
    # marking (0, seq_len-1) and (seq_len-1, 0) as adjacent positions.
    off_diag = torch.ones(seq_len - 1)
    adjacent_mask = (torch.diag(off_diag, 1) + torch.diag(off_diag, -1)).bool()

    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            head_attn = attn[layer_idx, head_idx]  # [seq_len, seq_len]

            # Check for delimiter focus: mean attention flowing into
            # delimiter key columns.
            if delimiter_indices:
                delimiter_attention = head_attn[:, delimiter_indices].mean().item()
                if delimiter_attention > 0.5:  # Threshold
                    roles['delimiter_focused'].append((layer_idx, head_idx))

            # Check for positional pattern (attention to adjacent tokens)
            positional_attention = head_attn[adjacent_mask].mean().item()
            if positional_attention > 0.6:
                roles['positional'].append((layer_idx, head_idx))

            # Check for broad attention (high entropy per query distribution)
            entropy = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=1).mean().item()
            if entropy > 2.0:  # Threshold
                roles['broad'].append((layer_idx, head_idx))

    logging.getLogger(__name__).info(f"Identified head roles: {[(k, len(v)) for k, v in roles.items()]}")

    return roles


def compute_token_attention_maps(attention_tensor: torch.Tensor,
                                  prompt_tokens: List[str],
                                  generated_tokens: List[str],
                                  num_layers: int,
                                  num_heads: int,
                                  prompt_length: int) -> List[Dict]:
    """
    Compute attention maps showing which prompt tokens each generated token attends to.

    This creates the INPUT → INTERNALS → OUTPUT connection for visualization.

    Args:
        attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len]
        prompt_tokens: List of tokens in the prompt
        generated_tokens: List of generated tokens
        num_layers: Number of layers
        num_heads: Number of heads
        prompt_length: Number of tokens in the prompt

    Returns:
        List of dicts, one per generated token:
        [{
            'token_idx': int,
            'token': str,
            'position': int,
            'attention_to_prompt': [
                {'prompt_idx': int, 'prompt_token': str, 'weight': float},
                ...
            ]  # sorted by weight descending
        }]
    """
    token_maps = []

    for token_idx, token in enumerate(generated_tokens):
        # Attention for this token: [num_layers, num_heads, seq_len, seq_len]
        token_attn = attention_tensor[token_idx]

        # Overall pattern: average over layers AND heads -> [seq_len, seq_len]
        avg_attn = token_attn.mean(dim=(0, 1))

        # When generating token `token_idx`, the model's query is the last
        # position of the sequence so far (prompt + previously generated
        # tokens), i.e. index prompt_length + token_idx - 1.
        # FIX: the old conditional (`... if token_idx > 0 else
        # prompt_length - 1`) evaluated to this same value in both branches.
        # NOTE(review): assumes current_pos < seq_len for every generated
        # token — confirm against how attention_tensor is captured.
        current_pos = prompt_length + token_idx - 1

        # Attention FROM the current position TO each prompt token.
        # Shape: [prompt_length]
        attention_to_prompt = avg_attn[current_pos, :prompt_length]

        # One entry per prompt position; fall back to a placeholder string if
        # prompt_tokens is shorter than prompt_length (mismatched tokenization).
        prompt_attentions = [
            {
                'prompt_idx': prompt_idx,
                'prompt_token': prompt_tokens[prompt_idx] if prompt_idx < len(prompt_tokens) else f'<{prompt_idx}>',
                'weight': attention_to_prompt[prompt_idx].item(),
            }
            for prompt_idx in range(prompt_length)
        ]

        # Strongest prompt influences first
        prompt_attentions.sort(key=lambda x: x['weight'], reverse=True)

        token_maps.append({
            'token_idx': token_idx,
            'token': token,
            'position': current_pos,
            'attention_to_prompt': prompt_attentions
        })

    logging.getLogger(__name__).info(f"Computed attention maps for {len(token_maps)} generated tokens")

    return token_maps


# Example usage
if __name__ == "__main__":
    print("Attention analysis module loaded successfully")

    # Demo (kept disabled): build random softmax-normalized attention and
    # run a rollout over it.
    # num_tokens, num_layers, num_heads, seq_len = 5, 4, 8, 16
    # fake_attn = torch.softmax(
    #     torch.randn(num_tokens, num_layers, num_heads, seq_len, seq_len), dim=-1)
    # rollout = AttentionRollout(fake_attn, num_layers, num_heads)
    # result = rollout.compute_rollout(token_idx=0)
    # print(f"Rollout shape: {result.shape}")