gary-boon Claude committed on
Commit
920a98d
·
1 Parent(s): bb8a292

Add backend support for ICL emergence analysis

Browse files

- Implement ICL attention extractor with PyTorch hooks
- Add induction head detector for pattern recognition
- Create context efficiency analyzer for optimal example usage
- Update model service with ICL emergence endpoints
- Support real-time attention weight extraction during generation
- Enable token-by-token generation for attention capture

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

backend/context_efficiency_analyzer.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Context Efficiency Analyzer for In-Context Learning
3
+
4
+ Measures how efficiently the model uses context examples to perform tasks.
5
+ Based on research showing that not all examples contribute equally and that
6
+ optimal context usage can significantly improve performance.
7
+ """
8
+
9
+ import torch
10
+ import numpy as np
11
+ from typing import List, Dict, Tuple, Optional
12
+ from dataclasses import dataclass
13
+ import logging
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ @dataclass
18
+ class TokenEfficiency:
19
+ """Efficiency metrics for individual tokens"""
20
+ token: str
21
+ position: int
22
+ information_content: float # Bits of information
23
+ redundancy_score: float # 0-1 (1 = completely redundant)
24
+ contribution_score: float # How much it contributes to output
25
+
26
+ @dataclass
27
+ class ExampleEfficiency:
28
+ """Efficiency metrics for each example"""
29
+ example_id: str
30
+ total_tokens: int
31
+ effective_tokens: int # Tokens that actually contribute
32
+ efficiency_ratio: float # effective/total
33
+ redundancy_rate: float # Percentage of redundant tokens
34
+ information_density: float # Bits per token
35
+ marginal_benefit: float # Additional benefit vs previous examples
36
+
37
+ @dataclass
38
+ class ContextEfficiencyAnalysis:
39
+ """Complete context efficiency analysis"""
40
+ overall_efficiency: float # 0-1 score
41
+ total_context_tokens: int
42
+ effective_context_tokens: int
43
+ example_efficiencies: List[ExampleEfficiency]
44
+ token_efficiencies: List[TokenEfficiency]
45
+ optimal_example_count: int # Suggested optimal number of examples
46
+ redundancy_patterns: Dict[str, float] # Pattern type -> frequency
47
+ compression_potential: float # How much context could be compressed
48
+ attention_utilization: float # How much of context gets attention
49
+
50
class ContextEfficiencyAnalyzer:
    """Analyzes how efficiently context is used in ICL.

    The heuristics are token-based: redundancy, information density and
    marginal benefit come from token statistics, while attention utilization
    is derived from real attention tensors when they are supplied.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # Assumes the model's parameters all live on a single device.
        self.device = next(model.parameters()).device

    def analyze_context_efficiency(
        self,
        examples: List[Tuple[str, str]],  # (input, output) pairs
        test_prompt: str,
        attention_weights: Optional[List[Dict]] = None,
        generated_tokens: Optional[List[str]] = None,
        confidence_scores: Optional[List[float]] = None
    ) -> ContextEfficiencyAnalysis:
        """Comprehensive analysis of context efficiency.

        Args:
            examples: (input, output) text pairs used as in-context examples.
            test_prompt: The task prompt (currently unused by the heuristics).
            attention_weights: Optional {'layer': int, 'attention': Tensor}
                records captured during generation.
            generated_tokens: Tokens produced by the model; used to estimate
                each context token's contribution.
            confidence_scores: Currently unused; kept for API compatibility.

        Returns:
            A ContextEfficiencyAnalysis with context-, example- and
            token-level metrics.
        """

        # Tokenize all examples, recording each example's [start, end) span.
        example_tokens = []
        example_boundaries = []
        current_pos = 0

        for input_text, output_text in examples:
            example_text = f"{input_text}\n{output_text}\n"
            tokens = self.tokenizer.tokenize(example_text)
            example_tokens.extend(tokens)
            example_boundaries.append((current_pos, current_pos + len(tokens)))
            current_pos += len(tokens)

        # Analyze each example's efficiency.
        example_efficiencies = []
        for idx, (start, end) in enumerate(example_boundaries):
            efficiency = self._analyze_example_efficiency(
                example_idx=idx,
                example_tokens=example_tokens[start:end],
                all_tokens=example_tokens,
                prior_token_count=start,
                attention_weights=attention_weights,
                generated_tokens=generated_tokens
            )
            example_efficiencies.append(efficiency)

        # Token-level efficiency metrics.
        token_efficiencies = self._analyze_token_efficiency(
            example_tokens=example_tokens,
            attention_weights=attention_weights,
            generated_tokens=generated_tokens
        )

        # Common redundancy patterns across the whole context.
        redundancy_patterns = self._identify_redundancy_patterns(
            example_tokens=example_tokens,
            token_efficiencies=token_efficiencies
        )

        # Suggested number of examples, from marginal benefits.
        optimal_count = self._calculate_optimal_example_count(
            example_efficiencies=example_efficiencies
        )

        # How much of the context could be removed outright.
        compression_potential = self._calculate_compression_potential(
            token_efficiencies=token_efficiencies
        )

        # How much of the context actually receives attention.
        attention_utilization = self._calculate_attention_utilization(
            attention_weights=attention_weights,
            total_context_tokens=len(example_tokens)
        )

        # A token "counts" as effective when it is mostly non-redundant.
        effective_tokens = sum(1 for t in token_efficiencies if t.redundancy_score < 0.5)
        overall_efficiency = effective_tokens / max(len(example_tokens), 1)

        return ContextEfficiencyAnalysis(
            overall_efficiency=overall_efficiency,
            total_context_tokens=len(example_tokens),
            effective_context_tokens=effective_tokens,
            example_efficiencies=example_efficiencies,
            token_efficiencies=token_efficiencies,
            optimal_example_count=optimal_count,
            redundancy_patterns=redundancy_patterns,
            compression_potential=compression_potential,
            attention_utilization=attention_utilization
        )

    def _analyze_example_efficiency(
        self,
        example_idx: int,
        example_tokens: List[str],
        all_tokens: List[str],
        prior_token_count: Optional[int] = None,
        attention_weights: Optional[List[Dict]] = None,
        generated_tokens: Optional[List[str]] = None
    ) -> ExampleEfficiency:
        """Analyze efficiency of a single example.

        Args:
            example_idx: Zero-based index of the example.
            example_tokens: This example's tokens.
            all_tokens: All context tokens, in order.
            prior_token_count: Number of context tokens preceding this
                example. When None, falls back to the old estimate
                example_idx * len(example_tokens), which is only correct when
                every example tokenizes to the same length.
            attention_weights: Unused here; kept for interface symmetry.
            generated_tokens: Unused here; kept for interface symmetry.
        """

        # BUGFIX: the previous-context slice used example_idx * len(example_tokens),
        # which mis-slices whenever examples have different token lengths. The
        # caller now passes the true boundary via prior_token_count.
        if prior_token_count is None:
            prior_token_count = example_idx * len(example_tokens)
        previous_tokens = all_tokens[:prior_token_count]

        # Redundancy: tokens already seen several times in earlier examples.
        redundant_count = 0
        if example_idx > 0:
            # Count prior occurrences once instead of rescanning per token (was O(n^2)).
            prior_counts: Dict[str, int] = {}
            for token in previous_tokens:
                prior_counts[token] = prior_counts.get(token, 0) + 1
            redundant_count = sum(
                1 for token in example_tokens if prior_counts.get(token, 0) > 2
            )

        redundancy_rate = redundant_count / max(len(example_tokens), 1)

        # Information density: simplified Shannon entropy (bits per token).
        unique_tokens = len(set(example_tokens))
        information_density = np.log2(max(unique_tokens, 1)) / max(len(example_tokens), 1)

        # Marginal benefit: fraction of genuinely new tokens introduced.
        if example_idx == 0:
            marginal_benefit = 1.0  # First example always has full benefit
        else:
            new_patterns = set(example_tokens) - set(previous_tokens)
            marginal_benefit = len(new_patterns) / max(len(example_tokens), 1)

        # Effective tokens: those not judged redundant.
        effective_tokens = int(len(example_tokens) * (1 - redundancy_rate))

        return ExampleEfficiency(
            example_id=str(example_idx + 1),
            total_tokens=len(example_tokens),
            effective_tokens=effective_tokens,
            efficiency_ratio=effective_tokens / max(len(example_tokens), 1),
            redundancy_rate=redundancy_rate,
            information_density=information_density,
            marginal_benefit=marginal_benefit
        )

    def _analyze_token_efficiency(
        self,
        example_tokens: List[str],
        attention_weights: Optional[List[Dict]],
        generated_tokens: Optional[List[str]]
    ) -> List[TokenEfficiency]:
        """Analyze efficiency of individual tokens."""

        token_efficiencies = []
        total = len(example_tokens)

        # Hoist global frequency counting out of the loop (was O(n^2)).
        frequencies: Dict[str, int] = {}
        for token in example_tokens:
            frequencies[token] = frequencies.get(token, 0) + 1

        # Pre-lowercase generated tokens for the substring check below.
        lowered_generated = [g.lower() for g in generated_tokens] if generated_tokens else []

        for idx, token in enumerate(example_tokens):
            # Rare tokens carry more information (self-information estimate).
            frequency = frequencies[token]
            information_content = np.log2(total / max(frequency, 1))

            # Local redundancy: repeats within a small window around idx.
            # NOTE(review): window is asymmetric ([idx-5, idx+5)) and includes
            # the token itself - preserved from the original heuristic.
            local_window = example_tokens[max(0, idx - 5):min(total, idx + 5)]
            local_frequency = local_window.count(token)
            redundancy_score = min(local_frequency / 3.0, 1.0)  # Cap at 1.0

            # Contribution: exact or case-insensitive substring overlap with
            # the generated output.
            contribution_score = 0.0
            if generated_tokens:
                if token in generated_tokens:
                    contribution_score = 1.0
                elif any(token.lower() in gen_token for gen_token in lowered_generated):
                    contribution_score = 0.5

            token_efficiencies.append(TokenEfficiency(
                token=token,
                position=idx,
                information_content=information_content,
                redundancy_score=redundancy_score,
                contribution_score=contribution_score
            ))

        return token_efficiencies

    def _identify_redundancy_patterns(
        self,
        example_tokens: List[str],
        token_efficiencies: List[TokenEfficiency]
    ) -> Dict[str, float]:
        """Identify common redundancy patterns.

        Returns a dict with fixed keys 'repeated_tokens', 'boilerplate',
        'structural_repetition' and 'semantic_overlap', each a 0-1 frequency.
        """

        patterns = {
            'repeated_tokens': 0.0,
            'boilerplate': 0.0,
            'structural_repetition': 0.0,
            'semantic_overlap': 0.0
        }

        # Tokens repeating more than 3 times across the whole context.
        token_counts: Dict[str, int] = {}
        for token in example_tokens:
            token_counts[token] = token_counts.get(token, 0) + 1

        repeated = sum(1 for count in token_counts.values() if count > 3)
        patterns['repeated_tokens'] = repeated / max(len(token_counts), 1)

        # Boilerplate: common programming-structure tokens.
        boilerplate_tokens = ['def', 'class', 'return', 'import', 'from', '"""', "'''"]
        boilerplate_count = sum(1 for token in example_tokens if token in boilerplate_tokens)
        patterns['boilerplate'] = boilerplate_count / max(len(example_tokens), 1)

        # Structural repetition: 3-token sequences appearing more than once.
        sequence_length = 3
        sequences: Dict[Tuple[str, ...], int] = {}
        for i in range(len(example_tokens) - sequence_length):
            seq = tuple(example_tokens[i:i + sequence_length])
            sequences[seq] = sequences.get(seq, 0) + 1

        repeated_sequences = sum(1 for count in sequences.values() if count > 1)
        patterns['structural_repetition'] = repeated_sequences / max(len(sequences), 1)

        # Semantic overlap: proxied by tokens with high redundancy scores.
        high_redundancy = sum(1 for t in token_efficiencies if t.redundancy_score > 0.7)
        patterns['semantic_overlap'] = high_redundancy / max(len(token_efficiencies), 1)

        return patterns

    def _calculate_optimal_example_count(
        self,
        example_efficiencies: List[ExampleEfficiency]
    ) -> int:
        """Determine the optimal number of examples from marginal benefits.

        Returns the count of examples up to (excluding) the first one whose
        marginal benefit drops below the threshold; all examples otherwise.
        """

        if not example_efficiencies:
            return 0

        threshold = 0.3  # Examples adding less than 30% benefit are not worth it

        for idx, efficiency in enumerate(example_efficiencies):
            if efficiency.marginal_benefit < threshold and idx > 0:
                return idx

        # All examples carry sufficient marginal benefit - keep them all.
        return len(example_efficiencies)

    def _calculate_compression_potential(
        self,
        token_efficiencies: List[TokenEfficiency]
    ) -> float:
        """Fraction (0-1) of context tokens that could be removed.

        A token is removable when it is highly redundant and contributes
        little to the generated output.
        """

        if not token_efficiencies:
            return 0.0

        removable = sum(
            1 for t in token_efficiencies
            if t.redundancy_score > 0.6 and t.contribution_score < 0.3
        )

        return removable / len(token_efficiencies)

    def _calculate_attention_utilization(
        self,
        attention_weights: Optional[List[Dict]],
        total_context_tokens: int
    ) -> float:
        """Calculate what fraction of context positions receive significant attention.

        Expects each record's 'attention' tensor to be
        [batch, heads, seq, seq]; records of lower rank are skipped or, at
        rank 3, averaged over dim 1 as the original code did.
        """

        if not attention_weights or total_context_tokens == 0:
            return 0.0

        attended_positions = set()

        for record in attention_weights:
            attn = record.get('attention')
            if attn is not None and attn.dim() >= 3:
                # Average across heads (dim 1 for [batch, heads, seq, seq]).
                avg_attn = attn.mean(dim=1)

                # Positions with attention above the threshold count as utilized.
                threshold = 0.05
                high_attention = (avg_attn > threshold).nonzero(as_tuple=True)

                if len(high_attention) > 1:
                    # BUGFIX: collect the last index tuple (the attended-to /
                    # key dimension); index 1 held query positions, which is
                    # not what "context gets attention" means.
                    attended_positions.update(high_attention[-1].tolist())

        # Keep only positions that fall inside the context window.
        context_attended = [pos for pos in attended_positions if pos < total_context_tokens]

        return len(context_attended) / total_context_tokens if total_context_tokens > 0 else 0.0
