Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| """ | |
| Induction Head Detection for In-Context Learning | |
| Based on research showing that ICL emerges abruptly in transformers through | |
| the formation of induction heads - attention patterns that copy from context. | |
| """ | |
| import torch | |
| import numpy as np | |
| from typing import List, Dict, Tuple, Optional | |
| from dataclasses import dataclass | |
| import logging | |
| logger = logging.getLogger(__name__) | |
@dataclass
class InductionHeadSignal:
    """Signals indicating induction head behavior.

    Note: the ``@dataclass`` decorator was missing even though ``dataclass``
    is imported and this class is instantiated with keyword arguments in
    ``InductionHeadDetector.detect_induction_heads`` — without it that call
    raises TypeError.

    Attributes:
        layer: Index of the transformer layer the head belongs to.
        head: Index of the attention head within that layer.
        strength: 0-1 score of induction pattern strength.
        pattern_type: One of 'copy', 'prefix_match', 'abstract'.
        emergence_point: Token position where the pattern emerges, or None.
    """
    layer: int
    head: int
    strength: float  # 0-1 score of induction pattern strength
    pattern_type: str  # 'copy', 'prefix_match', 'abstract'
    emergence_point: Optional[int]  # Token position where pattern emerges
@dataclass
class ICLEmergenceAnalysis:
    """Analysis of when and how ICL emerges.

    Note: the ``@dataclass`` decorator was missing even though ``dataclass``
    is imported and this class is instantiated with keyword arguments in
    ``InductionHeadDetector.analyze_icl_emergence`` — without it that call
    raises TypeError.

    Attributes:
        emergence_detected: Whether an emergence point was found.
        emergence_token: Token position where ICL kicks in, or None.
        emergence_layer: Layer where the strongest signal appears, or None.
        confidence: Confidence in the detection (0-1).
        induction_heads: Per-head induction signals that were detected.
        attention_entropy_drop: Attention entropy at each generated position.
        pattern_consistency: How consistent the pattern is across heads.
    """
    emergence_detected: bool
    emergence_token: Optional[int]  # Token position where ICL kicks in
    emergence_layer: Optional[int]  # Layer where strongest signal appears
    confidence: float  # Confidence in detection (0-1)
    induction_heads: List[InductionHeadSignal]
    attention_entropy_drop: List[float]  # Entropy at each position
    pattern_consistency: float  # How consistent the pattern is
