"""
Context Efficiency Analyzer for In-Context Learning

Measures how efficiently the model uses context examples to perform tasks.
Based on research showing that not all examples contribute equally and that
optimal context usage can significantly improve performance.
"""

import logging
from collections import Counter
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

logger = logging.getLogger(__name__)


@dataclass
class TokenEfficiency:
    """Efficiency metrics for individual tokens"""

    token: str
    position: int
    information_content: float  # Bits of information
    redundancy_score: float  # 0-1 (1 = completely redundant)
    contribution_score: float  # How much it contributes to output


@dataclass
class ExampleEfficiency:
    """Efficiency metrics for each example"""

    example_id: str
    total_tokens: int
    effective_tokens: int  # Tokens that actually contribute
    efficiency_ratio: float  # effective/total
    redundancy_rate: float  # Percentage of redundant tokens
    information_density: float  # Bits per token
    marginal_benefit: float  # Additional benefit vs previous examples


@dataclass
class ContextEfficiencyAnalysis:
    """Complete context efficiency analysis"""

    overall_efficiency: float  # 0-1 score
    total_context_tokens: int
    effective_context_tokens: int
    example_efficiencies: List[ExampleEfficiency]
    token_efficiencies: List[TokenEfficiency]
    optimal_example_count: int  # Suggested optimal number of examples
    redundancy_patterns: Dict[str, float]  # Pattern type -> frequency
    compression_potential: float  # How much context could be compressed
    attention_utilization: float  # How much of context gets attention


