Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 13,407 Bytes
920a98d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 |
"""
Context Efficiency Analyzer for In-Context Learning
Measures how efficiently the model uses context examples to perform tasks.
Based on research showing that not all examples contribute equally and that
optimal context usage can significantly improve performance.
"""
import torch
import numpy as np
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
@dataclass
class TokenEfficiency:
    """Per-token efficiency metrics (produced by _analyze_token_efficiency)."""
    token: str  # the token text itself
    position: int  # 0-based index of the token within the flat context token list
    information_content: float  # Bits of information: log2(total_tokens / token frequency)
    redundancy_score: float  # 0-1 (1 = completely redundant within a +/-5 token window)
    contribution_score: float  # 1.0 exact match in output, 0.5 substring match, else 0.0
@dataclass
class ExampleEfficiency:
    """Per-example efficiency metrics (produced by _analyze_example_efficiency)."""
    example_id: str  # 1-based example index rendered as a string
    total_tokens: int  # token count of this example
    effective_tokens: int  # Tokens that actually contribute (non-redundant share)
    efficiency_ratio: float  # effective/total
    redundancy_rate: float  # Fraction (0-1) of tokens already repeated in earlier examples
    information_density: float  # Bits per token: log2(unique tokens) / total tokens
    marginal_benefit: float  # Additional benefit vs previous examples (share of new tokens)
@dataclass
class ContextEfficiencyAnalysis:
    """Complete context efficiency analysis (returned by analyze_context_efficiency)."""
    overall_efficiency: float  # 0-1: share of context tokens with redundancy_score < 0.5
    total_context_tokens: int  # token count of the full example context
    effective_context_tokens: int  # count of tokens considered effective (low redundancy)
    example_efficiencies: List[ExampleEfficiency]  # one entry per context example
    token_efficiencies: List[TokenEfficiency]  # one entry per context token
    optimal_example_count: int  # Suggested optimal number of examples
    redundancy_patterns: Dict[str, float]  # Pattern type -> frequency (0-1)
    compression_potential: float  # 0-1: fraction of tokens that could be removed
    attention_utilization: float  # Fraction of context positions receiving attention
class ContextEfficiencyAnalyzer:
    """Analyzes how efficiently context is used in ICL"""

    def __init__(self, model, tokenizer):
        """Bind the model/tokenizer pair and cache the model's device.

        Args:
            model: A torch module; its first parameter determines the device.
            tokenizer: Tokenizer exposing ``tokenize(text) -> list[str]``.
        """
        self.model = model
        self.tokenizer = tokenizer
        # Cache the device once so later helpers need not re-inspect parameters.
        self.device = next(self.model.parameters()).device
def analyze_context_efficiency(
    self,
    examples: List[Tuple[str, str]],  # (input, output) pairs
    test_prompt: str,
    attention_weights: Optional[List[Dict]] = None,
    generated_tokens: List[str] = None,
    confidence_scores: List[float] = None
) -> ContextEfficiencyAnalysis:
    """
    Comprehensive analysis of context efficiency.

    Tokenizes each (input, output) example, scores every example and every
    token, then aggregates the results into one ContextEfficiencyAnalysis.

    Args:
        examples: (input, output) text pairs forming the ICL context.
        test_prompt: The query prompt.  NOTE(review): currently unused in
            this body — confirm whether it should influence the analysis.
        attention_weights: Optional records; each dict may hold an
            'attention' tensor (consumed by _calculate_attention_utilization).
        generated_tokens: Model output tokens, used for contribution scoring.
        confidence_scores: NOTE(review): currently unused in this body.

    Returns:
        ContextEfficiencyAnalysis aggregating example-, token- and
        context-level metrics.
    """
    # Tokenize all examples, remembering each example's [start, end) span
    # within the flat token list.
    example_tokens = []
    example_boundaries = []
    current_pos = 0
    for idx, (input_text, output_text) in enumerate(examples):
        # Each example is rendered as "input\noutput\n" before tokenizing.
        example_text = f"{input_text}\n{output_text}\n"
        tokens = self.tokenizer.tokenize(example_text)
        example_tokens.extend(tokens)
        example_boundaries.append((current_pos, current_pos + len(tokens)))
        current_pos += len(tokens)
    # Analyze each example's efficiency
    example_efficiencies = []
    for idx, (start, end) in enumerate(example_boundaries):
        efficiency = self._analyze_example_efficiency(
            example_idx=idx,
            example_tokens=example_tokens[start:end],
            all_tokens=example_tokens,
            attention_weights=attention_weights,
            generated_tokens=generated_tokens
        )
        example_efficiencies.append(efficiency)
    # Analyze token-level efficiency
    token_efficiencies = self._analyze_token_efficiency(
        example_tokens=example_tokens,
        attention_weights=attention_weights,
        generated_tokens=generated_tokens
    )
    # Calculate redundancy patterns
    redundancy_patterns = self._identify_redundancy_patterns(
        example_tokens=example_tokens,
        token_efficiencies=token_efficiencies
    )
    # Determine optimal example count
    optimal_count = self._calculate_optimal_example_count(
        example_efficiencies=example_efficiencies
    )
    # Calculate compression potential
    compression_potential = self._calculate_compression_potential(
        token_efficiencies=token_efficiencies
    )
    # Calculate attention utilization
    attention_utilization = self._calculate_attention_utilization(
        attention_weights=attention_weights,
        total_context_tokens=len(example_tokens)
    )
    # Overall efficiency: a token counts as "effective" when its redundancy
    # score is below 0.5; denominator is guarded against an empty context.
    effective_tokens = sum(1 for t in token_efficiencies if t.redundancy_score < 0.5)
    overall_efficiency = effective_tokens / max(len(example_tokens), 1)
    return ContextEfficiencyAnalysis(
        overall_efficiency=overall_efficiency,
        total_context_tokens=len(example_tokens),
        effective_context_tokens=effective_tokens,
        example_efficiencies=example_efficiencies,
        token_efficiencies=token_efficiencies,
        optimal_example_count=optimal_count,
        redundancy_patterns=redundancy_patterns,
        compression_potential=compression_potential,
        attention_utilization=attention_utilization
    )
