Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 16,766 Bytes
37ed739 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 |
"""
Attention analysis utilities for interpretability.
Implements:
1. Attention rollout (Kovaleva et al., 2019) - composition across layers
2. Head ranking by contribution
3. Helper functions for attention pattern analysis
References:
- Kovaleva et al. (2019): "Revealing the Dark Secrets of BERT"
- Clark et al. (2019): "What Does BERT Look At?"
"""
import torch
import numpy as np
from typing import Dict, List, Tuple, Optional
import logging
logger = logging.getLogger(__name__)
class AttentionRollout:
    """
    Compute attention rollout to track information flow through transformer layers.

    Attention rollout composes attention weights across layers to show which
    input tokens contribute most to each output token through the entire network.
    For layer l, rollout is computed as:

        A_rollout(l) = A(l) @ A_rollout(l-1)

    where @ is matrix multiplication and A(l) is the attention matrix at layer l.
    Rows of the (renormalized) rollout are distributions over source positions:
    rollout[i, j] = how much position i's representation draws from position j.
    """

    def __init__(self, attention_tensor: torch.Tensor, num_layers: int, num_heads: int):
        """
        Args:
            attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len]
            num_layers: Number of layers
            num_heads: Number of attention heads per layer
        """
        self.attention_tensor = attention_tensor
        self.num_layers = num_layers
        self.num_heads = num_heads
        # Populated by compute_rollout(); None until then.
        self.rollout: Optional[torch.Tensor] = None

    def compute_rollout(self, token_idx: int = -1, average_heads: bool = True) -> torch.Tensor:
        """
        Compute attention rollout for a specific generated token.

        Args:
            token_idx: Which generated token to analyze (-1 = last token)
            average_heads: Whether to average across heads before composition

        Returns:
            Rollout matrix [num_layers+1, seq_len, seq_len]
            or [num_layers+1, num_heads, seq_len, seq_len] if not averaging.
            Index 0 is the identity (pre-attention); index l is the rollout
            through layer l-1. Rows are renormalized to sum to 1.
        """
        # Extract attention for the chosen token.
        # Shape: [num_layers, num_heads, seq_len, seq_len]
        attn = self.attention_tensor[token_idx]
        if average_heads:
            # Average across heads first -> [num_layers, seq_len, seq_len]
            attn = attn.mean(dim=1)

        # Initialize rollout with the identity matrix (no attention = pure self).
        # Match dtype/device of the input so this works on GPU / half precision.
        seq_len = attn.shape[-1]
        eye = torch.eye(seq_len, dtype=attn.dtype, device=attn.device)
        if average_heads:
            rollout = [eye]
        else:
            # Keep heads separate: [num_heads, seq_len, seq_len]
            rollout = [eye.unsqueeze(0).repeat(self.num_heads, 1, 1)]

        # Compose attention across layers in the correct order:
        #   rollout = attn[L] @ attn[L-1] @ ... @ attn[0]
        # Built iteratively by applying each new layer on the LEFT.
        for layer_idx in range(self.num_layers):
            layer_attn = attn[layer_idx]
            if average_heads:
                # [seq_len, seq_len]
                rollout.append(layer_attn @ rollout[-1])
            else:
                # Batched per-head multiply: [num_heads, seq_len, seq_len]
                rollout.append(torch.bmm(layer_attn, rollout[-1]))

        # Stack into one tensor:
        # [num_layers+1, seq_len, seq_len] or [num_layers+1, num_heads, seq_len, seq_len]
        self.rollout = torch.stack(rollout)

        # After composition, rows no longer sum to 1; renormalize so each row
        # stays interpretable as an attention distribution over sources.
        # The same reduction works for both 3-D and 4-D rollouts.
        row_sums = self.rollout.sum(dim=-1, keepdim=True)
        row_sums = torch.clamp(row_sums, min=1e-10)  # guard against all-zero rows
        self.rollout = self.rollout / row_sums

        logger = logging.getLogger(__name__)
        logger.info(f"Computed attention rollout: shape={self.rollout.shape}")
        # Debug: sanity-check that the final rollout looks like a distribution.
        if self.rollout.shape[0] > 0:
            sample_weights = self.rollout[-1, 0, :]  # Last layer, first position
            logger.info(f"Sample rollout weights (pos 0): min={sample_weights.min().item():.6f}, max={sample_weights.max().item():.6f}, sum={sample_weights.sum().item():.6f}")
        return self.rollout

    def get_top_sources(self, target_token_idx: int, layer_idx: int, k: int = 8) -> List[Tuple[int, float]]:
        """
        Get top-k source tokens that contribute most to target token at a specific layer.

        Args:
            target_token_idx: Index of target token in sequence
            layer_idx: Which rollout stage to use (0 = identity, l = through layer l-1)
            k: Number of top sources to return

        Returns:
            List of (source_idx, weight) tuples, sorted by weight descending

        Raises:
            ValueError: if compute_rollout() has not been called, or was called
                with average_heads=False (per-head rollout has an extra dim).
        """
        if self.rollout is None:
            raise ValueError("Must call compute_rollout() first")
        if self.rollout.dim() == 4:
            raise ValueError("get_top_sources requires compute_rollout(average_heads=True)")
        # Rollout ROWS are distributions over sources: rollout[i, j] measures
        # how much target i draws from source j. (Indexing the column instead
        # would give "who attends to target", which is not what we want here.)
        # Shape: [seq_len]
        weights = self.rollout[layer_idx, target_token_idx, :]
        # Top-k, clamped to the sequence length.
        top_values, top_indices = torch.topk(weights, k=min(k, len(weights)))
        top_sources = [
            (idx.item(), val.item())
            for idx, val in zip(top_indices, top_values)
        ]
        return top_sources
