"""
Transformer Pipeline Analyzer
Captures and returns all intermediate states of transformer processing
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Any
from dataclasses import dataclass, asdict
import logging

logger = logging.getLogger(__name__)

@dataclass
class PipelineStep:
    """Represents a single step in the transformer pipeline"""
    step_number: int
    step_name: str
    step_type: str  # 'tokenization', 'embedding', 'attention', 'ffn', 'output'
    description: str
    data: Dict[str, Any]

class TransformerPipelineAnalyzer:
    """Analyzes the complete flow through a transformer model"""

    def __init__(self, model, tokenizer, adapter=None):
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter  # Model adapter for accessing architecture-specific components
        self.device = next(model.parameters()).device
        self.steps = []
        self.intermediate_states = {}
        
    def analyze_pipeline(self, text: str, max_new_tokens: int = 1, 
                        temperature: float = 0.7, top_k: int = 50, top_p: float = 0.95) -> Dict[str, Any]:
        """
        Capture all steps of transformer processing for multiple tokens
        
        Args:
            text: Input text to analyze
            max_new_tokens: Number of tokens to generate (default 1)
            temperature: Controls randomness in generation (default 0.7)
            top_k: Limits to top K most likely tokens (default 50)
            top_p: Cumulative probability cutoff (default 0.95)
            
        Returns:
            Dict containing tokens generated and their pipeline steps
        """
        all_tokens = []
        all_pipelines = []
        current_text = text
        
        # First generate all the tokens using the model's generate method
        # This ensures proper autoregressive generation
        with torch.no_grad():
            inputs = self.tokenizer(text, return_tensors="pt", padding=False, truncation=True)
            input_ids = inputs["input_ids"].to(self.device)
            
            # Generate tokens properly using model.generate()
            generated_ids = self.model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,  # Enable sampling for variety
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
            )
            
            # Extract only the new tokens with context-aware decoding
            new_token_ids = generated_ids[0, input_ids.shape[1]:].tolist()

            # Decode tokens progressively to maintain SentencePiece context
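            # Decoding a token in isolation can drop leading whitespace with
            # SentencePiece vocabularies (e.g. a piece like "▁world" renders as
            # "world" at sequence start), hence the decode-and-slice approach below.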
            generated_tokens = []
            # Baseline comes from decoding the prompt ids rather than len(text), so
            # tokenizers that prepend special tokens (e.g. a BOS token) don't shift
            # the slice offset
            prev_decoded_length = len(self.tokenizer.decode(input_ids[0], skip_special_tokens=False, clean_up_tokenization_spaces=False))
            for i in range(len(new_token_ids)):
                # Decode the full sequence up to this point
                full_sequence = torch.cat([input_ids[0], torch.tensor(new_token_ids[:i+1], device=input_ids.device)])
                full_decoded = self.tokenizer.decode(full_sequence, skip_special_tokens=False, clean_up_tokenization_spaces=False)
                # Extract just the new token by comparing lengths
                new_token = full_decoded[prev_decoded_length:]
                generated_tokens.append(new_token)
                prev_decoded_length = len(full_decoded)

            logger.info(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
        
        # Now analyze the pipeline for each generated token
        for token_idx, next_token in enumerate(generated_tokens):
            # Analyze pipeline for current text (which will predict the next token)
            pipeline_steps = self._analyze_single_token(current_text, token_idx)
            
            # Update the output step with the actual generated token
            # (since _analyze_single_token might predict differently due to sampling)
            for step in reversed(pipeline_steps):
                if step.step_type == 'output':
                    # Update with the actual generated token
                    step.data['predicted_token'] = next_token
                    step.data['actual_token_id'] = new_token_ids[token_idx] if token_idx < len(new_token_ids) else None
                    break
            
            all_tokens.append(next_token)
            all_pipelines.append(pipeline_steps)
            current_text += next_token
            
            # Store first pipeline for backward compatibility
            if token_idx == 0:
                self.last_single_token_steps = pipeline_steps
        
        return {
            'tokens': all_tokens,
            'pipelines': all_pipelines,
            'final_text': current_text,
            'num_tokens': len(all_tokens)
        }
    
    def _analyze_single_token(self, text: str, token_position: int) -> List[PipelineStep]:
        """
        Analyze the pipeline for generating a single token
        
        Args:
            text: Current text to continue from
            token_position: Position of this token in the generation sequence
            
        Returns:
            List of PipelineStep objects for this token
        """
        steps = []
        step_counter = 0
        
        # Step 1: Raw Input
        steps.append(PipelineStep(
            step_number=step_counter,
            step_name="Raw Input",
            step_type="input",
            description="The original text input provided by the user",
            data={"text": text, "length": len(text)}
        ))
        step_counter += 1
        
        # Step 2: Tokenization
        inputs = self.tokenizer(text, return_tensors="pt", padding=False, truncation=True)
        input_ids = inputs["input_ids"].to(self.device)
        tokens = [self.tokenizer.decode([tid]) for tid in input_ids[0]]
        token_ids = input_ids[0].tolist()
        
        steps.append(PipelineStep(
            step_number=step_counter,
            step_name="Tokenization",
            step_type="tokenization",
            description="Text split into subword tokens using the model's tokenizer",
            data={
                "tokens": tokens,
                "token_ids": token_ids,
                "num_tokens": len(tokens),
                "tokenizer_name": self.tokenizer.__class__.__name__
            }
        ))
        step_counter += 1
        
        # Step 3: Token Embeddings
        with torch.no_grad():
            # Get token embeddings
            if hasattr(self.model, 'transformer'):
                embed_layer = self.model.transformer.wte
                pos_embed_layer = self.model.transformer.wpe if hasattr(self.model.transformer, 'wpe') else None
            else:
                embed_layer = self.model.get_input_embeddings()
                pos_embed_layer = None
            
            token_embeddings = embed_layer(input_ids)
            
            # Add positional embeddings if available
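            # Learned absolute positions (e.g. GPT-2's wpe) are added here; rotary
            # models such as LLaMA inject positions inside attention instead, so
            # pos_embed_layer stays None for them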
            if pos_embed_layer:
                position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long, device=self.device)
                position_ids = position_ids.unsqueeze(0)
                position_embeddings = pos_embed_layer(position_ids)
                embeddings = token_embeddings + position_embeddings
            else:
                embeddings = token_embeddings
                position_embeddings = None
            
            steps.append(PipelineStep(
                step_number=step_counter,
                step_name="Initial Embeddings",
                step_type="embedding",
                description="Token embeddings combined with positional encodings",
                data={
                    "embedding_dim": embeddings.shape[-1],
                    "has_position_encoding": pos_embed_layer is not None,
                    "embeddings_sample": embeddings[0, :3, :8].cpu().numpy().tolist(),  # First 3 tokens, 8 dims
                    "embeddings_shape": list(embeddings.shape)
                }
            ))
            step_counter += 1
            
            # Step 4-N: Process through layers
            current_hidden = embeddings

            # Get model layers - use adapter if available for multi-architecture support
            if self.adapter:
                # Use adapter to get layer count and access layers
                num_layers = self.adapter.get_num_layers()
                sample_layers = min(4, num_layers)  # Sample first 4 layers for performance
                layers = [self.adapter.get_layer_module(i) for i in range(sample_layers)]
            elif hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
                # Fallback for CodeGen-style models
                layers = self.model.transformer.h[:4]
            else:
                # Fallback for other architectures
                layers = self.model.encoder.layer[:4] if hasattr(self.model, 'encoder') else []

            # Process through each layer
            for layer_idx, layer in enumerate(layers):
                # Attention mechanism
                layer_output = self._process_layer(layer, current_hidden, layer_idx)
                
                # Add attention step with tokens for labeling
                steps.append(PipelineStep(
                    step_number=step_counter,
                    step_name=f"Layer {layer_idx} - Multi-Head Attention",
                    step_type="attention",
                    description=f"Self-attention computation in layer {layer_idx}",
                    data={
                        "layer": layer_idx,
                        "num_heads": self._get_num_heads(layer),
                        "attention_pattern": layer_output.get("attention_pattern", None),
                        "tokens": tokens,  # Include tokens for labeling the attention matrix
                        "hidden_state_norm": float(torch.norm(layer_output["hidden_states"]).item())
                    }
                ))
                step_counter += 1
                
                # Feed-forward network
                if "ffn_output" in layer_output:
                    steps.append(PipelineStep(
                        step_number=step_counter,
                        step_name=f"Layer {layer_idx} - Feed-Forward Network",
                        step_type="ffn",
                        description=f"Feed-forward transformation in layer {layer_idx}",
                        data={
                            "layer": layer_idx,
                            "activation": "gelu",  # Most transformers use GELU
                            "hidden_state_norm": float(torch.norm(layer_output["ffn_output"]).item()),
                            "intermediate_size": layer_output.get("intermediate_size", 4096),
                            "hidden_size": layer_output.get("hidden_size", 1024),
                            "activation_stats": layer_output.get("activation_stats", {}),
                            "gate_values": layer_output.get("gate_values", None),
                            "tokens": tokens,  # Include tokens for context
                            "token_magnitudes": layer_output.get("token_magnitudes", [])
                        }
                    ))
                    step_counter += 1
                
                current_hidden = layer_output["hidden_states"]
            
            # Final layer norm (if exists)
            if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'ln_f'):
                current_hidden = self.model.transformer.ln_f(current_hidden)
                
                steps.append(PipelineStep(
                    step_number=step_counter,
                    step_name="Final Layer Normalization",
                    step_type="normalization",
                    description="Normalize hidden states before output projection",
                    data={
                        "norm_type": "LayerNorm",
                        "hidden_state_norm": float(torch.norm(current_hidden).item())
                    }
                ))
                step_counter += 1
            
            # Output projection
            if hasattr(self.model, 'lm_head'):
                logits = self.model.lm_head(current_hidden)
            else:
                logits = current_hidden
            
            # Get probabilities for the last token
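            # Only the last position matters for next-token prediction: its logits
            # score every vocabulary entry as a candidate continuation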
            last_token_logits = logits[0, -1, :]
            probs = torch.softmax(last_token_logits, dim=-1)
            
            # Get top 5 predictions
            top_probs, top_indices = torch.topk(probs, 5)
            # Decode candidates with context-aware decoding for SentencePiece tokenizers
            top_tokens = []
            # Slice against the decoded prompt ids rather than the raw text so that
            # any prepended special tokens (e.g. BOS) do not shift the offset
            base_length = len(self.tokenizer.decode(input_ids[0], skip_special_tokens=False, clean_up_tokenization_spaces=False))
            for idx in top_indices.tolist():
                # Append the candidate to the existing sequence and decode the delta;
                # this preserves SentencePiece behaviour (leading spaces, merges, etc.)
                full_sequence = torch.cat([input_ids[0], torch.tensor([idx], device=input_ids.device)])
                full_decoded = self.tokenizer.decode(full_sequence, skip_special_tokens=False, clean_up_tokenization_spaces=False)
                decoded = full_decoded[base_length:]
                top_tokens.append(decoded)
                # Debug logging for the top prediction (uses the module-level logger)
                if idx == top_indices[0].item():
                    logger.info(f"Token generation - Input: '{text}', Predicted ID: {idx}, Context-aware decoded: '{decoded}'")
            
            steps.append(PipelineStep(
                step_number=step_counter,
                step_name="Output Projection",
                step_type="output",
                description="Project to vocabulary and compute probabilities",
                data={
                    "vocab_size": logits.shape[-1],
                    "top_5_tokens": top_tokens,
                    "top_5_probs": top_probs.cpu().numpy().tolist(),
                    "predicted_token": top_tokens[0],
                    "confidence": float(top_probs[0].item())
                }
            ))
            step_counter += 1
            
            # Step N: Generated Result
            # For code generation, we might want to show the first meaningful token
            # Check if the predicted token is just whitespace or quote
            predicted_token = top_tokens[0]
            display_token = predicted_token
            additional_info = ""
            
            # If it's a trivial token (quote, newline, whitespace), note what comes next
            if predicted_token in ["'", '"', "\n", " ", "    ", "\t"]:
                additional_info = f"Next token: '{predicted_token}' (formatting)"
                # Show what would come after formatting tokens
                if len(top_tokens) > 1:
                    for alt_token in top_tokens[1:]:
                        if alt_token not in ["'", '"', "\n", " ", "    ", "\t"]:
                            additional_info += f", likely code token: '{alt_token}'"
                            break
            
            generated_text = text + predicted_token
            steps.append(PipelineStep(
                step_number=step_counter,
                step_name="Generated Result",
                step_type="generated",
                description=f"Complete text with token #{token_position + 1}",
                data={
                    "original_text": text,
                    "predicted_token": predicted_token,
                    "complete_text": generated_text,
                    "is_code": "def " in text.lower() or "class " in text.lower() or "import " in text.lower(),
                    "additional_info": additional_info,
                    "token_position": token_position + 1
                }
            ))
            step_counter += 1
        
        return steps
    
    def _process_layer(self, layer, hidden_states, layer_idx):
        """Process a single transformer layer"""
        output = {}

        try:
            # Process with attention weight capture
            with torch.no_grad():
                # Get attention module using adapter for multi-architecture support
                attn_module = None
                if self.adapter:
                    attn_module = self.adapter.get_attention_module(layer_idx)
                elif hasattr(layer, 'attn'):
                    attn_module = layer.attn
                elif hasattr(layer, 'self_attn'):
                    attn_module = layer.self_attn

                if attn_module:
                    # Apply pre-attention layer norm
                    # LLaMA uses input_layernorm, CodeGen uses ln_1
                    if hasattr(layer, 'input_layernorm'):
                        ln_output = layer.input_layernorm(hidden_states)
                    elif hasattr(layer, 'ln_1'):
                        ln_output = layer.ln_1(hidden_states)
                    else:
                        ln_output = hidden_states

                    # Try to extract attention manually for visualization
                    attention_extracted = False

                    # Check if this is CodeGen/GPT2 style (combined QKV)
                    if hasattr(attn_module, 'qkv_proj'):
                        # CodeGen architecture - has combined QKV projection
                        qkv = attn_module.qkv_proj(ln_output)
                        embed_dim = attn_module.embed_dim
                        n_head = attn_module.num_attention_heads if hasattr(attn_module, 'num_attention_heads') else 8

                        # Split into Q, K, V
                        query, key, value = qkv.split(embed_dim, dim=2)
                        attention_extracted = True

                    elif hasattr(attn_module, 'c_attn'):
                        # GPT2-style architecture
                        qkv = attn_module.c_attn(ln_output)
                        embed_dim = attn_module.embed_dim
                        n_head = attn_module.n_head if hasattr(attn_module, 'n_head') else 8

                        # Split into Q, K, V
                        query, key, value = qkv.split(embed_dim, dim=2)
                        attention_extracted = True

                    elif hasattr(attn_module, 'q_proj') and hasattr(attn_module, 'k_proj') and hasattr(attn_module, 'v_proj'):
                        # LLaMA architecture - separate Q, K, V projections
                        query = attn_module.q_proj(ln_output)
                        key = attn_module.k_proj(ln_output)
                        value = attn_module.v_proj(ln_output)

                        # Get dimensions
                        if hasattr(attn_module, 'num_heads'):
                            n_head = attn_module.num_heads
                        elif hasattr(attn_module, 'num_attention_heads'):
                            n_head = attn_module.num_attention_heads
                        else:
                            n_head = 32  # Default for LLaMA

                        embed_dim = query.shape[-1]
                        attention_extracted = True

                    if attention_extracted:
                        # Reshape for multi-head attention
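                        # Shapes: [batch, seq, embed_dim] -> [batch, n_head, seq, head_dim]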
                        batch_size, seq_len = query.shape[:2]
                        head_dim = embed_dim // n_head

                        query = query.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
                        key = key.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
                        value = value.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)

                        # Compute attention scores
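                        # Scaled dot-product: scores = Q @ K^T / sqrt(head_dim), one [seq, seq] map per head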
                        attn_weights = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5)

                        # Apply causal mask (for autoregressive models)
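                        # The -1e10 entries act as -inf, so softmax assigns ~0 weight to future positions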
                        causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=attn_weights.device) * -1e10, diagonal=1)
                        attn_weights = attn_weights + causal_mask.unsqueeze(0).unsqueeze(0)

                        # Apply softmax
                        attn_probs = torch.softmax(attn_weights, dim=-1)

                        # Average across heads for visualization
                        avg_attn = attn_probs.mean(dim=1)  # Shape: [batch, seq_len, seq_len]

                        # Store the full attention pattern
                        output["attention_pattern"] = avg_attn[0].cpu().numpy().tolist()
                        logger.info(f"Extracted attention pattern with shape: {avg_attn[0].shape}")

                        # Apply attention to values
                        attn_output = torch.matmul(attn_probs, value)
                        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)

                        # Apply output projection
                        if hasattr(attn_module, 'out_proj'):
                            # CodeGen architecture
                            attn_output = attn_module.out_proj(attn_output)
                        elif hasattr(attn_module, 'o_proj'):
                            # LLaMA uses o_proj
                            attn_output = attn_module.o_proj(attn_output)
                        elif hasattr(attn_module, 'c_proj'):
                            # GPT2-style architecture
                            attn_output = attn_module.c_proj(attn_output)

                        # Add residual connection
                        attn_output = hidden_states + attn_output
                    else:
                        # Fallback: call the layer directly (won't get attention pattern)
                        logger.warning(f"Could not extract attention manually for layer {layer_idx}, using layer forward pass")
                        attn_result = layer(hidden_states)
                        if isinstance(attn_result, tuple):
                            attn_output = attn_result[0]
                        else:
                            attn_output = attn_result
                        # Use identity matrix as fallback
                        seq_len = hidden_states.shape[1]
                        output["attention_pattern"] = np.eye(seq_len).tolist()
                    
                    # Apply MLP/FFN with detailed analysis
                    # Get FFN module using adapter for multi-architecture support
                    ffn_module = None
                    if self.adapter:
                        ffn_module = self.adapter.get_ffn_module(layer_idx)
                    elif hasattr(layer, 'mlp'):
                        ffn_module = layer.mlp

                    if ffn_module:
                        # Apply layer norm - LLaMA uses post_attention_layernorm, CodeGen uses ln_2
                        if hasattr(layer, 'post_attention_layernorm'):
                            ln2_output = layer.post_attention_layernorm(attn_output)
                        elif hasattr(layer, 'ln_2'):
                            ln2_output = layer.ln_2(attn_output)
                        else:
                            ln2_output = attn_output

                        # Extract detailed FFN information based on architecture
                        intermediate = None

                        if hasattr(ffn_module, 'gate_proj') and hasattr(ffn_module, 'up_proj'):
                            # LLaMA architecture - uses gated FFN (SwiGLU)
                            gate_output = ffn_module.gate_proj(ln2_output)
                            up_output = ffn_module.up_proj(ln2_output)
                            # SwiGLU activation: silu(gate_proj(x)) * up_proj(x)
                            intermediate = F.silu(gate_output) * up_output
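                            # The final down_proj is intentionally not applied here; the full
                            # MLP (including it) runs below via ffn_module(ln2_output)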
                            output["intermediate_size"] = ffn_module.gate_proj.out_features
                            output["hidden_size"] = ffn_module.gate_proj.in_features

                            # Store gate activation stats
                            with torch.no_grad():
                                gate_values = F.silu(gate_output).detach()
                                output["gate_values"] = {
                                    "mean": float(gate_values.mean().item()),
                                    "std": float(gate_values.std().item()),
                                    "max": float(gate_values.max().item()),
                                    "min": float(gate_values.min().item())
                                }

                        elif hasattr(ffn_module, 'fc_in'):
                            # CodeGen architecture
                            intermediate = ffn_module.fc_in(ln2_output)
                            output["intermediate_size"] = ffn_module.fc_in.out_features
                            output["hidden_size"] = ffn_module.fc_in.in_features

                        elif hasattr(ffn_module, 'c_fc'):
                            # GPT2 architecture
                            intermediate = ffn_module.c_fc(ln2_output)
                            output["intermediate_size"] = ffn_module.c_fc.out_features
                            output["hidden_size"] = ffn_module.c_fc.in_features

                        if intermediate is not None:
                            # Compute activation statistics
                            with torch.no_grad():
                                act_values = intermediate.detach()
                                output["activation_stats"] = {
                                    "mean": float(act_values.mean().item()),
                                    "std": float(act_values.std().item()),
                                    "max": float(act_values.max().item()),
                                    "min": float(act_values.min().item()),
                                    "sparsity": float((act_values == 0).float().mean().item()),  # Fraction of zeros
                                    "active_neurons": int((act_values.abs() > 0.1).sum().item())  # Neurons with significant activation
                                }

                                # Get per-token magnitudes (average activation magnitude per token)
                                token_mags = act_values.abs().mean(dim=-1)[0].cpu().numpy().tolist()
                                output["token_magnitudes"] = token_mags

                        # Apply full MLP
                        mlp_output = ffn_module(ln2_output)
                        output["ffn_output"] = mlp_output
                        hidden_states = attn_output + mlp_output
                    else:
                        hidden_states = attn_output
                else:
                    # BERT-style or other architecture
                    hidden_states = layer(hidden_states)[0]
                
                output["hidden_states"] = hidden_states
                    
        except Exception as e:
            logger.warning(f"Error processing layer {layer_idx}: {e}")
            import traceback
            logger.warning(f"Traceback: {traceback.format_exc()}")
            output["hidden_states"] = hidden_states
            # Fallback to simple pattern if real extraction fails
            if "attention_pattern" not in output:
                seq_len = hidden_states.shape[1]
                output["attention_pattern"] = np.eye(seq_len).tolist()  # Identity matrix as fallback
                logger.warning(f"Using fallback attention pattern for layer {layer_idx}")
            
        return output
    
    def _get_num_heads(self, layer):
        """Get number of attention heads in a layer"""
        # Handle both GPT-style (layer.attn) and LLaMA-style (layer.self_attn) modules
        attn = getattr(layer, 'attn', None) or getattr(layer, 'self_attn', None)
        if attn is not None:
            # CodeGen, GPT2, and other architectures name the head count differently
            for attr in ('num_attention_heads', 'n_head', 'num_heads'):
                if hasattr(attn, attr):
                    return getattr(attn, attr)
        return 8  # Default guess
    
    def get_steps_dict(self) -> List[Dict]:
        """Convert steps to dictionary format for JSON serialization
        
        This is kept for backward compatibility but may not work with multi-token generation.
        Use the result from analyze_pipeline directly instead.
        """
        # If we have stored steps from single token generation, return them
        if hasattr(self, 'last_single_token_steps'):
            return [asdict(step) for step in self.last_single_token_steps]
        return []
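

if __name__ == "__main__":
    # Minimal usage sketch, not part of the analyzer itself. Assumes a Hugging Face
    # causal LM; the checkpoint name is an illustrative choice (any compatible
    # model/tokenizer pair should work).
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "Salesforce/codegen-350M-mono"  # example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    analyzer = TransformerPipelineAnalyzer(model, tokenizer)
    result = analyzer.analyze_pipeline("def hello():", max_new_tokens=3)
    print(f"Generated {result['num_tokens']} tokens: {result['tokens']}")
    print(f"Final text: {result['final_text']}")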