class InductionHeadDetector:
    """Detects induction heads and ICL emergence in transformer models.

    Operates on attention records shaped as ``{'layer': int, 'attention': Tensor}``;
    the tensor is indexed as ``attn[0, head]`` giving a square
    ``[seq_len, seq_len]`` matrix per head — assumes batch-first
    ``[batch, heads, seq_len, seq_len]`` layout (TODO confirm with the
    producer of these records).
    """

    def __init__(self, model, tokenizer):
        """Store the model/tokenizer and remember the model's device.

        Args:
            model: Transformer model; only its parameters' device is read here.
            tokenizer: Matching tokenizer (kept for callers; unused in this class).
        """
        self.model = model
        self.tokenizer = tokenizer
        # Device of the first parameter; assumes all parameters share a device.
        self.device = next(model.parameters()).device

    def detect_induction_heads(
        self,
        attention_weights: List[Dict],
        input_ids: torch.Tensor,
        example_boundaries: List[Tuple[int, int]],
    ) -> List[InductionHeadSignal]:
        """
        Detect induction heads by looking for attention patterns that:

        1. Copy from previous occurrences (classic induction)
        2. Match prefixes across examples
        3. Show abstract pattern matching

        Args:
            attention_weights: Records with 'layer' and 'attention' keys; only
                the first record seen per layer index is analyzed.
            input_ids: Token ids, indexed batch-first as ``input_ids[0][i]``.
            example_boundaries: (start, end) token spans of in-context examples.

        Returns:
            One InductionHeadSignal per head whose best pattern score > 0.3.
        """
        signals: List[InductionHeadSignal] = []
        if not attention_weights or not example_boundaries:
            return signals

        seen_layers = set()
        for record in attention_weights:
            layer_idx = record.get('layer', 0)
            attn = record.get('attention')
            # Only the first record per layer is considered.
            if attn is None or layer_idx in seen_layers:
                continue
            seen_layers.add(layer_idx)

            if attn.dim() < 3:
                continue
            num_heads = attn.shape[1]
            seq_len = attn.shape[-1]

            for head_idx in range(num_heads):
                head_attn = attn[0, head_idx]  # [seq_len, seq_len]

                # Score each candidate induction pattern independently.
                copy_score = self._detect_copy_pattern(head_attn, input_ids)
                prefix_score = self._detect_prefix_matching(head_attn, example_boundaries)
                abstract_score = self._detect_abstract_pattern(head_attn, seq_len)

                best = max(copy_score, prefix_score, abstract_score)
                if best <= 0.3:  # Threshold for a significant pattern
                    continue

                # Tie-break order preserved: copy, then prefix_match, then abstract.
                if copy_score == best:
                    pattern_type = 'copy'
                elif prefix_score == best:
                    pattern_type = 'prefix_match'
                else:
                    pattern_type = 'abstract'

                signals.append(InductionHeadSignal(
                    layer=layer_idx,
                    head=head_idx,
                    strength=best,
                    pattern_type=pattern_type,
                    # Position where the pattern suddenly strengthens, if any.
                    emergence_point=self._find_emergence_point(head_attn),
                ))
        return signals

    def _detect_copy_pattern(self, attn_matrix: torch.Tensor, input_ids: torch.Tensor) -> float:
        """Average attention weight paid to earlier occurrences of the same token.

        Only weights above 0.1 count; analysis is capped at the first 50
        positions (and at the token sequence length) for efficiency.

        Returns:
            Mean significant copy-attention weight, or 0.0 if none found.
        """
        seq_len = attn_matrix.shape[0]
        total = 0.0
        count = 0
        tokens = input_ids[0]
        # Limit analysis to the first 50 positions for efficiency.
        for i in range(1, min(seq_len, 50)):
            if i >= len(tokens):
                break
            current = tokens[i].item()
            for j in range(i):
                # Earlier occurrence of the same token?
                if j < len(tokens) and tokens[j].item() == current:
                    weight = attn_matrix[i, j].item()
                    if weight > 0.1:  # Threshold for significant attention
                        total += weight
                        count += 1
        return total / max(count, 1)

    def _detect_prefix_matching(
        self,
        attn_matrix: torch.Tensor,
        example_boundaries: List[Tuple[int, int]],
    ) -> float:
        """Average cross-example attention at matching prefix offsets.

        For every ordered pair of examples, checks whether tokens of the
        later example attend (weight > 0.1) back to the same-offset tokens of
        the earlier example, over the first 5 offsets. Needs >= 2 examples.
        """
        if len(example_boundaries) < 2:
            return 0.0
        total = 0.0
        count = 0
        # Both positions must fit in the attention matrix; this single bound is
        # equivalent to the original's two symmetric per-axis checks.
        bound = min(attn_matrix.shape[0], attn_matrix.shape[1])
        for i, (start1, end1) in enumerate(example_boundaries[:-1]):
            for start2, end2 in example_boundaries[i + 1:]:
                # Compare the first 5 tokens of each example pair.
                for offset in range(min(5, end1 - start1, end2 - start2)):
                    pos1 = start1 + offset
                    pos2 = start2 + offset
                    if max(pos1, pos2) < bound:
                        # Later example attending to earlier example, same offset.
                        weight = attn_matrix[pos2, pos1].item()
                        if weight > 0.1:
                            total += weight
                            count += 1
        return total / max(count, 1)

    def _detect_abstract_pattern(self, attn_matrix: torch.Tensor, seq_len: int) -> float:
        """Detect abstract pattern matching via near-diagonal attention mass.

        For each of the first 50 positions (starting at ``window_size``), sums
        attention to the previous ``window_size - 1`` positions; windows whose
        mean exceeds 0.1 contribute. Result normalized into [0, 1].
        """
        window_size = 10
        score = 0.0
        for i in range(window_size, min(seq_len, 50)):
            # Offsets 1..window_size-1 are all valid here because
            # i >= window_size, so the original's `i - offset >= 0` guard
            # was always true and is dropped.
            diagonal_sum = sum(
                attn_matrix[i, i - offset].item()
                for offset in range(1, window_size)
            )
            mean_weight = diagonal_sum / window_size
            # High near-diagonal attention indicates structural copying.
            if mean_weight > 0.1:
                score += mean_weight
        return min(score / 10, 1.0)  # Normalize to [0, 1]

    def _find_emergence_point(self, attn_matrix: torch.Tensor) -> Optional[int]:
        """Find the token position where attention suddenly becomes focused.

        Computes the entropy of each row's causal attention distribution and
        returns the first position whose entropy falls below half the mean of
        the previous four positions. Returns None if no such drop exists or
        the usable sequence is shorter than 10 tokens.
        """
        seq_len = min(attn_matrix.shape[0], 50)  # Limit for efficiency
        if seq_len < 10:
            return None

        entropies: List[float] = []
        for i in range(seq_len):
            row = attn_matrix[i, :i + 1]  # Only look at previous positions
            if row.sum() > 0:
                dist = row / row.sum()
                # Shannon entropy; epsilon guards log(0).
                entropies.append(-(dist * torch.log(dist + 1e-10)).sum().item())
            else:
                entropies.append(0.0)

        if len(entropies) < 5:
            return None
        for i in range(4, len(entropies)):
            recent_avg = np.mean(entropies[i - 4:i])
            # A sudden drop (below 50% of the trailing mean) marks emergence.
            if recent_avg > 0 and entropies[i] < recent_avg * 0.5:
                return i
        return None

    def analyze_icl_emergence(
        self,
        attention_weights: List[Dict],
        input_ids: torch.Tensor,
        example_boundaries: List[Tuple[int, int]],
        generated_tokens: List[int],
    ) -> ICLEmergenceAnalysis:
        """
        Comprehensive analysis of when and how ICL emerges during generation.

        Combines induction-head detection with an attention-entropy
        trajectory: per-head emergence points vote (via median) for the
        emergence token; a sharp entropy drop serves as a fallback signal.
        """
        induction_heads = self.detect_induction_heads(
            attention_weights, input_ids, example_boundaries
        )
        entropy_trajectory = self._calculate_entropy_trajectory(
            attention_weights, len(generated_tokens)
        )

        emergence_token: Optional[int] = None
        emergence_layer: Optional[int] = None
        emergence_confidence = 0.0

        if induction_heads:
            # Head with the strongest induction signal.
            strongest_head = max(induction_heads, key=lambda h: h.strength)
            # Emergence points reported by individual heads. They are always
            # >= 4 when present, so `is not None` matches the original
            # truthiness filter while reading unambiguously.
            emergence_points = [
                h.emergence_point for h in induction_heads
                if h.emergence_point is not None
            ]
            if emergence_points:
                # Median is robust to a few outlier heads.
                emergence_token = int(np.median(emergence_points))
                emergence_layer = strongest_head.layer
                # Confidence: best head's strength, scaled by the fraction of
                # heads that reported an emergence point at all.
                consistency = len(emergence_points) / len(induction_heads)
                emergence_confidence = min(strongest_head.strength * consistency, 1.0)

        # Fallback signal: entropy dropping below 60% of the trailing 5-step mean.
        if entropy_trajectory and len(entropy_trajectory) > 5:
            for i in range(5, len(entropy_trajectory)):
                recent_avg = np.mean(entropy_trajectory[i - 5:i])
                if recent_avg > 0 and entropy_trajectory[i] < recent_avg * 0.6:
                    if emergence_token is None:
                        emergence_token = i
                        emergence_confidence = 0.5
                    break

        return ICLEmergenceAnalysis(
            emergence_detected=emergence_token is not None,
            emergence_token=emergence_token,
            emergence_layer=emergence_layer,
            confidence=emergence_confidence,
            induction_heads=induction_heads,
            attention_entropy_drop=entropy_trajectory,
            pattern_consistency=self._calculate_pattern_consistency(induction_heads),
        )

    def _calculate_entropy_trajectory(
        self,
        attention_weights: List[Dict],
        num_generated: int,
        num_layers: int = 20,
    ) -> List[float]:
        """Mean attention entropy (across layers and heads) per generated position.

        Args:
            attention_weights: Flat record list, assumed grouped as
                ``num_layers`` consecutive records per generated token —
                TODO confirm against the code that captures these records.
            num_generated: Number of generated tokens to cover.
            num_layers: Records per generated position. Default 20 preserves
                the previous hard-coded value (CodeGen model); now a
                parameter so other model depths can be analyzed.

        Returns:
            One mean-entropy value per generated position (0.0 when no usable
            attention record exists for that position).
        """
        entropies: List[float] = []
        if not attention_weights:
            return entropies

        for gen_idx in range(num_generated):
            position_entropy: List[float] = []
            # Slice of records belonging to this generated position, clamped
            # to the list length (the original's inner bound re-check was
            # redundant and is dropped).
            start = gen_idx * num_layers
            stop = min(start + num_layers, len(attention_weights))
            for rec_idx in range(start, stop):
                attn = attention_weights[rec_idx].get('attention')
                if attn is None or attn.dim() < 3:
                    continue
                avg_attn = attn[0].mean(dim=0)  # Average across heads
                if avg_attn.shape[0] > gen_idx:
                    # Last row corresponds to the newly generated token.
                    dist = avg_attn[-1]
                    if dist.sum() > 0:
                        dist = dist / dist.sum()
                        position_entropy.append(
                            -(dist * torch.log(dist + 1e-10)).sum().item()
                        )
            entropies.append(np.mean(position_entropy) if position_entropy else 0.0)
        return entropies

    def _calculate_pattern_consistency(self, induction_heads: List[InductionHeadSignal]) -> float:
        """Fraction of detected heads that share the dominant pattern type."""
        if not induction_heads:
            return 0.0
        counts: Dict[str, int] = {}
        for head in induction_heads:
            counts[head.pattern_type] = counts.get(head.pattern_type, 0) + 1
        # Consistency is the ratio of the dominant pattern.
        return max(counts.values()) / len(induction_heads)