"""
Transformer Pipeline Analyzer

Captures and returns all intermediate states of transformer processing:
tokenization, embeddings, per-layer attention / feed-forward statistics,
final normalization, output projection and the generated continuation.
"""
import torch
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict
import logging

logger = logging.getLogger(__name__)

# Tokens the "Generated Result" step treats as pure formatting (quotes / whitespace).
# NOTE(review): " " appears twice, kept exactly as found; one entry may originally
# have been a different whitespace string (e.g. two spaces) — confirm.
_FORMATTING_TOKENS = ("'", '"', "\n", " ", " ", "\t")


def _to_list(tensor: torch.Tensor) -> list:
    """Convert a tensor to nested Python floats.

    Goes through ``float()`` instead of ``.numpy()`` so that dtypes NumPy cannot
    represent (e.g. ``torch.bfloat16``) still convert.
    """
    return tensor.detach().float().cpu().tolist()


@dataclass
class PipelineStep:
    """Represents a single step in the transformer pipeline"""
    step_number: int  # 0-based index of the step within one token's pipeline
    step_name: str  # human-readable title, e.g. "Layer 0 - Multi-Head Attention"
    # One of: 'input', 'tokenization', 'embedding', 'attention', 'ffn',
    # 'normalization', 'output', 'generated'
    step_type: str
    description: str
    data: Dict[str, Any]  # step-specific payload made of plain Python values


