"""
Induction Head Detection for In-Context Learning

Based on research showing that ICL emerges abruptly in transformers through
the formation of induction heads - attention patterns that copy from context.
"""

import torch
import numpy as np
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
import logging

logger = logging.getLogger(__name__)

@dataclass
class InductionHeadSignal:
    """Signals indicating induction head behavior"""
    layer: int
    head: int
    strength: float  # 0-1 score of induction pattern strength
    pattern_type: str  # 'copy', 'prefix_match', 'abstract'
    emergence_point: Optional[int]  # Token position where pattern emerges

@dataclass
class ICLEmergenceAnalysis:
    """Analysis of when and how ICL emerges"""
    emergence_detected: bool
    emergence_token: Optional[int]  # Token position where ICL kicks in
    emergence_layer: Optional[int]  # Layer where strongest signal appears
    confidence: float  # Confidence in detection (0-1)
    induction_heads: List[InductionHeadSignal]
    attention_entropy_drop: List[float]  # Entropy trajectory per generated position; sharp drops signal emergence
    pattern_consistency: float  # How consistent the pattern is

class InductionHeadDetector:
    """Detects induction heads and ICL emergence in transformer models"""

    def __init__(self, model, tokenizer, adapter=None):
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter
        self.device = next(model.parameters()).device
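
    # NOTE: illustrative helper (not referenced elsewhere in this module);
    # it sketches one way to produce the record format the detector consumes,
    # assuming a Hugging Face-style model API.
    def collect_attention_records(self, input_ids: torch.Tensor) -> List[Dict]:
        """Illustrative sketch of building the attention records this detector
        consumes. Assumes a Hugging Face-style model whose plain forward pass
        (no KV cache) accepts output_attentions=True and returns one
        [batch, heads, seq, seq] tensor per layer; adapt for other stacks.
        """
        with torch.no_grad():
            outputs = self.model(input_ids.to(self.device), output_attentions=True)
        # One record per layer, matching the {'layer': int, 'attention': Tensor} schema
        return [
            {'layer': layer_idx, 'attention': attn.detach().cpu()}
            for layer_idx, attn in enumerate(outputs.attentions)
        ]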
        
    def detect_induction_heads(
        self,
        attention_weights: List[Dict],
        input_ids: torch.Tensor,
        example_boundaries: List[Tuple[int, int]]
    ) -> List[InductionHeadSignal]:
        """
        Detect induction heads by looking for attention patterns that:
        1. Copy from previous occurrences (classic induction)
        2. Match prefixes across examples
        3. Show abstract pattern matching
        """
        induction_heads = []
        
        if not attention_weights or not example_boundaries:
            return induction_heads
            
        # Analyze each layer and head (first record per layer wins)
        layers_analyzed = set()
        for record in attention_weights:
            layer_idx = record.get('layer', 0)
            attn = record.get('attention')

            if attn is None or layer_idx in layers_analyzed:
                continue

            layers_analyzed.add(layer_idx)

            # Promote [heads, seq, seq] to the expected [batch, heads, seq, seq]
            # so the per-head indexing below is valid
            if attn.dim() == 3:
                attn = attn.unsqueeze(0)

            # Analyze each attention head
            if attn.dim() == 4:
                num_heads = attn.shape[1]
                seq_len = attn.shape[-1]
                
                for head_idx in range(num_heads):
                    head_attn = attn[0, head_idx]  # [seq_len, seq_len]
                    
                    # Detect different induction patterns
                    copy_score = self._detect_copy_pattern(head_attn, input_ids)
                    prefix_score = self._detect_prefix_matching(head_attn, example_boundaries)
                    abstract_score = self._detect_abstract_pattern(head_attn, seq_len)
                    
                    # Determine strongest pattern
                    max_score = max(copy_score, prefix_score, abstract_score)
                    if max_score > 0.3:  # Threshold for significant pattern
                        pattern_type = (
                            'copy' if copy_score == max_score
                            else 'prefix_match' if prefix_score == max_score
                            else 'abstract'
                        )
                        
                        # Find emergence point (where pattern suddenly strengthens)
                        emergence_point = self._find_emergence_point(head_attn)
                        
                        induction_heads.append(InductionHeadSignal(
                            layer=layer_idx,
                            head=head_idx,
                            strength=max_score,
                            pattern_type=pattern_type,
                            emergence_point=emergence_point
                        ))
        
        return induction_heads
    
    def _detect_copy_pattern(self, attn_matrix: torch.Tensor, input_ids: torch.Tensor) -> float:
        """Detect the classic induction pattern: attention from a repeated
        token to the token that followed its previous occurrence."""
        seq_len = attn_matrix.shape[0]
        copy_score = 0.0
        count = 0

        # Look for positions that attend strongly to the successor of a
        # previous occurrence of the current token ([A][B]...[A] -> [B])
        for i in range(2, min(seq_len, 50)):  # Limit analysis for efficiency
            if i >= len(input_ids[0]):
                break

            current_token = input_ids[0][i].item()

            # Find previous occurrences of the same token
            for j in range(i - 1):
                if input_ids[0][j].item() == current_token:
                    # Score the attention paid to the position after the match
                    strength = attn_matrix[i, j + 1].item()
                    if strength > 0.1:  # Threshold for significant attention
                        copy_score += strength
                        count += 1

        return copy_score / max(count, 1)
    
    def _detect_prefix_matching(
        self, 
        attn_matrix: torch.Tensor, 
        example_boundaries: List[Tuple[int, int]]
    ) -> float:
        """Detect if attention matches prefixes across examples"""
        if len(example_boundaries) < 2:
            return 0.0
            
        prefix_score = 0.0
        count = 0
        
        # Check whether later examples attend to the same relative positions in earlier examples
        for i, (start1, end1) in enumerate(example_boundaries[:-1]):
            for start2, end2 in example_boundaries[i+1:]:
                # Compare attention patterns between the two examples
                for offset in range(min(5, end1 - start1, end2 - start2)):  # Check first 5 tokens
                    pos1 = start1 + offset
                    pos2 = start2 + offset

                    # Attention from the later example to the earlier one at the same offset
                    if pos2 < attn_matrix.shape[0] and pos1 < attn_matrix.shape[1]:
                        attention_strength = attn_matrix[pos2, pos1].item()
                        if attention_strength > 0.1:
                            prefix_score += attention_strength
                            count += 1
        
        return prefix_score / max(count, 1)
    
    def _detect_abstract_pattern(self, attn_matrix: torch.Tensor, seq_len: int) -> float:
        """Detect abstract pattern matching (e.g., function->function mapping)"""
        # Look for diagonal patterns offset by example length
        # This indicates attending to structurally similar positions
        
        abstract_score = 0.0
        window_size = 10
        
        for i in range(window_size, min(seq_len, 50)):
            # Sum attention over a short backward diagonal band
            diagonal_sum = 0.0
            for offset in range(1, min(window_size, i)):
                diagonal_sum += attn_matrix[i, i - offset].item()

            # High diagonal attention indicates structural copying
            if diagonal_sum / window_size > 0.1:
                abstract_score += diagonal_sum / window_size

        return min(abstract_score / 10, 1.0)  # Normalize to [0, 1]
    
    def _find_emergence_point(self, attn_matrix: torch.Tensor) -> Optional[int]:
        """Find the token position where the pattern suddenly emerges"""
        seq_len = min(attn_matrix.shape[0], 50)  # Limit for efficiency
        
        if seq_len < 10:
            return None
            
        # Calculate attention entropy at each position
        entropies = []
        for i in range(seq_len):
            attn_dist = attn_matrix[i, :i+1]  # Only look at previous positions
            if attn_dist.sum() > 0:
                attn_dist = attn_dist / attn_dist.sum()
                # Calculate entropy
                entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum().item()
                entropies.append(entropy)
            else:
                entropies.append(0.0)
        
        # Find sudden drops in entropy (indicating focused attention)
        if len(entropies) < 5:
            return None
            
        for i in range(4, len(entropies)):
            recent_avg = np.mean(entropies[i-4:i])
            current = entropies[i]
            
            # Sudden drop indicates emergence
            if recent_avg > 0 and current < recent_avg * 0.5:
                return i
        
        return None
    
    def analyze_icl_emergence(
        self,
        attention_weights: List[Dict],
        input_ids: torch.Tensor,
        example_boundaries: List[Tuple[int, int]],
        generated_tokens: List[int]
    ) -> ICLEmergenceAnalysis:
        """
        Comprehensive analysis of when and how ICL emerges during generation
        """
        
        # Detect induction heads
        induction_heads = self.detect_induction_heads(
            attention_weights, input_ids, example_boundaries
        )
        
        # Calculate attention entropy trajectory
        entropy_trajectory = self._calculate_entropy_trajectory(
            attention_weights, len(generated_tokens)
        )
        
        # Determine emergence point
        emergence_token = None
        emergence_layer = None
        emergence_confidence = 0.0
        
        if induction_heads:
            # Find strongest induction signal
            strongest_head = max(induction_heads, key=lambda h: h.strength)
            
            # Check for consistent emergence points across heads
            emergence_points = [h.emergence_point for h in induction_heads if h.emergence_point is not None]
            if emergence_points:
                # Median emergence point across heads
                emergence_token = int(np.median(emergence_points))
                emergence_layer = strongest_head.layer
                
                # Confidence based on consistency and strength
                consistency = len(emergence_points) / len(induction_heads)
                emergence_confidence = min(strongest_head.strength * consistency, 1.0)
        
        # Check for entropy drop as additional signal
        if entropy_trajectory and len(entropy_trajectory) > 5:
            for i in range(5, len(entropy_trajectory)):
                recent_avg = np.mean(entropy_trajectory[i-5:i])
                if recent_avg > 0 and entropy_trajectory[i] < recent_avg * 0.6:
                    if emergence_token is None:
                        emergence_token = i
                        emergence_confidence = 0.5
                    break
        
        # Calculate pattern consistency
        pattern_consistency = self._calculate_pattern_consistency(induction_heads)
        
        return ICLEmergenceAnalysis(
            emergence_detected=emergence_token is not None,
            emergence_token=emergence_token,
            emergence_layer=emergence_layer,
            confidence=emergence_confidence,
            induction_heads=induction_heads,
            attention_entropy_drop=entropy_trajectory,
            pattern_consistency=pattern_consistency
        )
    
    def _calculate_entropy_trajectory(
        self,
        attention_weights: List[Dict],
        num_generated: int
    ) -> List[float]:
        """Calculate attention entropy at each generated position"""
        entropies = []

        if not attention_weights:
            return entropies

        # Records are assumed to arrive in generation order, num_layers per step
        num_layers = self.adapter.get_num_layers() if self.adapter else 20  # Adapter value, else CodeGen's 20 layers

        for gen_idx in range(num_generated):
            position_entropy = []

            # Get attention for this generated position across all layers
            for i in range(gen_idx * num_layers, min((gen_idx + 1) * num_layers, len(attention_weights))):
                attn = attention_weights[i].get('attention')
                if attn is not None and attn.dim() == 3:
                    attn = attn.unsqueeze(0)  # Promote [heads, seq, seq] to batched form
                if attn is not None and attn.dim() == 4:
                    # Average across heads
                    avg_attn = attn[0].mean(dim=0)
                    if avg_attn.shape[0] > gen_idx:
                        # Attention distribution for the newly generated token
                        attn_dist = avg_attn[-1]  # Last query row is the newest position
                        if attn_dist.sum() > 0:
                            attn_dist = attn_dist / attn_dist.sum()
                            # Shannon entropy of the renormalized distribution
                            entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum().item()
                            position_entropy.append(entropy)
            
            if position_entropy:
                entropies.append(np.mean(position_entropy))
            else:
                entropies.append(0.0)
        
        return entropies
    
    def _calculate_pattern_consistency(self, induction_heads: List[InductionHeadSignal]) -> float:
        """Calculate how consistent the induction patterns are across heads"""
        if not induction_heads:
            return 0.0
            
        # Group by pattern type
        pattern_counts = {}
        for head in induction_heads:
            pattern_counts[head.pattern_type] = pattern_counts.get(head.pattern_type, 0) + 1
        
        # Consistency is ratio of dominant pattern
        max_count = max(pattern_counts.values())
        return max_count / len(induction_heads)
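

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not a definitive usage pattern).
# Builds a synthetic [batch, heads, seq, seq] attention record with a
# hand-planted induction pattern and runs detection on it. The stub model,
# token values, and boundaries below are all assumptions made for the demo;
# the detector only reads a parameter's device from the model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _StubModel:
        """Stand-in model: the detector only queries a parameter's device."""
        def parameters(self):
            yield torch.zeros(1)

    seq_len, num_heads = 12, 4
    # Repeated tokens: ...[7][8]... then [7] again, so the induction target
    # of the later [7] is the [8] that followed the earlier [7]
    input_ids = torch.tensor([[5, 7, 8, 9, 3, 4, 7, 8, 9, 3, 4, 7]])

    attn = torch.full((1, num_heads, seq_len, seq_len), 1e-4)
    # Head 0: each repeated token attends to the successor of its previous occurrence
    for i in range(2, seq_len):
        for j in range(i - 1):
            if input_ids[0, j] == input_ids[0, i]:
                attn[0, 0, i, j + 1] = 0.9
    # Keep attention causal and renormalize each query row into a distribution
    attn = attn.tril()
    attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-10)

    detector = InductionHeadDetector(_StubModel(), tokenizer=None)
    records = [{'layer': 0, 'attention': attn}]
    boundaries = [(0, 5), (5, 10)]

    # Head 0 should surface as a 'copy' head; the uniform heads should not
    for h in detector.detect_induction_heads(records, input_ids, boundaries):
        print(f"layer={h.layer} head={h.head} type={h.pattern_type} "
              f"strength={h.strength:.2f} emergence={h.emergence_point}")

    analysis = detector.analyze_icl_emergence(records, input_ids, boundaries, generated_tokens=[])
    print(f"emergence_detected={analysis.emergence_detected} "
          f"token={analysis.emergence_token} confidence={analysis.confidence:.2f}")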