def _analyze_example_efficiency(
    self,
    example_idx: int,
    example_tokens: List[str],
    all_tokens: List[str],
    attention_weights: Optional[List[Dict]],
    generated_tokens: List[str],
    start_pos: Optional[int] = None
) -> "ExampleEfficiency":
    """Analyze efficiency of a single example.

    Args:
        example_idx: 0-based index of the example within the context.
        example_tokens: Tokens of this example only.
        all_tokens: Tokens of the whole context (all examples concatenated).
        attention_weights: Unused here; kept for signature symmetry with
            the other ``_analyze_*`` helpers.
        generated_tokens: Unused here; kept for signature symmetry.
        start_pos: Offset of this example's first token inside
            ``all_tokens``.  When ``None`` (legacy behaviour) the offset is
            estimated as ``example_idx * len(example_tokens)``, which is
            only correct when every example has the same token count.

    Returns:
        ExampleEfficiency with redundancy / density / marginal-benefit
        metrics for this example.
    """
    # DEFECT (generalised fix): the legacy prefix estimate is wrong for
    # variable-length examples — it can spill into the current example or
    # miss earlier tokens.  Callers can now pass the exact boundary.
    if start_pos is None:
        start_pos = example_idx * len(example_tokens)
    prior_tokens = all_tokens[:start_pos]

    # Redundancy: tokens already seen more than twice in earlier examples.
    # Frequencies are built in one pass instead of calling list.count()
    # per token (was O(n^2)).
    redundant_count = 0
    if example_idx > 0:
        prior_counts: Dict[str, int] = {}
        for tok in prior_tokens:
            prior_counts[tok] = prior_counts.get(tok, 0) + 1
        redundant_count = sum(
            1 for tok in example_tokens if prior_counts.get(tok, 0) > 2
        )
    redundancy_rate = redundant_count / max(len(example_tokens), 1)

    # Information density: log2(unique tokens) per token — a simplified
    # Shannon-entropy proxy.
    unique_tokens = len(set(example_tokens))
    information_density = np.log2(max(unique_tokens, 1)) / max(len(example_tokens), 1)

    # Marginal benefit: share of this example's tokens that are new
    # relative to all earlier examples.
    if example_idx == 0:
        marginal_benefit = 1.0  # First example always has full benefit
    else:
        new_patterns = set(example_tokens) - set(prior_tokens)
        marginal_benefit = len(new_patterns) / max(len(example_tokens), 1)

    # Effective tokens: the non-redundant remainder.
    effective_tokens = int(len(example_tokens) * (1 - redundancy_rate))

    return ExampleEfficiency(
        example_id=str(example_idx + 1),
        total_tokens=len(example_tokens),
        effective_tokens=effective_tokens,
        efficiency_ratio=effective_tokens / max(len(example_tokens), 1),
        redundancy_rate=redundancy_rate,
        information_density=information_density,
        marginal_benefit=marginal_benefit
    )
