"""
Real Attention Extraction for In-Context Learning Analysis

This module hooks into transformer models to extract actual attention weights
during generation, providing real data for ICL analysis.
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Tuple
from dataclasses import dataclass
import logging

logger = logging.getLogger(__name__)

@dataclass
class AttentionData:
    """Stores attention data from model generation"""
    layer_attentions: List[torch.Tensor]  # Attention from each layer
    token_positions: List[int]  # Position of each generated token
    example_boundaries: List[Tuple[int, int]]  # Start/end positions of examples
    
class AttentionExtractor:
    """Extracts real attention patterns from transformer models during generation"""

    def __init__(self, model, tokenizer, adapter=None):
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter  # Model adapter for multi-architecture support
        self.device = next(model.parameters()).device

        # Storage for attention during generation
        self.attention_weights = []
        self.handles = []
        
    def register_hooks(self):
        """Register forward hooks to capture attention weights"""
        self.clear_hooks()

        # Use adapter if available for multi-architecture support
        if self.adapter:
            num_layers = self.adapter.get_num_layers()
            for i in range(num_layers):
                attn_module = self.adapter.get_attention_module(i)
                if attn_module:
                    handle = attn_module.register_forward_hook(
                        lambda module, input, output, layer_idx=i:
                        self._attention_hook(module, input, output, layer_idx)
                    )
                    self.handles.append(handle)
        # Fallback for CodeGen/GPT-2-style models (model.transformer.h) without an adapter
        elif hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
            # Hook into each transformer layer
            for i, layer in enumerate(self.model.transformer.h):
                if hasattr(layer, 'attn'):
                    handle = layer.attn.register_forward_hook(
                        lambda module, input, output, layer_idx=i:
                        self._attention_hook(module, input, output, layer_idx)
                    )
                    self.handles.append(handle)

        logger.info(f"Registered {len(self.handles)} attention hooks")
                    
    def _attention_hook(self, module, input, output, layer_idx):
        """Hook function to capture attention weights"""
        # Attention modules return tuples whose layout depends on use_cache and
        # output_attentions (e.g. (hidden_states, present, attn_weights) for
        # CodeGen/GPT-2); when present, the attention weights are the last
        # element, with shape [batch, heads, seq_len, seq_len]
        if isinstance(output, tuple) and len(output) >= 2:
            attention = output[-1]
            if torch.is_tensor(attention) and attention.dim() == 4:
                # Store attention weights on CPU to limit GPU memory use
                self.attention_weights.append({
                    'layer': layer_idx,
                    'attention': attention.detach().cpu()
                })
    
    def clear_hooks(self):
        """Remove all hooks"""
        for handle in self.handles:
            handle.remove()
        self.handles = []
        self.attention_weights = []
        
    def extract_attention_with_generation(
        self, 
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.7
    ) -> Tuple[torch.Tensor, List[Dict], List[torch.Tensor]]:
        """Generate text while extracting attention patterns"""
        
        # Register hooks before generation
        self.register_hooks()
        self.attention_weights = []
        
        try:
            # Generate token by token to capture attention at each step
            generated_ids = []
            all_scores = []  # Store scores for confidence calculation
            current_input_ids = input_ids.clone()
            current_attention_mask = attention_mask.clone()
            
            for _ in range(max_new_tokens):
                records_before = len(self.attention_weights)
                with torch.no_grad():
                    # Forward pass through model
                    outputs = self.model(
                        input_ids=current_input_ids,
                        attention_mask=current_attention_mask,
                        use_cache=False,  # Don't use cache to get full attention
                        output_attentions=True,
                        return_dict=True
                    )
                    
                    # Fall back to outputs.attentions only if the hooks captured
                    # nothing for this step, so each step is stored exactly once
                    if (len(self.attention_weights) == records_before
                            and getattr(outputs, 'attentions', None) is not None):
                        for layer_idx, attn in enumerate(outputs.attentions):
                            self.attention_weights.append({
                                'layer': layer_idx,
                                'attention': attn.detach().cpu()
                            })
                    
                    # Get next token logits
                    next_token_logits = outputs.logits[:, -1, :]
                    
                    # Store the scores
                    all_scores.append(next_token_logits)
                    
                    # Apply temperature
                    if temperature > 0:
                        next_token_logits = next_token_logits / temperature
                        probs = F.softmax(next_token_logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                    
                    # Stop if EOS token
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break
                        
                    # Append token
                    generated_ids.append(next_token.item())
                    current_input_ids = torch.cat([current_input_ids, next_token], dim=1)
                    current_attention_mask = torch.cat([
                        current_attention_mask,
                        torch.ones((1, 1), device=self.device,
                                   dtype=current_attention_mask.dtype)
                    ], dim=1)
            
            # Convert to tensor
            if generated_ids:
                generated_tensor = torch.tensor(generated_ids, device=self.device).unsqueeze(0)
            else:
                generated_tensor = torch.tensor([[]], device=self.device, dtype=torch.long)
                
            return generated_tensor, self.attention_weights, all_scores
            
        finally:
            # Always clear hooks after generation
            self.clear_hooks()
    
    def aggregate_attention_to_examples(
        self,
        attention_data: List[Dict],
        example_boundaries: List[Tuple[int, int]],
        prompt_length: int
    ) -> Dict[str, List[float]]:
        """
        Aggregate attention from generated tokens back to example regions
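
        Args:
            attention_data: Per-layer attention records as stored by
                extract_attention_with_generation (layer by layer for each
                generated token)
            example_boundaries: (start, end) token positions of each example
            prompt_length: Prompt length in tokens (kept for API compatibility;
                unused by the current aggregation)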
        
        Returns:
            Dict mapping example_id -> list of attention weights per generated token
        """
        
        if not attention_data or not example_boundaries:
            return {}
            
        attention_to_examples = {}
        
        # There is one attention record per layer for each generated token, so
        # the layer count can be derived from the records instead of
        # hard-coding a model-specific value (e.g. 20 for CodeGen-350M)
        num_layers = len({record['layer'] for record in attention_data})
        num_generated = len(attention_data) // num_layers
        
        logger.info(f"Processing {len(attention_data)} attention records for {num_generated} generated tokens")
        
        for example_idx, (start, end) in enumerate(example_boundaries):
            example_id = str(example_idx + 1)
            example_attention = []
            
            # For each generated token
            for gen_idx in range(num_generated):
                # Aggregate attention across all layers for this generated position
                total_attention = 0.0
                layer_count = 0
                
                # Records are appended layer by layer for each generated token,
                # so this slice holds exactly the records for gen_idx
                step_records = attention_data[gen_idx * num_layers:(gen_idx + 1) * num_layers]
                for attn_record in step_records:
                    if 'attention' not in attn_record:
                        continue
                    attn_tensor = attn_record['attention']
                    
                    # Shape: [batch, heads, seq_len, seq_len]; the last query
                    # position corresponds to the newly generated token
                    if attn_tensor.dim() >= 3:
                        seq_len = attn_tensor.shape[-1]
                        
                        # Average across heads the attention from the last
                        # position to the example region
                        if end <= seq_len:
                            attn_to_example = attn_tensor[0, :, -1, start:end].mean().item()
                            total_attention += attn_to_example
                            layer_count += 1
                
                # Average across layers
                if layer_count > 0:
                    example_attention.append(total_attention / layer_count)
                else:
                    example_attention.append(0.0)
            
            attention_to_examples[example_id] = example_attention
            
        # Normalize across examples at each generated position so weights sum to 1
        for gen_idx in range(num_generated):
            total = sum(
                attention_to_examples[ex_id][gen_idx] 
                for ex_id in attention_to_examples 
                if gen_idx < len(attention_to_examples[ex_id])
            )
            if total > 0:
                for ex_id in attention_to_examples:
                    if gen_idx < len(attention_to_examples[ex_id]):
                        attention_to_examples[ex_id][gen_idx] /= total
                        
        return attention_to_examples
    
    def calculate_example_influences(
        self, 
        attention_to_examples: Dict[str, List[float]]
    ) -> Dict[str, float]:
        """
        Calculate overall influence of each example based on attention patterns
        
        Returns:
            Dict mapping example_id -> influence score (0-1)
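
        Example (illustrative numbers; ``extractor`` is any instance):
            >>> extractor.calculate_example_influences(
            ...     {"1": [0.75, 0.75], "2": [0.25, 0.25]})
            {'1': 0.75, '2': 0.25}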
        """
        influences = {}
        
        for example_id, attention_weights in attention_to_examples.items():
            # Overall influence is the mean attention across all generated tokens
            if attention_weights:
                influences[example_id] = float(np.mean(attention_weights))
            else:
                influences[example_id] = 0.0
        
        # Normalize to sum to 1
        total = sum(influences.values())
        if total > 0:
            influences = {k: v/total for k, v in influences.items()}
            
        return influences
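

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module API. The checkpoint name
    # and the two-example prompt below are illustrative assumptions; any
    # causal LM that can return attentions should work the same way.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "Salesforce/codegen-350M-mono"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    # Two in-context examples followed by a query; boundaries are token
    # offsets (tokenizing the pieces separately is approximate at the seams)
    example_1 = "# add two numbers\ndef add(a, b):\n    return a + b\n\n"
    example_2 = "# subtract two numbers\ndef sub(a, b):\n    return a - b\n\n"
    query = "# multiply two numbers\ndef mul(a, b):\n"

    len_1 = len(tokenizer(example_1)["input_ids"])
    len_2 = len(tokenizer(example_2)["input_ids"])
    boundaries = [(0, len_1), (len_1, len_1 + len_2)]

    enc = tokenizer(example_1 + example_2 + query, return_tensors="pt")
    extractor = AttentionExtractor(model, tokenizer)
    generated, records, _scores = extractor.extract_attention_with_generation(
        enc["input_ids"], enc["attention_mask"],
        max_new_tokens=8, temperature=0.0  # temperature 0 -> greedy decoding
    )

    print("Generated:", tokenizer.decode(generated[0]))
    per_token = extractor.aggregate_attention_to_examples(
        records, boundaries, prompt_length=enc["input_ids"].shape[1]
    )
    print("Example influences:", extractor.calculate_example_influences(per_token))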