"""
Context Efficiency Analyzer for In-Context Learning

Measures how efficiently the model uses context examples to perform tasks.
Based on research showing that not all examples contribute equally and that
optimal context usage can significantly improve performance.
"""

import torch
import numpy as np
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from collections import Counter
import logging

logger = logging.getLogger(__name__)

@dataclass
class TokenEfficiency:
    """Efficiency metrics for individual tokens"""
    token: str
    position: int
    information_content: float  # Bits of information
    redundancy_score: float  # 0-1 (1 = completely redundant)
    contribution_score: float  # How much it contributes to output

@dataclass
class ExampleEfficiency:
    """Efficiency metrics for each example"""
    example_id: str
    total_tokens: int
    effective_tokens: int  # Tokens that actually contribute
    efficiency_ratio: float  # effective/total
    redundancy_rate: float  # Percentage of redundant tokens
    information_density: float  # Bits per token
    marginal_benefit: float  # Additional benefit vs previous examples

@dataclass
class ContextEfficiencyAnalysis:
    """Complete context efficiency analysis"""
    overall_efficiency: float  # 0-1 score
    total_context_tokens: int
    effective_context_tokens: int
    example_efficiencies: List[ExampleEfficiency]
    token_efficiencies: List[TokenEfficiency]
    optimal_example_count: int  # Suggested optimal number of examples
    redundancy_patterns: Dict[str, float]  # Pattern type -> frequency
    compression_potential: float  # How much context could be compressed
    attention_utilization: float  # How much of context gets attention

class ContextEfficiencyAnalyzer:
    """Analyzes how efficiently context is used in ICL"""
    
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        
    def analyze_context_efficiency(
        self,
        examples: List[Tuple[str, str]],  # (input, output) pairs
        test_prompt: str,
        attention_weights: Optional[List[Dict]] = None,
        generated_tokens: Optional[List[str]] = None,
        confidence_scores: Optional[List[float]] = None
    ) -> ContextEfficiencyAnalysis:
        """
        Comprehensive analysis of context efficiency
        """
        
        # Tokenize all examples
        example_tokens = []
        example_boundaries = []
        current_pos = 0
        
        for input_text, output_text in examples:
            example_text = f"{input_text}\n{output_text}\n"
            tokens = self.tokenizer.tokenize(example_text)
            example_tokens.extend(tokens)
            example_boundaries.append((current_pos, current_pos + len(tokens)))
            current_pos += len(tokens)
        
        # Analyze each example's efficiency
        example_efficiencies = []
        for idx, (start, end) in enumerate(example_boundaries):
            efficiency = self._analyze_example_efficiency(
                example_idx=idx,
                example_tokens=example_tokens[start:end],
                prior_tokens=example_tokens[:start],
                attention_weights=attention_weights,
                generated_tokens=generated_tokens
            )
            example_efficiencies.append(efficiency)
        
        # Analyze token-level efficiency
        token_efficiencies = self._analyze_token_efficiency(
            example_tokens=example_tokens,
            attention_weights=attention_weights,
            generated_tokens=generated_tokens
        )
        
        # Calculate redundancy patterns
        redundancy_patterns = self._identify_redundancy_patterns(
            example_tokens=example_tokens,
            token_efficiencies=token_efficiencies
        )
        
        # Determine optimal example count
        optimal_count = self._calculate_optimal_example_count(
            example_efficiencies=example_efficiencies
        )
        
        # Calculate compression potential
        compression_potential = self._calculate_compression_potential(
            token_efficiencies=token_efficiencies
        )
        
        # Calculate attention utilization
        attention_utilization = self._calculate_attention_utilization(
            attention_weights=attention_weights,
            total_context_tokens=len(example_tokens)
        )
        
        # Overall efficiency: the share of context tokens whose redundancy
        # score is below 0.5 (these count as "effective")
        effective_tokens = sum(1 for t in token_efficiencies if t.redundancy_score < 0.5)
        overall_efficiency = effective_tokens / max(len(example_tokens), 1)
        
        return ContextEfficiencyAnalysis(
            overall_efficiency=overall_efficiency,
            total_context_tokens=len(example_tokens),
            effective_context_tokens=effective_tokens,
            example_efficiencies=example_efficiencies,
            token_efficiencies=token_efficiencies,
            optimal_example_count=optimal_count,
            redundancy_patterns=redundancy_patterns,
            compression_potential=compression_potential,
            attention_utilization=attention_utilization
        )
    
    def _analyze_example_efficiency(
        self,
        example_idx: int,
        example_tokens: List[str],
        prior_tokens: List[str],
        attention_weights: Optional[List[Dict]],
        generated_tokens: Optional[List[str]]
    ) -> ExampleEfficiency:
        """Analyze efficiency of a single example"""
        
        # Calculate redundancy against the preceding examples: tokens already
        # seen more than twice before this example count as redundant
        redundant_count = 0
        for token in example_tokens:
            if prior_tokens.count(token) > 2:
                redundant_count += 1
        
        redundancy_rate = redundant_count / max(len(example_tokens), 1)
        
        # Approximate information density: log2 of the number of distinct
        # tokens, normalized per token (an upper bound on per-token entropy)
        unique_tokens = len(set(example_tokens))
        information_density = np.log2(max(unique_tokens, 1)) / max(len(example_tokens), 1)
        
        # Calculate marginal benefit (how much this example adds)
        if example_idx == 0:
            marginal_benefit = 1.0  # The first example always has full benefit
        else:
            # Estimate from the new unique tokens this example introduces
            new_patterns = set(example_tokens) - set(prior_tokens)
            marginal_benefit = len(new_patterns) / max(len(example_tokens), 1)
        
        # Calculate effective tokens (those that contribute)
        effective_tokens = int(len(example_tokens) * (1 - redundancy_rate))
        
        return ExampleEfficiency(
            example_id=str(example_idx + 1),
            total_tokens=len(example_tokens),
            effective_tokens=effective_tokens,
            efficiency_ratio=effective_tokens / max(len(example_tokens), 1),
            redundancy_rate=redundancy_rate,
            information_density=information_density,
            marginal_benefit=marginal_benefit
        )
    
    def _analyze_token_efficiency(
        self,
        example_tokens: List[str],
        attention_weights: Optional[List[Dict]],
        generated_tokens: Optional[List[str]]
    ) -> List[TokenEfficiency]:
        """Analyze efficiency of individual tokens"""
        
        token_efficiencies = []
        token_counts = Counter(example_tokens)  # Precompute global frequencies

        for idx, token in enumerate(example_tokens):
            # Information content as self-information, -log2(p): rare tokens
            # carry more bits than frequent ones
            frequency = token_counts[token]
            information_content = np.log2(len(example_tokens) / frequency)

            # Redundancy: a token that recurs within a +/-5-token window is
            # likely redundant; the window includes the token itself, so three
            # or more local occurrences saturate the score at 1.0
            local_window = example_tokens[max(0, idx - 5):idx + 5]
            local_frequency = local_window.count(token)
            redundancy_score = min(local_frequency / 3.0, 1.0)
            
            # Calculate contribution score
            # Based on whether similar tokens appear in output
            contribution_score = 0.0
            if generated_tokens:
                # Check if token or similar tokens appear in output
                if token in generated_tokens:
                    contribution_score = 1.0
                elif any(token.lower() in gen_token.lower() for gen_token in generated_tokens):
                    contribution_score = 0.5
            
            token_efficiencies.append(TokenEfficiency(
                token=token,
                position=idx,
                information_content=information_content,
                redundancy_score=redundancy_score,
                contribution_score=contribution_score
            ))
        
        return token_efficiencies
    
    def _identify_redundancy_patterns(
        self,
        example_tokens: List[str],
        token_efficiencies: List[TokenEfficiency]
    ) -> Dict[str, float]:
        """Identify common redundancy patterns"""
        
        patterns = {
            'repeated_tokens': 0.0,
            'boilerplate': 0.0,
            'structural_repetition': 0.0,
            'semantic_overlap': 0.0
        }
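        # Each value above is a frequency in [0, 1]; e.g. a boilerplate score
        # of 0.05 would mean 5% of the context tokens are structural keywords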
        
        # Count tokens that repeat heavily (more than 3 occurrences)
        token_counts = Counter(example_tokens)
        repeated = sum(1 for count in token_counts.values() if count > 3)
        patterns['repeated_tokens'] = repeated / max(len(token_counts), 1)
        
        # Detect boilerplate (common programming patterns)
        boilerplate_tokens = ['def', 'class', 'return', 'import', 'from', '"""', "'''"]
        boilerplate_count = sum(1 for token in example_tokens if token in boilerplate_tokens)
        patterns['boilerplate'] = boilerplate_count / max(len(example_tokens), 1)
        
        # Detect structural repetition: 3-gram sequences that occur more than once
        sequence_length = 3
        sequences = Counter(
            tuple(example_tokens[i:i + sequence_length])
            for i in range(len(example_tokens) - sequence_length + 1)
        )

        repeated_sequences = sum(1 for count in sequences.values() if count > 1)
        patterns['structural_repetition'] = repeated_sequences / max(len(sequences), 1)
        
        # Estimate semantic overlap (tokens with high redundancy scores)
        high_redundancy = sum(1 for t in token_efficiencies if t.redundancy_score > 0.7)
        patterns['semantic_overlap'] = high_redundancy / max(len(token_efficiencies), 1)
        
        return patterns
    
    def _calculate_optimal_example_count(
        self,
        example_efficiencies: List[ExampleEfficiency]
    ) -> int:
        """Determine the optimal number of examples based on marginal benefits"""
        
        if not example_efficiencies:
            return 0
        
        # Find the first example whose marginal benefit falls below the threshold
        threshold = 0.3  # Below this, an example adds too little to justify its token cost
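        # Worked example (illustrative): with marginal benefits
        # [1.0, 0.55, 0.20, 0.45], the scan stops at index 2, so only the
        # first two examples are suggested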
        
        for idx, efficiency in enumerate(example_efficiencies):
            if efficiency.marginal_benefit < threshold and idx > 0:
                return idx
        
        # If all examples have good marginal benefit, use all
        return len(example_efficiencies)
    
    def _calculate_compression_potential(
        self,
        token_efficiencies: List[TokenEfficiency]
    ) -> float:
        """Calculate how much the context could be compressed"""
        
        if not token_efficiencies:
            return 0.0
        
        # Tokens with high redundancy and low contribution can be removed
        removable = sum(
            1 for t in token_efficiencies 
            if t.redundancy_score > 0.6 and t.contribution_score < 0.3
        )
        
        return removable / len(token_efficiencies)
    
    def _calculate_attention_utilization(
        self,
        attention_weights: Optional[List[Dict]],
        total_context_tokens: int
    ) -> float:
        """Calculate what percentage of context receives significant attention"""
        
        if not attention_weights or total_context_tokens == 0:
            return 0.0
        
        # Aggregate attention across all layers and heads
        attended_positions = set()
        
        for record in attention_weights:
            attn = record.get('attention')
            if attn is None or attn.dim() < 3:
                continue

            # Average across heads, assuming [batch, heads, query, key] for
            # 4-D tensors and [heads, query, key] for 3-D tensors
            head_dim = 1 if attn.dim() >= 4 else 0
            avg_attn = attn.mean(dim=head_dim)

            # Key positions receiving attention above the threshold count as utilized
            threshold = 0.05
            high_attention = (avg_attn > threshold).nonzero(as_tuple=True)

            # The last index tensor holds the attended-to (key) positions
            attended_positions.update(high_attention[-1].tolist())
        
        # Filter to only context positions
        context_attended = [pos for pos in attended_positions if pos < total_context_tokens]
        
        return len(context_attended) / total_context_tokens
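

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the analyzer). It assumes a
# Hugging Face causal LM and tokenizer; the model name, task, and examples are
# placeholders, and the attention records follow the {'attention': tensor}
# layout that analyze_context_efficiency expects (one record per layer, each
# tensor shaped [batch, heads, query, key]).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "gpt2"  # Placeholder; any causal LM that returns attentions works
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    analyzer = ContextEfficiencyAnalyzer(model, tokenizer)

    examples = [
        ("Translate to French: cat", "chat"),
        ("Translate to French: dog", "chien"),
        ("Translate to French: bird", "oiseau"),
    ]
    test_prompt = "Translate to French: horse"

    # Collect per-layer attention records in the expected format. Positions
    # are treated as approximate, since the analyzer re-tokenizes the examples.
    context = "".join(f"{x}\n{y}\n" for x, y in examples) + test_prompt
    inputs = tokenizer(context, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    attention_records = [{"attention": attn} for attn in outputs.attentions]

    analysis = analyzer.analyze_context_efficiency(
        examples=examples,
        test_prompt=test_prompt,
        attention_weights=attention_records,
    )

    print(f"Overall efficiency:      {analysis.overall_efficiency:.2f}")
    print(f"Suggested example count: {analysis.optimal_example_count}")
    print(f"Compression potential:   {analysis.compression_potential:.2f}")
    print(f"Attention utilization:   {analysis.attention_utilization:.2f}")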