Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 13,790 Bytes
920a98d ed40a9a 920a98d ed40a9a 920a98d ed40a9a 920a98d ed40a9a 920a98d ed40a9a 920a98d ed40a9a 920a98d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 |
"""
Induction Head Detection for In-Context Learning
Based on research showing that ICL emerges abruptly in transformers through
the formation of induction heads - attention patterns that copy from context.
"""
import torch
import numpy as np
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
@dataclass
class InductionHeadSignal:
    """Signals indicating induction head behavior.

    One instance describes a single (layer, head) pair whose attention
    matches an induction-style pattern.
    """
    layer: int  # Index of the transformer layer containing the head
    head: int  # Index of the attention head within that layer
    strength: float  # 0-1 score of induction pattern strength
    pattern_type: str  # 'copy', 'prefix_match', 'abstract'
    emergence_point: Optional[int]  # Token position where pattern emerges (None if not found)
@dataclass
class ICLEmergenceAnalysis:
    """Analysis of when and how ICL emerges.

    Aggregates per-head induction signals plus an attention-entropy
    trajectory into a single emergence verdict.
    """
    emergence_detected: bool  # True when any emergence token was identified
    emergence_token: Optional[int]  # Token position where ICL kicks in
    emergence_layer: Optional[int]  # Layer where strongest signal appears
    confidence: float  # Confidence in detection (0-1)
    induction_heads: List[InductionHeadSignal]  # All heads above the detection threshold
    attention_entropy_drop: List[float]  # Entropy at each position
    pattern_consistency: float  # How consistent the pattern is across heads
