api/backend/induction_head_detector.py
gary-boon
Add Code Llama 7B support with hardware-aware filtering and ICL timeout fixes
ed40a9a
"""
Induction Head Detection for In-Context Learning
Based on research showing that ICL emerges abruptly in transformers through
the formation of induction heads - attention patterns that copy from context.
"""
import torch
import numpy as np
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
@dataclass
class InductionHeadSignal:
    """Signals indicating induction head behavior.

    One record per attention head whose strongest induction-style pattern
    score cleared the detection threshold.
    """
    layer: int  # transformer layer index of the head
    head: int  # head index within that layer
    strength: float  # 0-1 score of induction pattern strength
    pattern_type: str  # 'copy', 'prefix_match', 'abstract'
    emergence_point: Optional[int]  # Token position where pattern emerges
@dataclass
class ICLEmergenceAnalysis:
    """Analysis of when and how ICL emerges.

    Aggregates the per-head induction signals and the attention-entropy
    trajectory into one emergence verdict.
    """
    emergence_detected: bool  # True when an emergence token was identified
    emergence_token: Optional[int]  # Token position where ICL kicks in
    emergence_layer: Optional[int]  # Layer where strongest signal appears
    confidence: float  # Confidence in detection (0-1)
    induction_heads: List[InductionHeadSignal]  # heads passing the detection threshold
    attention_entropy_drop: List[float]  # Entropy at each position
    pattern_consistency: float  # How consistent the pattern is across heads
