| """ | |
| Context Efficiency Analyzer for In-Context Learning | |
| Measures how efficiently the model uses context examples to perform tasks. | |
| Based on research showing that not all examples contribute equally and that | |
| optimal context usage can significantly improve performance. | |
| """ | |
| import torch | |
| import numpy as np | |
| from typing import List, Dict, Tuple, Optional | |
| from dataclasses import dataclass | |
| import logging | |
| logger = logging.getLogger(__name__) | |


@dataclass
class TokenEfficiency:
    """Efficiency metrics for individual tokens"""
    token: str
    position: int
    information_content: float  # Bits of information
    redundancy_score: float     # 0-1 (1 = completely redundant)
    contribution_score: float   # How much it contributes to the output


@dataclass
class ExampleEfficiency:
    """Efficiency metrics for each example"""
    example_id: str
    total_tokens: int
    effective_tokens: int       # Tokens that actually contribute
    efficiency_ratio: float     # effective / total
    redundancy_rate: float      # Fraction of redundant tokens (0-1)
    information_density: float  # Bits per token
    marginal_benefit: float     # Additional benefit vs. previous examples


@dataclass
class ContextEfficiencyAnalysis:
    """Complete context efficiency analysis"""
    overall_efficiency: float   # 0-1 score
    total_context_tokens: int
    effective_context_tokens: int
    example_efficiencies: List[ExampleEfficiency]
    token_efficiencies: List[TokenEfficiency]
    optimal_example_count: int  # Suggested optimal number of examples
    redundancy_patterns: Dict[str, float]  # Pattern type -> frequency
    compression_potential: float  # Fraction of context that could be compressed away
    attention_utilization: float  # Fraction of context that receives attention


class ContextEfficiencyAnalyzer:
    """Analyzes how efficiently context is used in ICL"""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

    def analyze_context_efficiency(
        self,
        examples: List[Tuple[str, str]],  # (input, output) pairs
        test_prompt: str,
        attention_weights: Optional[List[Dict]] = None,
        generated_tokens: Optional[List[str]] = None,
        confidence_scores: Optional[List[float]] = None,  # currently unused
    ) -> ContextEfficiencyAnalysis:
        """
        Comprehensive analysis of context efficiency
        """
        # Tokenize all examples
        example_tokens = []
        example_boundaries = []
        current_pos = 0
        for idx, (input_text, output_text) in enumerate(examples):
            example_text = f"{input_text}\n{output_text}\n"
            tokens = self.tokenizer.tokenize(example_text)
            example_tokens.extend(tokens)
            example_boundaries.append((current_pos, current_pos + len(tokens)))
            current_pos += len(tokens)
        # Analyze each example's efficiency
        example_efficiencies = []
        for idx, (start, end) in enumerate(example_boundaries):
            efficiency = self._analyze_example_efficiency(
                example_idx=idx,
                example_start=start,
                example_tokens=example_tokens[start:end],
                all_tokens=example_tokens,
                attention_weights=attention_weights,
                generated_tokens=generated_tokens,
            )
            example_efficiencies.append(efficiency)
        # Analyze token-level efficiency
        token_efficiencies = self._analyze_token_efficiency(
            example_tokens=example_tokens,
            attention_weights=attention_weights,
            generated_tokens=generated_tokens,
        )

        # Calculate redundancy patterns
        redundancy_patterns = self._identify_redundancy_patterns(
            example_tokens=example_tokens,
            token_efficiencies=token_efficiencies,
        )

        # Determine optimal example count
        optimal_count = self._calculate_optimal_example_count(
            example_efficiencies=example_efficiencies,
        )

        # Calculate compression potential
        compression_potential = self._calculate_compression_potential(
            token_efficiencies=token_efficiencies,
        )

        # Calculate attention utilization
        attention_utilization = self._calculate_attention_utilization(
            attention_weights=attention_weights,
            total_context_tokens=len(example_tokens),
        )

        # Calculate overall efficiency
        effective_tokens = sum(1 for t in token_efficiencies if t.redundancy_score < 0.5)
        overall_efficiency = effective_tokens / max(len(example_tokens), 1)
        return ContextEfficiencyAnalysis(
            overall_efficiency=overall_efficiency,
            total_context_tokens=len(example_tokens),
            effective_context_tokens=effective_tokens,
            example_efficiencies=example_efficiencies,
            token_efficiencies=token_efficiencies,
            optimal_example_count=optimal_count,
            redundancy_patterns=redundancy_patterns,
            compression_potential=compression_potential,
            attention_utilization=attention_utilization,
        )

    def _analyze_example_efficiency(
        self,
        example_idx: int,
        example_start: int,  # token offset where this example begins
        example_tokens: List[str],
        all_tokens: List[str],
        attention_weights: Optional[List[Dict]],
        generated_tokens: Optional[List[str]],
    ) -> ExampleEfficiency:
        """Analyze efficiency of a single example"""
        # Calculate redundancy with previous examples;
        # all_tokens[:example_start] holds exactly the preceding examples' tokens
        redundant_count = 0
        if example_idx > 0:
            # Check for repeated patterns
            for token in example_tokens:
                if all_tokens[:example_start].count(token) > 2:
                    redundant_count += 1
        redundancy_rate = redundant_count / max(len(example_tokens), 1)
        # Calculate information density (simplified Shannon entropy)
        unique_tokens = len(set(example_tokens))
        information_density = np.log2(max(unique_tokens, 1)) / max(len(example_tokens), 1)
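        # i.e. density = log2(|unique tokens|) / n, an upper bound on per-token
        # entropy: e.g. 16 distinct tokens over 32 positions gives 4/32 = 0.125 bits/token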
        # Calculate marginal benefit (how much this example adds)
        if example_idx == 0:
            marginal_benefit = 1.0  # First example always has full benefit
        else:
            # Estimate based on new unique tokens introduced vs. the preceding examples
            new_patterns = set(example_tokens) - set(all_tokens[:example_start])
            marginal_benefit = len(new_patterns) / max(len(example_tokens), 1)

        # Calculate effective tokens (those that contribute)
        effective_tokens = int(len(example_tokens) * (1 - redundancy_rate))
        return ExampleEfficiency(
            example_id=str(example_idx + 1),
            total_tokens=len(example_tokens),
            effective_tokens=effective_tokens,
            efficiency_ratio=effective_tokens / max(len(example_tokens), 1),
            redundancy_rate=redundancy_rate,
            information_density=information_density,
            marginal_benefit=marginal_benefit,
        )

    def _analyze_token_efficiency(
        self,
        example_tokens: List[str],
        attention_weights: Optional[List[Dict]],
        generated_tokens: Optional[List[str]],
    ) -> List[TokenEfficiency]:
        """Analyze efficiency of individual tokens"""
        token_efficiencies = []
        for idx, token in enumerate(example_tokens):
            # Calculate information content (simplified);
            # rare tokens carry more information
            frequency = example_tokens.count(token)
            information_content = np.log2(len(example_tokens) / max(frequency, 1))
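            # Self-information I(t) = log2(N / f_t): with N = 100 context tokens,
            # a token seen once carries ~6.6 bits, one seen 10 times ~3.3 bits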
            # Calculate redundancy:
            # tokens that appear many times in the same local context are redundant
            local_window = example_tokens[max(0, idx - 5):idx + 5]
            local_frequency = local_window.count(token)
            redundancy_score = min(local_frequency / 3.0, 1.0)  # Cap at 1.0
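            # e.g. a token appearing 3+ times in this ~10-token window saturates
            # at redundancy_score = 1.0; a token unique in the window scores ~0.33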
            # Calculate contribution score
            # based on whether similar tokens appear in the output
            contribution_score = 0.0
            if generated_tokens:
                # Check if the token or a similar token appears in the output
                if token in generated_tokens:
                    contribution_score = 1.0
                elif any(token.lower() in gen_token.lower() for gen_token in generated_tokens):
                    contribution_score = 0.5

            token_efficiencies.append(TokenEfficiency(
                token=token,
                position=idx,
                information_content=information_content,
                redundancy_score=redundancy_score,
                contribution_score=contribution_score,
            ))
        return token_efficiencies

    def _identify_redundancy_patterns(
        self,
        example_tokens: List[str],
        token_efficiencies: List[TokenEfficiency],
    ) -> Dict[str, float]:
        """Identify common redundancy patterns"""
        patterns = {
            'repeated_tokens': 0.0,
            'boilerplate': 0.0,
            'structural_repetition': 0.0,
            'semantic_overlap': 0.0,
        }

        # Count repeated tokens
        token_counts = {}
        for token in example_tokens:
            token_counts[token] = token_counts.get(token, 0) + 1
        repeated = sum(1 for count in token_counts.values() if count > 3)
        patterns['repeated_tokens'] = repeated / max(len(token_counts), 1)

        # Detect boilerplate (common programming patterns)
        boilerplate_tokens = ['def', 'class', 'return', 'import', 'from', '"""', "'''"]
        boilerplate_count = sum(1 for token in example_tokens if token in boilerplate_tokens)
        patterns['boilerplate'] = boilerplate_count / max(len(example_tokens), 1)
        # Detect structural repetition: look for token sequences that repeat
        sequence_length = 3
        sequences = {}
        for i in range(len(example_tokens) - sequence_length + 1):
            seq = tuple(example_tokens[i:i + sequence_length])
            sequences[seq] = sequences.get(seq, 0) + 1
        repeated_sequences = sum(1 for count in sequences.values() if count > 1)
        patterns['structural_repetition'] = repeated_sequences / max(len(sequences), 1)
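        # e.g. tokens [a, b, c, a, b, c] yield trigram counts {(a,b,c): 2,
        # (b,c,a): 1, (c,a,b): 1}: one of three distinct trigrams repeats -> ~0.33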
        # Estimate semantic overlap (tokens with high redundancy scores)
        high_redundancy = sum(1 for t in token_efficiencies if t.redundancy_score > 0.7)
        patterns['semantic_overlap'] = high_redundancy / max(len(token_efficiencies), 1)
        return patterns

    def _calculate_optimal_example_count(
        self,
        example_efficiencies: List[ExampleEfficiency],
    ) -> int:
        """Determine the optimal number of examples based on marginal benefits"""
        if not example_efficiencies:
            return 0

        # Find the point where marginal benefit drops below a threshold
        threshold = 0.3  # Examples adding less than 30% benefit are not worth it
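        # e.g. marginal benefits [1.0, 0.6, 0.2, 0.1] cross the threshold at
        # index 2, so the suggested count is 2 examples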
        for idx, efficiency in enumerate(example_efficiencies):
            if efficiency.marginal_benefit < threshold and idx > 0:
                return idx

        # If all examples have good marginal benefit, use all of them
        return len(example_efficiencies)

    def _calculate_compression_potential(
        self,
        token_efficiencies: List[TokenEfficiency],
    ) -> float:
        """Calculate how much the context could be compressed"""
        if not token_efficiencies:
            return 0.0

        # Tokens with high redundancy and low contribution can be removed
        removable = sum(
            1 for t in token_efficiencies
            if t.redundancy_score > 0.6 and t.contribution_score < 0.3
        )
        return removable / len(token_efficiencies)

    def _calculate_attention_utilization(
        self,
        attention_weights: Optional[List[Dict]],
        total_context_tokens: int,
    ) -> float:
        """Calculate what fraction of the context receives significant attention"""
        if not attention_weights or total_context_tokens == 0:
            return 0.0

        # Aggregate attention across all layers and heads
        attended_positions = set()
        for record in attention_weights:
            attn = record.get('attention')
            if attn is not None and attn.dim() >= 3:
                # Average across heads (dim 1 of a [batch, heads, seq_q, seq_k] tensor)
                avg_attn = attn.mean(dim=1)
                # Positions with attention above a threshold are considered "utilized"
                threshold = 0.05
                high_attention = (avg_attn > threshold).nonzero(as_tuple=True)
                if len(high_attention) > 1:
                    # The last index tensor holds the attended (key) positions
                    attended_positions.update(high_attention[-1].tolist())

        # Filter to only context positions
        context_attended = [pos for pos in attended_positions if pos < total_context_tokens]
        return len(context_attended) / total_context_tokens if total_context_tokens > 0 else 0.0
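

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module):
    # any Hugging Face causal LM should work, since the analyzer only needs
    # model.parameters() for the device and tokenizer.tokenize() for tokens.
    # "gpt2" is used purely as an illustrative checkpoint.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    analyzer = ContextEfficiencyAnalyzer(model, tokenizer)

    analysis = analyzer.analyze_context_efficiency(
        examples=[
            ("Translate to French: cat", "chat"),
            ("Translate to French: dog", "chien"),
            ("Translate to French: bird", "oiseau"),
        ],
        test_prompt="Translate to French: horse",
    )
    print(f"Overall efficiency:    {analysis.overall_efficiency:.2f}")
    print(f"Optimal example count: {analysis.optimal_example_count}")
    print(f"Compression potential: {analysis.compression_potential:.2f}")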