Spaces:
Sleeping
Sleeping
gary-boon
Claude
committed on
Commit
·
920a98d
1
Parent(s):
bb8a292
Add backend support for ICL emergence analysis
Browse files- Implement ICL attention extractor with PyTorch hooks
- Add induction head detector for pattern recognition
- Create context efficiency analyzer for optimal example usage
- Update model service with ICL emergence endpoints
- Support real-time attention weight extraction during generation
- Enable token-by-token generation for attention capture
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
- backend/context_efficiency_analyzer.py +335 -0
- backend/icl_attention_extractor.py +251 -0
- backend/icl_service.py +310 -0
- backend/induction_head_detector.py +327 -0
- backend/model_service.py +66 -0
backend/context_efficiency_analyzer.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Context Efficiency Analyzer for In-Context Learning
|
| 3 |
+
|
| 4 |
+
Measures how efficiently the model uses context examples to perform tasks.
|
| 5 |
+
Based on research showing that not all examples contribute equally and that
|
| 6 |
+
optimal context usage can significantly improve performance.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
from typing import List, Dict, Tuple, Optional
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
import logging
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
@dataclass
class TokenEfficiency:
    """Efficiency metrics computed for a single context token."""
    token: str                  # The token text itself
    position: int               # Zero-based position within the tokenized context
    information_content: float  # Self-information estimate, in bits
    redundancy_score: float     # 0-1, where 1 means completely redundant
    contribution_score: float   # How much the token contributes to the output
|
| 25 |
+
|
| 26 |
+
@dataclass
class ExampleEfficiency:
    """Efficiency metrics computed for one in-context example."""
    example_id: str             # 1-based example identifier, as a string
    total_tokens: int           # Token count of the example
    effective_tokens: int       # Tokens that actually contribute (non-redundant)
    efficiency_ratio: float     # effective_tokens / total_tokens
    redundancy_rate: float      # Fraction of tokens judged redundant
    information_density: float  # Bits per token (entropy estimate)
    marginal_benefit: float     # Added value relative to preceding examples
|
| 36 |
+
|
| 37 |
+
@dataclass
class ContextEfficiencyAnalysis:
    """Aggregate result of a complete context-efficiency analysis run."""
    overall_efficiency: float                      # 0-1 score for the whole context
    total_context_tokens: int                      # Token count across all examples
    effective_context_tokens: int                  # Non-redundant token count
    example_efficiencies: List[ExampleEfficiency]  # Per-example breakdown
    token_efficiencies: List[TokenEfficiency]      # Per-token breakdown
    optimal_example_count: int                     # Suggested number of examples to keep
    redundancy_patterns: Dict[str, float]          # Pattern type -> frequency
    compression_potential: float                   # Fraction of context that could be removed
    attention_utilization: float                   # Fraction of context receiving attention