def _analyze_token_efficiency(
    self,
    example_tokens: List[str],
    attention_weights: Optional[List[Dict]],
    generated_tokens: List[str]
) -> List["TokenEfficiency"]:
    """Analyze efficiency of individual tokens.

    Args:
        example_tokens: Flat token list of the full context.
        attention_weights: Unused here; kept for signature symmetry.
        generated_tokens: Model output tokens used for contribution scoring.

    Returns:
        One TokenEfficiency per context token, in position order.
    """
    total = len(example_tokens)

    # PERF FIX: the original called example_tokens.count(token) inside the
    # loop (O(n^2) over the context).  Build the frequency table once.
    global_freq: Dict[str, int] = {}
    for tok in example_tokens:
        global_freq[tok] = global_freq.get(tok, 0) + 1

    # Hoist the lowercasing of generated tokens out of the loop as well.
    lowered_generated = [g.lower() for g in generated_tokens] if generated_tokens else []

    token_efficiencies = []
    for idx, token in enumerate(example_tokens):
        # Information content: rare tokens carry more bits.
        information_content = np.log2(total / max(global_freq[token], 1))

        # Redundancy: repeats of this token inside a local window.
        # NOTE(review): the slice [idx-5 : idx+5) is asymmetric (5 before,
        # 4 after, plus the token itself) — preserved to keep scores stable.
        local_window = example_tokens[max(0, idx - 5):min(total, idx + 5)]
        redundancy_score = min(local_window.count(token) / 3.0, 1.0)  # Cap at 1.0

        # Contribution: exact match in the output scores 1.0; a
        # case-insensitive substring match scores 0.5; otherwise 0.0.
        contribution_score = 0.0
        if generated_tokens:
            if token in generated_tokens:
                contribution_score = 1.0
            elif any(token.lower() in gen for gen in lowered_generated):
                contribution_score = 0.5

        token_efficiencies.append(TokenEfficiency(
            token=token,
            position=idx,
            information_content=information_content,
            redundancy_score=redundancy_score,
            contribution_score=contribution_score
        ))
    return token_efficiencies
def _identify_redundancy_patterns(
self,
example_tokens: List[str],
token_efficiencies: List[TokenEfficiency]
) -> Dict[str, float]:
"""Identify common redundancy patterns"""
patterns = {
'repeated_tokens': 0.0,
'boilerplate': 0.0,
'structural_repetition': 0.0,
'semantic_overlap': 0.0
}
# Count repeated tokens
token_counts = {}
for token in example_tokens:
token_counts[token] = token_counts.get(token, 0) + 1
repeated = sum(1 for count in token_counts.values() if count > 3)
patterns['repeated_tokens'] = repeated / max(len(token_counts), 1)
# Detect boilerplate (common programming patterns)
boilerplate_tokens = ['def', 'class', 'return', 'import', 'from', '"""', "'''"]
boilerplate_count = sum(1 for token in example_tokens if token in boilerplate_tokens)
patterns['boilerplate'] = boilerplate_count / max(len(example_tokens), 1)
# Detect structural repetition (same patterns)
# Look for sequences that repeat
sequence_length = 3
sequences = {}
for i in range(len(example_tokens) - sequence_length):
seq = tuple(example_tokens[i:i+sequence_length])
sequences[seq] = sequences.get(seq, 0) + 1
repeated_sequences = sum(1 for count in sequences.values() if count > 1)
patterns['structural_repetition'] = repeated_sequences / max(len(sequences), 1)
# Estimate semantic overlap (tokens with high redundancy scores)
high_redundancy = sum(1 for t in token_efficiencies if t.redundancy_score > 0.7)
patterns['semantic_overlap'] = high_redundancy / max(len(token_efficiencies), 1)
return patterns
def _calculate_optimal_example_count(
self,
example_efficiencies: List[ExampleEfficiency]
) -> int:
"""Determine the optimal number of examples based on marginal benefits"""
if not example_efficiencies:
return 0
# Find point where marginal benefit drops below threshold
threshold = 0.3 # Examples adding less than 30% benefit are not worth it
for idx, efficiency in enumerate(example_efficiencies):
if efficiency.marginal_benefit < threshold and idx > 0:
return idx
# If all examples have good marginal benefit, use all
return len(example_efficiencies)
def _calculate_compression_potential(
self,
token_efficiencies: List[TokenEfficiency]
) -> float:
"""Calculate how much the context could be compressed"""
if not token_efficiencies:
return 0.0
# Tokens with high redundancy and low contribution can be removed
removable = sum(
1 for t in token_efficiencies
if t.redundancy_score > 0.6 and t.contribution_score < 0.3
)
return removable / len(token_efficiencies)
def _calculate_attention_utilization(
self,
attention_weights: Optional[List[Dict]],
total_context_tokens: int
) -> float:
"""Calculate what percentage of context receives significant attention"""
if not attention_weights or total_context_tokens == 0:
return 0.0
# Aggregate attention across all layers and heads
attended_positions = set()
for record in attention_weights:
attn = record.get('attention')
if attn is not None and attn.dim() >= 3:
# Average across heads and look at which positions get attention
avg_attn = attn.mean(dim=1) # Average across heads
# Positions with attention > threshold are considered "utilized"
threshold = 0.05
high_attention = (avg_attn > threshold).nonzero(as_tuple=True)
if len(high_attention) > 1:
attended_positions.update(high_attention[1].tolist())
# Filter to only context positions
context_attended = [pos for pos in attended_positions if pos < total_context_tokens]
return len(context_attended) / total_context_tokens if total_context_tokens > 0 else 0.0 |