backend/icl_attention_extractor.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Real Attention Extraction for In-Context Learning Analysis
3
+
4
+ This module hooks into transformer models to extract actual attention weights
5
+ during generation, providing real data for ICL analysis.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+ from typing import List, Dict, Tuple, Optional, Any
12
+ from dataclasses import dataclass
13
+ import logging
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ @dataclass
18
+ class AttentionData:
19
+ """Stores attention data from model generation"""
20
+ layer_attentions: List[torch.Tensor] # Attention from each layer
21
+ token_positions: List[int] # Position of each generated token
22
+ example_boundaries: List[Tuple[int, int]] # Start/end positions of examples
23
+
24
class AttentionExtractor:
    """Extracts real attention patterns from transformer models during generation.

    Attention is captured two ways: forward hooks on each attention module
    (for models exposing attention weights in their layer outputs) and, as a
    fallback, the ``attentions`` field of the model output when the forward
    pass is run with ``output_attentions=True``.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # Assumes all model parameters share a single device.
        self.device = next(model.parameters()).device

        # Captured records: {'layer': int, 'attention': Tensor} in capture order.
        self.attention_weights = []
        self.handles = []

    def register_hooks(self):
        """Register forward hooks on each attention module to capture weights."""
        self.clear_hooks()

        # GPT/CodeGen-style models keep their blocks under model.transformer.h.
        if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
            for i, layer in enumerate(self.model.transformer.h):
                if hasattr(layer, 'attn'):
                    # Bind the layer index as a default argument so each
                    # closure keeps its own value (late-binding pitfall).
                    handle = layer.attn.register_forward_hook(
                        lambda module, input, output, layer_idx=i:
                        self._attention_hook(module, input, output, layer_idx)
                    )
                    self.handles.append(handle)

        logger.info(f"Registered {len(self.handles)} attention hooks")

    def _attention_hook(self, module, input, output, layer_idx):
        """Hook function to capture attention weights.

        For CodeGen-style attention modules the forward output is a tuple of
        (hidden_states, attention_weights, ...).
        """
        if isinstance(output, tuple) and len(output) >= 2:
            attention = output[1]
            if attention is not None:
                # Detach + move to CPU so captured tensors don't pin GPU memory.
                self.attention_weights.append({
                    'layer': layer_idx,
                    'attention': attention.detach().cpu()
                })

    def clear_hooks(self):
        """Remove all hooks and drop any captured attention records."""
        for handle in self.handles:
            handle.remove()
        self.handles = []
        self.attention_weights = []

    def extract_attention_with_generation(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.7
    ) -> Tuple[torch.Tensor, List[Dict], List[torch.Tensor]]:
        """Generate token by token while extracting attention patterns.

        Returns:
            (generated ids tensor of shape [1, n], attention records,
            raw next-token logits for each generation step).
        """

        # Register hooks before generation; start with a clean record list.
        self.register_hooks()
        self.attention_weights = []

        try:
            generated_ids = []
            all_scores = []  # Raw logits per step, for confidence calculation
            current_input_ids = input_ids.clone()
            current_attention_mask = attention_mask.clone()

            for _ in range(max_new_tokens):
                # Used below to detect whether the hooks fired this step.
                records_before = len(self.attention_weights)

                with torch.no_grad():
                    outputs = self.model(
                        input_ids=current_input_ids,
                        attention_mask=current_attention_mask,
                        use_cache=False,  # Full attention matrix at every step
                        output_attentions=True,
                        return_dict=True
                    )

                # BUGFIX: attention from outputs.attentions was previously
                # appended unconditionally, duplicating every record the hooks
                # had already captured. Only fall back to the model output when
                # the hooks captured nothing during this forward pass.
                if len(self.attention_weights) == records_before:
                    if hasattr(outputs, 'attentions') and outputs.attentions is not None:
                        for layer_idx, attn in enumerate(outputs.attentions):
                            self.attention_weights.append({
                                'layer': layer_idx,
                                'attention': attn.detach().cpu()
                            })

                # Next-token logits for the final position.
                next_token_logits = outputs.logits[:, -1, :]
                all_scores.append(next_token_logits)

                # Temperature sampling, or greedy argmax when temperature == 0.
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # Stop on EOS.
                if next_token.item() == self.tokenizer.eos_token_id:
                    break

                # Append the token and extend the mask for the next step.
                generated_ids.append(next_token.item())
                current_input_ids = torch.cat([current_input_ids, next_token], dim=1)
                current_attention_mask = torch.cat([
                    current_attention_mask,
                    # Match the existing mask's dtype to avoid silent upcasting.
                    torch.ones((1, 1), device=self.device, dtype=current_attention_mask.dtype)
                ], dim=1)

            if generated_ids:
                generated_tensor = torch.tensor(generated_ids, device=self.device).unsqueeze(0)
            else:
                generated_tensor = torch.tensor([[]], device=self.device, dtype=torch.long)

            return generated_tensor, self.attention_weights, all_scores

        finally:
            # Always remove hooks; clear_hooks() rebinds (not mutates)
            # self.attention_weights, so the list returned above stays intact.
            self.clear_hooks()

    def aggregate_attention_to_examples(
        self,
        attention_data: List[Dict],
        example_boundaries: List[Tuple[int, int]],
        prompt_length: int
    ) -> Dict[str, List[float]]:
        """
        Aggregate attention from generated tokens back to example regions.

        Args:
            attention_data: {'layer': int, 'attention': Tensor} records,
                one record per layer for each generated token, in step order.
            example_boundaries: [start, end) token spans of each example.
            prompt_length: Prompt length in tokens (currently unused).

        Returns:
            Dict mapping example_id -> normalized attention weight per
            generated token.
        """

        if not attention_data or not example_boundaries:
            return {}

        attention_to_examples = {}

        # BUGFIX: the layer count was hard-coded to 20 (CodeGen-specific).
        # Derive it from the captured records so any model depth works.
        num_layers = max(record.get('layer', 0) for record in attention_data) + 1
        num_generated = len(attention_data) // num_layers

        logger.info(f"Processing {len(attention_data)} attention records for {num_generated} generated tokens")

        for example_idx, (start, end) in enumerate(example_boundaries):
            example_id = str(example_idx + 1)
            example_attention = []

            for gen_idx in range(num_generated):
                total_attention = 0.0
                layer_count = 0

                # Records [gen_idx*num_layers:(gen_idx+1)*num_layers] hold all
                # layers for this generated token - direct slice instead of
                # the previous O(records^2) scan over the whole list.
                step_records = attention_data[gen_idx * num_layers:(gen_idx + 1) * num_layers]
                for attn_record in step_records:
                    if 'attention' in attn_record:
                        attn_tensor = attn_record['attention']

                        # Expect [batch, heads, seq_len, seq_len]; the last row
                        # is the attention paid by the newest token.
                        if attn_tensor.dim() >= 3:
                            seq_len = attn_tensor.shape[-1]

                            if end <= seq_len:
                                # Average over heads and over the example span.
                                attn_to_example = attn_tensor[0, :, -1, start:end].mean().item()
                                total_attention += attn_to_example
                                layer_count += 1

                # Average across layers; 0.0 when nothing usable was captured.
                if layer_count > 0:
                    example_attention.append(total_attention / layer_count)
                else:
                    example_attention.append(0.0)

            attention_to_examples[example_id] = example_attention

        # Normalize so each generated token's attention sums to 1 across examples.
        for gen_idx in range(num_generated):
            total = sum(
                attention_to_examples[ex_id][gen_idx]
                for ex_id in attention_to_examples
                if gen_idx < len(attention_to_examples[ex_id])
            )
            if total > 0:
                for ex_id in attention_to_examples:
                    if gen_idx < len(attention_to_examples[ex_id]):
                        attention_to_examples[ex_id][gen_idx] /= total

        return attention_to_examples

    def calculate_example_influences(
        self,
        attention_to_examples: Dict[str, List[float]]
    ) -> Dict[str, float]:
        """
        Calculate overall influence of each example based on attention patterns.

        Returns:
            Dict mapping example_id -> influence score; scores sum to 1
            whenever any attention was observed.
        """
        influences = {}

        for example_id, attention_weights in attention_to_examples.items():
            # Influence = mean attention across all generated tokens.
            if attention_weights:
                influences[example_id] = float(np.mean(attention_weights))
            else:
                influences[example_id] = 0.0

        # Normalize to sum to 1.
        total = sum(influences.values())
        if total > 0:
            influences = {k: v / total for k, v in influences.items()}

        return influences
backend/icl_service.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ In-Context Learning Analysis Service
3
+
4
+ Analyzes how examples influence model behavior during code generation.
5
+ """
6
+
7
+ import torch
8
+ import numpy as np
9
+ from typing import List, Dict, Optional, Any, Tuple
10
+ from dataclasses import dataclass
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ import torch.nn.functional as F
13
+ from .icl_attention_extractor import AttentionExtractor
14
+ from .induction_head_detector import InductionHeadDetector, ICLEmergenceAnalysis
15
+ import logging
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ @dataclass
20
+ class ICLExample:
21
+ """Represents an in-context learning example"""
22
+ input: str
23
+ output: str
24
+
25
+ @dataclass
26
+ class ICLAnalysisResult:
27
+ """Results from ICL analysis"""
28
+ shot_count: int
29
+ generated_code: str
30
+ tokens: List[str]
31
+ confidence_scores: List[float]
32
+ attention_from_examples: Dict[str, List[float]] # example_id -> attention weights per token
33
+ perplexity: float
34
+ avg_confidence: float
35
+ example_influences: Dict[str, float] # example_id -> overall influence score
36
+ hidden_state_drift: Optional[List[float]] = None # magnitude of hidden state changes
37
+ icl_emergence: Optional[ICLEmergenceAnalysis] = None # When/how ICL kicks in
38
+
39
class ICLAnalyzer:
    """Analyzes in-context learning effects on model behavior.

    Combines a standard generation pass (text, confidence, perplexity) with a
    second token-by-token pass that captures real attention data, falling back
    to simulated attention patterns when extraction fails.
    """

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

        # Extracts real attention weights during generation.
        self.attention_extractor = AttentionExtractor(model, tokenizer)

        # Detects induction-head / ICL-emergence signals.
        self.induction_detector = InductionHeadDetector(model, tokenizer)

        # Storage for attention patterns and hidden states.
        self.attention_maps = []
        self.hidden_states = []
        # BUGFIX: initialize explicitly instead of relying on hasattr checks.
        # None until a real-attention extraction succeeds.
        self.last_attention_data = None

    def prepare_prompt_with_examples(self, examples: List[ICLExample], test_prompt: str) -> str:
        """Construct a prompt with examples in the standard input/output format."""
        if not examples:
            return test_prompt

        prompt_parts = []
        for example in examples:
            prompt_parts.append(f"{example.input}\n{example.output}\n")
        prompt_parts.append(test_prompt)

        return "\n".join(prompt_parts)

    def extract_attention_patterns(self, outputs, input_ids, example_boundaries: List[Tuple[int, int]]) -> Dict[str, List[float]]:
        """Extract attention patterns - real if available, simulated otherwise."""

        # Prefer real attention data captured by the extractor.
        if self.last_attention_data:
            logger.info("Using real attention data from model hooks")
            prompt_length = len(input_ids[0])
            return self.attention_extractor.aggregate_attention_to_examples(
                self.last_attention_data,
                example_boundaries,
                prompt_length
            )

        # Fall back to simulated patterns.
        logger.info("Using simulated attention patterns")
        attention_from_examples = {}

        if not example_boundaries:
            return attention_from_examples

        generated_ids = outputs.sequences[0][len(input_ids[0]):]
        num_generated = len(generated_ids)

        if num_generated == 0:
            return attention_from_examples

        # Simulated pattern: later examples get slightly more base weight,
        # attention decays over generated tokens, plus small Gaussian noise.
        for idx, (start, end) in enumerate(example_boundaries):
            example_id = str(idx + 1)
            base_weight = 0.3 + (idx * 0.1) / len(example_boundaries)

            attention_weights = []
            for token_idx in range(num_generated):
                weight = base_weight * np.exp(-token_idx * 0.05)
                weight += np.random.normal(0, 0.02)
                weight = max(0, min(1, weight))
                attention_weights.append(weight)

            attention_from_examples[example_id] = attention_weights

        # Normalize per generated token so weights sum to 1 across examples.
        if len(attention_from_examples) > 1:
            for token_idx in range(num_generated):
                total = sum(weights[token_idx] for weights in attention_from_examples.values())
                if total > 0:
                    for example_id in attention_from_examples:
                        attention_from_examples[example_id][token_idx] /= total

        return attention_from_examples

    def calculate_example_influences(self, attention_from_examples: Dict[str, List[float]]) -> Dict[str, float]:
        """Calculate an overall influence score for each example (sums to 1)."""

        # With real attention data, defer to the extractor's implementation.
        if self.last_attention_data:
            return self.attention_extractor.calculate_example_influences(attention_from_examples)

        # Otherwise average the simulated weights.
        influences = {}

        for example_id, weights in attention_from_examples.items():
            influences[example_id] = float(np.mean(weights)) if weights else 0.0

        total = sum(influences.values())
        if total > 0 and total != 1.0:
            influences = {k: v / total for k, v in influences.items()}

        return influences

    def track_hidden_state_drift(self, base_hidden_states, example_hidden_states) -> List[float]:
        """Track how hidden states change from base (no examples) to with-examples.

        Returns the L2 distance between corresponding hidden states, truncated
        to the shorter of the two sequences; empty list when either is None.
        """
        if base_hidden_states is None or example_hidden_states is None:
            return []

        drift = []
        min_len = min(len(base_hidden_states), len(example_hidden_states))

        for i in range(min_len):
            base = base_hidden_states[i]
            example = example_hidden_states[i]

            if isinstance(base, torch.Tensor):
                base = base.cpu().numpy()
            if isinstance(example, torch.Tensor):
                example = example.cpu().numpy()

            distance = np.linalg.norm(example - base)
            drift.append(float(distance))

        return drift

    def analyze_generation(
        self,
        examples: List[ICLExample],
        test_prompt: str,
        max_length: int = 150,
        temperature: float = 0.7,
        base_hidden_states: Optional[Any] = None
    ) -> ICLAnalysisResult:
        """Analyze how the given examples influence generation of test_prompt.

        Args:
            examples: In-context examples prepended to the prompt.
            test_prompt: The task prompt itself.
            max_length: Total sequence budget (prompt + generated tokens).
            temperature: Sampling temperature; 0 means greedy decoding.
            base_hidden_states: Optional baseline hidden states for drift tracking.
        """

        # Prepare prompt and tokenize.
        full_prompt = self.prepare_prompt_with_examples(examples, test_prompt)
        inputs = self.tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)).to(self.device)

        # Locate each example's [start, end) span in token space.
        # NOTE(review): spans are found by tokenizing each example on its own,
        # which assumes the tokenizer splits identically in full context.
        example_boundaries = []
        if examples:
            current_pos = 0
            for example in examples:
                example_text = f"{example.input}\n{example.output}\n"
                example_tokens = self.tokenizer(example_text, add_special_tokens=False)["input_ids"]
                example_boundaries.append((current_pos, current_pos + len(example_tokens)))
                current_pos += len(example_tokens)

        # Standard generation pass: text, scores, perplexity.
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_scores=True,
                output_hidden_states=False
            )

        # Second pass: token-by-token generation to capture real attention.
        try:
            logger.info("Extracting real attention data")
            # BUGFIX: clamp the token budget at zero - the old expression went
            # negative whenever the prompt was longer than max_length.
            attention_budget = min(30, max(0, max_length - len(input_ids[0])))  # Limit for performance
            _, attention_data, _ = self.attention_extractor.extract_attention_with_generation(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=attention_budget,
                temperature=temperature
            )
            self.last_attention_data = attention_data
            logger.info(f"Successfully extracted {len(attention_data)} attention records")
        except Exception as e:
            logger.warning(f"Real attention extraction failed: {e}")
            self.last_attention_data = None

        # Extract generated tokens - show raw output, no trimming.
        generated_ids = outputs.sequences[0][len(input_ids[0]):]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        tokens = [self.tokenizer.decode([token_id]) for token_id in generated_ids]

        # Per-token confidence = max softmax probability at each step.
        confidence_scores = []
        if outputs.scores:
            for score in outputs.scores:
                probs = F.softmax(score[0], dim=-1)
                confidence_scores.append(probs.max().item())

        # Perplexity over the generated tokens.
        if outputs.scores:
            log_probs = []
            for i, score in enumerate(outputs.scores):
                if i < len(generated_ids):
                    token_id = generated_ids[i]
                    log_probs.append(F.log_softmax(score[0], dim=-1)[token_id].item())
            perplexity = float(np.exp(-np.mean(log_probs))) if log_probs else 0.0
        else:
            perplexity = 0.0

        # Attention patterns and per-example influence.
        attention_from_examples = self.extract_attention_patterns(outputs, input_ids, example_boundaries)
        example_influences = self.calculate_example_influences(attention_from_examples)

        # Hidden state drift, when a baseline is supplied.
        hidden_state_drift = None
        if base_hidden_states is not None and hasattr(outputs, 'hidden_states'):
            current_hidden = outputs.hidden_states[-1] if outputs.hidden_states else None
            if current_hidden is not None:
                hidden_state_drift = self.track_hidden_state_drift(base_hidden_states, current_hidden)

        # ICL emergence analysis requires real attention data and examples.
        icl_emergence = None
        if self.last_attention_data and len(examples) > 0:
            try:
                icl_emergence = self.induction_detector.analyze_icl_emergence(
                    self.last_attention_data,
                    input_ids,
                    example_boundaries,
                    generated_ids.tolist() if generated_ids.numel() > 0 else []
                )
                logger.info(f"ICL emergence analysis: detected={icl_emergence.emergence_detected}, "
                            f"token={icl_emergence.emergence_token}, confidence={icl_emergence.confidence:.2f}")
            except Exception as e:
                logger.warning(f"ICL emergence analysis failed: {e}")

        return ICLAnalysisResult(
            shot_count=len(examples),
            generated_code=generated_text,
            tokens=tokens,
            confidence_scores=confidence_scores,
            attention_from_examples=attention_from_examples,
            perplexity=perplexity,
            avg_confidence=float(np.mean(confidence_scores)) if confidence_scores else 0.0,
            example_influences=example_influences,
            hidden_state_drift=hidden_state_drift,
            icl_emergence=icl_emergence
        )

    def compare_shot_settings(
        self,
        examples: List[ICLExample],
        test_prompt: str,
        max_length: int = 150,
        temperature: float = 0.7
    ) -> Dict[str, ICLAnalysisResult]:
        """Compare 0-shot, 1-shot, and few-shot generation on the same prompt."""
        results = {}

        # 0-shot (no examples).
        results['zero_shot'] = self.analyze_generation([], test_prompt, max_length, temperature)

        # BUGFIX: the old code passed results['zero_shot'].hidden_state_drift
        # as the hidden-state baseline. That field is a list of drift
        # magnitudes - not hidden states - and is always None for the
        # zero-shot run (no baseline was given), so the value was never valid.
        # Until the zero-shot pass exposes its actual hidden states
        # (output_hidden_states=True), there is no usable baseline.
        base_hidden = None

        # 1-shot (first example only).
        if len(examples) >= 1:
            results['one_shot'] = self.analyze_generation(
                examples[:1], test_prompt, max_length, temperature, base_hidden
            )

        # Few-shot (all examples).
        if len(examples) >= 2:
            results['few_shot'] = self.analyze_generation(
                examples, test_prompt, max_length, temperature, base_hidden
            )

        return results