|
| 49 |
+
|
| 50 |
+
class ContextEfficiencyAnalyzer:
    """Analyzes how efficiently context examples are used during in-context learning.

    Works on tokenized (input, output) example pairs, optionally combined with
    attention records and generated tokens captured during model generation.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # Device of the loaded model (kept for parity with sibling analyzers).
        self.device = next(model.parameters()).device

    def analyze_context_efficiency(
        self,
        examples: List[Tuple[str, str]],  # (input, output) pairs
        test_prompt: str,
        attention_weights: Optional[List[Dict]] = None,
        generated_tokens: Optional[List[str]] = None,
        confidence_scores: Optional[List[float]] = None
    ) -> ContextEfficiencyAnalysis:
        """Run the complete context-efficiency analysis.

        Args:
            examples: (input, output) text pairs used as in-context examples.
            test_prompt: final task prompt (currently unused; kept for API parity).
            attention_weights: optional per-layer records, each a dict holding an
                'attention' tensor as produced by AttentionExtractor.
            generated_tokens: optional tokens the model generated.
            confidence_scores: optional per-token confidences (currently unused).

        Returns:
            ContextEfficiencyAnalysis with example-, token- and context-level metrics.
        """
        # Tokenize every example and record each one's [start, end) token span.
        example_tokens: List[str] = []
        example_boundaries: List[Tuple[int, int]] = []
        current_pos = 0

        for input_text, output_text in examples:
            tokens = self.tokenizer.tokenize(f"{input_text}\n{output_text}\n")
            example_tokens.extend(tokens)
            example_boundaries.append((current_pos, current_pos + len(tokens)))
            current_pos += len(tokens)

        # Per-example efficiency. Pass the exact token prefix preceding each
        # example; the previous approximation (example_idx * len(example_tokens))
        # was wrong whenever examples tokenized to different lengths.
        example_efficiencies = [
            self._analyze_example_efficiency(
                example_idx=idx,
                example_tokens=example_tokens[start:end],
                prior_tokens=example_tokens[:start],
            )
            for idx, (start, end) in enumerate(example_boundaries)
        ]

        # Token-level efficiency metrics.
        token_efficiencies = self._analyze_token_efficiency(
            example_tokens=example_tokens,
            generated_tokens=generated_tokens,
        )

        # Frequency of common redundancy patterns across the context.
        redundancy_patterns = self._identify_redundancy_patterns(
            example_tokens=example_tokens,
            token_efficiencies=token_efficiencies,
        )

        optimal_count = self._calculate_optimal_example_count(example_efficiencies)
        compression_potential = self._calculate_compression_potential(token_efficiencies)
        attention_utilization = self._calculate_attention_utilization(
            attention_weights=attention_weights,
            total_context_tokens=len(example_tokens),
        )

        # A context token is "effective" when it is not mostly redundant.
        effective_tokens = sum(1 for t in token_efficiencies if t.redundancy_score < 0.5)
        overall_efficiency = effective_tokens / max(len(example_tokens), 1)

        return ContextEfficiencyAnalysis(
            overall_efficiency=overall_efficiency,
            total_context_tokens=len(example_tokens),
            effective_context_tokens=effective_tokens,
            example_efficiencies=example_efficiencies,
            token_efficiencies=token_efficiencies,
            optimal_example_count=optimal_count,
            redundancy_patterns=redundancy_patterns,
            compression_potential=compression_potential,
            attention_utilization=attention_utilization,
        )

    def _analyze_example_efficiency(
        self,
        example_idx: int,
        example_tokens: List[str],
        prior_tokens: List[str],
    ) -> ExampleEfficiency:
        """Score one example against the tokens that actually precede it in context."""
        # Redundancy: tokens already seen more than twice earlier in the context.
        redundant_count = 0
        if example_idx > 0:
            for token in example_tokens:
                if prior_tokens.count(token) > 2:
                    redundant_count += 1

        redundancy_rate = redundant_count / max(len(example_tokens), 1)

        # Simplified Shannon entropy: log2 of unique-token count per token.
        unique_tokens = len(set(example_tokens))
        information_density = np.log2(max(unique_tokens, 1)) / max(len(example_tokens), 1)

        # Marginal benefit: fraction of this example's vocabulary that is new.
        if example_idx == 0:
            marginal_benefit = 1.0  # The first example always contributes fully.
        else:
            new_patterns = set(example_tokens) - set(prior_tokens)
            marginal_benefit = len(new_patterns) / max(len(example_tokens), 1)

        # Effective tokens: the non-redundant share of the example.
        effective_tokens = int(len(example_tokens) * (1 - redundancy_rate))

        return ExampleEfficiency(
            example_id=str(example_idx + 1),
            total_tokens=len(example_tokens),
            effective_tokens=effective_tokens,
            efficiency_ratio=effective_tokens / max(len(example_tokens), 1),
            redundancy_rate=redundancy_rate,
            information_density=information_density,
            marginal_benefit=marginal_benefit,
        )

    def _analyze_token_efficiency(
        self,
        example_tokens: List[str],
        generated_tokens: Optional[List[str]],
    ) -> List[TokenEfficiency]:
        """Score each context token for information, redundancy and contribution."""
        # Hoist global token frequencies out of the loop; the previous
        # per-iteration list.count() made this O(n^2).
        global_counts: Dict[str, int] = {}
        for tok in example_tokens:
            global_counts[tok] = global_counts.get(tok, 0) + 1

        total = len(example_tokens)
        token_efficiencies: List[TokenEfficiency] = []

        for idx, token in enumerate(example_tokens):
            # Rare tokens carry more information (simplified self-information).
            information_content = np.log2(total / max(global_counts[token], 1))

            # Tokens repeated within a small surrounding window are redundant.
            local_window = example_tokens[max(0, idx - 5):min(total, idx + 5)]
            local_frequency = local_window.count(token)
            redundancy_score = min(local_frequency / 3.0, 1.0)  # Cap at 1.0

            # Contribution: an exact echo in the output scores 1.0; a
            # case-insensitive substring match scores 0.5.
            contribution_score = 0.0
            if generated_tokens:
                if token in generated_tokens:
                    contribution_score = 1.0
                elif any(token.lower() in gen_token.lower() for gen_token in generated_tokens):
                    contribution_score = 0.5

            token_efficiencies.append(TokenEfficiency(
                token=token,
                position=idx,
                information_content=information_content,
                redundancy_score=redundancy_score,
                contribution_score=contribution_score,
            ))

        return token_efficiencies

    def _identify_redundancy_patterns(
        self,
        example_tokens: List[str],
        token_efficiencies: List[TokenEfficiency],
    ) -> Dict[str, float]:
        """Estimate the frequency of common redundancy patterns in the context."""
        patterns = {
            'repeated_tokens': 0.0,
            'boilerplate': 0.0,
            'structural_repetition': 0.0,
            'semantic_overlap': 0.0
        }

        # Share of distinct tokens that repeat more than three times.
        token_counts: Dict[str, int] = {}
        for token in example_tokens:
            token_counts[token] = token_counts.get(token, 0) + 1
        repeated = sum(1 for count in token_counts.values() if count > 3)
        patterns['repeated_tokens'] = repeated / max(len(token_counts), 1)

        # Boilerplate: common structural programming tokens.
        boilerplate_tokens = ['def', 'class', 'return', 'import', 'from', '"""', "'''"]
        boilerplate_count = sum(1 for token in example_tokens if token in boilerplate_tokens)
        patterns['boilerplate'] = boilerplate_count / max(len(example_tokens), 1)

        # Structural repetition: share of 3-grams that occur more than once.
        sequence_length = 3
        sequences: Dict[tuple, int] = {}
        for i in range(len(example_tokens) - sequence_length):
            seq = tuple(example_tokens[i:i + sequence_length])
            sequences[seq] = sequences.get(seq, 0) + 1
        repeated_sequences = sum(1 for count in sequences.values() if count > 1)
        patterns['structural_repetition'] = repeated_sequences / max(len(sequences), 1)

        # Semantic-overlap proxy: share of highly redundant tokens.
        high_redundancy = sum(1 for t in token_efficiencies if t.redundancy_score > 0.7)
        patterns['semantic_overlap'] = high_redundancy / max(len(token_efficiencies), 1)

        return patterns

    def _calculate_optimal_example_count(
        self,
        example_efficiencies: List[ExampleEfficiency]
    ) -> int:
        """Suggest how many examples to keep, based on marginal benefit."""
        if not example_efficiencies:
            return 0

        # Stop at the first example (after the first) whose marginal benefit
        # falls below 30%.
        threshold = 0.3
        for idx, efficiency in enumerate(example_efficiencies):
            if efficiency.marginal_benefit < threshold and idx > 0:
                return idx

        # Every example still adds enough value: keep them all.
        return len(example_efficiencies)

    def _calculate_compression_potential(
        self,
        token_efficiencies: List[TokenEfficiency]
    ) -> float:
        """Fraction of context tokens that could likely be removed."""
        if not token_efficiencies:
            return 0.0

        # Highly redundant tokens with little contribution are removable.
        removable = sum(
            1 for t in token_efficiencies
            if t.redundancy_score > 0.6 and t.contribution_score < 0.3
        )
        return removable / len(token_efficiencies)

    def _calculate_attention_utilization(
        self,
        attention_weights: Optional[List[Dict]],
        total_context_tokens: int
    ) -> float:
        """Percentage of context positions that receive significant attention."""
        if not attention_weights or total_context_tokens == 0:
            return 0.0

        # Aggregate attention across all layers and heads.
        attended_positions = set()

        for record in attention_weights:
            attn = record.get('attention')
            if attn is not None and attn.dim() >= 3:
                # Average over the head dimension.
                avg_attn = attn.mean(dim=1)

                # Positions with attention above the threshold count as utilized.
                threshold = 0.05
                high_attention = (avg_attn > threshold).nonzero(as_tuple=True)

                # NOTE(review): index 1 of the nonzero tuple is the query axis of
                # a [batch, query, key] tensor; confirm the key axis (index 2)
                # was not intended here.
                if len(high_attention) > 1:
                    attended_positions.update(high_attention[1].tolist())

        # Only positions inside the example context count.
        context_attended = [pos for pos in attended_positions if pos < total_context_tokens]
        return len(context_attended) / total_context_tokens if total_context_tokens > 0 else 0.0