class HeadRanker:
    """
    Rank attention heads by their contribution to model predictions.

    Multiple ranking strategies:
    1. Rollout contribution: How much each head's attention flows to output
    2. Mean max weight: Average of maximum attention weight per head
    3. Entropy: Uncertainty in head's attention distribution
    """

    def __init__(self, attention_tensor: torch.Tensor, num_layers: int, num_heads: int):
        """
        Args:
            attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len]
            num_layers: Number of layers
            num_heads: Number of heads per layer
        """
        self.attention_tensor = attention_tensor
        self.num_layers = num_layers
        self.num_heads = num_heads

    def rank_by_rollout_contribution(self, token_idx: int = -1, top_k: int = 20) -> List[Tuple[int, int, float]]:
        """
        Rank heads by their rollout contribution.

        Args:
            token_idx: Which generated token to analyze
            top_k: Number of top heads to return

        Returns:
            List of (layer_idx, head_idx, contribution_score) tuples
        """
        # Compute rollout without averaging heads
        rollout_computer = AttentionRollout(self.attention_tensor, self.num_layers, self.num_heads)
        rollout = rollout_computer.compute_rollout(token_idx=token_idx, average_heads=False)
        # NOTE(review): AttentionRollout renormalizes every row to sum to 1, so
        # rollout[-1, head_idx].sum() equals seq_len for EVERY head, and the
        # score never depends on layer_idx (the per-head rollout composes the
        # same head index across all layers). As written this ranking is
        # degenerate (all heads tie); kept behavior-compatible pending a
        # redesign. Prefer rank_by_max_weight() or rank_by_entropy().
        head_contributions = []
        for layer_idx in range(self.num_layers):
            for head_idx in range(self.num_heads):
                contribution = rollout[-1, head_idx].sum().item()
                head_contributions.append((layer_idx, head_idx, contribution))
        head_contributions.sort(key=lambda x: x[2], reverse=True)
        return head_contributions[:top_k]

    def rank_by_max_weight(self, top_k: int = 20) -> List[Tuple[int, int, float]]:
        """
        Rank heads by average maximum attention weight.
        Heads with high max weights are focusing strongly on specific tokens.

        Args:
            top_k: Number of top heads to return

        Returns:
            List of (layer_idx, head_idx, avg_max_weight) tuples
        """
        head_scores = []
        # Average across all generated tokens
        attn = self.attention_tensor.mean(dim=0)  # [num_layers, num_heads, seq_len, seq_len]
        for layer_idx in range(self.num_layers):
            for head_idx in range(self.num_heads):
                head_attn = attn[layer_idx, head_idx]  # [seq_len, seq_len]
                # Rows are each target's distribution over sources (softmax is
                # over the last dim), so the per-target max is over dim=-1.
                max_weights = head_attn.max(dim=-1).values
                avg_max = max_weights.mean().item()
                head_scores.append((layer_idx, head_idx, avg_max))
        # Sort by score descending
        head_scores.sort(key=lambda x: x[2], reverse=True)
        return head_scores[:top_k]

    def rank_by_entropy(self, top_k: int = 20, high_entropy: bool = False) -> List[Tuple[int, int, float]]:
        """
        Rank heads by attention distribution entropy.
        Low entropy = focused attention (head attends to few tokens)
        High entropy = diffuse attention (head attends to many tokens)

        Args:
            top_k: Number of top heads to return
            high_entropy: If True, return highest entropy heads; if False, return lowest

        Returns:
            List of (layer_idx, head_idx, entropy) tuples
        """
        head_entropies = []
        # Average across all generated tokens
        attn = self.attention_tensor.mean(dim=0)  # [num_layers, num_heads, seq_len, seq_len]
        for layer_idx in range(self.num_layers):
            for head_idx in range(self.num_heads):
                head_attn = attn[layer_idx, head_idx]  # [seq_len, seq_len]
                # Entropy H = -sum(p * log(p)) of each target ROW's distribution,
                # i.e. summed over sources (dim=-1) — consistent with
                # identify_head_roles(). The 1e-10 guards log(0).
                entropy_per_token = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=-1)
                avg_entropy = entropy_per_token.mean().item()
                head_entropies.append((layer_idx, head_idx, avg_entropy))
        # Sort by entropy (ascending unless high_entropy requested)
        head_entropies.sort(key=lambda x: x[2], reverse=high_entropy)
        return head_entropies[:top_k]
