File size: 13,193 Bytes
920a98d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed40a9a
 
920a98d
 
ed40a9a
920a98d
ed40a9a
 
 
 
 
920a98d
ed40a9a
 
920a98d
ed40a9a
 
920a98d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
"""
In-Context Learning Analysis Service

Analyzes how examples influence model behavior during code generation.
"""

import torch
import numpy as np
from typing import List, Dict, Optional, Any, Tuple
from dataclasses import dataclass
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F
from .icl_attention_extractor import AttentionExtractor
from .induction_head_detector import InductionHeadDetector, ICLEmergenceAnalysis
import logging

logger = logging.getLogger(__name__)

@dataclass
class ICLExample:
    """Represents an in-context learning example"""
    input: str   # demonstration prompt text; placed before `output` in the few-shot prompt
    output: str  # demonstration completion text; follows `input`, separated by a newline
    
@dataclass 
class ICLAnalysisResult:
    """Results from ICL analysis"""
    shot_count: int  # number of examples included in the prompt
    generated_code: str  # decoded model continuation (special tokens skipped)
    tokens: List[str]  # generated tokens, decoded individually
    confidence_scores: List[float]  # max softmax probability at each generation step
    attention_from_examples: Dict[str, List[float]]  # example_id -> attention weights per token
    perplexity: float  # exp(-mean log-prob) over the generated tokens; 0.0 when no scores
    avg_confidence: float  # mean of confidence_scores; 0.0 when empty
    example_influences: Dict[str, float]  # example_id -> overall influence score
    hidden_state_drift: Optional[List[float]] = None  # magnitude of hidden state changes
    icl_emergence: Optional[ICLEmergenceAnalysis] = None  # When/how ICL kicks in