|
backend/icl_attention_extractor.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Real Attention Extraction for In-Context Learning Analysis
|
| 3 |
+
|
| 4 |
+
This module hooks into transformer models to extract actual attention weights
|
| 5 |
+
during generation, providing real data for ICL analysis.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import numpy as np
|
| 11 |
+
from typing import List, Dict, Tuple, Optional, Any
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
import logging
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
@dataclass
class AttentionData:
    """Container for attention captured during a single model generation."""
    layer_attentions: List[torch.Tensor]        # One attention tensor per layer
    token_positions: List[int]                  # Position of each generated token
    example_boundaries: List[Tuple[int, int]]   # [start, end) spans of the examples
|
| 23 |
+
|
| 24 |
+
class AttentionExtractor:
    """Extracts real attention patterns from transformer models during generation.

    Uses forward hooks (and, as a fallback, the model's `output_attentions`
    output) to record per-layer attention while tokens are generated one by one.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

        # Attention records captured during generation, and live hook handles.
        self.attention_weights = []
        self.handles = []

    def register_hooks(self):
        """Register forward hooks on every attention module to capture weights."""
        self.clear_hooks()

        # CodeGen-style models keep their transformer blocks in `transformer.h`.
        if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
            for i, layer in enumerate(self.model.transformer.h):
                if hasattr(layer, 'attn'):
                    # layer_idx is bound as a default argument to avoid the
                    # late-binding closure pitfall.
                    handle = layer.attn.register_forward_hook(
                        lambda module, input, output, layer_idx=i:
                            self._attention_hook(module, input, output, layer_idx)
                    )
                    self.handles.append(handle)

        logger.info(f"Registered {len(self.handles)} attention hooks")

    def _attention_hook(self, module, input, output, layer_idx):
        """Forward-hook callback: store this layer's attention weights, if any."""
        # For CodeGen, output is (hidden_states, attention_weights).
        if isinstance(output, tuple) and len(output) >= 2:
            attention = output[1]
            if attention is not None:
                # Detach and move to CPU so captured weights don't pin GPU memory.
                self.attention_weights.append({
                    'layer': layer_idx,
                    'attention': attention.detach().cpu()
                })

    def clear_hooks(self):
        """Remove all hooks and drop any captured attention."""
        for handle in self.handles:
            handle.remove()
        self.handles = []
        self.attention_weights = []

    def extract_attention_with_generation(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.7
    ) -> Tuple[torch.Tensor, List[Dict], List[torch.Tensor]]:
        """Generate up to `max_new_tokens` tokens while recording attention.

        Returns:
            (generated_ids of shape [1, n], attention records, per-step logits).
        """
        # Register hooks before generation starts.
        self.register_hooks()
        self.attention_weights = []

        try:
            generated_ids = []
            all_scores = []  # Per-step logits, kept for confidence calculation.
            current_input_ids = input_ids.clone()
            current_attention_mask = attention_mask.clone()

            # Generate token by token so attention is captured at every step.
            for _ in range(max_new_tokens):
                with torch.no_grad():
                    outputs = self.model(
                        input_ids=current_input_ids,
                        attention_mask=current_attention_mask,
                        use_cache=False,  # Full (uncached) attention each step.
                        output_attentions=True,
                        return_dict=True
                    )

                # Fallback capture in case the hooks did not fire.
                if hasattr(outputs, 'attentions') and outputs.attentions is not None:
                    for layer_idx, attn in enumerate(outputs.attentions):
                        self.attention_weights.append({
                            'layer': layer_idx,
                            'attention': attn.detach().cpu()
                        })

                # Logits for the next-token position; stored pre-temperature.
                next_token_logits = outputs.logits[:, -1, :]
                all_scores.append(next_token_logits)

                # Sample with temperature, or take the argmax greedily.
                if temperature > 0:
                    probs = F.softmax(next_token_logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # Stop on the end-of-sequence token.
                if next_token.item() == self.tokenizer.eos_token_id:
                    break

                generated_ids.append(next_token.item())
                current_input_ids = torch.cat([current_input_ids, next_token], dim=1)
                # Match the existing mask's dtype: a default float32 `ones`
                # column may clash with the tokenizer's int64 mask.
                current_attention_mask = torch.cat([
                    current_attention_mask,
                    torch.ones((1, 1), device=self.device, dtype=current_attention_mask.dtype)
                ], dim=1)

            # Convert generated ids to a [1, n] tensor (empty if nothing produced).
            if generated_ids:
                generated_tensor = torch.tensor(generated_ids, device=self.device).unsqueeze(0)
            else:
                generated_tensor = torch.tensor([[]], device=self.device, dtype=torch.long)

            return generated_tensor, self.attention_weights, all_scores

        finally:
            # Always clear hooks, even if generation raised. clear_hooks rebinds
            # self.attention_weights, so the returned list reference is kept.
            self.clear_hooks()

    def aggregate_attention_to_examples(
        self,
        attention_data: List[Dict],
        example_boundaries: List[Tuple[int, int]],
        prompt_length: int
    ) -> Dict[str, List[float]]:
        """
        Aggregate attention from generated tokens back to example regions.

        Returns:
            Dict mapping example_id -> list of attention weights per generated token
        """
        if not attention_data or not example_boundaries:
            return {}

        # Derive the layer count from the captured records instead of
        # hard-coding a model-specific constant (was fixed at 20).
        num_layers = max(record['layer'] for record in attention_data) + 1
        num_generated = len(attention_data) // num_layers

        logger.info(f"Processing {len(attention_data)} attention records for {num_generated} generated tokens")

        attention_to_examples: Dict[str, List[float]] = {}

        for example_idx, (start, end) in enumerate(example_boundaries):
            example_id = str(example_idx + 1)
            example_attention = []

            for gen_idx in range(num_generated):
                # Records are appended layer-by-layer per generated token, so
                # this slice holds all layers for step gen_idx. (Direct slicing
                # replaces the previous full-list scan per generated token.)
                step_records = attention_data[gen_idx * num_layers:(gen_idx + 1) * num_layers]

                total_attention = 0.0
                layer_count = 0
                for record in step_records:
                    attn_tensor = record.get('attention')
                    if attn_tensor is None or attn_tensor.dim() < 3:
                        continue

                    # Shape [batch, heads, seq, seq]; the last query row belongs
                    # to the newly generated token. Average its attention over
                    # heads across the example's token span.
                    seq_len = attn_tensor.shape[-1]
                    if end <= seq_len:
                        total_attention += attn_tensor[0, :, -1, start:end].mean().item()
                        layer_count += 1

                # Average across the layers that contributed.
                example_attention.append(
                    total_attention / layer_count if layer_count > 0 else 0.0
                )

            attention_to_examples[example_id] = example_attention

        # Normalize across examples so weights sum to 1 per generated token.
        for gen_idx in range(num_generated):
            total = sum(
                weights[gen_idx]
                for weights in attention_to_examples.values()
                if gen_idx < len(weights)
            )
            if total > 0:
                for weights in attention_to_examples.values():
                    if gen_idx < len(weights):
                        weights[gen_idx] /= total

        return attention_to_examples

    def calculate_example_influences(
        self,
        attention_to_examples: Dict[str, List[float]]
    ) -> Dict[str, float]:
        """
        Calculate overall influence of each example based on attention patterns.

        Returns:
            Dict mapping example_id -> influence score (0-1)
        """
        influences = {}

        # Overall influence is the mean attention across all generated tokens.
        for example_id, attention_weights in attention_to_examples.items():
            if attention_weights:
                influences[example_id] = float(np.mean(attention_weights))
            else:
                influences[example_id] = 0.0

        # Normalize so the influence scores sum to 1.
        total = sum(influences.values())
        if total > 0:
            influences = {k: v / total for k, v in influences.items()}

        return influences
|
backend/icl_service.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
In-Context Learning Analysis Service
|
| 3 |
+
|
| 4 |
+
Analyzes how examples influence model behavior during code generation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
from typing import List, Dict, Optional, Any, Tuple
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from .icl_attention_extractor import AttentionExtractor
|
| 14 |
+
from .induction_head_detector import InductionHeadDetector, ICLEmergenceAnalysis
|
| 15 |
+
import logging
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
@dataclass
class ICLExample:
    """A single in-context learning example: an input/output text pair."""
    input: str   # The example's input text
    output: str  # The expected completion for that input
|
| 24 |
+
|
| 25 |
+
@dataclass
class ICLAnalysisResult:
    """Results from a single ICL generation-analysis run."""
    shot_count: int                       # number of in-context examples used
    generated_code: str                   # decoded continuation text
    tokens: List[str]                     # generated tokens, decoded individually
    confidence_scores: List[float]        # max softmax probability per generated token
    attention_from_examples: Dict[str, List[float]]  # example_id -> attention weights per token
    perplexity: float                     # perplexity of the sampled continuation
    avg_confidence: float                 # mean of confidence_scores
    example_influences: Dict[str, float]  # example_id -> overall influence score
    hidden_state_drift: Optional[List[float]] = None  # magnitude of hidden state changes
    # When/how ICL kicks in; forward reference so the definition does not
    # depend on import order of the detector module.
    icl_emergence: Optional["ICLEmergenceAnalysis"] = None
|
| 38 |
+
|
| 39 |
+
class ICLAnalyzer:
    """Analyzes in-context learning effects on model behavior.

    Wraps a causal LM and its tokenizer, then measures how few-shot examples
    embedded in the prompt influence generation: per-example attention,
    per-token confidence, perplexity, hidden-state drift, and ICL emergence.
    """

    def __init__(self, model: "AutoModelForCausalLM", tokenizer: "AutoTokenizer"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

        # Initialize attention extractor for real attention data
        self.attention_extractor = AttentionExtractor(model, tokenizer)

        # Initialize induction head detector
        self.induction_detector = InductionHeadDetector(model, tokenizer)

        # Storage for attention patterns
        self.attention_maps = []
        self.hidden_states = []

        # Fix: initialize explicitly instead of relying on hasattr() checks —
        # the attribute was previously created only after the first
        # analyze_generation() call.
        self.last_attention_data = None

    def prepare_prompt_with_examples(self, examples: List["ICLExample"], test_prompt: str) -> str:
        """Construct the full prompt: each example as "input\\noutput\\n", joined
        with newlines, followed by the test prompt."""
        if not examples:
            return test_prompt

        prompt_parts = [f"{example.input}\n{example.output}\n" for example in examples]
        prompt_parts.append(test_prompt)
        return "\n".join(prompt_parts)

    def extract_attention_patterns(
        self,
        outputs,
        input_ids,
        example_boundaries: List[Tuple[int, int]],
    ) -> Dict[str, List[float]]:
        """Extract per-example attention weights for each generated token.

        Uses real hook-captured attention when the last extraction succeeded;
        otherwise falls back to a simulated decaying pattern (placeholder).
        Returns a mapping example_id -> list of weights, one per generated token.
        """
        # Try to use real attention data if available
        if self.last_attention_data:
            logger.info("Using real attention data from model hooks")
            prompt_length = len(input_ids[0])
            return self.attention_extractor.aggregate_attention_to_examples(
                self.last_attention_data,
                example_boundaries,
                prompt_length,
            )

        # Fall back to simulated patterns
        logger.info("Using simulated attention patterns")
        attention_from_examples: Dict[str, List[float]] = {}

        if not example_boundaries:
            return attention_from_examples

        generated_ids = outputs.sequences[0][len(input_ids[0]):]
        num_generated = len(generated_ids)
        if num_generated == 0:
            return attention_from_examples

        # Simulated pattern: later examples get a slightly larger base weight,
        # decaying exponentially over generated tokens, with small noise.
        for idx, (start, end) in enumerate(example_boundaries):
            example_id = str(idx + 1)
            base_weight = 0.3 + (idx * 0.1) / len(example_boundaries)

            attention_weights = []
            for token_idx in range(num_generated):
                weight = base_weight * np.exp(-token_idx * 0.05)
                weight += np.random.normal(0, 0.02)
                attention_weights.append(max(0, min(1, weight)))

            attention_from_examples[example_id] = attention_weights

        # Normalize so the weights across examples sum to 1 at each token.
        if len(attention_from_examples) > 1:
            for token_idx in range(num_generated):
                total = sum(weights[token_idx] for weights in attention_from_examples.values())
                if total > 0:
                    for example_id in attention_from_examples:
                        attention_from_examples[example_id][token_idx] /= total

        return attention_from_examples

    def calculate_example_influences(self, attention_from_examples: Dict[str, List[float]]) -> Dict[str, float]:
        """Reduce per-token attention to one influence score per example.

        Scores are normalized to sum to 1 across examples (when nonzero).
        """
        # If we have real attention data, use the extractor's method
        if self.last_attention_data:
            return self.attention_extractor.calculate_example_influences(attention_from_examples)

        # Otherwise: mean attention weight per example, then normalize.
        influences = {
            example_id: float(np.mean(weights)) if weights else 0.0
            for example_id, weights in attention_from_examples.items()
        }

        total = sum(influences.values())
        if total > 0 and total != 1.0:
            influences = {k: v / total for k, v in influences.items()}

        return influences

    def track_hidden_state_drift(self, base_hidden_states, example_hidden_states) -> List[float]:
        """L2 distance between baseline and with-examples hidden states, per position."""
        if base_hidden_states is None or example_hidden_states is None:
            return []

        drift = []
        min_len = min(len(base_hidden_states), len(example_hidden_states))

        for i in range(min_len):
            base = base_hidden_states[i]
            example = example_hidden_states[i]

            if isinstance(base, torch.Tensor):
                base = base.cpu().numpy()
            if isinstance(example, torch.Tensor):
                example = example.cpu().numpy()

            drift.append(float(np.linalg.norm(example - base)))

        return drift

    def analyze_generation(
        self,
        examples: List["ICLExample"],
        test_prompt: str,
        max_length: int = 150,
        temperature: float = 0.7,
        base_hidden_states: Optional[Any] = None,
    ) -> "ICLAnalysisResult":
        """Generate with the given examples and measure their influence.

        Runs a standard generation pass (for text, scores, perplexity), then a
        best-effort hooked pass for real attention, and finally aggregates
        everything into an ICLAnalysisResult.
        """
        # Prepare prompt
        full_prompt = self.prepare_prompt_with_examples(examples, test_prompt)

        # Tokenize
        inputs = self.tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)).to(self.device)

        # Find example boundaries in token space.
        # NOTE(review): boundaries are computed by re-tokenizing each example
        # independently, which can drift from the joint tokenization of the
        # full prompt (and ignores the "\n" separators join() inserts) —
        # confirm against the extractor's expectations.
        example_boundaries = []
        if examples:
            current_pos = 0
            for example in examples:
                example_text = f"{example.input}\n{example.output}\n"
                example_tokens = self.tokenizer(example_text, add_special_tokens=False)["input_ids"]
                example_boundaries.append((current_pos, current_pos + len(example_tokens)))
                current_pos += len(example_tokens)

        # First do standard generation to get scores and text
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_scores=True,
                output_hidden_states=False,
            )

        # Then try to extract real attention data (best effort).
        try:
            logger.info("Extracting real attention data")
            _, attention_data, _ = self.attention_extractor.extract_attention_with_generation(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=min(30, max_length - len(input_ids[0])),  # Limit for performance
                temperature=temperature,
            )
            self.last_attention_data = attention_data
            logger.info(f"Successfully extracted {len(attention_data)} attention records")
        except Exception as e:
            logger.warning(f"Real attention extraction failed: {e}")
            self.last_attention_data = None

        # Extract generated tokens - show raw output, no trimming
        generated_ids = outputs.sequences[0][len(input_ids[0]):]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        tokens = [self.tokenizer.decode([token_id]) for token_id in generated_ids]

        # Per-token confidence = max softmax probability at each step.
        confidence_scores = []
        if outputs.scores:
            for score in outputs.scores:
                probs = F.softmax(score[0], dim=-1)
                confidence_scores.append(probs.max().item())

        # Perplexity of the sampled continuation under the model.
        if outputs.scores:
            log_probs = []
            for i, score in enumerate(outputs.scores):
                if i < len(generated_ids):
                    token_id = generated_ids[i]
                    log_probs.append(F.log_softmax(score[0], dim=-1)[token_id].item())
            perplexity = np.exp(-np.mean(log_probs)) if log_probs else 0.0
        else:
            perplexity = 0.0

        # Extract attention patterns and per-example influences.
        attention_from_examples = self.extract_attention_patterns(outputs, input_ids, example_boundaries)
        example_influences = self.calculate_example_influences(attention_from_examples)

        # Track hidden state drift if base states provided.
        # NOTE(review): generate() above runs with output_hidden_states=False,
        # so outputs.hidden_states is never populated and this branch is
        # effectively dead — enable that flag to make drift tracking live.
        hidden_state_drift = None
        if base_hidden_states is not None and getattr(outputs, "hidden_states", None):
            current_hidden = outputs.hidden_states[-1]
            if current_hidden is not None:
                hidden_state_drift = self.track_hidden_state_drift(base_hidden_states, current_hidden)

        # Analyze ICL emergence if we have attention data and examples
        icl_emergence = None
        if self.last_attention_data and len(examples) > 0:
            try:
                icl_emergence = self.induction_detector.analyze_icl_emergence(
                    self.last_attention_data,
                    input_ids,
                    example_boundaries,
                    generated_ids.tolist() if generated_ids.numel() > 0 else [],
                )
                logger.info(f"ICL emergence analysis: detected={icl_emergence.emergence_detected}, "
                            f"token={icl_emergence.emergence_token}, confidence={icl_emergence.confidence:.2f}")
            except Exception as e:
                logger.warning(f"ICL emergence analysis failed: {e}")

        return ICLAnalysisResult(
            shot_count=len(examples),
            generated_code=generated_text,
            tokens=tokens,
            confidence_scores=confidence_scores,
            attention_from_examples=attention_from_examples,
            perplexity=perplexity,
            avg_confidence=np.mean(confidence_scores) if confidence_scores else 0.0,
            example_influences=example_influences,
            hidden_state_drift=hidden_state_drift,
            icl_emergence=icl_emergence,
        )

    def compare_shot_settings(
        self,
        examples: List["ICLExample"],
        test_prompt: str,
        max_length: int = 150,
        temperature: float = 0.7,
    ) -> Dict[str, "ICLAnalysisResult"]:
        """Run 0-shot, 1-shot, and few-shot analyses for the same test prompt."""
        results = {}

        # 0-shot (no examples)
        results['zero_shot'] = self.analyze_generation([], test_prompt, max_length, temperature)
        # NOTE(review): this passes the zero-shot *drift* list (always None
        # here, since no baseline was supplied) as the hidden-state baseline;
        # actual baseline hidden states are never captured, so drift tracking
        # stays disabled downstream.
        base_hidden = results['zero_shot'].hidden_state_drift

        # 1-shot (first example only)
        if len(examples) >= 1:
            results['one_shot'] = self.analyze_generation(
                examples[:1], test_prompt, max_length, temperature, base_hidden
            )

        # Few-shot (all examples)
        if len(examples) >= 2:
            results['few_shot'] = self.analyze_generation(
                examples, test_prompt, max_length, temperature, base_hidden
            )

        return results