def identify_head_roles(attention_tensor: torch.Tensor, tokens: List[str],
                        num_layers: int, num_heads: int) -> Dict[str, List[Tuple[int, int]]]:
    """
    Identify potential roles of attention heads based on attention patterns.

    Heuristics:
    - Delimiter heads: High attention to brackets, colons, etc.
    - Positional heads: Attend primarily to adjacent tokens
    - Broad heads: Uniform attention across many tokens

    Args:
        attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len]
        tokens: List of token strings
        num_layers: Number of layers
        num_heads: Number of heads

    Returns:
        Dictionary mapping role names to list of (layer_idx, head_idx) tuples
    """
    delimiter_tokens = {'(', ')', '{', '}', '[', ']', ':', ',', ';'}
    roles = {
        'delimiter_focused': [],
        'positional': [],
        'broad': []
    }
    # Average across all generated tokens
    attn = attention_tensor.mean(dim=0)  # [num_layers, num_heads, seq_len, seq_len]
    seq_len = attn.shape[-1]

    # Loop-invariant precomputation (hoisted out of the per-head loops):
    # indices of delimiter tokens in the sequence.
    delimiter_indices = [i for i, tok in enumerate(tokens) if tok in delimiter_tokens]
    # Mask of the immediate super-/sub-diagonal (adjacent positions) WITHOUT
    # wraparound — torch.roll on an identity matrix would spuriously mark the
    # corners (position 0 "adjacent" to the last position).
    adjacent_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    if seq_len > 1:
        off = torch.arange(seq_len - 1)
        adjacent_mask[off, off + 1] = True
        adjacent_mask[off + 1, off] = True

    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            head_attn = attn[layer_idx, head_idx]  # [seq_len, seq_len]
            # Check for delimiter focus: average mass placed on delimiter columns
            if delimiter_indices:
                delimiter_attention = head_attn[:, delimiter_indices].mean().item()
                if delimiter_attention > 0.5:  # Threshold
                    roles['delimiter_focused'].append((layer_idx, head_idx))
            # Check for positional pattern (attention on adjacent tokens);
            # skipped for seq_len == 1 where an empty mask would yield NaN.
            if seq_len > 1:
                positional_attention = head_attn[adjacent_mask].mean().item()
                if positional_attention > 0.6:
                    roles['positional'].append((layer_idx, head_idx))
            # Check for broad attention (high per-row entropy over sources)
            entropy = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=1).mean().item()
            if entropy > 2.0:  # Threshold
                roles['broad'].append((layer_idx, head_idx))
    logging.getLogger(__name__).info(f"Identified head roles: {[(k, len(v)) for k, v in roles.items()]}")
    return roles
