| """ | |
| Attention analysis utilities for interpretability. | |
| Implements: | |
| 1. Attention rollout (Kovaleva et al., 2019) - composition across layers | |
| 2. Head ranking by contribution | |
| 3. Helper functions for attention pattern analysis | |
| References: | |
| - Kovaleva et al. (2019): "Revealing the Dark Secrets of BERT" | |
| - Clark et al. (2019): "What Does BERT Look At?" | |
| """ | |
| import torch | |
| import numpy as np | |
| from typing import Dict, List, Tuple, Optional | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class AttentionRollout:
    """
    Compute attention rollout to track information flow through transformer layers.

    Attention rollout composes attention weights across layers to show which
    input tokens contribute most to each output token through the entire network.

    For layer l, the rollout is computed as:
        A_rollout(l) = A(l) @ A_rollout(l-1)
    where @ is matrix multiplication and A(l) is the attention matrix at layer l.
    """
    def __init__(self, attention_tensor: torch.Tensor, num_layers: int, num_heads: int):
        """
        Args:
            attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len]
            num_layers: Number of layers
            num_heads: Number of attention heads per layer
        """
        self.attention_tensor = attention_tensor
        self.num_layers = num_layers
        self.num_heads = num_heads
        # Will store the rollout result
        self.rollout = None
    def compute_rollout(self, token_idx: int = -1, average_heads: bool = True) -> torch.Tensor:
        """
        Compute attention rollout for a specific generated token.

        Args:
            token_idx: Which generated token to analyze (-1 = last token)
            average_heads: Whether to average across heads before composition

        Returns:
            Rollout matrix [num_layers+1, seq_len, seq_len],
            or [num_layers+1, num_heads, seq_len, seq_len] if not averaging
        """
        # Extract attention for the specific token
        # Shape: [num_layers, num_heads, seq_len, seq_len]
        attn = self.attention_tensor[token_idx]

        if average_heads:
            # Average across heads first
            # Shape: [num_layers, seq_len, seq_len]
            attn = attn.mean(dim=1)

        # Initialize rollout with the identity matrix (before any attention is
        # applied, each position maps only to itself)
        seq_len = attn.shape[-1]
        if average_heads:
            rollout = [torch.eye(seq_len)]
        else:
            # Keep heads separate
            rollout = [torch.eye(seq_len).unsqueeze(0).repeat(self.num_heads, 1, 1)]

        # Compose attention across layers.
        # We build the rollout from layer 0 to layer L, multiplying in the order
        #     rollout = attn[L] @ attn[L-1] @ ... @ attn[0]
        # so each new layer is applied on the LEFT: new_rollout = attn[i] @ old_rollout
        for layer_idx in range(self.num_layers):
            layer_attn = attn[layer_idx]
            if average_heads:
                # Shape: [seq_len, seq_len]
                rollout.append(layer_attn @ rollout[-1])
            else:
                # Multiply each head separately, new layer on the left
                # Shape: [num_heads, seq_len, seq_len]
                rollout.append(torch.bmm(layer_attn, rollout[-1]))

        # Stack into a tensor
        # Shape: [num_layers+1, seq_len, seq_len] or [num_layers+1, num_heads, seq_len, seq_len]
        self.rollout = torch.stack(rollout)

        # Normalize so each row sums to 1: after composing attention the rows no
        # longer sum to 1, and renormalizing keeps the result interpretable as
        # attention weights.
        row_sums = self.rollout.sum(dim=-1, keepdim=True)
        row_sums = torch.clamp(row_sums, min=1e-10)  # avoid division by zero
        self.rollout = self.rollout / row_sums

        logger.info(f"Computed attention rollout: shape={self.rollout.shape}")

        # Debug: check that the rollout looks reasonable
        sample_weights = self.rollout[-1, 0, :]  # last layer, query position 0, all sources
        logger.info(
            f"Sample rollout weights (pos 0): min={sample_weights.min().item():.6f}, "
            f"max={sample_weights.max().item():.6f}, sum={sample_weights.sum().item():.6f}"
        )

        return self.rollout
    def get_top_sources(self, target_token_idx: int, layer_idx: int, k: int = 8) -> List[Tuple[int, float]]:
        """
        Get the top-k source tokens that contribute most to a target token at a specific layer.

        Assumes the rollout was computed with average_heads=True.

        Args:
            target_token_idx: Index of the target token in the sequence
            layer_idx: Which layer's rollout to use
            k: Number of top sources to return

        Returns:
            List of (source_idx, weight) tuples, sorted by weight descending
        """
        if self.rollout is None:
            raise ValueError("Must call compute_rollout() first")

        # Rollout weights from the target token to every source position
        # Shape: [seq_len]
        weights = self.rollout[layer_idx, target_token_idx, :]

        # Get top-k
        top_values, top_indices = torch.topk(weights, k=min(k, len(weights)))

        # Convert to a list of tuples
        top_sources = [
            (idx.item(), val.item())
            for idx, val in zip(top_indices, top_values)
        ]
        return top_sources
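
# --- Illustrative usage (not part of the core API) ---------------------------
# A minimal sketch of driving AttentionRollout with synthetic, softmax-normalized
# attention. The shapes, token_idx, and k values below are illustrative
# assumptions, not values used by the module itself.
def _example_attention_rollout() -> None:
    num_tokens, num_layers, num_heads, seq_len = 3, 2, 4, 6
    # Random attention, normalized over the key dimension so each row sums to 1
    fake_attn = torch.softmax(
        torch.randn(num_tokens, num_layers, num_heads, seq_len, seq_len), dim=-1
    )
    rollout = AttentionRollout(fake_attn, num_layers, num_heads)
    rollout.compute_rollout(token_idx=0, average_heads=True)
    # Top contributing source positions for the last sequence position
    top_sources = rollout.get_top_sources(target_token_idx=seq_len - 1, layer_idx=-1, k=3)
    print(f"Top sources: {top_sources}")
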
class HeadRanker:
    """
    Rank attention heads by their contribution to model predictions.

    Multiple ranking strategies:
    1. Rollout contribution: How much each head's attention flows to the output
    2. Mean max weight: Average of the maximum attention weight per head
    3. Entropy: Uncertainty in the head's attention distribution
    """

    def __init__(self, attention_tensor: torch.Tensor, num_layers: int, num_heads: int):
        """
        Args:
            attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len]
            num_layers: Number of layers
            num_heads: Number of heads per layer
        """
        self.attention_tensor = attention_tensor
        self.num_layers = num_layers
        self.num_heads = num_heads
    def rank_by_rollout_contribution(self, token_idx: int = -1, top_k: int = 20) -> List[Tuple[int, int, float]]:
        """
        Rank heads by their rollout contribution.

        This measures how much information from each head flows to the final output.

        Args:
            token_idx: Which generated token to analyze
            top_k: Number of top heads to return

        Returns:
            List of (layer_idx, head_idx, contribution_score) tuples
        """
        # Compute the rollout without averaging heads
        # Shape: [num_layers+1, num_heads, seq_len, seq_len]
        rollout_computer = AttentionRollout(self.attention_tensor, self.num_layers, self.num_heads)
        rollout = rollout_computer.compute_rollout(token_idx=token_idx, average_heads=False)

        # For each (layer, head), score the rollout composed up through that layer
        # for that head's track as the sum of its attention weights
        head_contributions = []
        for layer_idx in range(self.num_layers):
            for head_idx in range(self.num_heads):
                contribution = rollout[layer_idx + 1, head_idx].sum().item()
                head_contributions.append((layer_idx, head_idx, contribution))

        # Sort by contribution descending and return the top-k
        head_contributions.sort(key=lambda x: x[2], reverse=True)
        return head_contributions[:top_k]
    def rank_by_max_weight(self, top_k: int = 20) -> List[Tuple[int, int, float]]:
        """
        Rank heads by average maximum attention weight.

        Heads with high max weights are focusing strongly on specific tokens.

        Args:
            top_k: Number of top heads to return

        Returns:
            List of (layer_idx, head_idx, avg_max_weight) tuples
        """
        head_scores = []

        # Average across all generated tokens
        attn = self.attention_tensor.mean(dim=0)  # [num_layers, num_heads, seq_len, seq_len]

        for layer_idx in range(self.num_layers):
            for head_idx in range(self.num_heads):
                # Max attention weight assigned by each query position, then averaged
                head_attn = attn[layer_idx, head_idx]  # [seq_len, seq_len]
                max_weights = head_attn.max(dim=-1)[0]  # max per query position
                avg_max = max_weights.mean().item()
                head_scores.append((layer_idx, head_idx, avg_max))

        # Sort by score descending
        head_scores.sort(key=lambda x: x[2], reverse=True)
        return head_scores[:top_k]
    def rank_by_entropy(self, top_k: int = 20, high_entropy: bool = False) -> List[Tuple[int, int, float]]:
        """
        Rank heads by attention distribution entropy.

        Low entropy = focused attention (the head attends to few tokens)
        High entropy = diffuse attention (the head attends to many tokens)

        Args:
            top_k: Number of top heads to return
            high_entropy: If True, return highest-entropy heads; if False, return lowest

        Returns:
            List of (layer_idx, head_idx, entropy) tuples
        """
        head_entropies = []

        # Average across all generated tokens
        attn = self.attention_tensor.mean(dim=0)  # [num_layers, num_heads, seq_len, seq_len]

        for layer_idx in range(self.num_layers):
            for head_idx in range(self.num_heads):
                head_attn = attn[layer_idx, head_idx]  # [seq_len, seq_len]
                # Entropy of each query position's attention distribution over sources
                # H = -sum(p * log(p))
                entropy_per_token = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=-1)
                avg_entropy = entropy_per_token.mean().item()
                head_entropies.append((layer_idx, head_idx, avg_entropy))

        # Sort by entropy
        head_entropies.sort(key=lambda x: x[2], reverse=high_entropy)
        return head_entropies[:top_k]
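
# --- Illustrative usage (not part of the core API) ---------------------------
# A minimal sketch of the three HeadRanker strategies on synthetic attention.
# The shapes and top_k values are illustrative assumptions only.
def _example_head_ranking() -> None:
    num_tokens, num_layers, num_heads, seq_len = 3, 2, 4, 6
    fake_attn = torch.softmax(
        torch.randn(num_tokens, num_layers, num_heads, seq_len, seq_len), dim=-1
    )
    ranker = HeadRanker(fake_attn, num_layers, num_heads)
    print(ranker.rank_by_rollout_contribution(token_idx=0, top_k=5))
    print(ranker.rank_by_max_weight(top_k=5))
    # Lowest-entropy heads (most focused attention)
    print(ranker.rank_by_entropy(top_k=5, high_entropy=False))
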
def identify_head_roles(attention_tensor: torch.Tensor, tokens: List[str],
                        num_layers: int, num_heads: int) -> Dict[str, List[Tuple[int, int]]]:
    """
    Identify potential roles of attention heads based on attention patterns.

    Heuristics:
    - Delimiter heads: High attention to brackets, colons, etc.
    - Positional heads: Attend primarily to adjacent tokens
    - Broad heads: Uniform attention across many tokens

    Args:
        attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len]
        tokens: List of token strings
        num_layers: Number of layers
        num_heads: Number of heads

    Returns:
        Dictionary mapping role names to lists of (layer_idx, head_idx) tuples
    """
    delimiter_tokens = {'(', ')', '{', '}', '[', ']', ':', ',', ';'}
    roles = {
        'delimiter_focused': [],
        'positional': [],
        'broad': []
    }

    # Average across all generated tokens
    attn = attention_tensor.mean(dim=0)  # [num_layers, num_heads, seq_len, seq_len]

    # Positions of delimiter tokens (independent of layer and head)
    delimiter_indices = [i for i, tok in enumerate(tokens) if tok in delimiter_tokens]

    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            head_attn = attn[layer_idx, head_idx]  # [seq_len, seq_len]

            # Check for delimiter focus: average total attention mass each query
            # position sends to delimiter tokens
            if delimiter_indices:
                delimiter_attention = head_attn[:, delimiter_indices].sum(dim=1).mean().item()
                if delimiter_attention > 0.5:  # Threshold
                    roles['delimiter_focused'].append((layer_idx, head_idx))

            # Check for positional pattern (attention to adjacent positions)
            diagonal_mask = torch.eye(head_attn.shape[0], dtype=torch.bool)
            adjacent_mask = diagonal_mask.roll(1, dims=1) | diagonal_mask.roll(-1, dims=1)
            positional_attention = head_attn[adjacent_mask].mean().item()
            if positional_attention > 0.6:
                roles['positional'].append((layer_idx, head_idx))

            # Check for broad attention (high entropy)
            entropy = -(head_attn * torch.log(head_attn + 1e-10)).sum(dim=1).mean().item()
            if entropy > 2.0:  # Threshold
                roles['broad'].append((layer_idx, head_idx))

    logger.info(f"Identified head roles: {[(k, len(v)) for k, v in roles.items()]}")
    return roles
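
# --- Illustrative usage (not part of the core API) ---------------------------
# A minimal sketch of role identification on synthetic attention over a short,
# made-up token list; real tokenizer output would be passed in practice.
def _example_head_roles() -> None:
    tokens = ["def", "foo", "(", "x", ")", ":"]
    num_tokens, num_layers, num_heads = 2, 2, 4
    seq_len = len(tokens)
    fake_attn = torch.softmax(
        torch.randn(num_tokens, num_layers, num_heads, seq_len, seq_len), dim=-1
    )
    roles = identify_head_roles(fake_attn, tokens, num_layers, num_heads)
    for role, heads in roles.items():
        print(f"{role}: {heads}")
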
def compute_token_attention_maps(attention_tensor: torch.Tensor,
                                 prompt_tokens: List[str],
                                 generated_tokens: List[str],
                                 num_layers: int,
                                 num_heads: int,
                                 prompt_length: int) -> List[Dict]:
    """
    Compute attention maps showing which prompt tokens each generated token attends to.

    This creates the INPUT → INTERNALS → OUTPUT connection for visualization.

    Args:
        attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len]
        prompt_tokens: List of tokens in the prompt
        generated_tokens: List of generated tokens
        num_layers: Number of layers
        num_heads: Number of heads
        prompt_length: Number of tokens in the prompt

    Returns:
        List of dicts, one per generated token:
        [{
            'token_idx': int,
            'token': str,
            'attention_to_prompt': [
                {'prompt_idx': int, 'prompt_token': str, 'weight': float},
                ...
            ]
        }]
    """
    token_maps = []

    for token_idx, token in enumerate(generated_tokens):
        # Attention for this token: [num_layers, num_heads, seq_len, seq_len]
        token_attn = attention_tensor[token_idx]

        # Average across all layers and heads to get the overall attention pattern
        # Shape: [seq_len, seq_len]
        avg_attn = token_attn.mean(dim=0).mean(dim=0)

        # When generating this token, the model is at the last position of the
        # current sequence (before the new token is appended):
        #   sequence length at generation time = prompt_length + token_idx
        #   last position index                = prompt_length + token_idx - 1
        current_pos = prompt_length + token_idx - 1

        # Attention FROM the current position TO the prompt tokens, i.e. which
        # prompt tokens the model attended to when generating this token
        # Shape: [prompt_length]
        attention_to_prompt = avg_attn[current_pos, :prompt_length]

        # Debug: log sample attention weights for the first generated token
        if token_idx == 0:
            logger.info(
                f"Token 0 attention weights: min={attention_to_prompt.min().item():.6f}, "
                f"max={attention_to_prompt.max().item():.6f}, sum={attention_to_prompt.sum().item():.6f}"
            )
            logger.info(f"First 5 weights: {attention_to_prompt[:5].tolist()}")

        # Build the list of prompt token attentions
        prompt_attentions = []
        for prompt_idx in range(prompt_length):
            prompt_attentions.append({
                'prompt_idx': prompt_idx,
                'prompt_token': prompt_tokens[prompt_idx] if prompt_idx < len(prompt_tokens) else f'<{prompt_idx}>',
                'weight': attention_to_prompt[prompt_idx].item()
            })

        # Sort by weight descending
        prompt_attentions.sort(key=lambda x: x['weight'], reverse=True)

        token_maps.append({
            'token_idx': token_idx,
            'token': token,
            'position': current_pos,
            'attention_to_prompt': prompt_attentions
        })

    logger.info(f"Computed attention maps for {len(token_maps)} generated tokens")
    return token_maps
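
# --- Illustrative usage (not part of the core API) ---------------------------
# A minimal sketch of the prompt-attribution maps on synthetic attention. The
# prompt/generated token strings and lengths are illustrative assumptions;
# seq_len is taken as prompt_length + number of generated tokens.
def _example_token_attention_maps() -> None:
    prompt_tokens = ["Write", "a", "poem"]
    generated_tokens = ["Roses", "are", "red"]
    num_layers, num_heads = 2, 4
    prompt_length = len(prompt_tokens)
    seq_len = prompt_length + len(generated_tokens)
    fake_attn = torch.softmax(
        torch.randn(len(generated_tokens), num_layers, num_heads, seq_len, seq_len), dim=-1
    )
    maps = compute_token_attention_maps(
        fake_attn, prompt_tokens, generated_tokens, num_layers, num_heads, prompt_length
    )
    # Strongest prompt token for each generated token
    for m in maps:
        top = m['attention_to_prompt'][0]
        print(f"{m['token']!r} <- {top['prompt_token']!r} ({top['weight']:.3f})")
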
# Example usage
if __name__ == "__main__":
    print("Attention analysis module loaded successfully")

    # Example: compute rollout on fake data
    num_tokens, num_layers, num_heads, seq_len = 5, 4, 8, 16
    fake_attn = torch.softmax(torch.randn(num_tokens, num_layers, num_heads, seq_len, seq_len), dim=-1)

    rollout = AttentionRollout(fake_attn, num_layers, num_heads)
    result = rollout.compute_rollout(token_idx=0)
    print(f"Rollout shape: {result.shape}")