|
backend/induction_head_detector.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Induction Head Detection for In-Context Learning
|
| 3 |
+
|
| 4 |
+
Based on research showing that ICL emerges abruptly in transformers through
|
| 5 |
+
the formation of induction heads - attention patterns that copy from context.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
from typing import List, Dict, Tuple, Optional
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
@dataclass
class InductionHeadSignal:
    """Signals indicating induction-head behavior in one attention head."""
    layer: int                       # transformer layer index
    head: int                        # attention head index within the layer
    strength: float                  # 0-1 score of induction pattern strength
    pattern_type: str                # 'copy', 'prefix_match', or 'abstract'
    emergence_point: Optional[int]   # token position where the pattern emerges
|
| 24 |
+
|
| 25 |
+
@dataclass
class ICLEmergenceAnalysis:
    """Analysis of when and how in-context learning emerges during generation."""
    emergence_detected: bool              # whether an emergence point was found
    emergence_token: Optional[int]        # token position where ICL kicks in
    emergence_layer: Optional[int]        # layer where the strongest signal appears
    confidence: float                     # confidence in the detection (0-1)
    # Forward reference keeps this definition independent of declaration order.
    induction_heads: List["InductionHeadSignal"]
    attention_entropy_drop: List[float]   # attention entropy at each position
    pattern_consistency: float            # how consistent the pattern is across heads