class TransformerPipelineAnalyzer:
    """Analyzes the complete flow through a transformer model.

    Args:
        model: Hugging Face-style model; must provide ``parameters()``,
            ``generate()`` and a forward call (its ``logits`` / ``hidden_states``
            outputs are used when present).
        tokenizer: Matching tokenizer: callable, with ``decode``,
            ``pad_token_id`` and ``eos_token_id``.
        max_display_layers: Number of leading layers whose attention / FFN
            internals are recorded as steps (default 4).
    """

    def __init__(self, model, tokenizer, max_display_layers: int = 4):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        self.max_display_layers = max_display_layers
        # Not read by this class; kept so existing callers that access them still work.
        self.steps = []
        self.intermediate_states = {}

    def analyze_pipeline(self, text: str, max_new_tokens: int = 1, temperature: float = 0.7,
                         top_k: int = 50, top_p: float = 0.95) -> Dict[str, Any]:
        """
        Capture all steps of transformer processing for multiple tokens.

        Tokens are first generated with ``model.generate`` (proper autoregressive
        decoding). Each generated token is then re-analysed step by step on exactly
        the token ids the model conditioned on when producing it.

        Args:
            text: Input text to analyze
            max_new_tokens: Number of tokens to generate (default 1)
            temperature: Controls randomness in generation (default 0.7);
                values <= 0 select greedy decoding
            top_k: Limits to top K most likely tokens (default 50)
            top_p: Cumulative probability cutoff (default 0.95)

        Returns:
            Dict with 'tokens' (decoded generated tokens), 'pipelines' (one list of
            PipelineStep per generated token), 'final_text' and 'num_tokens'.
        """
        with torch.no_grad():
            inputs = self.tokenizer(text, return_tensors="pt", padding=False, truncation=True)
            input_ids = inputs["input_ids"].to(self.device)
            attention_mask = inputs.get("attention_mask")

            # A pad id of 0 is legitimate, so test for None instead of using `or`.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            gen_kwargs: Dict[str, Any] = {"max_new_tokens": max_new_tokens, "pad_token_id": pad_token_id}
            if attention_mask is not None:
                gen_kwargs["attention_mask"] = attention_mask.to(self.device)
            if temperature > 0:
                gen_kwargs.update(do_sample=True, temperature=temperature, top_k=top_k, top_p=top_p)
            else:
                # Sampling needs a strictly positive temperature; treat <= 0 as greedy.
                gen_kwargs["do_sample"] = False
            generated_ids = self.model.generate(input_ids, **gen_kwargs)

        prompt_len = input_ids.shape[1]
        # Only the newly generated ids (generate() returns prompt + continuation).
        new_token_ids = generated_ids[0, prompt_len:].tolist()
        generated_tokens = [self._decode([tid]) for tid in new_token_ids]
        logger.info(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")

        all_pipelines = []
        for token_idx, (token_id, next_token) in enumerate(zip(new_token_ids, generated_tokens)):
            # Analyse the exact ids the model saw; re-tokenizing decoded text does not
            # always reproduce them.
            context_ids = generated_ids[:, :prompt_len + token_idx]
            context_text = text + self._decode(new_token_ids[:token_idx])
            pipeline_steps = self._analyze_single_token(context_text, token_idx, input_ids=context_ids)
            # _analyze_single_token reports the greedy (argmax) choice; record the
            # token generate() actually produced instead.
            self._apply_sampled_token(pipeline_steps, context_text, next_token, token_id)
            all_pipelines.append(pipeline_steps)

        # First pipeline kept for get_steps_dict() (backward compatibility). Reset on
        # every call so a run that generates nothing does not expose stale steps.
        self.last_single_token_steps = all_pipelines[0] if all_pipelines else []

        return {
            'tokens': generated_tokens,
            'pipelines': all_pipelines,
            # Decode the new ids together so characters split across several
            # tokens are reassembled instead of decoded piecemeal.
            'final_text': text + self._decode(new_token_ids),
            'num_tokens': len(generated_tokens),
        }

    def _decode(self, token_ids: List[int]) -> str:
        """Decode ids verbatim: keep special tokens and skip whitespace clean-up."""
        return self.tokenizer.decode(token_ids, skip_special_tokens=False,
                                     clean_up_tokenization_spaces=False)

    def _apply_sampled_token(self, steps: List[PipelineStep], context_text: str,
                             token: str, token_id: int) -> None:
        """Overwrite the greedy prediction in *steps* with the token actually generated."""
        top_tokens: List[str] = []
        for step in reversed(steps):
            if step.step_type == 'output':
                step.data['predicted_token'] = token
                step.data['actual_token_id'] = token_id
                top_tokens = step.data.get('top_5_tokens', [])
                break
        # Keep the "Generated Result" step consistent with the output step.
        for step in steps:
            if step.step_type == 'generated':
                step.data['predicted_token'] = token
                step.data['complete_text'] = context_text + token
                step.data['additional_info'] = self._formatting_note(
                    token, [t for t in top_tokens if t != token])

    @staticmethod
    def _formatting_note(token: str, alternatives: List[str]) -> str:
        """Describe a quote/whitespace prediction plus the first non-formatting alternative."""
        if token not in _FORMATTING_TOKENS:
            return ""
        note = f"Next token: '{token}' (formatting)"
        code_token = next((alt for alt in alternatives if alt not in _FORMATTING_TOKENS), None)
        if code_token is not None:
            note += f", likely code token: '{code_token}'"
        return note

    @torch.no_grad()
    def _analyze_single_token(self, text: str, token_position: int,
                              input_ids: Optional[torch.Tensor] = None) -> List[PipelineStep]:
        """
        Analyze the pipeline for generating a single token

        Args:
            text: Current text to continue from
            token_position: Position of this token in the generation sequence
            input_ids: Optional (1, seq_len) id tensor to analyse; when omitted,
                *text* is tokenized

        Returns:
            List of PipelineStep objects for this token
        """
        steps: List[PipelineStep] = []

        def add_step(name: str, step_type: str, description: str, data: Dict[str, Any]) -> None:
            # step_number is simply the step's index in the list.
            steps.append(PipelineStep(step_number=len(steps), step_name=name,
                                      step_type=step_type, description=description, data=data))

        # Step 1: Raw Input
        add_step("Raw Input", "input", "The original text input provided by the user",
                 {"text": text, "length": len(text)})

        # Step 2: Tokenization
        if input_ids is None:
            input_ids = self.tokenizer(text, return_tensors="pt", padding=False,
                                       truncation=True)["input_ids"]
        input_ids = input_ids.to(self.device)
        token_ids = input_ids[0].tolist()
        tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
        add_step("Tokenization", "tokenization",
                 "Text split into subword tokens using the model's tokenizer",
                 {
                     "tokens": tokens,
                     "token_ids": token_ids,
                     "num_tokens": len(tokens),
                     "tokenizer_name": self.tokenizer.__class__.__name__,
                 })

        # Step 3: Token embeddings (+ learned absolute positions when the model has `wpe`)
        transformer = getattr(self.model, 'transformer', None)
        embed_layer = getattr(transformer, 'wte', None)
        if embed_layer is None:
            embed_layer = self.model.get_input_embeddings()
        pos_embed_layer = getattr(transformer, 'wpe', None)
        embeddings = embed_layer(input_ids)
        if pos_embed_layer is not None:
            position_ids = torch.arange(input_ids.shape[-1], dtype=torch.long,
                                        device=input_ids.device).unsqueeze(0)
            embeddings = embeddings + pos_embed_layer(position_ids)
        add_step("Initial Embeddings", "embedding",
                 "Token embeddings combined with positional encodings",
                 {
                     "embedding_dim": embeddings.shape[-1],
                     "has_position_encoding": pos_embed_layer is not None,
                     "embeddings_sample": _to_list(embeddings[0, :3, :8]),  # First 3 tokens, 8 dims
                     "embeddings_shape": list(embeddings.shape),
                 })

        # Reference forward pass through the FULL model. The manual walk below only
        # visualises the first few layers, so the final prediction must come from
        # here rather than from a truncated layer stack.
        try:
            model_outputs = self.model(input_ids=input_ids, output_hidden_states=True)
        except Exception:
            logger.warning("Reference forward pass failed; using manually propagated hidden states",
                           exc_info=True)
            model_outputs = None
        reference_hidden = getattr(model_outputs, 'hidden_states', None)

        # Step 4-N: Process through layers
        layers = getattr(transformer, 'h', None)
        if layers is None:
            layers = self.model.encoder.layer if hasattr(self.model, 'encoder') else []

        current_hidden = embeddings
        max_layers = getattr(self, 'max_display_layers', 4)
        for layer_idx, layer in enumerate(layers[:max_layers]):
            if reference_hidden is not None and layer_idx < len(reference_hidden):
                # Assumes the HF convention that hidden_states[i] is the input to
                # layer i (index 0 = embedding output) — confirm for exotic models.
                # Using it stops approximation errors compounding across layers.
                current_hidden = reference_hidden[layer_idx]
            layer_output = self._process_layer(layer, current_hidden, layer_idx)

            add_step(f"Layer {layer_idx} - Multi-Head Attention", "attention",
                     f"Self-attention computation in layer {layer_idx}",
                     {
                         "layer": layer_idx,
                         "num_heads": self._get_num_heads(layer),
                         "attention_pattern": layer_output.get("attention_pattern", None),
                         "tokens": tokens,  # Include tokens for labeling the attention matrix
                         "hidden_state_norm": float(torch.norm(layer_output["hidden_states"]).item()),
                     })

            if "ffn_output" in layer_output:
                add_step(f"Layer {layer_idx} - Feed-Forward Network", "ffn",
                         f"Feed-forward transformation in layer {layer_idx}",
                         {
                             "layer": layer_idx,
                             "activation": "gelu",  # assumed label; not read from the model
                             "hidden_state_norm": float(torch.norm(layer_output["ffn_output"]).item()),
                             "intermediate_size": layer_output.get("intermediate_size", 4096),
                             "hidden_size": layer_output.get("hidden_size", 1024),
                             "activation_stats": layer_output.get("activation_stats", {}),
                             "gate_values": layer_output.get("gate_values", None),
                             "tokens": tokens,  # Include tokens for context
                             "token_magnitudes": layer_output.get("token_magnitudes", []),
                         })

            current_hidden = layer_output["hidden_states"]

        # Final layer norm (if exists)
        final_hidden = current_hidden
        ln_f = getattr(transformer, 'ln_f', None)
        if ln_f is not None:
            if reference_hidden is not None:
                # NOTE(review): assumes the model's last reported hidden state is
                # already ln_f-normalised (the usual HF decoder convention) — confirm.
                final_hidden = reference_hidden[-1]
            else:
                final_hidden = ln_f(current_hidden)
            add_step("Final Layer Normalization", "normalization",
                     "Normalize hidden states before output projection",
                     {"norm_type": "LayerNorm", "hidden_state_norm": float(torch.norm(final_hidden).item())})

        # Output projection: prefer the full model's logits; fall back to projecting
        # the manually propagated hidden state.
        logits = getattr(model_outputs, 'logits', None)
        if logits is None:
            logits = self.model.lm_head(final_hidden) if hasattr(self.model, 'lm_head') else final_hidden

        # Probabilities for the position after the last input token (float32 softmax).
        probs = torch.softmax(logits[0, -1, :].float(), dim=-1)
        top_probs, top_indices = torch.topk(probs, min(5, probs.shape[-1]))
        top_ids = top_indices.tolist()
        # Decode tokens verbatim, preserving whitespace and special characters.
        top_tokens = [self._decode([idx]) for idx in top_ids]
        logger.info(f"Token generation - Input: '{text}', Predicted ID: {top_ids[0]}, Decoded: '{top_tokens[0]}'")

        add_step("Output Projection", "output", "Project to vocabulary and compute probabilities",
                 {
                     "vocab_size": logits.shape[-1],
                     "top_5_tokens": top_tokens,
                     "top_5_probs": _to_list(top_probs),
                     "predicted_token": top_tokens[0],
                     "confidence": float(top_probs[0].item()),
                 })

        # Step N: Generated Result (greedy choice; analyze_pipeline swaps in the
        # token that was actually generated).
        predicted_token = top_tokens[0]
        lowered = text.lower()
        add_step("Generated Result", "generated", f"Complete text with token #{token_position + 1}",
                 {
                     "original_text": text,
                     "predicted_token": predicted_token,
                     "complete_text": text + predicted_token,
                     "is_code": "def " in lowered or "class " in lowered or "import " in lowered,
                     "additional_info": self._formatting_note(predicted_token, top_tokens[1:]),
                     "token_position": token_position + 1,
                 })

        return steps

    def _process_layer(self, layer, hidden_states, layer_idx):
        """Process a single transformer layer, capturing attention and FFN details.

        Args:
            layer: One block from the model's layer list.
            hidden_states: Input to the layer, indexed below as (batch, seq_len, hidden).
            layer_idx: Layer index, used only in log messages.

        Returns:
            Dict that always holds 'hidden_states' (the layer output, or the input
            unchanged if processing failed) and 'attention_pattern' (head-averaged
            seq_len x seq_len matrix for the first batch item; identity fallback).
            When the MLP can be inspected it also holds 'ffn_output',
            'intermediate_size', 'hidden_size', 'activation_stats' and
            'token_magnitudes'.
        """
        output: Dict[str, Any] = {}
        try:
            with torch.no_grad():
                if hasattr(layer, 'attn'):
                    # GPT-style block: pre-norm attention with a residual connection.
                    ln_output = layer.ln_1(hidden_states) if hasattr(layer, 'ln_1') else hidden_states
                    attn_delta = self._attention_sublayer(layer, ln_output, output)
                    residual = hidden_states + attn_delta

                    if hasattr(layer, 'mlp'):
                        if hasattr(layer, 'ln_2'):
                            # Sequential block: MLP sees the normalised post-attention state.
                            mlp_input = layer.ln_2(residual)
                        elif hasattr(layer, 'ln_1'):
                            # NOTE(review): a single pre-norm is treated as a parallel
                            # attention/MLP block (MLP reads the same ln_1 output), as
                            # GPT-J/CodeGen-style blocks appear to do — confirm per model.
                            mlp_input = ln_output
                        else:
                            mlp_input = residual
                        self._capture_ffn_stats(layer.mlp, mlp_input, output)
                        mlp_output = layer.mlp(mlp_input)
                        output["ffn_output"] = mlp_output
                        new_hidden = residual + mlp_output
                    else:
                        new_hidden = residual
                else:
                    # BERT-style or other architecture
                    new_hidden = layer(hidden_states)[0]
                output["hidden_states"] = new_hidden
        except Exception as e:
            logger.warning(f"Error processing layer {layer_idx}: {e}", exc_info=True)
            output["hidden_states"] = hidden_states

        # Fallback to simple pattern if real extraction fails
        if "attention_pattern" not in output:
            seq_len = hidden_states.shape[1]
            output["attention_pattern"] = np.eye(seq_len).tolist()  # Identity matrix as fallback
            logger.warning(f"Using fallback attention pattern for layer {layer_idx}")

        return output

    def _attention_sublayer(self, layer, ln_output, output):
        """Run the layer's attention on *ln_output* and return it before the residual add.

        When the module exposes a fused Q/K/V projection (``qkv_proj`` or
        ``c_attn``), attention is recomputed explicitly so the head-averaged
        pattern can be stored in ``output["attention_pattern"]``. Otherwise the
        module is simply called and no pattern is recorded.
        """
        attn = layer.attn
        qkv_proj = getattr(attn, 'qkv_proj', None)
        if qkv_proj is None:
            qkv_proj = getattr(attn, 'c_attn', None)
        if qkv_proj is None:
            # Unknown attention module: run it as-is (first tuple item if it returns one).
            result = attn(ln_output)
            return result[0] if isinstance(result, tuple) else result

        qkv = qkv_proj(ln_output)
        embed_dim = attn.embed_dim
        n_head = self._get_num_heads(layer)
        head_dim = embed_dim // n_head

        # NOTE(review): assumes the fused projection is laid out as contiguous
        # [Q | K | V] blocks. HF CodeGen's qkv_proj appears to use a partitioned
        # layout plus rotary position embeddings, neither reproduced here, so
        # patterns from that architecture are likely approximate — verify.
        query, key, value = qkv.split(embed_dim, dim=2)
        batch_size, seq_len = query.shape[:2]

        def to_heads(t):
            # (batch, seq, embed) -> (batch, heads, seq, head_dim)
            return t.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)

        query, key, value = to_heads(query), to_heads(key), to_heads(value)
        scores = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5)

        # Causal mask built as booleans and applied with masked_fill. If a module's
        # `bias` buffer is boolean, adding it to the scores would add 0/1 rather
        # than block future positions.
        causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device).tril()
        scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
        attn_probs = torch.softmax(scores, dim=-1)

        # Average across heads for visualization: [batch, seq_len, seq_len]
        avg_attn = attn_probs.mean(dim=1)
        output["attention_pattern"] = _to_list(avg_attn[0])  # Full seq_len x seq_len
        logger.info(f"Extracted attention pattern with shape: {avg_attn[0].shape}")

        context = torch.matmul(attn_probs, value).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        out_proj = getattr(attn, 'out_proj', None)
        if out_proj is None:
            out_proj = getattr(attn, 'c_proj', None)
        if out_proj is not None:
            context = out_proj(context)
        resid_dropout = getattr(attn, 'resid_dropout', None)
        if resid_dropout is not None:
            context = resid_dropout(context)
        return context

    @staticmethod
    def _capture_ffn_stats(mlp, mlp_input, output):
        """Record FFN sizes and activation statistics of the MLP's first projection.

        Looks for ``fc_in`` then ``c_fc``; does nothing when neither exists.
        """
        fc = getattr(mlp, 'fc_in', None)
        if fc is None:
            fc = getattr(mlp, 'c_fc', None)
        if fc is None:
            return

        acts = fc(mlp_input)
        act_fn = getattr(mlp, 'act', None)
        if callable(act_fn):
            # Sparsity / active-neuron counts are only meaningful after the nonlinearity.
            acts = act_fn(acts)
        acts = acts.detach()

        # Sizes from tensor shapes: works even for projection modules that do not
        # define in_features / out_features.
        output["hidden_size"] = mlp_input.shape[-1]
        output["intermediate_size"] = acts.shape[-1]
        output["activation_stats"] = {
            "mean": float(acts.mean().item()),
            "std": float(acts.std().item()),
            "max": float(acts.max().item()),
            "min": float(acts.min().item()),
            "sparsity": float((acts == 0).float().mean().item()),  # Fraction of zeros
            "active_neurons": int((acts.abs() > 0.1).sum().item()),  # Significant activations
        }
        # Per-token magnitude: mean |activation| over the intermediate dimension.
        output["token_magnitudes"] = _to_list(acts.abs().mean(dim=-1)[0])

    def _get_num_heads(self, layer):
        """Get number of attention heads in a layer (8 when it cannot be determined)."""
        if hasattr(layer, 'attn'):
            for attr in ('num_attention_heads', 'n_head', 'num_heads'):
                if hasattr(layer.attn, attr):
                    return getattr(layer.attn, attr)
        return 8  # Default guess

    def get_steps_dict(self) -> List[Dict]:
        """Convert steps to dictionary format for JSON serialization

        Kept for backward compatibility: returns the first token's pipeline from
        the most recent analyze_pipeline() call (empty list if none). Prefer the
        result of analyze_pipeline directly for multi-token generation.
        """
        return [asdict(step) for step in getattr(self, 'last_single_token_steps', [])]