backend/induction_head_detector.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Induction Head Detection for In-Context Learning
3
+
4
+ Based on research showing that ICL emerges abruptly in transformers through
5
+ the formation of induction heads - attention patterns that copy from context.
6
+ """
7
+
8
+ import torch
9
+ import numpy as np
10
+ from typing import List, Dict, Tuple, Optional
11
+ from dataclasses import dataclass
12
+ import logging
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
@dataclass
class InductionHeadSignal:
    """Signals indicating induction head behavior"""
    layer: int  # Transformer layer index the head lives in
    head: int  # Attention head index within that layer
    strength: float  # 0-1 score of induction pattern strength
    pattern_type: str  # 'copy', 'prefix_match', 'abstract'
    emergence_point: Optional[int]  # Token position where pattern emerges (None if not found)
24
+
25
@dataclass
class ICLEmergenceAnalysis:
    """Analysis of when and how ICL emerges"""
    emergence_detected: bool  # True when an emergence token was identified
    emergence_token: Optional[int]  # Token position where ICL kicks in
    emergence_layer: Optional[int]  # Layer where strongest signal appears
    confidence: float  # Confidence in detection (0-1)
    induction_heads: List[InductionHeadSignal]  # All heads showing induction-like patterns
    attention_entropy_drop: List[float]  # Entropy at each position
    pattern_consistency: float  # How consistent the pattern is across heads (0-1)
35
+
36
class InductionHeadDetector:
    """Detects induction heads and ICL emergence in transformer models"""

    def __init__(self, model, tokenizer):
        # Model and tokenizer are stored for API symmetry with the other
        # analyzers; the detection methods below operate on attention data
        # that callers extracted elsewhere.
        self.model = model
        self.tokenizer = tokenizer
        # Device of the model's parameters.
        # NOTE(review): currently unused by the visible methods — kept,
        # presumably, for future tensor placement; confirm before removing.
        self.device = next(model.parameters()).device
43
+
44
+ def detect_induction_heads(
45
+ self,
46
+ attention_weights: List[Dict],
47
+ input_ids: torch.Tensor,
48
+ example_boundaries: List[Tuple[int, int]]
49
+ ) -> List[InductionHeadSignal]:
50
+ """
51
+ Detect induction heads by looking for attention patterns that:
52
+ 1. Copy from previous occurrences (classic induction)
53
+ 2. Match prefixes across examples
54
+ 3. Show abstract pattern matching
55
+ """
56
+ induction_heads = []
57
+
58
+ if not attention_weights or not example_boundaries:
59
+ return induction_heads
60
+
61
+ # Analyze each layer and head
62
+ layers_analyzed = {}
63
+ for record in attention_weights:
64
+ layer_idx = record.get('layer', 0)
65
+ attn = record.get('attention')
66
+
67
+ if attn is None or layer_idx in layers_analyzed:
68
+ continue
69
+
70
+ layers_analyzed[layer_idx] = True
71
+
72
+ # Analyze each attention head
73
+ if attn.dim() >= 3:
74
+ num_heads = attn.shape[1]
75
+ seq_len = attn.shape[-1]
76
+
77
+ for head_idx in range(num_heads):
78
+ head_attn = attn[0, head_idx] # [seq_len, seq_len]
79
+
80
+ # Detect different induction patterns
81
+ copy_score = self._detect_copy_pattern(head_attn, input_ids)
82
+ prefix_score = self._detect_prefix_matching(head_attn, example_boundaries)
83
+ abstract_score = self._detect_abstract_pattern(head_attn, seq_len)
84
+
85
+ # Determine strongest pattern
86
+ max_score = max(copy_score, prefix_score, abstract_score)
87
+ if max_score > 0.3: # Threshold for significant pattern
88
+ pattern_type = 'copy' if copy_score == max_score else \
89
+ 'prefix_match' if prefix_score == max_score else 'abstract'
90
+
91
+ # Find emergence point (where pattern suddenly strengthens)
92
+ emergence_point = self._find_emergence_point(head_attn)
93
+
94
+ induction_heads.append(InductionHeadSignal(
95
+ layer=layer_idx,
96
+ head=head_idx,
97
+ strength=max_score,
98
+ pattern_type=pattern_type,
99
+ emergence_point=emergence_point
100
+ ))
101
+
102
+ return induction_heads
103
+
104
+ def _detect_copy_pattern(self, attn_matrix: torch.Tensor, input_ids: torch.Tensor) -> float:
105
+ """Detect if attention head copies from previous occurrences"""
106
+ seq_len = attn_matrix.shape[0]
107
+ copy_score = 0.0
108
+ count = 0
109
+
110
+ # Look for positions that attend strongly to previous same/similar tokens
111
+ for i in range(1, min(seq_len, 50)): # Limit analysis for efficiency
112
+ if i >= len(input_ids[0]):
113
+ break
114
+
115
+ current_token = input_ids[0][i].item()
116
+
117
+ # Find previous occurrences of the same token
118
+ for j in range(i):
119
+ if j < len(input_ids[0]) and input_ids[0][j].item() == current_token:
120
+ # Check if attention is strong to this position
121
+ if attn_matrix[i, j] > 0.1: # Threshold for significant attention
122
+ copy_score += attn_matrix[i, j].item()
123
+ count += 1
124
+
125
+ return copy_score / max(count, 1)
126
+
127
+ def _detect_prefix_matching(
128
+ self,
129
+ attn_matrix: torch.Tensor,
130
+ example_boundaries: List[Tuple[int, int]]
131
+ ) -> float:
132
+ """Detect if attention matches prefixes across examples"""
133
+ if len(example_boundaries) < 2:
134
+ return 0.0
135
+
136
+ prefix_score = 0.0
137
+ count = 0
138
+
139
+ # Check if tokens attend to similar positions in different examples
140
+ for i, (start1, end1) in enumerate(example_boundaries[:-1]):
141
+ for j, (start2, end2) in enumerate(example_boundaries[i+1:], i+1):
142
+ # Compare attention patterns between examples
143
+ for offset in range(min(5, end1-start1, end2-start2)): # Check first 5 tokens
144
+ pos1 = start1 + offset
145
+ pos2 = start2 + offset
146
+
147
+ if pos1 < attn_matrix.shape[0] and pos2 < attn_matrix.shape[1]:
148
+ # Check if later example attends to earlier example at same offset
149
+ if pos2 < attn_matrix.shape[0] and pos1 < attn_matrix.shape[1]:
150
+ attention_strength = attn_matrix[pos2, pos1].item()
151
+ if attention_strength > 0.1:
152
+ prefix_score += attention_strength
153
+ count += 1
154
+
155
+ return prefix_score / max(count, 1)
156
+
157
+ def _detect_abstract_pattern(self, attn_matrix: torch.Tensor, seq_len: int) -> float:
158
+ """Detect abstract pattern matching (e.g., function->function mapping)"""
159
+ # Look for diagonal patterns offset by example length
160
+ # This indicates attending to structurally similar positions
161
+
162
+ abstract_score = 0.0
163
+ window_size = 10
164
+
165
+ for i in range(window_size, min(seq_len, 50)):
166
+ # Check if attention follows a diagonal pattern with offset
167
+ diagonal_sum = 0.0
168
+ for offset in range(1, min(window_size, i)):
169
+ if i - offset >= 0:
170
+ diagonal_sum += attn_matrix[i, i - offset].item()
171
+
172
+ # High diagonal attention indicates structural copying
173
+ if diagonal_sum / window_size > 0.1:
174
+ abstract_score += diagonal_sum / window_size
175
+
176
+ return min(abstract_score / 10, 1.0) # Normalize
177
+
178
+ def _find_emergence_point(self, attn_matrix: torch.Tensor) -> Optional[int]:
179
+ """Find the token position where the pattern suddenly emerges"""
180
+ seq_len = min(attn_matrix.shape[0], 50) # Limit for efficiency
181
+
182
+ if seq_len < 10:
183
+ return None
184
+
185
+ # Calculate attention entropy at each position
186
+ entropies = []
187
+ for i in range(seq_len):
188
+ attn_dist = attn_matrix[i, :i+1] # Only look at previous positions
189
+ if attn_dist.sum() > 0:
190
+ attn_dist = attn_dist / attn_dist.sum()
191
+ # Calculate entropy
192
+ entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum().item()
193
+ entropies.append(entropy)
194
+ else:
195
+ entropies.append(0.0)
196
+
197
+ # Find sudden drops in entropy (indicating focused attention)
198
+ if len(entropies) < 5:
199
+ return None
200
+
201
+ for i in range(4, len(entropies)):
202
+ recent_avg = np.mean(entropies[i-4:i])
203
+ current = entropies[i]
204
+
205
+ # Sudden drop indicates emergence
206
+ if recent_avg > 0 and current < recent_avg * 0.5:
207
+ return i
208
+
209
+ return None
210
+
211
    def analyze_icl_emergence(
        self,
        attention_weights: List[Dict],
        input_ids: torch.Tensor,
        example_boundaries: List[Tuple[int, int]],
        generated_tokens: List[int]
    ) -> ICLEmergenceAnalysis:
        """
        Comprehensive analysis of when and how ICL emerges during generation.

        Combines three signals:
          1. Induction-head detection over the captured attention weights.
          2. The per-position attention-entropy trajectory of generation.
          3. Cross-head consistency of the detected pattern types.

        Args:
            attention_weights: Per-layer attention records (see
                detect_induction_heads for the expected schema).
            input_ids: Prompt token ids, shape [1, seq_len].
            example_boundaries: (start, end) token spans of each ICL example.
            generated_tokens: Ids of the tokens produced during generation;
                only its length is used here.

        Returns:
            ICLEmergenceAnalysis summarizing whether/where ICL emerged.
        """

        # Detect induction heads across all captured layers/heads.
        induction_heads = self.detect_induction_heads(
            attention_weights, input_ids, example_boundaries
        )

        # Mean attention entropy at each generated position.
        entropy_trajectory = self._calculate_entropy_trajectory(
            attention_weights, len(generated_tokens)
        )

        # Determine emergence point (defaults: nothing detected).
        emergence_token = None
        emergence_layer = None
        emergence_confidence = 0.0

        if induction_heads:
            # Head with the strongest induction signal.
            strongest_head = max(induction_heads, key=lambda h: h.strength)

            # Emergence points reported by individual heads.
            # NOTE(review): the truthiness filter also drops a legitimate
            # emergence_point of 0 — confirm that position 0 can never emerge.
            emergence_points = [h.emergence_point for h in induction_heads if h.emergence_point]
            if emergence_points:
                # Median of per-head emergence points (a robust central value,
                # not literally the most common one).
                emergence_token = int(np.median(emergence_points))
                emergence_layer = strongest_head.layer

                # Confidence = head strength scaled by the fraction of heads
                # that reported an emergence point, capped at 1.0.
                consistency = len(emergence_points) / len(induction_heads)
                emergence_confidence = min(strongest_head.strength * consistency, 1.0)

        # Fallback signal: a sharp entropy drop (below 60% of the trailing
        # 5-position mean) marks emergence when no head reported one.
        if entropy_trajectory and len(entropy_trajectory) > 5:
            for i in range(5, len(entropy_trajectory)):
                recent_avg = np.mean(entropy_trajectory[i-5:i])
                if recent_avg > 0 and entropy_trajectory[i] < recent_avg * 0.6:
                    if emergence_token is None:
                        emergence_token = i
                        emergence_confidence = 0.5  # weaker, entropy-only evidence
                    break

        # Fraction of heads agreeing on the dominant pattern type.
        pattern_consistency = self._calculate_pattern_consistency(induction_heads)

        return ICLEmergenceAnalysis(
            emergence_detected=emergence_token is not None,
            emergence_token=emergence_token,
            emergence_layer=emergence_layer,
            confidence=emergence_confidence,
            induction_heads=induction_heads,
            attention_entropy_drop=entropy_trajectory,
            pattern_consistency=pattern_consistency
        )
274
+
275
+ def _calculate_entropy_trajectory(
276
+ self,
277
+ attention_weights: List[Dict],
278
+ num_generated: int
279
+ ) -> List[float]:
280
+ """Calculate attention entropy at each generated position"""
281
+ entropies = []
282
+
283
+ if not attention_weights:
284
+ return entropies
285
+
286
+ # Group attention by position
287
+ num_layers = 20 # CodeGen model
288
+
289
+ for gen_idx in range(num_generated):
290
+ position_entropy = []
291
+
292
+ # Get attention for this generated position across all layers
293
+ for i in range(gen_idx * num_layers, min((gen_idx + 1) * num_layers, len(attention_weights))):
294
+ if i < len(attention_weights):
295
+ attn = attention_weights[i].get('attention')
296
+ if attn is not None and attn.dim() >= 3:
297
+ # Average across heads
298
+ avg_attn = attn[0].mean(dim=0)
299
+ if avg_attn.shape[0] > gen_idx:
300
+ # Get attention distribution for this position
301
+ attn_dist = avg_attn[-1] # Last position is newly generated
302
+ if attn_dist.sum() > 0:
303
+ attn_dist = attn_dist / attn_dist.sum()
304
+ # Calculate entropy
305
+ entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum().item()
306
+ position_entropy.append(entropy)
307
+
308
+ if position_entropy:
309
+ entropies.append(np.mean(position_entropy))
310
+ else:
311
+ entropies.append(0.0)
312
+
313
+ return entropies
314
+
315
+ def _calculate_pattern_consistency(self, induction_heads: List[InductionHeadSignal]) -> float:
316
+ """Calculate how consistent the induction patterns are across heads"""
317
+ if not induction_heads:
318
+ return 0.0
319
+
320
+ # Group by pattern type
321
+ pattern_counts = {}
322
+ for head in induction_heads:
323
+ pattern_counts[head.pattern_type] = pattern_counts.get(head.pattern_type, 0) + 1
324
+
325
+ # Consistency is ratio of dominant pattern
326
+ max_count = max(pattern_counts.values())
327
+ return max_count / len(induction_heads)
backend/model_service.py CHANGED
@@ -57,6 +57,17 @@ class AblatedGenerationRequest(BaseModel):
57
  extract_traces: bool = False
58
  disabled_components: Optional[Dict[str, Any]] = None
59
 
 
 
 
 
 
 
 
 
 
 
 
60
  class DemoRequest(BaseModel):
61
  demo_id: str
62
 
@@ -855,6 +866,61 @@ async def generate_ablated(request: AblatedGenerationRequest, authenticated: boo
855
  )
856
  return result
857
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858
  @app.get("/demos")
859
  async def list_demos(authenticated: bool = Depends(verify_api_key)):
860
  """List available demo prompts"""
 
57
  extract_traces: bool = False
58
  disabled_components: Optional[Dict[str, Any]] = None
59
 
60
# Request schema for a single in-context learning example.
class ICLExample(BaseModel):
    input: str  # Example input text
    output: str  # Expected completion paired with the input
63
+
64
# Request payload for the /generate/icl endpoint.
class ICLGenerationRequest(BaseModel):
    examples: List[ICLExample]  # Few-shot examples provided as context
    prompt: str  # Test prompt to complete
    max_tokens: int = 200  # Increased to accommodate examples + generation
    temperature: float = 0.7  # Sampling temperature
    # NOTE(review): 'analyze' is not consumed by the visible /generate/icl
    # handler (analysis always runs) — confirm intended use or remove.
    analyze: bool = True
70
+
71
  class DemoRequest(BaseModel):
72
  demo_id: str
73
 
 
866
  )
867
  return result
868
 
869
@app.post("/generate/icl")
async def generate_icl(request: ICLGenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text with in-context learning analysis"""
    from .icl_service import ICLAnalyzer, ICLExample as ICLExampleData

    # Build the analyzer around the globally managed model/tokenizer.
    analyzer = ICLAnalyzer(manager.model, manager.tokenizer)

    # Mirror the pydantic request examples into the analyzer's own type.
    icl_examples = [
        ICLExampleData(input=example.input, output=example.output)
        for example in request.examples
    ]

    analysis = analyzer.analyze_generation(
        examples=icl_examples,
        test_prompt=request.prompt,
        max_length=request.max_tokens,
        temperature=request.temperature
    )

    # Serialize with the camelCase keys the frontend expects.
    payload = {
        "shotCount": analysis.shot_count,
        "generatedCode": analysis.generated_code,
        "tokens": analysis.tokens,
        "confidenceScores": analysis.confidence_scores,
        "attentionFromExamples": analysis.attention_from_examples,
        "perplexity": analysis.perplexity,
        "avgConfidence": analysis.avg_confidence,
        "exampleInfluences": analysis.example_influences,
        "hiddenStateDrift": analysis.hidden_state_drift,
    }

    # Attach emergence data only when the analysis produced it.
    emergence = analysis.icl_emergence
    if emergence:
        payload["iclEmergence"] = {
            "emergenceDetected": emergence.emergence_detected,
            "emergenceToken": emergence.emergence_token,
            "emergenceLayer": emergence.emergence_layer,
            "confidence": emergence.confidence,
            "inductionHeads": [
                {
                    "layer": signal.layer,
                    "head": signal.head,
                    "strength": signal.strength,
                    "patternType": signal.pattern_type,
                    "emergencePoint": signal.emergence_point,
                }
                for signal in emergence.induction_heads
            ],
            "attentionEntropyDrop": emergence.attention_entropy_drop,
            "patternConsistency": emergence.pattern_consistency,
        }

    return payload
923
+
924
  @app.get("/demos")
925
  async def list_demos(authenticated: bool = Depends(verify_api_key)):
926
  """List available demo prompts"""