| """ | |
| Real Attention Extraction for In-Context Learning Analysis | |
| This module hooks into transformer models to extract actual attention weights | |
| during generation, providing real data for ICL analysis. | |
| """ | |

import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


@dataclass
class AttentionData:
    """Stores attention data from model generation."""

    layer_attentions: List[torch.Tensor]        # Attention from each layer
    token_positions: List[int]                  # Position of each generated token
    example_boundaries: List[Tuple[int, int]]   # (start, end) positions of examples
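
# Illustrative sketch (an assumption, not data produced by this module): how an
# AttentionData instance might be populated. Shapes and positions below are
# placeholders.
#
#   data = AttentionData(
#       layer_attentions=[torch.rand(1, 16, 128, 128)],  # [batch, heads, seq, seq]
#       token_positions=[128],
#       example_boundaries=[(0, 42), (43, 90)],
#   )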


class AttentionExtractor:
    """Extracts real attention patterns from transformer models during generation."""

    def __init__(self, model, tokenizer, adapter=None):
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter  # Model adapter for multi-architecture support
        self.device = next(model.parameters()).device
        # Storage for attention captured during generation
        self.attention_weights = []
        self.handles = []

    def register_hooks(self):
        """Register forward hooks to capture attention weights."""
        self.clear_hooks()

        # Use the adapter if available for multi-architecture support
        if self.adapter:
            num_layers = self.adapter.get_num_layers()
            for i in range(num_layers):
                attn_module = self.adapter.get_attention_module(i)
                if attn_module:
                    # Bind layer_idx as a default argument so each hook closes
                    # over its own layer index rather than the loop variable
                    handle = attn_module.register_forward_hook(
                        lambda module, input, output, layer_idx=i:
                            self._attention_hook(module, input, output, layer_idx)
                    )
                    self.handles.append(handle)
        # Fallback for CodeGen-style models without an adapter
        elif hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
            # Hook into each transformer layer
            for i, layer in enumerate(self.model.transformer.h):
                if hasattr(layer, 'attn'):
                    handle = layer.attn.register_forward_hook(
                        lambda module, input, output, layer_idx=i:
                            self._attention_hook(module, input, output, layer_idx)
                    )
                    self.handles.append(handle)

        logger.info(f"Registered {len(self.handles)} attention hooks")

    def _attention_hook(self, module, input, output, layer_idx):
        """Hook function to capture attention weights."""
        # For CodeGen, the module output is a tuple whose second element holds
        # the attention weights when they are returned
        if isinstance(output, tuple) and len(output) >= 2:
            attention = output[1]
            if attention is not None:
                # Store attention weights on the CPU to limit device memory use
                self.attention_weights.append({
                    'layer': layer_idx,
                    'attention': attention.detach().cpu()
                })

    def clear_hooks(self):
        """Remove all registered hooks and reset any captured attention."""
        for handle in self.handles:
            handle.remove()
        self.handles = []
        self.attention_weights = []

    def extract_attention_with_generation(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.7
    ) -> Tuple[torch.Tensor, List[Dict], List[torch.Tensor]]:
        """Generate text token by token while extracting attention patterns.

        Assumes a batch size of 1.
        """
        # Register hooks before generation
        self.register_hooks()
        self.attention_weights = []
        try:
            # Generate token by token to capture attention at each step
            generated_ids = []
            all_scores = []  # Store logits for confidence calculation
            current_input_ids = input_ids.clone()
            current_attention_mask = attention_mask.clone()

            for _ in range(max_new_tokens):
                records_before = len(self.attention_weights)
                with torch.no_grad():
                    # Forward pass; disabling the KV cache keeps the full
                    # seq_len x seq_len attention matrices available
                    outputs = self.model(
                        input_ids=current_input_ids,
                        attention_mask=current_attention_mask,
                        use_cache=False,
                        output_attentions=True,
                        return_dict=True
                    )

                # Fall back to outputs.attentions only when the hooks captured
                # nothing this step; appending both would duplicate records and
                # break the per-token layer indexing used during aggregation
                if (len(self.attention_weights) == records_before
                        and getattr(outputs, 'attentions', None) is not None):
                    for layer_idx, attn in enumerate(outputs.attentions):
                        self.attention_weights.append({
                            'layer': layer_idx,
                            'attention': attn.detach().cpu()
                        })

                # Next-token logits come from the last position
                next_token_logits = outputs.logits[:, -1, :]
                # Store the raw scores
                all_scores.append(next_token_logits)

                # Sample with temperature, or decode greedily when it is 0
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # Stop at the EOS token
                if next_token.item() == self.tokenizer.eos_token_id:
                    break

                # Append the new token to the running context
                generated_ids.append(next_token.item())
                current_input_ids = torch.cat([current_input_ids, next_token], dim=1)
                current_attention_mask = torch.cat([
                    current_attention_mask,
                    torch.ones((1, 1), device=self.device, dtype=current_attention_mask.dtype)
                ], dim=1)

            # Convert the generated ids to a [1, num_generated] tensor
            if generated_ids:
                generated_tensor = torch.tensor(generated_ids, device=self.device).unsqueeze(0)
            else:
                generated_tensor = torch.empty((1, 0), device=self.device, dtype=torch.long)

            return generated_tensor, self.attention_weights, all_scores
        finally:
            # Always clear hooks after generation
            self.clear_hooks()
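
    # Usage sketch (hedged): assumes a standard Hugging Face tokenizer and a
    # batch of size 1; `extractor` and `prompt` are illustrative names.
    #
    #   enc = tokenizer(prompt, return_tensors="pt")
    #   gen_ids, attn_records, scores = extractor.extract_attention_with_generation(
    #       enc["input_ids"], enc["attention_mask"], max_new_tokens=20
    #   )
    #   completion = tokenizer.decode(gen_ids[0])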

    def aggregate_attention_to_examples(
        self,
        attention_data: List[Dict],
        example_boundaries: List[Tuple[int, int]],
        prompt_length: int
    ) -> Dict[str, List[float]]:
        """
        Aggregate attention from generated tokens back to example regions.

        Args:
            attention_data: per-layer attention records from generation
            example_boundaries: (start, end) token positions of each example
            prompt_length: prompt length in tokens (kept for API compatibility;
                not used in the current aggregation)

        Returns:
            Dict mapping example_id -> list of attention weights per generated token
        """
        if not attention_data or not example_boundaries:
            return {}

        attention_to_examples = {}

        # Each generated token contributes one record per layer, so infer the
        # layer count from the records themselves instead of hardcoding it
        # (CodeGen-350M, for instance, has 20 layers)
        num_layers = max(record['layer'] for record in attention_data) + 1
        num_generated = len(attention_data) // num_layers

        logger.info(f"Processing {len(attention_data)} attention records for {num_generated} generated tokens")

        for example_idx, (start, end) in enumerate(example_boundaries):
            example_id = str(example_idx + 1)
            example_attention = []

            # For each generated token, average its attention to this example
            # across all layers
            for gen_idx in range(num_generated):
                total_attention = 0.0
                layer_count = 0

                # Records [gen_idx * num_layers : (gen_idx + 1) * num_layers]
                # hold this token's attention, one record per layer
                for attn_record in attention_data[gen_idx * num_layers:(gen_idx + 1) * num_layers]:
                    if 'attention' not in attn_record:
                        continue
                    attn_tensor = attn_record['attention']
                    # Shape: [batch, heads, seq_len, seq_len]; the last row is
                    # the attention of the newly generated token
                    if attn_tensor.dim() >= 3:
                        seq_len = attn_tensor.shape[-1]
                        if end <= seq_len:
                            # Average across heads over the example region
                            attn_to_example = attn_tensor[0, :, -1, start:end].mean().item()
                            total_attention += attn_to_example
                            layer_count += 1

                # Average across layers
                if layer_count > 0:
                    example_attention.append(total_attention / layer_count)
                else:
                    example_attention.append(0.0)

            attention_to_examples[example_id] = example_attention

        # Normalize attention across examples for each generated token
        for gen_idx in range(num_generated):
            total = sum(
                attention_to_examples[ex_id][gen_idx]
                for ex_id in attention_to_examples
                if gen_idx < len(attention_to_examples[ex_id])
            )
            if total > 0:
                for ex_id in attention_to_examples:
                    if gen_idx < len(attention_to_examples[ex_id]):
                        attention_to_examples[ex_id][gen_idx] /= total

        return attention_to_examples
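
    # Illustrative (assumed) return value for two examples and three generated
    # tokens; after per-token normalization each column sums to 1:
    #
    #   {"1": [0.60, 0.55, 0.70],
    #    "2": [0.40, 0.45, 0.30]}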

    def calculate_example_influences(
        self,
        attention_to_examples: Dict[str, List[float]]
    ) -> Dict[str, float]:
        """
        Calculate the overall influence of each example based on attention patterns.

        Returns:
            Dict mapping example_id -> influence score (0-1)
        """
        influences = {}
        for example_id, attention_weights in attention_to_examples.items():
            # Overall influence is the mean attention across all generated tokens
            if attention_weights:
                influences[example_id] = float(np.mean(attention_weights))
            else:
                influences[example_id] = 0.0

        # Normalize so the influences sum to 1
        total = sum(influences.values())
        if total > 0:
            influences = {k: v / total for k, v in influences.items()}

        return influences
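

if __name__ == "__main__":
    # Minimal end-to-end smoke test. This is a hedged sketch: the checkpoint
    # name, prompt, and example boundaries are assumptions made for
    # illustration, not values used elsewhere in this module.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "Salesforce/codegen-350M-mono"  # assumed CodeGen checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    # Two tiny few-shot examples followed by a query
    prompt = (
        "# add two numbers\ndef add(a, b):\n    return a + b\n\n"
        "# subtract two numbers\ndef sub(a, b):\n    return a - b\n\n"
        "# multiply two numbers\ndef mul(a, b):"
    )
    enc = tokenizer(prompt, return_tensors="pt")

    extractor = AttentionExtractor(model, tokenizer)
    gen_ids, attn_records, scores = extractor.extract_attention_with_generation(
        enc["input_ids"], enc["attention_mask"], max_new_tokens=8
    )
    print(tokenizer.decode(gen_ids[0]))

    # Token boundaries for the two examples are hardcoded guesses here; in
    # practice they should come from the prompt construction step
    boundaries = [(0, 20), (20, 40)]
    per_token = extractor.aggregate_attention_to_examples(
        attn_records, boundaries, prompt_length=enc["input_ids"].shape[1]
    )
    print(extractor.calculate_example_influences(per_token))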