# api/backend/context_efficiency_analyzer.py
# gary-boon: "Add backend support for ICL emergence analysis" (commit 920a98d)
# File-viewer residue retained from the source page: raw / history / blame, 13.4 kB
"""
Context Efficiency Analyzer for In-Context Learning
Measures how efficiently the model uses context examples to perform tasks.
Based on research showing that not all examples contribute equally and that
optimal context usage can significantly improve performance.
"""
import logging
from collections import Counter
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import numpy as np
import torch
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class TokenEfficiency:
    """Efficiency metrics for an individual context token."""

    token: str  # Token string as produced by the tokenizer
    position: int  # Index of the token within the concatenated example tokens
    information_content: float  # Bits of information
    redundancy_score: float  # 0-1 (1 = completely redundant)
    contribution_score: float  # How much it contributes to output
@dataclass
class ExampleEfficiency:
    """Efficiency metrics for a single in-context example."""

    example_id: str  # Identifier of the example within the context
    total_tokens: int  # Number of tokens in the example
    effective_tokens: int  # Tokens that actually contribute (not counted as redundant)
    efficiency_ratio: float  # effective/total
    redundancy_rate: float  # Fraction (0-1) of redundant tokens
    information_density: float  # Bits per token
    marginal_benefit: float  # Additional benefit vs previous examples
@dataclass
class ContextEfficiencyAnalysis:
    """Complete context efficiency analysis for one set of examples."""

    overall_efficiency: float  # 0-1 score (effective tokens / total tokens)
    total_context_tokens: int  # Number of tokens across all examples
    effective_context_tokens: int  # Tokens whose redundancy score is below 0.5
    example_efficiencies: List[ExampleEfficiency]  # One entry per example, in order
    token_efficiencies: List[TokenEfficiency]  # One entry per context token, in order
    optimal_example_count: int  # Suggested optimal number of examples
    redundancy_patterns: Dict[str, float]  # Pattern type -> frequency
    compression_potential: float  # How much context could be compressed (0-1)
    attention_utilization: float  # How much of context gets attention (0-1)