def compute_token_attention_maps(attention_tensor: torch.Tensor,
                                 prompt_tokens: List[str],
                                 generated_tokens: List[str],
                                 num_layers: int,
                                 num_heads: int,
                                 prompt_length: int) -> List[Dict]:
    """
    Compute attention maps showing which prompt tokens each generated token attends to.
    This creates the INPUT → INTERNALS → OUTPUT connection for visualization.

    Args:
        attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len]
        prompt_tokens: List of tokens in the prompt
        generated_tokens: List of generated tokens
        num_layers: Number of layers (kept for interface compatibility; unused)
        num_heads: Number of heads (kept for interface compatibility; unused)
        prompt_length: Number of tokens in the prompt

    Returns:
        List of dicts, one per generated token:
        [{
            'token_idx': int,
            'token': str,
            'position': int,
            'attention_to_prompt': [
                {'prompt_idx': int, 'prompt_token': str, 'weight': float},
                ...
            ]
        }]
    """
    logger = logging.getLogger(__name__)
    token_maps = []
    for token_idx, token in enumerate(generated_tokens):
        # Attention for this token: [num_layers, num_heads, seq_len, seq_len]
        token_attn = attention_tensor[token_idx]
        # Average across all layers and heads for the overall pattern.
        # Shape: [seq_len, seq_len]
        avg_attn = token_attn.mean(dim=0).mean(dim=0)
        # When generating this token the model sits at the last position of the
        # current sequence (length = prompt_length + token_idx), whose index is
        # prompt_length + token_idx - 1. (The original token_idx == 0 special
        # case was redundant: it reduced to the same expression.)
        current_pos = prompt_length + token_idx - 1
        # Attention FROM the current position TO the prompt tokens — which
        # prompt tokens the model looked at while generating this token.
        # Shape: [prompt_length]
        attention_to_prompt = avg_attn[current_pos, :prompt_length]
        # Debug: log sample weights for the first generated token only.
        if token_idx == 0:
            logger.info(f"Token 0 attention weights: min={attention_to_prompt.min().item():.6f}, max={attention_to_prompt.max().item():.6f}, sum={attention_to_prompt.sum().item():.6f}")
            logger.info(f"First 5 weights: {attention_to_prompt[:5].tolist()}")
        # Build per-prompt-token entries; placeholder name if tokens list is short.
        prompt_attentions = []
        for prompt_idx in range(prompt_length):
            prompt_attentions.append({
                'prompt_idx': prompt_idx,
                'prompt_token': prompt_tokens[prompt_idx] if prompt_idx < len(prompt_tokens) else f'<{prompt_idx}>',
                'weight': attention_to_prompt[prompt_idx].item()
            })
        # Sort by weight descending so the strongest sources come first.
        prompt_attentions.sort(key=lambda x: x['weight'], reverse=True)
        token_maps.append({
            'token_idx': token_idx,
            'token': token,
            'position': current_pos,
            'attention_to_prompt': prompt_attentions
        })
    logger.info(f"Computed attention maps for {len(token_maps)} generated tokens")
    return token_maps
# Manual smoke check: real usage imports the classes/functions above.
if __name__ == "__main__":
    print("Attention analysis module loaded successfully")
    # Quick-start sketch with synthetic data (uncomment to try):
    #
    #   num_tokens, num_layers, num_heads, seq_len = 5, 4, 8, 16
    #   fake_attn = torch.softmax(
    #       torch.randn(num_tokens, num_layers, num_heads, seq_len, seq_len), dim=-1)
    #   rollout = AttentionRollout(fake_attn, num_layers, num_heads)
    #   result = rollout.compute_rollout(token_idx=0)
    #   print(f"Rollout shape: {result.shape}")
|