|
| 35 |
+
|
| 36 |
+
class InductionHeadDetector:
    """Detects induction heads and ICL emergence in transformer models."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

    def detect_induction_heads(
        self,
        attention_weights: List[Dict],
        input_ids: torch.Tensor,
        example_boundaries: List[Tuple[int, int]],
    ) -> List["InductionHeadSignal"]:
        """
        Detect induction heads by looking for attention patterns that:
        1. Copy from previous occurrences (classic induction)
        2. Match prefixes across examples
        3. Show abstract pattern matching
        """
        induction_heads = []

        if not attention_weights or not example_boundaries:
            return induction_heads

        # Analyze only the first record seen per layer.
        layers_seen = set()
        for record in attention_weights:
            layer_idx = record.get('layer', 0)
            attn = record.get('attention')

            if attn is None or layer_idx in layers_seen:
                continue
            layers_seen.add(layer_idx)

            # Analyze each attention head
            if attn.dim() >= 3:
                num_heads = attn.shape[1]
                seq_len = attn.shape[-1]

                for head_idx in range(num_heads):
                    head_attn = attn[0, head_idx]  # [seq_len, seq_len]

                    # Score the three induction-pattern families.
                    copy_score = self._detect_copy_pattern(head_attn, input_ids)
                    prefix_score = self._detect_prefix_matching(head_attn, example_boundaries)
                    abstract_score = self._detect_abstract_pattern(head_attn, seq_len)

                    # Keep only heads with a significant dominant pattern.
                    max_score = max(copy_score, prefix_score, abstract_score)
                    if max_score > 0.3:  # Threshold for significant pattern
                        if copy_score == max_score:
                            pattern_type = 'copy'
                        elif prefix_score == max_score:
                            pattern_type = 'prefix_match'
                        else:
                            pattern_type = 'abstract'

                        # Find emergence point (where pattern suddenly strengthens)
                        emergence_point = self._find_emergence_point(head_attn)

                        induction_heads.append(InductionHeadSignal(
                            layer=layer_idx,
                            head=head_idx,
                            strength=max_score,
                            pattern_type=pattern_type,
                            emergence_point=emergence_point,
                        ))

        return induction_heads

    def _detect_copy_pattern(self, attn_matrix: torch.Tensor, input_ids: torch.Tensor) -> float:
        """Average attention strength from each token to prior occurrences of
        the same token (classic induction/copy behavior)."""
        seq_len = attn_matrix.shape[0]
        copy_score = 0.0
        count = 0

        # Look for positions that attend strongly to previous same tokens.
        for i in range(1, min(seq_len, 50)):  # Limit analysis for efficiency
            if i >= len(input_ids[0]):
                break

            current_token = input_ids[0][i].item()

            # Find previous occurrences of the same token
            for j in range(i):
                if j < len(input_ids[0]) and input_ids[0][j].item() == current_token:
                    # Count only significant attention to that position.
                    if attn_matrix[i, j] > 0.1:
                        copy_score += attn_matrix[i, j].item()
                        count += 1

        return copy_score / max(count, 1)

    def _detect_prefix_matching(
        self,
        attn_matrix: torch.Tensor,
        example_boundaries: List[Tuple[int, int]],
    ) -> float:
        """Average attention from a position in a later example to the same
        offset in an earlier example (prefix matching across examples)."""
        if len(example_boundaries) < 2:
            return 0.0

        prefix_score = 0.0
        count = 0

        # Check if tokens attend to similar positions in different examples
        for i, (start1, end1) in enumerate(example_boundaries[:-1]):
            for j, (start2, end2) in enumerate(example_boundaries[i + 1:], i + 1):
                # Compare the first few tokens of each pair of examples.
                for offset in range(min(5, end1 - start1, end2 - start2)):
                    pos1 = start1 + offset
                    pos2 = start2 + offset

                    if pos1 < attn_matrix.shape[0] and pos2 < attn_matrix.shape[1]:
                        # Later example attending to earlier example at same offset
                        if pos2 < attn_matrix.shape[0] and pos1 < attn_matrix.shape[1]:
                            attention_strength = attn_matrix[pos2, pos1].item()
                            if attention_strength > 0.1:
                                prefix_score += attention_strength
                                count += 1

        return prefix_score / max(count, 1)

    def _detect_abstract_pattern(self, attn_matrix: torch.Tensor, seq_len: int) -> float:
        """Score near-diagonal attention (attending to structurally similar
        recent positions), a proxy for abstract/structural copying."""
        abstract_score = 0.0
        window_size = 10

        for i in range(window_size, min(seq_len, 50)):
            # Sum attention over the trailing window (offset diagonal).
            diagonal_sum = 0.0
            for offset in range(1, min(window_size, i)):
                if i - offset >= 0:
                    diagonal_sum += attn_matrix[i, i - offset].item()

            # High diagonal attention indicates structural copying
            if diagonal_sum / window_size > 0.1:
                abstract_score += diagonal_sum / window_size

        return min(abstract_score / 10, 1.0)  # Normalize

    def _find_emergence_point(self, attn_matrix: torch.Tensor) -> Optional[int]:
        """Find the token position where attention entropy suddenly drops
        (indicating the head starts focusing), or None if no drop is found."""
        seq_len = min(attn_matrix.shape[0], 50)  # Limit for efficiency

        if seq_len < 10:
            return None

        # Calculate attention entropy at each position (over prior positions only).
        entropies = []
        for i in range(seq_len):
            attn_dist = attn_matrix[i, :i + 1]
            if attn_dist.sum() > 0:
                attn_dist = attn_dist / attn_dist.sum()
                entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum().item()
                entropies.append(entropy)
            else:
                entropies.append(0.0)

        if len(entropies) < 5:
            return None

        # A sudden drop below half the recent average marks emergence.
        for i in range(4, len(entropies)):
            recent_avg = np.mean(entropies[i - 4:i])
            if recent_avg > 0 and entropies[i] < recent_avg * 0.5:
                return i

        return None

    def analyze_icl_emergence(
        self,
        attention_weights: List[Dict],
        input_ids: torch.Tensor,
        example_boundaries: List[Tuple[int, int]],
        generated_tokens: List[int],
    ) -> "ICLEmergenceAnalysis":
        """
        Comprehensive analysis of when and how ICL emerges during generation.
        """
        # Detect induction heads
        induction_heads = self.detect_induction_heads(
            attention_weights, input_ids, example_boundaries
        )

        # Calculate attention entropy trajectory
        entropy_trajectory = self._calculate_entropy_trajectory(
            attention_weights, len(generated_tokens)
        )

        # Determine emergence point
        emergence_token = None
        emergence_layer = None
        emergence_confidence = 0.0

        if induction_heads:
            # Find strongest induction signal
            strongest_head = max(induction_heads, key=lambda h: h.strength)

            # Check for consistent emergence points across heads
            emergence_points = [h.emergence_point for h in induction_heads if h.emergence_point]
            if emergence_points:
                # Median emergence point across heads
                emergence_token = int(np.median(emergence_points))
                emergence_layer = strongest_head.layer

                # Confidence based on consistency and strength
                consistency = len(emergence_points) / len(induction_heads)
                emergence_confidence = min(strongest_head.strength * consistency, 1.0)

        # Check for entropy drop as additional signal
        if entropy_trajectory and len(entropy_trajectory) > 5:
            for i in range(5, len(entropy_trajectory)):
                recent_avg = np.mean(entropy_trajectory[i - 5:i])
                if recent_avg > 0 and entropy_trajectory[i] < recent_avg * 0.6:
                    if emergence_token is None:
                        emergence_token = i
                        emergence_confidence = 0.5
                    break

        # Calculate pattern consistency
        pattern_consistency = self._calculate_pattern_consistency(induction_heads)

        return ICLEmergenceAnalysis(
            emergence_detected=emergence_token is not None,
            emergence_token=emergence_token,
            emergence_layer=emergence_layer,
            confidence=emergence_confidence,
            induction_heads=induction_heads,
            attention_entropy_drop=entropy_trajectory,
            pattern_consistency=pattern_consistency,
        )

    def _calculate_entropy_trajectory(
        self,
        attention_weights: List[Dict],
        num_generated: int,
    ) -> List[float]:
        """Average attention entropy (across layers and heads) at each
        generated position."""
        entropies = []

        if not attention_weights:
            return entropies

        # Records are assumed to arrive as one entry per layer per generated
        # token.  Fix: read the layer count from the model config instead of
        # hard-coding 20 (which was specific to the CodeGen checkpoint).
        config = getattr(getattr(self, "model", None), "config", None)
        num_layers = (
            getattr(config, "num_hidden_layers", None)
            or getattr(config, "n_layer", None)
            or 20  # previous hard-coded fallback
        )

        for gen_idx in range(num_generated):
            position_entropy = []

            # Get attention for this generated position across all layers.
            start = gen_idx * num_layers
            stop = min(start + num_layers, len(attention_weights))
            for i in range(start, stop):
                attn = attention_weights[i].get('attention')
                if attn is not None and attn.dim() >= 3:
                    # Average across heads
                    avg_attn = attn[0].mean(dim=0)
                    if avg_attn.shape[0] > gen_idx:
                        # Last row corresponds to the newly generated token.
                        attn_dist = avg_attn[-1]
                        if attn_dist.sum() > 0:
                            attn_dist = attn_dist / attn_dist.sum()
                            entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum().item()
                            position_entropy.append(entropy)

            entropies.append(float(np.mean(position_entropy)) if position_entropy else 0.0)

        return entropies

    def _calculate_pattern_consistency(self, induction_heads: List["InductionHeadSignal"]) -> float:
        """Ratio of heads exhibiting the dominant pattern type (0 when empty)."""
        if not induction_heads:
            return 0.0

        # Group by pattern type
        pattern_counts = {}
        for head in induction_heads:
            pattern_counts[head.pattern_type] = pattern_counts.get(head.pattern_type, 0) + 1

        # Consistency is ratio of dominant pattern
        return max(pattern_counts.values()) / len(induction_heads)
|
backend/model_service.py
CHANGED
|
@@ -57,6 +57,17 @@ class AblatedGenerationRequest(BaseModel):
|
|
| 57 |
extract_traces: bool = False
|
| 58 |
disabled_components: Optional[Dict[str, Any]] = None
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
class DemoRequest(BaseModel):
|
| 61 |
demo_id: str
|
| 62 |
|
|
@@ -855,6 +866,61 @@ async def generate_ablated(request: AblatedGenerationRequest, authenticated: boo
|
|
| 855 |
)
|
| 856 |
return result
|
| 857 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 858 |
@app.get("/demos")
|
| 859 |
async def list_demos(authenticated: bool = Depends(verify_api_key)):
|
| 860 |
"""List available demo prompts"""
|
|
|
|
| 57 |
extract_traces: bool = False
|
| 58 |
disabled_components: Optional[Dict[str, Any]] = None
|
| 59 |
|
| 60 |
+
class ICLExample(BaseModel):
    """One few-shot example supplied by the client: input text plus its expected output."""
    input: str
    output: str
|
| 63 |
+
|
| 64 |
+
class ICLGenerationRequest(BaseModel):
    """Request payload for the /generate/icl endpoint."""
    examples: List[ICLExample]
    prompt: str
    max_tokens: int = 200  # Increased to accommodate examples + generation
    temperature: float = 0.7
    analyze: bool = True
|
| 70 |
+
|
| 71 |
class DemoRequest(BaseModel):
|
| 72 |
demo_id: str
|
| 73 |
|
|
|
|
| 866 |
)
|
| 867 |
return result
|
| 868 |
|
| 869 |
+
@app.post("/generate/icl")
async def generate_icl(request: ICLGenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text with in-context learning analysis.

    Builds an ICLAnalyzer around the loaded model, runs generation with the
    supplied few-shot examples, and returns a camelCase JSON payload with
    confidence, attention, influence and (when available) emergence data.
    """
    from .icl_service import ICLAnalyzer, ICLExample as ICLExampleData

    # NOTE(review): request.analyze is currently ignored — analysis always runs.
    analyzer = ICLAnalyzer(manager.model, manager.tokenizer)

    # Map the pydantic request examples onto the analyzer's dataclass.
    shots = [ICLExampleData(input=ex.input, output=ex.output) for ex in request.examples]

    result = analyzer.analyze_generation(
        examples=shots,
        test_prompt=request.prompt,
        max_length=request.max_tokens,
        temperature=request.temperature,
    )

    # Flatten the analysis result into the camelCase JSON shape the UI expects.
    response_data = {
        "shotCount": result.shot_count,
        "generatedCode": result.generated_code,
        "tokens": result.tokens,
        "confidenceScores": result.confidence_scores,
        "attentionFromExamples": result.attention_from_examples,
        "perplexity": result.perplexity,
        "avgConfidence": result.avg_confidence,
        "exampleInfluences": result.example_influences,
        "hiddenStateDrift": result.hidden_state_drift,
    }

    emergence = result.icl_emergence
    if emergence:
        head_payload = [
            {
                "layer": h.layer,
                "head": h.head,
                "strength": h.strength,
                "patternType": h.pattern_type,
                "emergencePoint": h.emergence_point,
            }
            for h in emergence.induction_heads
        ]
        response_data["iclEmergence"] = {
            "emergenceDetected": emergence.emergence_detected,
            "emergenceToken": emergence.emergence_token,
            "emergenceLayer": emergence.emergence_layer,
            "confidence": emergence.confidence,
            "inductionHeads": head_payload,
            "attentionEntropyDrop": emergence.attention_entropy_drop,
            "patternConsistency": emergence.pattern_consistency,
        }

    return response_data
|
| 923 |
+
|
| 924 |
@app.get("/demos")
|
| 925 |
async def list_demos(authenticated: bool = Depends(verify_api_key)):
|
| 926 |
"""List available demo prompts"""
|