class InductionHeadDetector:
    """Detects induction heads and ICL emergence in transformer models.

    Induction heads are attention heads that attend back to earlier
    occurrences of the current token (or to structurally similar positions),
    enabling in-context copying; ICL is believed to emerge abruptly as they
    form.
    """

    def __init__(self, model, tokenizer, adapter=None):
        """
        Args:
            model: Transformer model; only ``parameters()`` is used here.
            tokenizer: Tokenizer paired with the model (stored for callers).
            adapter: Optional adapter exposing ``get_num_layers()``; when
                absent, entropy grouping falls back to 20 layers (CodeGen).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter
        # Device of the first parameter; assumes the model is on one device.
        self.device = next(model.parameters()).device

    def detect_induction_heads(
        self,
        attention_weights: List[Dict],
        input_ids: torch.Tensor,
        example_boundaries: List[Tuple[int, int]]
    ) -> List["InductionHeadSignal"]:
        """
        Detect induction heads by looking for attention patterns that:
        1. Copy from previous occurrences (classic induction)
        2. Match prefixes across examples
        3. Show abstract pattern matching

        Args:
            attention_weights: Records with 'layer' (int) and 'attention'
                (tensor; indexed as [batch, head, query, key] — assumes
                batch-first 4D, TODO confirm against the producer).
            input_ids: Token ids; only batch element 0 is inspected.
            example_boundaries: (start, end) token spans of each ICL example.

        Returns:
            One InductionHeadSignal per head whose strongest pattern score
            exceeds the 0.3 significance threshold.
        """
        induction_heads: List["InductionHeadSignal"] = []
        if not attention_weights or not example_boundaries:
            return induction_heads

        seen_layers = set()  # only the first record per layer is analyzed
        for record in attention_weights:
            layer_idx = record.get('layer', 0)
            attn = record.get('attention')
            if attn is None or layer_idx in seen_layers:
                continue
            seen_layers.add(layer_idx)
            if attn.dim() < 3:
                continue

            num_heads = attn.shape[1]
            seq_len = attn.shape[-1]
            for head_idx in range(num_heads):
                head_attn = attn[0, head_idx]  # [seq_len, seq_len]
                # Score the three candidate induction patterns.
                copy_score = self._detect_copy_pattern(head_attn, input_ids)
                prefix_score = self._detect_prefix_matching(head_attn, example_boundaries)
                abstract_score = self._detect_abstract_pattern(head_attn, seq_len)

                max_score = max(copy_score, prefix_score, abstract_score)
                if max_score <= 0.3:  # below significance threshold
                    continue
                if copy_score == max_score:
                    pattern_type = 'copy'
                elif prefix_score == max_score:
                    pattern_type = 'prefix_match'
                else:
                    pattern_type = 'abstract'
                induction_heads.append(InductionHeadSignal(
                    layer=layer_idx,
                    head=head_idx,
                    strength=max_score,
                    pattern_type=pattern_type,
                    # Position where the pattern suddenly strengthens, if any
                    emergence_point=self._find_emergence_point(head_attn)
                ))
        return induction_heads

    def _detect_copy_pattern(self, attn_matrix: torch.Tensor, input_ids: torch.Tensor) -> float:
        """Mean attention weight from each token back to earlier copies of itself.

        Only links stronger than 0.1 are counted; returns 0.0 when none exist.
        """
        seq_len = attn_matrix.shape[0]
        copy_score = 0.0
        count = 0
        tokens = input_ids[0]
        # Limit analysis to the first 50 positions for efficiency.
        for i in range(1, min(seq_len, 50)):
            if i >= len(tokens):
                break
            current_token = tokens[i].item()
            # Scan all earlier positions for the same token id.
            # (j < i < len(tokens), so no extra bounds check is needed.)
            for j in range(i):
                if tokens[j].item() == current_token:
                    if attn_matrix[i, j] > 0.1:  # significant attention only
                        copy_score += attn_matrix[i, j].item()
                        count += 1
        return copy_score / max(count, 1)

    def _detect_prefix_matching(
        self,
        attn_matrix: torch.Tensor,
        example_boundaries: List[Tuple[int, int]]
    ) -> float:
        """Mean strength with which later examples attend to earlier examples
        at the same token offset (prefix alignment across ICL examples)."""
        if len(example_boundaries) < 2:
            return 0.0
        prefix_score = 0.0
        count = 0
        # Both positions must index validly on both matrix axes; this single
        # bound replaces the original's four redundant comparisons.
        limit = min(attn_matrix.shape[0], attn_matrix.shape[1])
        for i, (start1, end1) in enumerate(example_boundaries[:-1]):
            for start2, end2 in example_boundaries[i + 1:]:
                # Compare only the first 5 tokens of each example.
                for offset in range(min(5, end1 - start1, end2 - start2)):
                    pos1 = start1 + offset
                    pos2 = start2 + offset
                    if pos1 < limit and pos2 < limit:
                        # Later example (query pos2) attending to earlier
                        # example (key pos1) at the same offset.
                        attention_strength = attn_matrix[pos2, pos1].item()
                        if attention_strength > 0.1:
                            prefix_score += attention_strength
                            count += 1
        return prefix_score / max(count, 1)

    def _detect_abstract_pattern(self, attn_matrix: torch.Tensor, seq_len: int) -> float:
        """Detect diagonal-with-offset attention indicating structural copying
        (e.g., function->function mapping at a fixed stride)."""
        abstract_score = 0.0
        window_size = 10
        for i in range(window_size, min(seq_len, 50)):
            # Sum attention to the previous window_size-1 positions.
            # offset < i always, so i - offset >= 1 needs no guard.
            diagonal_sum = 0.0
            for offset in range(1, min(window_size, i)):
                diagonal_sum += attn_matrix[i, i - offset].item()
            # NOTE(review): divides by window_size even though at most
            # window_size-1 offsets are summed — presumably intentional
            # normalization; confirm with the original author.
            if diagonal_sum / window_size > 0.1:
                abstract_score += diagonal_sum / window_size
        return min(abstract_score / 10, 1.0)  # normalize to [0, 1]

    def _find_emergence_point(self, attn_matrix: torch.Tensor) -> Optional[int]:
        """Find the token position where focused attention suddenly emerges.

        Computes the Shannon entropy of each row's attention over its causal
        prefix and returns the first position whose entropy drops below half
        of the preceding 4-position average, or None if no such drop exists.
        """
        seq_len = min(attn_matrix.shape[0], 50)  # limit for efficiency
        if seq_len < 10:
            return None
        entropies = []
        for i in range(seq_len):
            attn_dist = attn_matrix[i, :i + 1]  # only previous positions
            total = attn_dist.sum()
            if total > 0:
                attn_dist = attn_dist / total
                # Epsilon guards log(0) on zero-probability entries.
                entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum().item()
                entropies.append(entropy)
            else:
                entropies.append(0.0)
        if len(entropies) < 5:
            return None
        for i in range(4, len(entropies)):
            recent_avg = np.mean(entropies[i - 4:i])
            # A sudden (>50%) entropy drop marks emergence of focused attention.
            if recent_avg > 0 and entropies[i] < recent_avg * 0.5:
                return i
        return None

    def analyze_icl_emergence(
        self,
        attention_weights: List[Dict],
        input_ids: torch.Tensor,
        example_boundaries: List[Tuple[int, int]],
        generated_tokens: List[int]
    ) -> "ICLEmergenceAnalysis":
        """
        Comprehensive analysis of when and how ICL emerges during generation.

        Combines per-head induction signals with the attention-entropy
        trajectory; the entropy drop serves as a fallback emergence signal
        when no head reports an emergence point.
        """
        induction_heads = self.detect_induction_heads(
            attention_weights, input_ids, example_boundaries
        )
        entropy_trajectory = self._calculate_entropy_trajectory(
            attention_weights, len(generated_tokens)
        )
        emergence_token = None
        emergence_layer = None
        emergence_confidence = 0.0
        if induction_heads:
            # Head with the strongest induction signal.
            strongest_head = max(induction_heads, key=lambda h: h.strength)
            # BUGFIX: compare against None explicitly so a legitimate
            # emergence at position 0 is not discarded by truthiness.
            emergence_points = [
                h.emergence_point for h in induction_heads
                if h.emergence_point is not None
            ]
            if emergence_points:
                # Median emergence point across heads (robust to outliers).
                emergence_token = int(np.median(emergence_points))
                emergence_layer = strongest_head.layer
                # Confidence scales with both strength and cross-head consistency.
                consistency = len(emergence_points) / len(induction_heads)
                emergence_confidence = min(strongest_head.strength * consistency, 1.0)
        # Fallback: a >40% drop in mean attention entropy also signals emergence.
        if entropy_trajectory and len(entropy_trajectory) > 5:
            for i in range(5, len(entropy_trajectory)):
                recent_avg = np.mean(entropy_trajectory[i - 5:i])
                if recent_avg > 0 and entropy_trajectory[i] < recent_avg * 0.6:
                    if emergence_token is None:
                        emergence_token = i
                        emergence_confidence = 0.5
                    break
        pattern_consistency = self._calculate_pattern_consistency(induction_heads)
        return ICLEmergenceAnalysis(
            emergence_detected=emergence_token is not None,
            emergence_token=emergence_token,
            emergence_layer=emergence_layer,
            confidence=emergence_confidence,
            induction_heads=induction_heads,
            attention_entropy_drop=entropy_trajectory,
            pattern_consistency=pattern_consistency
        )

    def _calculate_entropy_trajectory(
        self,
        attention_weights: List[Dict],
        num_generated: int
    ) -> List[float]:
        """Mean attention entropy (over layers and heads) at each generated position.

        Assumes attention_weights holds one record per layer per generated
        token, in generation order — TODO confirm against the producer.
        """
        entropies: List[float] = []
        if not attention_weights:
            return entropies
        # Use adapter when available; fall back to CodeGen's 20 layers.
        num_layers = self.adapter.get_num_layers() if self.adapter else 20
        for gen_idx in range(num_generated):
            position_entropy = []
            # This position's slice of records across all layers; min() caps
            # the range so no extra in-loop bounds check is needed.
            start = gen_idx * num_layers
            stop = min(start + num_layers, len(attention_weights))
            for rec_idx in range(start, stop):
                attn = attention_weights[rec_idx].get('attention')
                if attn is None or attn.dim() < 3:
                    continue
                avg_attn = attn[0].mean(dim=0)  # average across heads
                if avg_attn.shape[0] > gen_idx:
                    attn_dist = avg_attn[-1]  # last row = newly generated token
                    total = attn_dist.sum()
                    if total > 0:
                        attn_dist = attn_dist / total
                        position_entropy.append(
                            -(attn_dist * torch.log(attn_dist + 1e-10)).sum().item()
                        )
            entropies.append(np.mean(position_entropy) if position_entropy else 0.0)
        return entropies

    def _calculate_pattern_consistency(self, induction_heads: List["InductionHeadSignal"]) -> float:
        """Fraction of detected heads that share the dominant pattern type."""
        if not induction_heads:
            return 0.0
        pattern_counts: Dict[str, int] = {}
        for head in induction_heads:
            pattern_counts[head.pattern_type] = pattern_counts.get(head.pattern_type, 0) + 1
        # Consistency is the ratio of the dominant pattern.
        return max(pattern_counts.values()) / len(induction_heads)