""" Induction Head Detection for In-Context Learning Based on research showing that ICL emerges abruptly in transformers through the formation of induction heads - attention patterns that copy from context. """ import torch import numpy as np from typing import List, Dict, Tuple, Optional from dataclasses import dataclass import logging logger = logging.getLogger(__name__) @dataclass class InductionHeadSignal: """Signals indicating induction head behavior""" layer: int head: int strength: float # 0-1 score of induction pattern strength pattern_type: str # 'copy', 'prefix_match', 'abstract' emergence_point: Optional[int] # Token position where pattern emerges @dataclass class ICLEmergenceAnalysis: """Analysis of when and how ICL emerges""" emergence_detected: bool emergence_token: Optional[int] # Token position where ICL kicks in emergence_layer: Optional[int] # Layer where strongest signal appears confidence: float # Confidence in detection (0-1) induction_heads: List[InductionHeadSignal] attention_entropy_drop: List[float] # Entropy at each position pattern_consistency: float # How consistent the pattern is class InductionHeadDetector: """Detects induction heads and ICL emergence in transformer models""" def __init__(self, model, tokenizer): self.model = model self.tokenizer = tokenizer self.device = next(model.parameters()).device def detect_induction_heads( self, attention_weights: List[Dict], input_ids: torch.Tensor, example_boundaries: List[Tuple[int, int]] ) -> List[InductionHeadSignal]: """ Detect induction heads by looking for attention patterns that: 1. Copy from previous occurrences (classic induction) 2. Match prefixes across examples 3. Show abstract pattern matching """ induction_heads = [] if not attention_weights or not example_boundaries: return induction_heads # Analyze each layer and head layers_analyzed = {} for record in attention_weights: layer_idx = record.get('layer', 0) attn = record.get('attention') if attn is None or layer_idx in layers_analyzed: continue layers_analyzed[layer_idx] = True # Analyze each attention head if attn.dim() >= 3: num_heads = attn.shape[1] seq_len = attn.shape[-1] for head_idx in range(num_heads): head_attn = attn[0, head_idx] # [seq_len, seq_len] # Detect different induction patterns copy_score = self._detect_copy_pattern(head_attn, input_ids) prefix_score = self._detect_prefix_matching(head_attn, example_boundaries) abstract_score = self._detect_abstract_pattern(head_attn, seq_len) # Determine strongest pattern max_score = max(copy_score, prefix_score, abstract_score) if max_score > 0.3: # Threshold for significant pattern pattern_type = 'copy' if copy_score == max_score else \ 'prefix_match' if prefix_score == max_score else 'abstract' # Find emergence point (where pattern suddenly strengthens) emergence_point = self._find_emergence_point(head_attn) induction_heads.append(InductionHeadSignal( layer=layer_idx, head=head_idx, strength=max_score, pattern_type=pattern_type, emergence_point=emergence_point )) return induction_heads def _detect_copy_pattern(self, attn_matrix: torch.Tensor, input_ids: torch.Tensor) -> float: """Detect if attention head copies from previous occurrences""" seq_len = attn_matrix.shape[0] copy_score = 0.0 count = 0 # Look for positions that attend strongly to previous same/similar tokens for i in range(1, min(seq_len, 50)): # Limit analysis for efficiency if i >= len(input_ids[0]): break current_token = input_ids[0][i].item() # Find previous occurrences of the same token for j in range(i): if j < len(input_ids[0]) 
    def _detect_prefix_matching(
        self,
        attn_matrix: torch.Tensor,
        example_boundaries: List[Tuple[int, int]],
    ) -> float:
        """Detect whether attention matches prefixes across examples."""
        if len(example_boundaries) < 2:
            return 0.0

        prefix_score = 0.0
        count = 0

        # Check whether tokens attend to the same offsets in earlier examples
        for i, (start1, end1) in enumerate(example_boundaries[:-1]):
            for start2, end2 in example_boundaries[i + 1:]:
                # Compare attention over the first 5 tokens of each example
                for offset in range(min(5, end1 - start1, end2 - start2)):
                    pos1 = start1 + offset
                    pos2 = start2 + offset

                    # Check whether the later example attends to the earlier
                    # example at the same offset (attn_matrix is square)
                    if pos2 < attn_matrix.shape[0] and pos1 < attn_matrix.shape[1]:
                        attention_strength = attn_matrix[pos2, pos1].item()
                        if attention_strength > 0.1:
                            prefix_score += attention_strength
                            count += 1

        return prefix_score / max(count, 1)

    def _detect_abstract_pattern(self, attn_matrix: torch.Tensor, seq_len: int) -> float:
        """Detect abstract pattern matching (e.g., function->function mapping)."""
        # Look for banded attention over a short window of preceding positions,
        # which indicates attending to structurally similar positions
        abstract_score = 0.0
        window_size = 10

        for i in range(window_size, min(seq_len, 50)):
            # Sum attention over the preceding window
            diagonal_sum = 0.0
            for offset in range(1, min(window_size, i)):
                diagonal_sum += attn_matrix[i, i - offset].item()

            # High windowed attention indicates structural copying
            if diagonal_sum / window_size > 0.1:
                abstract_score += diagonal_sum / window_size

        return min(abstract_score / 10, 1.0)  # Normalize to [0, 1]

    def _find_emergence_point(self, attn_matrix: torch.Tensor) -> Optional[int]:
        """Find the token position where the pattern suddenly emerges."""
        seq_len = min(attn_matrix.shape[0], 50)  # Limit for efficiency

        if seq_len < 10:
            return None

        # Calculate attention entropy at each position
        entropies = []
        for i in range(seq_len):
            attn_dist = attn_matrix[i, :i + 1]  # Only look at previous positions
            if attn_dist.sum() > 0:
                attn_dist = attn_dist / attn_dist.sum()
                # Calculate entropy: H = -sum(p * log p)
                entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum().item()
                entropies.append(entropy)
            else:
                entropies.append(0.0)

        # Find sudden drops in entropy (indicating focused attention)
        if len(entropies) < 5:
            return None

        for i in range(4, len(entropies)):
            recent_avg = np.mean(entropies[i - 4:i])
            current = entropies[i]
            # A sudden drop indicates emergence
            if recent_avg > 0 and current < recent_avg * 0.5:
                return i

        return None
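    # --- Hedged sketch: shared entropy helper ----------------------------------
    # _find_emergence_point (above) and _calculate_entropy_trajectory (below)
    # both inline the same computation, H = -sum(p * log p) over a normalized
    # attention row. A minimal standalone version is sketched here for
    # reference; the methods above and below do not call it.
    @staticmethod
    def _attention_entropy(attn_dist: torch.Tensor) -> float:
        total = attn_dist.sum()
        if total <= 0:
            return 0.0
        p = attn_dist / total  # Normalize to a probability distribution
        return -(p * torch.log(p + 1e-10)).sum().item()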
    def analyze_icl_emergence(
        self,
        attention_weights: List[Dict],
        input_ids: torch.Tensor,
        example_boundaries: List[Tuple[int, int]],
        generated_tokens: List[int],
    ) -> ICLEmergenceAnalysis:
        """
        Comprehensive analysis of when and how ICL emerges during generation.
        """
        # Detect induction heads
        induction_heads = self.detect_induction_heads(
            attention_weights, input_ids, example_boundaries
        )

        # Calculate the attention entropy trajectory
        entropy_trajectory = self._calculate_entropy_trajectory(
            attention_weights, len(generated_tokens)
        )

        # Determine the emergence point
        emergence_token = None
        emergence_layer = None
        emergence_confidence = 0.0

        if induction_heads:
            # Find the strongest induction signal
            strongest_head = max(induction_heads, key=lambda h: h.strength)

            # Check for consistent emergence points across heads
            emergence_points = [
                h.emergence_point for h in induction_heads
                if h.emergence_point is not None
            ]
            if emergence_points:
                # Median emergence point across heads
                emergence_token = int(np.median(emergence_points))
                emergence_layer = strongest_head.layer

                # Confidence based on consistency and strength
                consistency = len(emergence_points) / len(induction_heads)
                emergence_confidence = min(strongest_head.strength * consistency, 1.0)

        # Check for an entropy drop as an additional signal
        if entropy_trajectory and len(entropy_trajectory) > 5:
            for i in range(5, len(entropy_trajectory)):
                recent_avg = np.mean(entropy_trajectory[i - 5:i])
                if recent_avg > 0 and entropy_trajectory[i] < recent_avg * 0.6:
                    if emergence_token is None:
                        emergence_token = i
                        emergence_confidence = 0.5
                    break

        # Calculate pattern consistency
        pattern_consistency = self._calculate_pattern_consistency(induction_heads)

        return ICLEmergenceAnalysis(
            emergence_detected=emergence_token is not None,
            emergence_token=emergence_token,
            emergence_layer=emergence_layer,
            confidence=emergence_confidence,
            induction_heads=induction_heads,
            attention_entropy_drop=entropy_trajectory,
            pattern_consistency=pattern_consistency,
        )

    def _calculate_entropy_trajectory(
        self, attention_weights: List[Dict], num_generated: int
    ) -> List[float]:
        """Calculate attention entropy at each generated position."""
        entropies = []

        if not attention_weights:
            return entropies

        # Group attention records by generated position, assuming the records
        # are ordered as num_layers entries per generation step
        num_layers = 20  # CodeGen model

        for gen_idx in range(num_generated):
            position_entropy = []

            # Get attention for this generated position across all layers
            start = gen_idx * num_layers
            end = min((gen_idx + 1) * num_layers, len(attention_weights))
            for i in range(start, end):
                attn = attention_weights[i].get('attention')
                if attn is not None and attn.dim() >= 3:
                    # Average across heads
                    avg_attn = attn[0].mean(dim=0)
                    if avg_attn.shape[0] > gen_idx:
                        # The last position is the newly generated token
                        attn_dist = avg_attn[-1]
                        if attn_dist.sum() > 0:
                            attn_dist = attn_dist / attn_dist.sum()
                            # Calculate entropy: H = -sum(p * log p)
                            entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum().item()
                            position_entropy.append(entropy)

            entropies.append(float(np.mean(position_entropy)) if position_entropy else 0.0)

        return entropies

    def _calculate_pattern_consistency(
        self, induction_heads: List[InductionHeadSignal]
    ) -> float:
        """Calculate how consistent the induction patterns are across heads."""
        if not induction_heads:
            return 0.0

        # Group by pattern type
        pattern_counts = {}
        for head in induction_heads:
            pattern_counts[head.pattern_type] = pattern_counts.get(head.pattern_type, 0) + 1

        # Consistency is the ratio of the dominant pattern
        max_count = max(pattern_counts.values())
        return max_count / len(induction_heads)
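

# --- Hedged usage sketch --------------------------------------------------------
# A minimal example of driving the detector, assuming a Hugging Face transformers
# causal LM run with output_attentions=True. The model name, the example
# boundaries, and the construction of the attention_weights list of
# {'layer': ..., 'attention': ...} dicts (the format the methods above expect)
# are illustrative assumptions, not part of this module's API.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "gpt2"  # Hypothetical choice; any causal LM exposing attentions works
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    # Two few-shot examples followed by a query
    prompt = "cat -> chat\ndog -> chien\nbird ->"
    inputs = tokenizer(prompt, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # Build the List[Dict] format used by detect_induction_heads:
    # one record per layer, each holding a [batch, heads, seq, seq] tensor
    attention_weights = [
        {'layer': layer_idx, 'attention': attn}
        for layer_idx, attn in enumerate(outputs.attentions)
    ]

    # Token spans of the two in-context examples (illustrative values;
    # in practice these would be computed from the tokenization)
    example_boundaries = [(0, 7), (7, 14)]

    detector = InductionHeadDetector(model, tokenizer)
    heads = detector.detect_induction_heads(
        attention_weights, inputs["input_ids"], example_boundaries
    )
    for signal in heads:
        print(f"L{signal.layer} H{signal.head}: {signal.pattern_type} "
              f"(strength={signal.strength:.2f})")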