class ContextEfficiencyAnalyzer:
    """Analyzes how efficiently context is used in ICL.

    The metrics are token-count heuristics computed over the tokenized
    in-context examples, plus an optional attention-based utilization score.
    """

    def __init__(self, model, tokenizer):
        """Store the model and tokenizer and record the model's device.

        Args:
            model: Object exposing ``parameters()``; the device of its first
                parameter is stored as ``self.device``.
            tokenizer: Object exposing ``tokenize(text)``; used to split each
                example into tokens.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

    def analyze_context_efficiency(
        self,
        examples: List[Tuple[str, str]],  # (input, output) pairs
        test_prompt: str,
        attention_weights: Optional[List[Dict]] = None,
        generated_tokens: Optional[List[str]] = None,
        confidence_scores: Optional[List[float]] = None,
    ) -> "ContextEfficiencyAnalysis":
        """Run the comprehensive context-efficiency analysis.

        Args:
            examples: ``(input, output)`` text pairs; each pair is tokenized
                as ``f"{input}\\n{output}\\n"``.
            test_prompt: Accepted for interface compatibility; not used by
                the current implementation.
            attention_weights: Optional list of records whose ``'attention'``
                entry is read by ``_calculate_attention_utilization``.
            generated_tokens: Optional tokens of the model output, used for
                the per-token contribution score.
            confidence_scores: Accepted for interface compatibility; not used
                by the current implementation.

        Returns:
            A ``ContextEfficiencyAnalysis`` aggregating all metrics.
        """
        # Tokenize every example and remember where each one starts and ends
        # inside the concatenated token list.
        example_tokens: List[str] = []
        example_boundaries: List[Tuple[int, int]] = []
        for input_text, output_text in examples:
            tokens = self.tokenizer.tokenize(f"{input_text}\n{output_text}\n")
            start = len(example_tokens)
            example_tokens.extend(tokens)
            example_boundaries.append((start, len(example_tokens)))

        # Per-example metrics; the true start offset is passed so that the
        # "previous examples" prefix is exact even when lengths differ.
        example_efficiencies = [
            self._analyze_example_efficiency(
                example_idx=idx,
                example_tokens=example_tokens[start:end],
                all_tokens=example_tokens,
                attention_weights=attention_weights,
                generated_tokens=generated_tokens,
                start_offset=start,
            )
            for idx, (start, end) in enumerate(example_boundaries)
        ]

        token_efficiencies = self._analyze_token_efficiency(
            example_tokens=example_tokens,
            attention_weights=attention_weights,
            generated_tokens=generated_tokens,
        )
        redundancy_patterns = self._identify_redundancy_patterns(
            example_tokens=example_tokens,
            token_efficiencies=token_efficiencies,
        )
        optimal_count = self._calculate_optimal_example_count(
            example_efficiencies=example_efficiencies
        )
        compression_potential = self._calculate_compression_potential(
            token_efficiencies=token_efficiencies
        )
        attention_utilization = self._calculate_attention_utilization(
            attention_weights=attention_weights,
            total_context_tokens=len(example_tokens),
        )

        # A token counts as "effective" when its redundancy score is below 0.5.
        effective_tokens = sum(
            1 for t in token_efficiencies if t.redundancy_score < 0.5
        )
        overall_efficiency = effective_tokens / max(len(example_tokens), 1)

        return ContextEfficiencyAnalysis(
            overall_efficiency=overall_efficiency,
            total_context_tokens=len(example_tokens),
            effective_context_tokens=effective_tokens,
            example_efficiencies=example_efficiencies,
            token_efficiencies=token_efficiencies,
            optimal_example_count=optimal_count,
            redundancy_patterns=redundancy_patterns,
            compression_potential=compression_potential,
            attention_utilization=attention_utilization,
        )

    @staticmethod
    def _overlap_with_prior_context(
        example_idx: int,
        example_tokens: List[str],
        all_tokens: List[str],
        start_offset: Optional[int] = None,
    ) -> Tuple[int, float]:
        """Compare an example's tokens with the context that precedes it.

        Args:
            example_idx: Zero-based index of the example.
            example_tokens: Tokens of the example itself.
            all_tokens: Concatenated tokens of all examples.
            start_offset: Index in ``all_tokens`` where this example starts.
                When ``None``, falls back to ``example_idx * len(example_tokens)``,
                which is only exact if all earlier examples have the same
                length as this one.

        Returns:
            ``(redundant_count, marginal_benefit)``: the number of tokens in
            the example that occur more than twice in the preceding context,
            and the count of distinct tokens not seen in the preceding
            context divided by the example length. The first example yields
            ``(0, 1.0)``.
        """
        if example_idx <= 0:
            # First example: nothing precedes it, so it has full benefit.
            return 0, 1.0
        if start_offset is None:
            start_offset = example_idx * len(example_tokens)

        # Count the prefix once instead of calling list.count per token.
        prior_counts = Counter(all_tokens[:start_offset])
        redundant_count = sum(
            1 for token in example_tokens if prior_counts[token] > 2
        )
        new_patterns = set(example_tokens).difference(prior_counts)
        marginal_benefit = len(new_patterns) / max(len(example_tokens), 1)
        return redundant_count, marginal_benefit

    def _analyze_example_efficiency(
        self,
        example_idx: int,
        example_tokens: List[str],
        all_tokens: List[str],
        attention_weights: Optional[List[Dict]],
        generated_tokens: Optional[List[str]],
        start_offset: Optional[int] = None,
    ) -> "ExampleEfficiency":
        """Analyze efficiency of a single example.

        Args:
            example_idx: Zero-based index of the example.
            example_tokens: Tokens belonging to this example.
            all_tokens: Concatenated tokens of all examples.
            attention_weights: Not used by this method.
            generated_tokens: Not used by this method.
            start_offset: Index in ``all_tokens`` where this example starts;
                see ``_overlap_with_prior_context`` for the fallback used
                when it is omitted.

        Returns:
            An ``ExampleEfficiency`` whose ``example_id`` is the one-based
            index rendered as a string.
        """
        total_tokens = len(example_tokens)
        denominator = max(total_tokens, 1)  # guards empty examples

        redundant_count, marginal_benefit = self._overlap_with_prior_context(
            example_idx, example_tokens, all_tokens, start_offset
        )
        redundancy_rate = redundant_count / denominator

        # Information density: log2 of the distinct-token count, per token.
        unique_tokens = len(set(example_tokens))
        information_density = np.log2(max(unique_tokens, 1)) / denominator

        # Integer arithmetic avoids the truncation that
        # int(total * (1 - rate)) suffers from float rounding.
        effective_tokens = total_tokens - redundant_count

        return ExampleEfficiency(
            example_id=str(example_idx + 1),
            total_tokens=total_tokens,
            effective_tokens=effective_tokens,
            efficiency_ratio=effective_tokens / denominator,
            redundancy_rate=redundancy_rate,
            information_density=information_density,
            marginal_benefit=marginal_benefit,
        )

    def _analyze_token_efficiency(
        self,
        example_tokens: List[str],
        attention_weights: Optional[List[Dict]],
        generated_tokens: Optional[List[str]],
    ) -> List["TokenEfficiency"]:
        """Analyze efficiency of individual tokens.

        Args:
            example_tokens: Concatenated tokens of all examples.
            attention_weights: Not used by this method.
            generated_tokens: Optional output tokens; a context token scores
                1.0 when it appears verbatim among them, 0.5 when its
                lower-cased form is a substring of a lower-cased output
                token, and 0.0 otherwise.

        Returns:
            One ``TokenEfficiency`` per context token, in order.
        """
        total_tokens = len(example_tokens)
        frequencies = Counter(example_tokens)  # one pass instead of count() per token

        generated = list(generated_tokens) if generated_tokens else []
        generated_exact = set(generated)
        generated_lower = [gen_token.lower() for gen_token in generated]

        token_efficiencies = []
        for idx, token in enumerate(example_tokens):
            # Rare tokens carry more information.
            information_content = np.log2(
                total_tokens / max(frequencies[token], 1)
            )

            # Local window: 5 tokens before, the token itself, 4 after.
            # Three or more occurrences inside it saturate the score at 1.0.
            local_window = example_tokens[max(0, idx - 5):idx + 5]
            redundancy_score = min(local_window.count(token) / 3.0, 1.0)

            contribution_score = 0.0
            if generated:
                if token in generated_exact:
                    contribution_score = 1.0
                else:
                    token_lower = token.lower()
                    if any(token_lower in gen for gen in generated_lower):
                        contribution_score = 0.5

            token_efficiencies.append(TokenEfficiency(
                token=token,
                position=idx,
                information_content=information_content,
                redundancy_score=redundancy_score,
                contribution_score=contribution_score,
            ))
        return token_efficiencies

    def _identify_redundancy_patterns(
        self,
        example_tokens: List[str],
        token_efficiencies: List["TokenEfficiency"],
    ) -> Dict[str, float]:
        """Identify common redundancy patterns.

        Returns:
            A dict with four ratios:
            ``'repeated_tokens'`` (distinct tokens occurring more than 3
            times / distinct tokens), ``'boilerplate'`` (tokens from a fixed
            keyword list / all tokens), ``'structural_repetition'`` (distinct
            3-token windows occurring more than once / distinct windows) and
            ``'semantic_overlap'`` (tokens with redundancy score above 0.7 /
            all scored tokens).
        """
        total_tokens = len(example_tokens)

        # Repeated tokens.
        token_counts = Counter(example_tokens)
        repeated = sum(1 for count in token_counts.values() if count > 3)
        repeated_tokens = repeated / max(len(token_counts), 1)

        # Boilerplate (common programming patterns).
        boilerplate_tokens = {
            'def', 'class', 'return', 'import', 'from', '"""', "'''"
        }
        boilerplate_count = sum(
            1 for token in example_tokens if token in boilerplate_tokens
        )
        boilerplate = boilerplate_count / max(total_tokens, 1)

        # Structural repetition: the "+ 1" makes the final window count too.
        sequence_length = 3
        sequences = Counter(
            tuple(example_tokens[i:i + sequence_length])
            for i in range(total_tokens - sequence_length + 1)
        )
        repeated_sequences = sum(
            1 for count in sequences.values() if count > 1
        )
        structural_repetition = repeated_sequences / max(len(sequences), 1)

        # Semantic overlap, estimated from high per-token redundancy scores.
        high_redundancy = sum(
            1 for t in token_efficiencies if t.redundancy_score > 0.7
        )
        semantic_overlap = high_redundancy / max(len(token_efficiencies), 1)

        return {
            'repeated_tokens': repeated_tokens,
            'boilerplate': boilerplate,
            'structural_repetition': structural_repetition,
            'semantic_overlap': semantic_overlap,
        }

    def _calculate_optimal_example_count(
        self,
        example_efficiencies: List["ExampleEfficiency"],
    ) -> int:
        """Determine the optimal number of examples based on marginal benefits.

        Returns:
            0 for an empty list; otherwise the index of the first example
            after the first whose marginal benefit is below 0.3, or the total
            number of examples when none falls below that threshold.
        """
        if not example_efficiencies:
            return 0

        threshold = 0.3  # Examples adding less than 30% benefit are not worth it
        for idx, efficiency in enumerate(example_efficiencies):
            # The first example is always kept, whatever its score.
            if idx > 0 and efficiency.marginal_benefit < threshold:
                return idx

        # Every example has good marginal benefit: use all of them.
        return len(example_efficiencies)

    def _calculate_compression_potential(
        self,
        token_efficiencies: List["TokenEfficiency"],
    ) -> float:
        """Calculate how much the context could be compressed.

        Returns:
            The fraction of tokens with redundancy score above 0.6 and
            contribution score below 0.3; 0.0 for an empty list.
        """
        if not token_efficiencies:
            return 0.0

        # Highly redundant, low-contribution tokens are considered removable.
        removable = sum(
            1 for t in token_efficiencies
            if t.redundancy_score > 0.6 and t.contribution_score < 0.3
        )
        return removable / len(token_efficiencies)

    def _calculate_attention_utilization(
        self,
        attention_weights: Optional[List[Dict]],
        total_context_tokens: int,
    ) -> float:
        """Calculate what fraction of context receives significant attention.

        Args:
            attention_weights: Optional list of records; each record's
                ``'attention'`` entry is used when it is not ``None`` and has
                at least 3 dimensions.
            total_context_tokens: Number of context tokens; positions at or
                beyond this index are ignored.

        Returns:
            The number of distinct context positions whose head-averaged
            attention exceeds 0.05 in any record, divided by
            ``total_context_tokens``; 0.0 when there is nothing to measure.
        """
        if not attention_weights or total_context_tokens <= 0:
            return 0.0

        threshold = 0.05  # Positions above this are considered "utilized"
        attended_positions = set()
        for record in attention_weights:
            attn = record.get('attention')
            if attn is None or attn.dim() < 3:
                continue
            # NOTE(review): assumes the trailing dims are (heads, query, key),
            # e.g. (batch, heads, query, key) — TODO confirm against the
            # code that produces these records.
            avg_attn = attn.mean(dim=-3)  # Average across heads
            # The last axis indexes the positions being attended TO, which is
            # what "receives attention" means.
            key_positions = (avg_attn > threshold).nonzero(as_tuple=True)[-1]
            attended_positions.update(key_positions.tolist())

        # Keep only positions that fall inside the context.
        context_attended = sum(
            1 for pos in attended_positions if pos < total_context_tokens
        )
        return context_attended / total_context_tokens