""" Real Attention Extraction for In-Context Learning Analysis This module hooks into transformer models to extract actual attention weights during generation, providing real data for ICL analysis. """ import torch import torch.nn.functional as F import numpy as np from typing import List, Dict, Tuple, Optional, Any from dataclasses import dataclass import logging logger = logging.getLogger(__name__) @dataclass class AttentionData: """Stores attention data from model generation""" layer_attentions: List[torch.Tensor] # Attention from each layer token_positions: List[int] # Position of each generated token example_boundaries: List[Tuple[int, int]] # Start/end positions of examples class AttentionExtractor: """Extracts real attention patterns from transformer models during generation""" def __init__(self, model, tokenizer): self.model = model self.tokenizer = tokenizer self.device = next(model.parameters()).device # Storage for attention during generation self.attention_weights = [] self.handles = [] def register_hooks(self): """Register forward hooks to capture attention weights""" self.clear_hooks() # For CodeGen models, attention is in the transformer blocks if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'): # Hook into each transformer layer for i, layer in enumerate(self.model.transformer.h): if hasattr(layer, 'attn'): handle = layer.attn.register_forward_hook( lambda module, input, output, layer_idx=i: self._attention_hook(module, input, output, layer_idx) ) self.handles.append(handle) logger.info(f"Registered {len(self.handles)} attention hooks") def _attention_hook(self, module, input, output, layer_idx): """Hook function to capture attention weights""" # For CodeGen, output is (hidden_states, attention_weights) if isinstance(output, tuple) and len(output) >= 2: attention = output[1] if attention is not None: # Store attention weights self.attention_weights.append({ 'layer': layer_idx, 'attention': attention.detach().cpu() }) def clear_hooks(self): """Remove all hooks""" for handle in self.handles: handle.remove() self.handles = [] self.attention_weights = [] def extract_attention_with_generation( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, max_new_tokens: int = 50, temperature: float = 0.7 ) -> Tuple[torch.Tensor, List[Dict], List[torch.Tensor]]: """Generate text while extracting attention patterns""" # Register hooks before generation self.register_hooks() self.attention_weights = [] try: # Generate token by token to capture attention at each step generated_ids = [] all_scores = [] # Store scores for confidence calculation current_input_ids = input_ids.clone() current_attention_mask = attention_mask.clone() for _ in range(max_new_tokens): with torch.no_grad(): # Forward pass through model outputs = self.model( input_ids=current_input_ids, attention_mask=current_attention_mask, use_cache=False, # Don't use cache to get full attention output_attentions=True, return_dict=True ) # Capture attention from outputs if hooks didn't get it if hasattr(outputs, 'attentions') and outputs.attentions is not None: for layer_idx, attn in enumerate(outputs.attentions): self.attention_weights.append({ 'layer': layer_idx, 'attention': attn.detach().cpu() }) # Get next token logits next_token_logits = outputs.logits[:, -1, :] # Store the scores all_scores.append(next_token_logits) # Apply temperature if temperature > 0: next_token_logits = next_token_logits / temperature probs = F.softmax(next_token_logits, dim=-1) next_token = torch.multinomial(probs, 

                # Apply temperature
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # Stop if EOS token
                if next_token.item() == self.tokenizer.eos_token_id:
                    break

                # Append the token and grow the attention mask to match
                generated_ids.append(next_token.item())
                current_input_ids = torch.cat([current_input_ids, next_token], dim=1)
                current_attention_mask = torch.cat([
                    current_attention_mask,
                    torch.ones((1, 1), device=self.device, dtype=current_attention_mask.dtype)
                ], dim=1)

            # Convert to tensor
            if generated_ids:
                generated_tensor = torch.tensor(generated_ids, device=self.device).unsqueeze(0)
            else:
                generated_tensor = torch.empty((1, 0), dtype=torch.long, device=self.device)

            # clear_hooks() in the finally block rebinds self.attention_weights
            # to a fresh list, so the reference returned here keeps its records
            return generated_tensor, self.attention_weights, all_scores

        finally:
            # Always clear hooks after generation
            self.clear_hooks()

    def aggregate_attention_to_examples(
        self,
        attention_data: List[Dict],
        example_boundaries: List[Tuple[int, int]],
        prompt_length: int
    ) -> Dict[str, List[float]]:
        """
        Aggregate attention from generated tokens back to example regions

        Returns:
            Dict mapping example_id -> list of attention weights per generated token
        """
        if not attention_data or not example_boundaries:
            return {}

        attention_to_examples = {}

        # Attention records arrive in layer order for each generated token, so
        # consecutive blocks of num_layers records belong to a single token.
        # Read the layer count from the model config rather than hard-coding it
        # (20 is correct for CodeGen-350M but not for other variants).
        num_layers = getattr(self.model.config, 'n_layer', None)
        if num_layers is None:
            num_layers = getattr(self.model.config, 'num_hidden_layers', 20)
        num_generated = len(attention_data) // num_layers

        logger.info(
            f"Processing {len(attention_data)} attention records "
            f"for {num_generated} generated tokens"
        )

        for example_idx, (start, end) in enumerate(example_boundaries):
            example_id = str(example_idx + 1)
            example_attention = []

            # For each generated token, aggregate attention across all layers
            for gen_idx in range(num_generated):
                total_attention = 0.0
                layer_count = 0

                # Records [gen_idx * num_layers : (gen_idx + 1) * num_layers]
                # correspond to this generated position
                for attn_record in attention_data[gen_idx * num_layers:(gen_idx + 1) * num_layers]:
                    if 'attention' not in attn_record:
                        continue
                    attn_tensor = attn_record['attention']
                    # Shape: [batch, heads, seq_len, seq_len]; the last query
                    # position corresponds to the newly generated token
                    if attn_tensor.dim() >= 3:
                        seq_len = attn_tensor.shape[-1]
                        if end <= seq_len:
                            # Average across heads: attention from the last
                            # position to the example region
                            attn_to_example = attn_tensor[0, :, -1, start:end].mean().item()
                            total_attention += attn_to_example
                            layer_count += 1

                # Average across layers
                if layer_count > 0:
                    example_attention.append(total_attention / layer_count)
                else:
                    example_attention.append(0.0)

            attention_to_examples[example_id] = example_attention

        # Normalize attention across examples for each generated token
        for gen_idx in range(num_generated):
            total = sum(
                attention_to_examples[ex_id][gen_idx]
                for ex_id in attention_to_examples
                if gen_idx < len(attention_to_examples[ex_id])
            )
            if total > 0:
                for ex_id in attention_to_examples:
                    if gen_idx < len(attention_to_examples[ex_id]):
                        attention_to_examples[ex_id][gen_idx] /= total

        return attention_to_examples
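
    # Worked example for the normalization above (illustrative numbers only):
    # with two examples and one generated token, layer-averaged attention of
    # 0.06 for example "1" and 0.02 for example "2" normalizes to 0.75 and
    # 0.25, so the per-token scores always sum to 1 across examples.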

    def calculate_example_influences(
        self,
        attention_to_examples: Dict[str, List[float]]
    ) -> Dict[str, float]:
        """
        Calculate overall influence of each example based on attention patterns

        Returns:
            Dict mapping example_id -> influence score (0-1)
        """
        influences = {}

        for example_id, attention_weights in attention_to_examples.items():
            # Overall influence is the mean attention across all generated tokens
            if attention_weights:
                influences[example_id] = float(np.mean(attention_weights))
            else:
                influences[example_id] = 0.0

        # Normalize influences to sum to 1
        total = sum(influences.values())
        if total > 0:
            influences = {k: v / total for k, v in influences.items()}

        return influences
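

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions not part of this module: the checkpoint
# name ("Salesforce/codegen-350M-mono") and the example boundary, which here
# just splits the prompt in half for illustration; real callers would pass
# the actual token ranges of their in-context examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "Salesforce/codegen-350M-mono"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    prompt = "def add(a, b):\n    return a + b\n\ndef sub(a, b):"
    encoded = tokenizer(prompt, return_tensors="pt")

    extractor = AttentionExtractor(model, tokenizer)
    generated, attention_records, scores = extractor.extract_attention_with_generation(
        encoded["input_ids"], encoded["attention_mask"], max_new_tokens=16
    )
    print(tokenizer.decode(generated[0]))

    # Hypothetical boundary: treat the first half of the prompt as example "1"
    prompt_len = encoded["input_ids"].shape[1]
    boundaries = [(0, prompt_len // 2)]
    per_token = extractor.aggregate_attention_to_examples(
        attention_records, boundaries, prompt_length=prompt_len
    )
    print(extractor.calculate_example_influences(per_token))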