class ICLAnalyzer:
    """Analyzes in-context learning effects on model behavior.

    Wraps a causal LM + tokenizer and, for a set of few-shot examples,
    generates a continuation while collecting:
      * per-token confidence (max softmax probability) and perplexity,
      * attention paid to each example (real hook data when extraction
        succeeds, simulated decay patterns otherwise),
      * optional hidden-state drift against a no-example baseline,
      * induction-head / ICL-emergence diagnostics.
    """

    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, adapter=None):
        """Store the model/tokenizer and build the helper analyzers.

        Args:
            model: causal LM used for generation.
            tokenizer: matching tokenizer; a pad token is installed from
                eos_token if missing (needed for Code-Llama).
            adapter: optional adapter forwarded to the helper extractors.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter
        self.device = next(model.parameters()).device

        # Ensure tokenizer has pad_token (needed for Code-Llama)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Hook-based extractor for real attention data
        self.attention_extractor = AttentionExtractor(model, tokenizer, adapter=adapter)

        # Detector for induction heads / ICL emergence timing
        self.induction_detector = InductionHeadDetector(model, tokenizer, adapter=adapter)

        # Storage for attention patterns
        self.attention_maps = []
        self.hidden_states = []

    def prepare_prompt_with_examples(self, examples: List[ICLExample], test_prompt: str) -> str:
        """Construct the few-shot prompt.

        Each example contributes "<input>newline<output>newline"; the parts
        (examples then test_prompt) are joined with one more newline, so
        consecutive examples are separated by a blank line. With no examples
        the test prompt is returned unchanged.
        """
        if not examples:
            return test_prompt

        prompt_parts = [f"{example.input}\n{example.output}\n" for example in examples]
        prompt_parts.append(test_prompt)

        return "\n".join(prompt_parts)

    def _compute_example_boundaries(self, examples: List[ICLExample]) -> List[Tuple[int, int]]:
        """Return (start, end) token spans of each example in the full prompt.

        FIX: the previous inline computation tokenized each example in
        isolation and summed lengths. That drifts out of sync with the real
        prompt because (a) prepare_prompt_with_examples inserts an extra
        newline join-separator between parts, (b) leading special tokens
        (e.g. BOS) shift every span, and (c) tokenization is not compositional
        across concatenation points. Tokenizing growing prefixes of the exact
        prompt text accounts for all three.
        """
        if not examples:
            return []

        parts = [f"{ex.input}\n{ex.output}\n" for ex in examples]
        # Tokens emitted before any text (BOS and friends).
        prev_len = len(self.tokenizer("")["input_ids"])

        boundaries: List[Tuple[int, int]] = []
        for i in range(1, len(parts) + 1):
            # Same join rule as prepare_prompt_with_examples.
            prefix = "\n".join(parts[:i])
            prefix_len = len(self.tokenizer(prefix)["input_ids"])
            boundaries.append((prev_len, prefix_len))
            prev_len = prefix_len
        return boundaries

    def extract_attention_patterns(self, outputs, input_ids, example_boundaries: List[Tuple[int, int]]) -> Dict[str, List[float]]:
        """Extract attention patterns - real if available, simulated otherwise.

        Returns a mapping example_id ("1", "2", ...) -> one attention weight
        per generated token.
        """
        # Prefer real attention data captured during analyze_generation.
        if getattr(self, 'last_attention_data', None):
            logger.info("Using real attention data from model hooks")
            prompt_length = len(input_ids[0])
            return self.attention_extractor.aggregate_attention_to_examples(
                self.last_attention_data,
                example_boundaries,
                prompt_length
            )

        # Fall back to simulated patterns
        logger.info("Using simulated attention patterns")
        attention_from_examples: Dict[str, List[float]] = {}

        if not example_boundaries:
            return attention_from_examples

        generated_ids = outputs.sequences[0][len(input_ids[0]):]
        num_generated = len(generated_ids)

        if num_generated == 0:
            return attention_from_examples

        # Simulated pattern: later examples start slightly stronger; every
        # example's weight decays exponentially across generated tokens, with
        # small Gaussian noise, clamped into [0, 1].
        for idx, (start, end) in enumerate(example_boundaries):
            example_id = str(idx + 1)
            base_weight = 0.3 + (idx * 0.1) / len(example_boundaries)

            attention_weights = []
            for token_idx in range(num_generated):
                weight = base_weight * np.exp(-token_idx * 0.05)
                weight += np.random.normal(0, 0.02)
                weight = max(0, min(1, weight))
                attention_weights.append(weight)

            attention_from_examples[example_id] = attention_weights

        # Normalize so per-token weights across examples sum to 1.
        if len(attention_from_examples) > 1:
            for token_idx in range(num_generated):
                total = sum(weights[token_idx] for weights in attention_from_examples.values())
                if total > 0:
                    for example_id in attention_from_examples:
                        attention_from_examples[example_id][token_idx] /= total

        return attention_from_examples

    def calculate_example_influences(self, attention_from_examples: Dict[str, List[float]]) -> Dict[str, float]:
        """Calculate an overall influence score for each example.

        With real attention data the extractor's own scoring is used;
        otherwise the mean attention weight per example, renormalized to sum
        to 1, is returned.
        """
        # If we have real attention data, use the extractor's method.
        if getattr(self, 'last_attention_data', None):
            return self.attention_extractor.calculate_example_influences(attention_from_examples)

        # Simulated path: mean weight per example, then renormalize.
        influences = {
            example_id: float(np.mean(weights)) if weights else 0.0
            for example_id, weights in attention_from_examples.items()
        }

        total = sum(influences.values())
        if total > 0 and total != 1.0:
            influences = {k: v / total for k, v in influences.items()}

        return influences

    def track_hidden_state_drift(self, base_hidden_states, example_hidden_states) -> List[float]:
        """L2 distance between baseline and with-example hidden states at each
        shared position; empty list when either side is missing."""
        if base_hidden_states is None or example_hidden_states is None:
            return []

        drift: List[float] = []
        min_len = min(len(base_hidden_states), len(example_hidden_states))

        for i in range(min_len):
            base = base_hidden_states[i]
            example = example_hidden_states[i]

            # Accept torch tensors or numpy arrays.
            if isinstance(base, torch.Tensor):
                base = base.cpu().numpy()
            if isinstance(example, torch.Tensor):
                example = example.cpu().numpy()

            drift.append(float(np.linalg.norm(example - base)))

        return drift

    def analyze_generation(
        self,
        examples: List[ICLExample],
        test_prompt: str,
        max_length: int = 150,
        temperature: float = 0.7,
        base_hidden_states: Optional[Any] = None
    ) -> ICLAnalysisResult:
        """Generate a continuation conditioned on examples and analyze it.

        Args:
            examples: few-shot demonstrations prepended to the prompt.
            test_prompt: text the model should continue.
            max_length: total sequence-length budget passed to generate().
            temperature: sampling temperature; 0 disables sampling.
            base_hidden_states: optional no-example hidden states for drift.

        Returns:
            ICLAnalysisResult with the generated code and all metrics.
        """
        # Prepare prompt
        full_prompt = self.prepare_prompt_with_examples(examples, test_prompt)

        # Tokenize
        inputs = self.tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)).to(self.device)

        prompt_length = input_ids.shape[1]

        # Token spans of the examples inside the prompt (BOS/separator aware).
        example_boundaries = self._compute_example_boundaries(examples)

        # First do standard generation to get scores and text
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_scores=True,
                output_hidden_states=False
            )

        # Then try to extract real attention data
        try:
            logger.info("Extracting real attention data")
            # FIX: clamp to at least 1 new token -- when the prompt is longer
            # than max_length, the old expression went negative.
            budget = max(1, max_length - prompt_length)
            _, attention_data, _ = self.attention_extractor.extract_attention_with_generation(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=min(30, budget),  # Limit for performance
                temperature=temperature
            )
            self.last_attention_data = attention_data
            logger.info(f"Successfully extracted {len(attention_data)} attention records")
        except Exception as e:
            logger.warning(f"Real attention extraction failed: {e}")
            self.last_attention_data = None

        # Extract generated tokens - show raw output, no trimming
        generated_ids = outputs.sequences[0][prompt_length:]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        tokens = [self.tokenizer.decode([token_id]) for token_id in generated_ids]

        # Confidence = max softmax probability at each generation step.
        confidence_scores = []
        if outputs.scores:
            for score in outputs.scores:
                probs = F.softmax(score[0], dim=-1)
                confidence_scores.append(probs.max().item())

        # Perplexity = exp(-mean log-prob of the tokens actually generated).
        perplexity = 0.0
        if outputs.scores:
            log_probs = []
            for i, score in enumerate(outputs.scores):
                if i < len(generated_ids):
                    token_id = generated_ids[i]
                    log_probs.append(F.log_softmax(score[0], dim=-1)[token_id].item())
            if log_probs:
                # FIX: cast so the field is a plain float, not np.float64.
                perplexity = float(np.exp(-np.mean(log_probs)))

        # Extract attention patterns (real if the hook extraction succeeded).
        attention_from_examples = self.extract_attention_patterns(outputs, input_ids, example_boundaries)

        # Calculate example influences
        example_influences = self.calculate_example_influences(attention_from_examples)

        # Track hidden state drift if base states provided.
        # NOTE(review): generate() above runs with output_hidden_states=False,
        # so outputs.hidden_states is never populated and this stays None.
        hidden_state_drift = None
        if base_hidden_states is not None and hasattr(outputs, 'hidden_states'):
            current_hidden = outputs.hidden_states[-1] if outputs.hidden_states else None
            if current_hidden is not None:
                hidden_state_drift = self.track_hidden_state_drift(base_hidden_states, current_hidden)

        # Analyze ICL emergence if we have attention data and examples
        icl_emergence = None
        if self.last_attention_data and len(examples) > 0:
            try:
                icl_emergence = self.induction_detector.analyze_icl_emergence(
                    self.last_attention_data,
                    input_ids,
                    example_boundaries,
                    generated_ids.tolist() if generated_ids.numel() > 0 else []
                )
                logger.info(f"ICL emergence analysis: detected={icl_emergence.emergence_detected}, "
                          f"token={icl_emergence.emergence_token}, confidence={icl_emergence.confidence:.2f}")
            except Exception as e:
                logger.warning(f"ICL emergence analysis failed: {e}")

        return ICLAnalysisResult(
            shot_count=len(examples),
            generated_code=generated_text,
            tokens=tokens,
            confidence_scores=confidence_scores,
            attention_from_examples=attention_from_examples,
            perplexity=perplexity,
            # FIX: cast so the field is a plain float, not np.float64.
            avg_confidence=float(np.mean(confidence_scores)) if confidence_scores else 0.0,
            example_influences=example_influences,
            hidden_state_drift=hidden_state_drift,
            icl_emergence=icl_emergence
        )

    def compare_shot_settings(
        self,
        examples: List[ICLExample],
        test_prompt: str,
        max_length: int = 150,
        temperature: float = 0.7
    ) -> Dict[str, ICLAnalysisResult]:
        """Compare 0-shot, 1-shot, and few-shot generation for one prompt.

        Returns a dict keyed 'zero_shot' (always), 'one_shot' (>=1 example)
        and 'few_shot' (>=2 examples).
        """
        results = {}

        # 0-shot (no examples)
        results['zero_shot'] = self.analyze_generation([], test_prompt, max_length, temperature)

        # FIX: the old code passed results['zero_shot'].hidden_state_drift as
        # the baseline "hidden states". That field is a drift-magnitude list,
        # not hidden states -- and it is always None here because generation
        # runs with output_hidden_states=False. Pass None explicitly (same
        # runtime behavior, no misleading data flow); a real baseline would
        # require capturing hidden states in analyze_generation.
        base_hidden = None

        # 1-shot (first example only)
        if len(examples) >= 1:
            results['one_shot'] = self.analyze_generation(
                examples[:1], test_prompt, max_length, temperature, base_hidden
            )

        # Few-shot (all examples)
        if len(examples) >= 2:
            results['few_shot'] = self.analyze_generation(
                examples, test_prompt, max_length, temperature, base_hidden
            )

        return results