class ContextEfficiencyAnalyzer:
    """Analyzes how efficiently context is used in ICL"""

    def __init__(self, model, tokenizer):
        """Store the model/tokenizer pair and remember the model's device.

        Args:
            model: Object exposing ``parameters()``; the device of its first
                parameter is recorded as ``self.device``.
            tokenizer: Object exposing ``tokenize(text) -> list of tokens``.
        """
        self.model = model
        self.tokenizer = tokenizer
        # A parameter-free model would make a bare next() raise StopIteration;
        # fall back to CPU instead of failing construction.
        first_param = next(model.parameters(), None)
        self.device = (
            first_param.device if first_param is not None else torch.device("cpu")
        )

    def analyze_context_efficiency(
        self,
        examples: List[Tuple[str, str]],  # (input, output) pairs
        test_prompt: str,
        attention_weights: Optional[List[Dict]] = None,
        generated_tokens: Optional[List[str]] = None,
        confidence_scores: Optional[List[float]] = None,
    ) -> ContextEfficiencyAnalysis:
        """
        Comprehensive analysis of context efficiency

        Args:
            examples: ``(input, output)`` text pairs; each pair is tokenized
                as ``f"{input}\\n{output}\\n"``.
            test_prompt: Accepted for interface compatibility; not read here.
            attention_weights: Optional records, each a dict that may hold a
                tensor under the ``'attention'`` key.
            generated_tokens: Optional tokens produced by the model, used to
                score how much each context token contributed.
            confidence_scores: Accepted for interface compatibility; not read
                here.

        Returns:
            A ``ContextEfficiencyAnalysis`` aggregating per-example and
            per-token metrics.
        """
        # Tokenize all examples, remembering where each one starts and ends
        # inside the concatenated token list.
        example_tokens: List[str] = []
        example_boundaries: List[Tuple[int, int]] = []
        current_pos = 0

        for input_text, output_text in examples:
            example_text = f"{input_text}\n{output_text}\n"
            tokens = self.tokenizer.tokenize(example_text)
            example_tokens.extend(tokens)
            example_boundaries.append((current_pos, current_pos + len(tokens)))
            current_pos += len(tokens)

        # Analyze each example's efficiency
        example_efficiencies = [
            self._analyze_example_efficiency(
                example_idx=idx,
                example_tokens=example_tokens[start:end],
                all_tokens=example_tokens,
                attention_weights=attention_weights,
                generated_tokens=generated_tokens,
                # Pass the real offset so examples of different lengths are
                # compared against exactly the tokens that precede them.
                example_start=start,
            )
            for idx, (start, end) in enumerate(example_boundaries)
        ]

        # Analyze token-level efficiency
        token_efficiencies = self._analyze_token_efficiency(
            example_tokens=example_tokens,
            attention_weights=attention_weights,
            generated_tokens=generated_tokens,
        )

        # Calculate redundancy patterns
        redundancy_patterns = self._identify_redundancy_patterns(
            example_tokens=example_tokens,
            token_efficiencies=token_efficiencies,
        )

        # Determine optimal example count
        optimal_count = self._calculate_optimal_example_count(
            example_efficiencies=example_efficiencies
        )

        # Calculate compression potential
        compression_potential = self._calculate_compression_potential(
            token_efficiencies=token_efficiencies
        )

        # Calculate attention utilization
        attention_utilization = self._calculate_attention_utilization(
            attention_weights=attention_weights,
            total_context_tokens=len(example_tokens),
        )

        # Calculate overall efficiency: share of tokens that are not locally
        # redundant.
        effective_tokens = sum(
            1 for t in token_efficiencies if t.redundancy_score < 0.5
        )
        overall_efficiency = effective_tokens / max(len(example_tokens), 1)

        return ContextEfficiencyAnalysis(
            overall_efficiency=overall_efficiency,
            total_context_tokens=len(example_tokens),
            effective_context_tokens=effective_tokens,
            example_efficiencies=example_efficiencies,
            token_efficiencies=token_efficiencies,
            optimal_example_count=optimal_count,
            redundancy_patterns=redundancy_patterns,
            compression_potential=compression_potential,
            attention_utilization=attention_utilization,
        )

    def _analyze_example_efficiency(
        self,
        example_idx: int,
        example_tokens: List[str],
        all_tokens: List[str],
        attention_weights: Optional[List[Dict]],
        generated_tokens: Optional[List[str]],
        example_start: Optional[int] = None,
    ) -> ExampleEfficiency:
        """Analyze efficiency of a single example

        Args:
            example_idx: Zero-based index of the example within the context.
            example_tokens: Tokens belonging to this example only.
            all_tokens: Concatenated tokens of every example.
            attention_weights: Currently unused by this method.
            generated_tokens: Currently unused by this method.
            example_start: Offset of this example's first token inside
                ``all_tokens``. When omitted, the offset is approximated as
                ``example_idx * len(example_tokens)``, which is only exact
                when all earlier examples have this example's length.

        Returns:
            An ``ExampleEfficiency`` whose ``example_id`` is the one-based
            index as a string.
        """
        total = len(example_tokens)
        if example_start is None:
            # Legacy approximation kept for callers that do not know the offset.
            example_start = example_idx * total

        redundant_count = 0
        marginal_benefit = 1.0  # First example always has full benefit
        if example_idx > 0:
            # Count the preceding context once instead of list.count per token.
            prior_counts = Counter(all_tokens[:example_start])

            # A token is redundant if it already appeared more than twice.
            redundant_count = sum(
                1 for token in example_tokens if prior_counts[token] > 2
            )

            # Marginal benefit: share of this example made of unseen tokens.
            new_patterns = set(example_tokens).difference(prior_counts)
            marginal_benefit = len(new_patterns) / max(total, 1)

        redundancy_rate = redundant_count / max(total, 1)

        # Calculate information density (simplified Shannon entropy)
        unique_tokens = len(set(example_tokens))
        information_density = np.log2(max(unique_tokens, 1)) / max(total, 1)

        # Effective tokens are the non-redundant ones; integer subtraction
        # avoids float truncation from int(total * (1 - rate)).
        effective_tokens = total - redundant_count

        return ExampleEfficiency(
            example_id=str(example_idx + 1),
            total_tokens=total,
            effective_tokens=effective_tokens,
            efficiency_ratio=effective_tokens / max(total, 1),
            redundancy_rate=redundancy_rate,
            information_density=information_density,
            marginal_benefit=marginal_benefit,
        )

    def _analyze_token_efficiency(
        self,
        example_tokens: List[str],
        attention_weights: Optional[List[Dict]],
        generated_tokens: Optional[List[str]],
    ) -> List[TokenEfficiency]:
        """Analyze efficiency of individual tokens

        Args:
            example_tokens: Concatenated context tokens.
            attention_weights: Currently unused by this method.
            generated_tokens: Optional output tokens; when empty or ``None``
                every contribution score is 0.0.

        Returns:
            One ``TokenEfficiency`` per context token, in order.
        """
        total = len(example_tokens)
        frequencies = Counter(example_tokens)

        # Hoisted so each context token does not rebuild these.
        generated_set = set(generated_tokens) if generated_tokens else set()
        generated_lower = (
            [gen.lower() for gen in generated_tokens] if generated_tokens else []
        )
        contribution_cache: Dict[str, float] = {}

        token_efficiencies = []
        for idx, token in enumerate(example_tokens):
            # Information content (simplified): rare tokens carry more bits.
            information_content = np.log2(total / frequencies[token])

            # Redundancy: occurrences in a local window covering 5 tokens
            # before and 4 after the current one, saturating at 3 occurrences.
            local_window = example_tokens[max(0, idx - 5):idx + 5]
            redundancy_score = min(local_window.count(token) / 3.0, 1.0)

            # Contribution: 1.0 on an exact match in the output, 0.5 when the
            # token is a case-insensitive substring of an output token.
            contribution_score = contribution_cache.get(token)
            if contribution_score is None:
                if token in generated_set:
                    contribution_score = 1.0
                else:
                    lowered = token.lower()
                    if any(lowered in gen for gen in generated_lower):
                        contribution_score = 0.5
                    else:
                        contribution_score = 0.0
                contribution_cache[token] = contribution_score

            token_efficiencies.append(TokenEfficiency(
                token=token,
                position=idx,
                information_content=information_content,
                redundancy_score=redundancy_score,
                contribution_score=contribution_score,
            ))

        return token_efficiencies

    def _identify_redundancy_patterns(
        self,
        example_tokens: List[str],
        token_efficiencies: List[TokenEfficiency],
    ) -> Dict[str, float]:
        """Identify common redundancy patterns

        Returns:
            A dict with the keys ``'repeated_tokens'``, ``'boilerplate'``,
            ``'structural_repetition'`` and ``'semantic_overlap'``, each a
            ratio in [0, 1].
        """
        patterns = {
            'repeated_tokens': 0.0,
            'boilerplate': 0.0,
            'structural_repetition': 0.0,
            'semantic_overlap': 0.0,
        }

        # Share of distinct tokens that occur more than three times.
        token_counts = Counter(example_tokens)
        repeated = sum(1 for count in token_counts.values() if count > 3)
        patterns['repeated_tokens'] = repeated / max(len(token_counts), 1)

        # Detect boilerplate (common programming patterns)
        boilerplate_tokens = frozenset(
            ['def', 'class', 'return', 'import', 'from', '"""', "'''"]
        )
        boilerplate_count = sum(
            1 for token in example_tokens if token in boilerplate_tokens
        )
        patterns['boilerplate'] = boilerplate_count / max(len(example_tokens), 1)

        # Detect structural repetition: share of distinct token trigrams that
        # occur more than once. The "+ 1" includes the window that ends on the
        # final token.
        sequence_length = 3
        window_count = max(len(example_tokens) - sequence_length + 1, 0)
        sequences = Counter(
            tuple(example_tokens[i:i + sequence_length])
            for i in range(window_count)
        )
        repeated_sequences = sum(1 for count in sequences.values() if count > 1)
        patterns['structural_repetition'] = (
            repeated_sequences / max(len(sequences), 1)
        )

        # Estimate semantic overlap (tokens with high redundancy scores)
        high_redundancy = sum(
            1 for t in token_efficiencies if t.redundancy_score > 0.7
        )
        patterns['semantic_overlap'] = (
            high_redundancy / max(len(token_efficiencies), 1)
        )

        return patterns

    def _calculate_optimal_example_count(
        self,
        example_efficiencies: List[ExampleEfficiency],
    ) -> int:
        """Determine the optimal number of examples based on marginal benefits

        Returns:
            The index of the first non-initial example whose marginal benefit
            is below 0.3 (i.e. the number of examples before it), the total
            number of examples if none is, or 0 for an empty list.
        """
        if not example_efficiencies:
            return 0

        # Examples adding less than 30% benefit are not worth it. The first
        # example is never a cut-off point.
        threshold = 0.3
        for idx, efficiency in enumerate(example_efficiencies):
            if idx > 0 and efficiency.marginal_benefit < threshold:
                return idx

        # If all examples have good marginal benefit, use all
        return len(example_efficiencies)

    def _calculate_compression_potential(
        self,
        token_efficiencies: List[TokenEfficiency],
    ) -> float:
        """Calculate how much the context could be compressed

        Returns:
            Fraction of tokens with redundancy above 0.6 and contribution
            below 0.3; 0.0 for an empty list.
        """
        if not token_efficiencies:
            return 0.0

        # Tokens with high redundancy and low contribution can be removed
        removable = sum(
            1 for t in token_efficiencies
            if t.redundancy_score > 0.6 and t.contribution_score < 0.3
        )
        return removable / len(token_efficiencies)

    def _calculate_attention_utilization(
        self,
        attention_weights: Optional[List[Dict]],
        total_context_tokens: int,
    ) -> float:
        """Calculate what percentage of context receives significant attention

        Args:
            attention_weights: Records that may hold a tensor under
                ``'attention'``; records without one, or with fewer than
                three dimensions, are skipped.
            total_context_tokens: Number of context tokens; attended
                positions at or beyond this index are ignored.

        Returns:
            Fraction of context positions whose averaged attention exceeds
            0.05 in at least one record, or 0.0 when there is nothing to
            measure.
        """
        if not attention_weights or total_context_tokens <= 0:
            return 0.0

        # Positions with attention > threshold are considered "utilized"
        threshold = 0.05

        attended_positions = set()
        for record in attention_weights:
            attn = record.get('attention')
            if attn is None or attn.dim() < 3:
                continue

            # NOTE(review): assumes the layout is (batch, heads, query, key),
            # so dim=1 is the head axis - confirm against the producer of
            # these records.
            avg_attn = attn.mean(dim=1)
            high_attention = (avg_attn > threshold).nonzero(as_tuple=True)

            # The last axis is the attended (key) position. Index 1 would be
            # the query axis for 4-D input, and nearly every query row
            # exceeds the threshold somewhere.
            attended_positions.update(high_attention[-1].tolist())

        # Filter to only context positions
        # NOTE(review): presumes attention positions line up with the
        # tokenizer positions of the concatenated examples - verify.
        context_attended = sum(
            1 for pos in attended_positions if pos < total_context_tokens
        )
        return context_attended / total_context_tokens