class InductionHeadDetector:
    """Detects induction heads and ICL emergence in transformer models.

    An induction head is an attention head that attends back to earlier
    occurrences of the current token (or to structurally analogous positions
    in earlier examples), a mechanism associated with the abrupt emergence of
    in-context learning.
    """

    def __init__(self, model, tokenizer, adapter=None):
        """
        Args:
            model: Transformer model; used here only to derive the device.
            tokenizer: Paired tokenizer (stored for callers; unused internally).
            adapter: Optional adapter exposing get_num_layers(); when absent a
                fallback of 20 layers (CodeGen) is assumed.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter
        self.device = next(model.parameters()).device

    def detect_induction_heads(
        self,
        attention_weights: List[Dict],
        input_ids: torch.Tensor,
        example_boundaries: List[Tuple[int, int]]
    ) -> List[InductionHeadSignal]:
        """
        Detect induction heads by looking for attention patterns that:
        1. Copy from previous occurrences (classic induction)
        2. Match prefixes across examples
        3. Show abstract pattern matching

        Args:
            attention_weights: Records with 'layer' (int) and 'attention'
                (tensor; assumed [batch, heads, seq, seq] — TODO confirm
                against the producer of these records).
            input_ids: Token ids, shape [1, seq_len].
            example_boundaries: (start, end) token spans of each ICL example.

        Returns:
            One InductionHeadSignal per head whose strongest pattern score
            exceeds the 0.3 significance threshold.
        """
        induction_heads: List[InductionHeadSignal] = []
        if not attention_weights or not example_boundaries:
            return induction_heads

        analyzed_layers = set()  # each layer is scored once, on its first record
        for record in attention_weights:
            layer_idx = record.get('layer', 0)
            attn = record.get('attention')
            if attn is None or layer_idx in analyzed_layers:
                continue
            analyzed_layers.add(layer_idx)

            # Skip records that lack a per-head attention map.
            if attn.dim() < 3:
                continue
            num_heads = attn.shape[1]
            seq_len = attn.shape[-1]
            for head_idx in range(num_heads):
                head_attn = attn[0, head_idx]  # [seq_len, seq_len], batch item 0

                # Score the three induction-style patterns independently.
                copy_score = self._detect_copy_pattern(head_attn, input_ids)
                prefix_score = self._detect_prefix_matching(head_attn, example_boundaries)
                abstract_score = self._detect_abstract_pattern(head_attn, seq_len)

                max_score = max(copy_score, prefix_score, abstract_score)
                if max_score <= 0.3:  # threshold for a significant pattern
                    continue
                if copy_score == max_score:
                    pattern_type = 'copy'
                elif prefix_score == max_score:
                    pattern_type = 'prefix_match'
                else:
                    pattern_type = 'abstract'

                induction_heads.append(InductionHeadSignal(
                    layer=layer_idx,
                    head=head_idx,
                    strength=max_score,
                    pattern_type=pattern_type,
                    # Position where focused attention first appears, if any.
                    emergence_point=self._find_emergence_point(head_attn),
                ))
        return induction_heads

    def _detect_copy_pattern(self, attn_matrix: torch.Tensor, input_ids: torch.Tensor) -> float:
        """Mean attention mass sent to earlier occurrences of the same token.

        Averages the attention weights (from a position to every previous
        position holding an identical token id) that exceed 0.1; returns 0.0
        when no pair qualifies.
        """
        seq_len = attn_matrix.shape[0]
        copy_score = 0.0
        count = 0
        # Cap the scan at 50 positions to bound the O(n^2) token comparison;
        # also never read past the end of the token sequence.
        limit = min(seq_len, 50, len(input_ids[0]))
        for i in range(1, limit):
            current_token = input_ids[0][i].item()
            # Find previous occurrences of the same token.
            for j in range(i):
                if input_ids[0][j].item() != current_token:
                    continue
                weight = attn_matrix[i, j].item()
                if weight > 0.1:  # only count clearly significant attention
                    copy_score += weight
                    count += 1
        return copy_score / max(count, 1)

    def _detect_prefix_matching(
        self,
        attn_matrix: torch.Tensor,
        example_boundaries: List[Tuple[int, int]]
    ) -> float:
        """Mean cross-example attention between matching offsets.

        For every ordered pair of examples, checks whether the token at
        offset k of the later example attends to the token at offset k of the
        earlier example (first 5 offsets only). Averages the attention
        weights that exceed 0.1; returns 0.0 when none do.
        """
        if len(example_boundaries) < 2:
            return 0.0
        prefix_score = 0.0
        count = 0
        for idx, (start1, end1) in enumerate(example_boundaries[:-1]):
            for start2, end2 in example_boundaries[idx + 1:]:
                # Compare at most the first 5 tokens of the shorter example.
                for offset in range(min(5, end1 - start1, end2 - start2)):
                    pos1 = start1 + offset
                    pos2 = start2 + offset
                    # Guard only the indices actually used below; the original
                    # duplicated this check with the axes swapped, which is
                    # redundant for the (square) attention matrices seen here.
                    if pos2 >= attn_matrix.shape[0] or pos1 >= attn_matrix.shape[1]:
                        continue
                    attention_strength = attn_matrix[pos2, pos1].item()
                    if attention_strength > 0.1:
                        prefix_score += attention_strength
                        count += 1
        return prefix_score / max(count, 1)

    def _detect_abstract_pattern(self, attn_matrix: torch.Tensor, seq_len: int) -> float:
        """Score near-diagonal attention suggesting structural copying.

        For each position past the first `window_size`, sums attention onto
        the band of the previous `window_size - 1` positions; bands whose
        mean weight exceeds 0.1 contribute to the score. Result is squashed
        into [0, 1].
        """
        abstract_score = 0.0
        window_size = 10
        for i in range(window_size, min(seq_len, 50)):
            # Attention mass on the short diagonal band just behind position i.
            # (i >= window_size here, so offsets never reach before index 0.)
            diagonal_sum = sum(
                attn_matrix[i, i - offset].item()
                for offset in range(1, min(window_size, i))
            )
            band_mean = diagonal_sum / window_size
            if band_mean > 0.1:  # band is meaningfully attended
                abstract_score += band_mean
        return min(abstract_score / 10, 1.0)  # normalize into [0, 1]

    def _find_emergence_point(self, attn_matrix: torch.Tensor) -> Optional[int]:
        """Find the token position where attention suddenly becomes focused.

        Computes the entropy of each position's causal attention distribution
        and reports the first position whose entropy falls below half the
        mean of the preceding 4 positions. Returns None for short sequences
        or when no such drop occurs.
        """
        seq_len = min(attn_matrix.shape[0], 50)  # bound per-head work
        if seq_len < 10:
            return None

        entropies = []
        for i in range(seq_len):
            attn_dist = attn_matrix[i, :i + 1]  # causal: previous positions only
            total = attn_dist.sum()
            if total > 0:
                attn_dist = attn_dist / total
                # Shannon entropy; epsilon keeps log finite at zero weights.
                entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum().item()
                entropies.append(entropy)
            else:
                entropies.append(0.0)

        # A sudden drop versus the trailing 4-position average marks emergence.
        # (seq_len >= 10 guarantees enough entries; no extra length check needed.)
        for i in range(4, len(entropies)):
            recent_avg = np.mean(entropies[i - 4:i])
            if recent_avg > 0 and entropies[i] < recent_avg * 0.5:
                return i
        return None

    def analyze_icl_emergence(
        self,
        attention_weights: List[Dict],
        input_ids: torch.Tensor,
        example_boundaries: List[Tuple[int, int]],
        generated_tokens: List[int]
    ) -> ICLEmergenceAnalysis:
        """
        Comprehensive analysis of when and how ICL emerges during generation.

        Combines the per-head induction signals with the attention-entropy
        trajectory of the generated tokens; the entropy-drop heuristic acts
        as a lower-confidence fallback when no head reports a position.
        """
        induction_heads = self.detect_induction_heads(
            attention_weights, input_ids, example_boundaries
        )
        entropy_trajectory = self._calculate_entropy_trajectory(
            attention_weights, len(generated_tokens)
        )

        emergence_token = None
        emergence_layer = None
        emergence_confidence = 0.0

        if induction_heads:
            # Strongest induction signal anchors the reported layer/confidence.
            strongest_head = max(induction_heads, key=lambda h: h.strength)
            # FIX: explicit None check — truthiness would silently drop an
            # emergence point of 0.
            emergence_points = [
                h.emergence_point for h in induction_heads
                if h.emergence_point is not None
            ]
            if emergence_points:
                # Median emergence point is robust to a few outlier heads.
                emergence_token = int(np.median(emergence_points))
                emergence_layer = strongest_head.layer
                # Confidence scales with signal strength and with the share of
                # heads that reported an emergence point at all.
                consistency = len(emergence_points) / len(induction_heads)
                emergence_confidence = min(strongest_head.strength * consistency, 1.0)

        # Entropy-drop fallback signal.
        if entropy_trajectory and len(entropy_trajectory) > 5:
            for i in range(5, len(entropy_trajectory)):
                recent_avg = np.mean(entropy_trajectory[i - 5:i])
                if recent_avg > 0 and entropy_trajectory[i] < recent_avg * 0.6:
                    if emergence_token is None:
                        emergence_token = i
                        emergence_confidence = 0.5
                    break

        pattern_consistency = self._calculate_pattern_consistency(induction_heads)
        return ICLEmergenceAnalysis(
            emergence_detected=emergence_token is not None,
            emergence_token=emergence_token,
            emergence_layer=emergence_layer,
            confidence=emergence_confidence,
            induction_heads=induction_heads,
            attention_entropy_drop=entropy_trajectory,
            pattern_consistency=pattern_consistency
        )

    def _calculate_entropy_trajectory(
        self,
        attention_weights: List[Dict],
        num_generated: int
    ) -> List[float]:
        """Mean attention entropy (across layers, heads averaged) per generated token.

        NOTE(review): assumes attention_weights holds one record per layer per
        generation step, in step order — confirm against the producer.
        """
        entropies: List[float] = []
        if not attention_weights:
            return entropies

        # One record per layer per step; fall back to CodeGen's 20 layers.
        num_layers = self.adapter.get_num_layers() if self.adapter else 20

        for gen_idx in range(num_generated):
            position_entropy = []
            start = gen_idx * num_layers
            stop = min(start + num_layers, len(attention_weights))
            # Slice bounds already clamp to the list length; the original's
            # extra per-index length check was dead code.
            for record in attention_weights[start:stop]:
                attn = record.get('attention')
                if attn is None or attn.dim() < 3:
                    continue
                avg_attn = attn[0].mean(dim=0)  # average across heads
                if avg_attn.shape[0] <= gen_idx:
                    continue
                # Last row is the distribution of the newly generated token.
                attn_dist = avg_attn[-1]
                total = attn_dist.sum()
                if total > 0:
                    attn_dist = attn_dist / total
                    entropy = -(attn_dist * torch.log(attn_dist + 1e-10)).sum().item()
                    position_entropy.append(entropy)
            entropies.append(np.mean(position_entropy) if position_entropy else 0.0)
        return entropies

    def _calculate_pattern_consistency(self, induction_heads: List[InductionHeadSignal]) -> float:
        """Fraction of detected heads that share the dominant pattern type."""
        if not induction_heads:
            return 0.0
        pattern_counts: Dict[str, int] = {}
        for head in induction_heads:
            pattern_counts[head.pattern_type] = pattern_counts.get(head.pattern_type, 0) + 1
        return max(pattern_counts.values()) / len(induction_heads)