File size: 23,926 Bytes
767a3fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
"""
Transformer Pipeline Analyzer
Captures and returns all intermediate states of transformer processing
"""

import torch
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict
import logging

logger = logging.getLogger(__name__)


@dataclass
class PipelineStep:
    """One recorded stage of the transformer pipeline.

    Instances are plain containers; ``asdict`` converts them into dictionaries
    (see ``TransformerPipelineAnalyzer.get_steps_dict``).
    """
    step_number: int  # 0-based index of the step within one token's pipeline
    step_name: str  # human-readable label, e.g. "Layer 0 - Multi-Head Attention"
    # One of: 'input', 'tokenization', 'embedding', 'attention', 'ffn',
    # 'normalization', 'output', 'generated'
    step_type: str
    description: str  # one-sentence explanation of the step
    data: Dict[str, Any]  # step-specific payload; keys depend on step_type


class TransformerPipelineAnalyzer:
    """Analyzes the flow of tokens through a (mostly GPT-style) transformer.

    Tokens are sampled with ``model.generate``; afterwards the forward pass is
    replayed manually for every generated position and summarised as a list
    of :class:`PipelineStep` objects.

    Layer layouts are detected by attribute names:
      * GPT-2 style blocks: ``ln_1`` / ``attn.c_attn`` / ``attn.c_proj`` /
        ``ln_2`` / ``mlp.c_fc``.
      * Blocks with ``attn`` but no ``ln_2`` are treated as parallel
        attention + MLP blocks (as in HF CodeGen / GPT-J).
      * Any other block is run through its own forward.
    """

    # Only the first few layers are replayed, to keep the analysis cheap.
    MAX_ANALYZED_LAYERS = 4

    # Predicted tokens considered "formatting" when summarising the result.
    _FORMATTING_TOKENS = ("'", '"', "\n", " ", "    ", "\t")

    def __init__(self, model, tokenizer):
        """
        Args:
            model: Causal LM exposing ``parameters()``, ``generate()`` and the
                submodules inspected by the analysis methods.
            tokenizer: Tokenizer exposing ``__call__``, ``decode``,
                ``pad_token_id`` and ``eos_token_id``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        self.steps = []  # kept for backward compatibility; not used by this class
        self.intermediate_states = {}  # kept for backward compatibility; not used by this class
        # Pipeline of the first token from the most recent analyze_pipeline call.
        self.last_single_token_steps: List[PipelineStep] = []

    def analyze_pipeline(self, text: str, max_new_tokens: int = 1,
                        temperature: float = 0.7, top_k: int = 50, top_p: float = 0.95) -> Dict[str, Any]:
        """
        Capture all steps of transformer processing for multiple tokens.

        Tokens are first sampled with ``model.generate``. The pipeline is then
        replayed for every generated position, using exactly the token ids the
        model conditioned on (prompt ids plus the ids generated so far). This
        means re-tokenizing decoded text cannot change the analyzed context.

        Args:
            text: Input text to analyze
            max_new_tokens: Maximum number of tokens to generate (default 1).
                Fewer pipelines are returned if ``generate`` returns fewer tokens.
            temperature: Sampling temperature passed to ``generate`` (default 0.7)
            top_k: Top-k cutoff passed to ``generate`` (default 50)
            top_p: Cumulative-probability cutoff passed to ``generate`` (default 0.95)

        Returns:
            Dict with keys:
              'tokens': decoded text of each generated token,
              'pipelines': one list of PipelineStep per generated token,
              'final_text': ``text`` with the decoded tokens appended,
              'num_tokens': number of generated tokens.
        """
        all_tokens: List[str] = []
        all_pipelines: List[List[PipelineStep]] = []
        current_text = text
        # Reset so get_steps_dict never returns steps from a previous call.
        self.last_single_token_steps = []

        with torch.no_grad():
            inputs = self.tokenizer(text, return_tensors="pt", padding=False, truncation=True)
            input_ids = inputs["input_ids"].to(self.device)
            prompt_len = input_ids.shape[1]

            # `pad_token_id or eos_token_id` would discard a legitimate pad id of 0.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            generated_ids = self.model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,  # Enable sampling for variety
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                pad_token_id=pad_token_id,
            )

            # The prompt occupies the first `prompt_len` positions; keep only new ids.
            new_token_ids = generated_ids[0, prompt_len:].tolist()
            generated_tokens = [self._decode_token(tid) for tid in new_token_ids]

            logger.info(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")

        for token_idx, (token_id, next_token) in enumerate(zip(new_token_ids, generated_tokens)):
            # Exact context the model conditioned on when it produced this token.
            context_ids = generated_ids[:, :prompt_len + token_idx]
            pipeline_steps = self._analyze_single_token(current_text, token_idx, input_ids=context_ids)

            # The replay reports the most likely token; overwrite the output
            # step with the token that sampling actually produced.
            for step in reversed(pipeline_steps):
                if step.step_type == 'output':
                    step.data['predicted_token'] = next_token
                    step.data['actual_token_id'] = token_id
                    break

            all_tokens.append(next_token)
            all_pipelines.append(pipeline_steps)
            current_text += next_token

            # Store first pipeline for backward compatibility (see get_steps_dict)
            if token_idx == 0:
                self.last_single_token_steps = pipeline_steps

        return {
            'tokens': all_tokens,
            'pipelines': all_pipelines,
            'final_text': current_text,
            'num_tokens': len(all_tokens)
        }

    def _decode_token(self, token_id: int) -> str:
        """Decode a single token id verbatim (special tokens and spacing kept)."""
        return self.tokenizer.decode([token_id], skip_special_tokens=False, clean_up_tokenization_spaces=False)

    def _analyze_single_token(self, text: str, token_position: int,
                              input_ids: Optional[torch.Tensor] = None) -> List[PipelineStep]:
        """
        Analyze the pipeline for generating a single token.

        Args:
            text: Current text to continue from. It is used for display and,
                when ``input_ids`` is not given, tokenized to obtain the context.
            token_position: 0-based position of this token in the generation sequence
            input_ids: Optional ``(1, seq_len)`` tensor of context token ids.
                When given, it is analyzed as-is instead of re-tokenizing ``text``.

        Returns:
            List of PipelineStep objects for this token
        """
        steps: List[PipelineStep] = []

        def add_step(name: str, step_type: str, description: str, data: Dict[str, Any]) -> None:
            # step_number is simply the step's index within this pipeline.
            steps.append(PipelineStep(
                step_number=len(steps),
                step_name=name,
                step_type=step_type,
                description=description,
                data=data,
            ))

        # Step 1: Raw Input
        add_step("Raw Input", "input", "The original text input provided by the user",
                 {"text": text, "length": len(text)})

        # Step 2: Tokenization
        if input_ids is None:
            inputs = self.tokenizer(text, return_tensors="pt", padding=False, truncation=True)
            input_ids = inputs["input_ids"]
        input_ids = input_ids.to(self.device)
        token_ids = input_ids[0].tolist()
        tokens = [self.tokenizer.decode([tid]) for tid in token_ids]

        add_step("Tokenization", "tokenization",
                 "Text split into subword tokens using the model's tokenizer",
                 {
                     "tokens": tokens,
                     "token_ids": token_ids,
                     "num_tokens": len(tokens),
                     "tokenizer_name": self.tokenizer.__class__.__name__
                 })

        with torch.no_grad():
            # Step 3: token embeddings (+ learned positional embeddings if present)
            transformer = getattr(self.model, 'transformer', None)
            if transformer is not None and hasattr(transformer, 'wte'):
                embed_layer = transformer.wte
                pos_embed_layer = getattr(transformer, 'wpe', None)
            else:
                embed_layer = self.model.get_input_embeddings()
                pos_embed_layer = None

            embeddings = embed_layer(input_ids)
            if pos_embed_layer is not None:
                position_ids = torch.arange(input_ids.shape[-1], dtype=torch.long, device=self.device).unsqueeze(0)
                embeddings = embeddings + pos_embed_layer(position_ids)

            add_step("Initial Embeddings", "embedding",
                     "Token embeddings combined with positional encodings",
                     {
                         "embedding_dim": embeddings.shape[-1],
                         "has_position_encoding": pos_embed_layer is not None,
                         "embeddings_sample": embeddings[0, :3, :8].float().cpu().tolist(),  # First 3 tokens, 8 dims
                         "embeddings_shape": list(embeddings.shape)
                     })

            # Steps 4..N: replay the first few transformer layers
            if transformer is not None and hasattr(transformer, 'h'):
                layers = transformer.h
            else:
                layers = self.model.encoder.layer if hasattr(self.model, 'encoder') else []

            current_hidden = embeddings
            for layer_idx, layer in enumerate(layers[:self.MAX_ANALYZED_LAYERS]):
                layer_output = self._process_layer(layer, current_hidden, layer_idx)

                add_step(f"Layer {layer_idx} - Multi-Head Attention", "attention",
                         f"Self-attention computation in layer {layer_idx}",
                         {
                             "layer": layer_idx,
                             "num_heads": self._get_num_heads(layer),
                             "attention_pattern": layer_output.get("attention_pattern", None),
                             "tokens": tokens,  # Include tokens for labeling the attention matrix
                             "hidden_state_norm": float(torch.norm(layer_output["hidden_states"]).item())
                         })

                if "ffn_output" in layer_output:
                    add_step(f"Layer {layer_idx} - Feed-Forward Network", "ffn",
                             f"Feed-forward transformation in layer {layer_idx}",
                             {
                                 "layer": layer_idx,
                                 "activation": "gelu",  # Assumed label; not read from the model
                                 "hidden_state_norm": float(torch.norm(layer_output["ffn_output"]).item()),
                                 "intermediate_size": layer_output.get("intermediate_size", 4096),
                                 "hidden_size": layer_output.get("hidden_size", 1024),
                                 "activation_stats": layer_output.get("activation_stats", {}),
                                 "gate_values": layer_output.get("gate_values", None),
                                 "tokens": tokens,  # Include tokens for context
                                 "token_magnitudes": layer_output.get("token_magnitudes", [])
                             })

                current_hidden = layer_output["hidden_states"]

            # Final layer norm (if exists)
            if transformer is not None and hasattr(transformer, 'ln_f'):
                current_hidden = transformer.ln_f(current_hidden)
                add_step("Final Layer Normalization", "normalization",
                         "Normalize hidden states before output projection",
                         {
                             "norm_type": "LayerNorm",
                             "hidden_state_norm": float(torch.norm(current_hidden).item())
                         })

            # Output projection to the vocabulary
            if hasattr(self.model, 'lm_head'):
                logits = self.model.lm_head(current_hidden)
            else:
                logits = current_hidden

            # Distribution for the token following the last context position
            probs = torch.softmax(logits[0, -1, :], dim=-1)

            # Top predictions (at most 5; fewer if the vocabulary is tiny)
            top_probs, top_indices = torch.topk(probs, min(5, probs.shape[-1]))
            top_tokens = [self._decode_token(idx) for idx in top_indices.tolist()]
            logger.info(f"Token generation - Input: '{text}', Predicted ID: {top_indices[0].item()}, Decoded: '{top_tokens[0]}'")

            add_step("Output Projection", "output",
                     "Project to vocabulary and compute probabilities",
                     {
                         "vocab_size": logits.shape[-1],
                         "top_5_tokens": top_tokens,
                         "top_5_probs": top_probs.float().cpu().tolist(),
                         "predicted_token": top_tokens[0],
                         "confidence": float(top_probs[0].item())
                     })

            # Step N: Generated Result.
            # If the top prediction is a trivial formatting token (quote,
            # newline, whitespace), also report the best non-formatting candidate.
            predicted_token = top_tokens[0]
            additional_info = ""
            if predicted_token in self._FORMATTING_TOKENS:
                additional_info = f"Next token: '{predicted_token}' (formatting)"
                for alt_token in top_tokens[1:]:
                    if alt_token not in self._FORMATTING_TOKENS:
                        additional_info += f", likely code token: '{alt_token}'"
                        break

            lowered = text.lower()
            add_step("Generated Result", "generated",
                     f"Complete text with token #{token_position + 1}",
                     {
                         "original_text": text,
                         "predicted_token": predicted_token,
                         "complete_text": text + predicted_token,
                         "is_code": "def " in lowered or "class " in lowered or "import " in lowered,
                         "additional_info": additional_info,
                         "token_position": token_position + 1
                     })

        return steps

    def _process_layer(self, layer, hidden_states, layer_idx):
        """Replay one transformer layer and collect statistics about it.

        Args:
            layer: One block from the model's layer list.
            hidden_states: Residual-stream input, presumably
                ``(batch, seq_len, hidden)``; batch 0 is the one summarised.
            layer_idx: Index of the layer (used for logging only).

        Returns:
            Dict with:
              'hidden_states': the layer output. If the replay raised, this is
                the unchanged input.
              'attention_pattern': head-averaged ``seq_len x seq_len`` matrix.
                It comes from the manual GPT-style replay, or is an identity
                matrix on error.
              'ffn_output', 'intermediate_size', 'hidden_size',
              'activation_stats', 'token_magnitudes': present when the block's
                MLP is recognised.
        """
        output: Dict[str, Any] = {}

        try:
            with torch.no_grad():
                if hasattr(layer, 'attn'):
                    new_hidden = self._process_gpt_style_layer(layer, hidden_states, output)
                else:
                    # BERT-style or other architecture: run the layer as-is
                    new_hidden = layer(hidden_states)[0]
                output["hidden_states"] = new_hidden
        except Exception as e:
            # Best effort: keep the analysis going with the layer's input.
            logger.warning(f"Error processing layer {layer_idx}: {e}", exc_info=True)
            output["hidden_states"] = hidden_states
            if "attention_pattern" not in output:
                seq_len = hidden_states.shape[1]
                output["attention_pattern"] = np.eye(seq_len).tolist()  # Identity matrix as fallback
                logger.warning(f"Using fallback attention pattern for layer {layer_idx}")

        return output

    def _process_gpt_style_layer(self, layer, hidden_states, output):
        """Replay a block exposing ``attn`` (plus optional ``ln_1``/``ln_2``/``mlp``).

        Statistics are written into ``output``; the block's output hidden
        states are returned.
        """
        ln_output = layer.ln_1(hidden_states) if hasattr(layer, 'ln_1') else hidden_states

        attn_delta = self._manual_attention(layer, ln_output, output)
        if attn_delta is None:
            # Unrecognised projection layout: defer to the module's own forward.
            attn_delta = layer.attn(ln_output)
            if isinstance(attn_delta, tuple):
                attn_delta = attn_delta[0]

        # Residual connection around attention
        attn_output = hidden_states + attn_delta

        if not hasattr(layer, 'mlp'):
            return attn_output

        if hasattr(layer, 'ln_2'):
            # Sequential block (GPT-2): MLP reads the post-attention stream.
            mlp_input = layer.ln_2(attn_output)
        else:
            # No ln_2 (CodeGen / GPT-J): parallel block, MLP also reads ln_1(x).
            mlp_input = ln_output

        self._record_ffn_stats(layer.mlp, mlp_input, output)
        mlp_output = layer.mlp(mlp_input)
        output["ffn_output"] = mlp_output
        return attn_output + mlp_output

    def _manual_attention(self, layer, ln_output, output):
        """Recompute causal self-attention from the block's fused QKV projection.

        Stores the head-averaged pattern in ``output["attention_pattern"]``.

        Returns:
            The projected attention output (before the residual add), or
            ``None`` when the layer has no recognised fused QKV projection.
        """
        attn = layer.attn
        if hasattr(attn, 'qkv_proj'):
            # CodeGen architecture - combined QKV projection.
            # NOTE(review): HF CodeGen packs q/k/v in an interleaved layout and
            # applies rotary position embeddings; this plain split ignores both,
            # so results for CodeGen are an approximation — confirm before relying on it.
            qkv = attn.qkv_proj(ln_output)
        elif hasattr(attn, 'c_attn'):
            # GPT2-style architecture
            qkv = attn.c_attn(ln_output)
        else:
            return None

        embed_dim = attn.embed_dim
        n_head = self._get_num_heads(layer)
        head_dim = embed_dim // n_head
        batch_size, seq_len = qkv.shape[:2]

        def split_heads(t):
            # (batch, seq, embed) -> (batch, heads, seq, head_dim)
            return t.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)

        query, key, value = (split_heads(t) for t in qkv.split(embed_dim, dim=2))

        scores = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5)

        # Explicit causal mask. HF GPT-2's `attn.bias` is a 0/1 lower-triangular
        # buffer meant for masking, so adding it to the scores hides nothing.
        causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device).tril()
        scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
        attn_probs = torch.softmax(scores, dim=-1)

        # Average across heads for visualization: (batch, seq_len, seq_len)
        avg_attn = attn_probs.mean(dim=1)
        output["attention_pattern"] = avg_attn[0].float().cpu().tolist()
        logger.debug(f"Extracted attention pattern with shape: {avg_attn[0].shape}")

        context = torch.matmul(attn_probs, value)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)

        if hasattr(attn, 'out_proj'):
            context = attn.out_proj(context)  # CodeGen architecture
        elif hasattr(attn, 'c_proj'):
            context = attn.c_proj(context)  # GPT2-style architecture

        if hasattr(attn, 'resid_dropout'):
            context = attn.resid_dropout(context)
        return context

    def _record_ffn_stats(self, mlp, mlp_input, output):
        """Record FFN sizes and activation statistics into ``output`` when recognised."""
        fc = getattr(mlp, 'fc_in', None)  # CodeGen architecture
        if fc is None:
            fc = getattr(mlp, 'c_fc', None)  # GPT2 architecture
        if fc is None:
            return

        in_features, out_features = self._projection_dims(fc)
        output["intermediate_size"] = out_features
        output["hidden_size"] = in_features

        # Statistics are taken after the MLP's activation function when it is
        # exposed as `mlp.act`; otherwise the raw up-projection is used.
        intermediate = fc(mlp_input)
        act_fn = getattr(mlp, 'act', None)
        act_values = (act_fn(intermediate) if callable(act_fn) else intermediate).detach().float()
        output["activation_stats"] = {
            "mean": float(act_values.mean().item()),
            "std": float(act_values.std().item()),
            "max": float(act_values.max().item()),
            "min": float(act_values.min().item()),
            "sparsity": float((act_values == 0).float().mean().item()),  # Fraction of exact zeros
            "active_neurons": int((act_values.abs() > 0.1).sum().item())  # (token, neuron) entries with |act| > 0.1
        }

        # Mean absolute activation per token (batch 0)
        output["token_magnitudes"] = act_values.abs().mean(dim=-1)[0].cpu().tolist()

    @staticmethod
    def _projection_dims(proj) -> Tuple[int, int]:
        """Return ``(in_features, out_features)`` of a projection module.

        ``nn.Linear`` exposes ``in_features``/``out_features``. HF ``Conv1D``
        (used by GPT-2) does not; it stores its weight as ``(in, out)``.
        """
        if hasattr(proj, 'in_features') and hasattr(proj, 'out_features'):
            return int(proj.in_features), int(proj.out_features)
        weight = proj.weight
        return int(weight.shape[0]), int(weight.shape[1])

    def _get_num_heads(self, layer):
        """Get number of attention heads in a layer (8 if it cannot be determined)."""
        if hasattr(layer, 'attn'):
            if hasattr(layer.attn, 'num_attention_heads'):
                return layer.attn.num_attention_heads  # CodeGen
            elif hasattr(layer.attn, 'n_head'):
                return layer.attn.n_head  # older GPT2
            elif hasattr(layer.attn, 'num_heads'):
                return layer.attn.num_heads  # GPT2 and other architectures
        return 8  # Default guess

    def get_steps_dict(self) -> List[Dict]:
        """Convert the first token's steps from the last analysis to dictionaries.

        This is kept for backward compatibility and only covers the first
        generated token. Use the result of analyze_pipeline directly for
        multi-token generation.
        """
        return [asdict(step) for step in getattr(self, 'last_single_token_steps', [])]