""" Transformer Pipeline Analyzer Captures and returns all intermediate states of transformer processing """ import torch import numpy as np from typing import Dict, List, Any, Optional, Tuple from dataclasses import dataclass, asdict import logging logger = logging.getLogger(__name__) @dataclass class PipelineStep: """Represents a single step in the transformer pipeline""" step_number: int step_name: str step_type: str # 'tokenization', 'embedding', 'attention', 'ffn', 'output' description: str data: Dict[str, Any] class TransformerPipelineAnalyzer: """Analyzes the complete flow through a transformer model""" def __init__(self, model, tokenizer, adapter=None): self.model = model self.tokenizer = tokenizer self.adapter = adapter # Model adapter for accessing architecture-specific components self.device = next(model.parameters()).device self.steps = [] self.intermediate_states = {} def analyze_pipeline(self, text: str, max_new_tokens: int = 1, temperature: float = 0.7, top_k: int = 50, top_p: float = 0.95) -> Dict[str, Any]: """ Capture all steps of transformer processing for multiple tokens Args: text: Input text to analyze max_new_tokens: Number of tokens to generate (default 1) temperature: Controls randomness in generation (default 0.7) top_k: Limits to top K most likely tokens (default 50) top_p: Cumulative probability cutoff (default 0.95) Returns: Dict containing tokens generated and their pipeline steps """ all_tokens = [] all_pipelines = [] current_text = text # First generate all the tokens using the model's generate method # This ensures proper autoregressive generation with torch.no_grad(): inputs = self.tokenizer(text, return_tensors="pt", padding=False, truncation=True) input_ids = inputs["input_ids"].to(self.device) # Generate tokens properly using model.generate() generated_ids = self.model.generate( input_ids, max_new_tokens=max_new_tokens, do_sample=True, # Enable sampling for variety temperature=temperature, top_k=top_k, top_p=top_p, pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id ) # Extract only the new tokens with context-aware decoding new_token_ids = generated_ids[0, input_ids.shape[1]:].tolist() # Decode tokens progressively to maintain SentencePiece context generated_tokens = [] prev_decoded_length = len(text) for i, tid in enumerate(new_token_ids): # Decode the full sequence up to this point full_sequence = torch.cat([input_ids[0], torch.tensor(new_token_ids[:i+1], device=input_ids.device)]) full_decoded = self.tokenizer.decode(full_sequence, skip_special_tokens=False, clean_up_tokenization_spaces=False) # Extract just the new token by comparing lengths new_token = full_decoded[prev_decoded_length:] generated_tokens.append(new_token) prev_decoded_length = len(full_decoded) logger.info(f"Generated {len(generated_tokens)} tokens: {generated_tokens}") # Now analyze the pipeline for each generated token for token_idx, next_token in enumerate(generated_tokens): # Analyze pipeline for current text (which will predict the next token) pipeline_steps = self._analyze_single_token(current_text, token_idx) # Update the output step with the actual generated token # (since _analyze_single_token might predict differently due to sampling) for step in reversed(pipeline_steps): if step.step_type == 'output': # Update with the actual generated token step.data['predicted_token'] = next_token step.data['actual_token_id'] = new_token_ids[token_idx] if token_idx < len(new_token_ids) else None break all_tokens.append(next_token) all_pipelines.append(pipeline_steps) current_text += next_token # Store first pipeline for backward compatibility if token_idx == 0: self.last_single_token_steps = pipeline_steps return { 'tokens': all_tokens, 'pipelines': all_pipelines, 'final_text': current_text, 'num_tokens': len(all_tokens) } def _analyze_single_token(self, text: str, token_position: int) -> List[PipelineStep]: """ Analyze the pipeline for generating a single token Args: text: Current text to continue from token_position: Position of this token in the generation sequence Returns: List of PipelineStep objects for this token """ steps = [] step_counter = 0 # Step 1: Raw Input steps.append(PipelineStep( step_number=step_counter, step_name="Raw Input", step_type="input", description="The original text input provided by the user", data={"text": text, "length": len(text)} )) step_counter += 1 # Step 2: Tokenization inputs = self.tokenizer(text, return_tensors="pt", padding=False, truncation=True) input_ids = inputs["input_ids"].to(self.device) tokens = [self.tokenizer.decode([tid]) for tid in input_ids[0]] token_ids = input_ids[0].tolist() steps.append(PipelineStep( step_number=step_counter, step_name="Tokenization", step_type="tokenization", description="Text split into subword tokens using the model's tokenizer", data={ "tokens": tokens, "token_ids": token_ids, "num_tokens": len(tokens), "tokenizer_name": self.tokenizer.__class__.__name__ } )) step_counter += 1 # Step 3: Token Embeddings with torch.no_grad(): # Get token embeddings if hasattr(self.model, 'transformer'): embed_layer = self.model.transformer.wte pos_embed_layer = self.model.transformer.wpe if hasattr(self.model.transformer, 'wpe') else None else: embed_layer = self.model.get_input_embeddings() pos_embed_layer = None token_embeddings = embed_layer(input_ids) # Add positional embeddings if available if pos_embed_layer: position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long, device=self.device) position_ids = position_ids.unsqueeze(0) position_embeddings = pos_embed_layer(position_ids) embeddings = token_embeddings + position_embeddings else: embeddings = token_embeddings position_embeddings = None steps.append(PipelineStep( step_number=step_counter, step_name="Initial Embeddings", step_type="embedding", description="Token embeddings combined with positional encodings", data={ "embedding_dim": embeddings.shape[-1], "has_position_encoding": pos_embed_layer is not None, "embeddings_sample": embeddings[0, :3, :8].cpu().numpy().tolist(), # First 3 tokens, 8 dims "embeddings_shape": list(embeddings.shape) } )) step_counter += 1 # Step 4-N: Process through layers current_hidden = embeddings # Get model layers - use adapter if available for multi-architecture support if self.adapter: # Use adapter to get layer count and access layers num_layers = self.adapter.get_num_layers() sample_layers = min(4, num_layers) # Sample first 4 layers for performance layers = [self.adapter.get_layer_module(i) for i in range(sample_layers)] elif hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'): # Fallback for CodeGen-style models layers = self.model.transformer.h[:4] else: # Fallback for other architectures layers = self.model.encoder.layer[:4] if hasattr(self.model, 'encoder') else [] # Process through each layer for layer_idx, layer in enumerate(layers): # Attention mechanism layer_output = self._process_layer(layer, current_hidden, layer_idx) # Add attention step with tokens for labeling steps.append(PipelineStep( step_number=step_counter, step_name=f"Layer {layer_idx} - Multi-Head Attention", step_type="attention", description=f"Self-attention computation in layer {layer_idx}", data={ "layer": layer_idx, "num_heads": self._get_num_heads(layer), "attention_pattern": layer_output.get("attention_pattern", None), "tokens": tokens, # Include tokens for labeling the attention matrix "hidden_state_norm": float(torch.norm(layer_output["hidden_states"]).item()) } )) step_counter += 1 # Feed-forward network if "ffn_output" in layer_output: steps.append(PipelineStep( step_number=step_counter, step_name=f"Layer {layer_idx} - Feed-Forward Network", step_type="ffn", description=f"Feed-forward transformation in layer {layer_idx}", data={ "layer": layer_idx, "activation": "gelu", # Most transformers use GELU "hidden_state_norm": float(torch.norm(layer_output["ffn_output"]).item()), "intermediate_size": layer_output.get("intermediate_size", 4096), "hidden_size": layer_output.get("hidden_size", 1024), "activation_stats": layer_output.get("activation_stats", {}), "gate_values": layer_output.get("gate_values", None), "tokens": tokens, # Include tokens for context "token_magnitudes": layer_output.get("token_magnitudes", []) } )) step_counter += 1 current_hidden = layer_output["hidden_states"] # Final layer norm (if exists) if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'ln_f'): current_hidden = self.model.transformer.ln_f(current_hidden) steps.append(PipelineStep( step_number=step_counter, step_name="Final Layer Normalization", step_type="normalization", description="Normalize hidden states before output projection", data={ "norm_type": "LayerNorm", "hidden_state_norm": float(torch.norm(current_hidden).item()) } )) step_counter += 1 # Output projection if hasattr(self.model, 'lm_head'): logits = self.model.lm_head(current_hidden) else: logits = current_hidden # Get probabilities for the last token last_token_logits = logits[0, -1, :] probs = torch.softmax(last_token_logits, dim=-1) # Get top 5 predictions top_probs, top_indices = torch.topk(probs, 5) # Decode tokens with context-aware decoding for SentencePiece tokenizers top_tokens = [] for idx in top_indices.tolist(): # For context-aware decoding: append token to existing sequence and decode the delta # This ensures proper SentencePiece decoding (handles leading spaces, etc.) full_sequence = torch.cat([input_ids[0], torch.tensor([idx], device=input_ids.device)]) full_decoded = self.tokenizer.decode(full_sequence, skip_special_tokens=False, clean_up_tokenization_spaces=False) # Extract just the new token by removing the original text decoded = full_decoded[len(text):] top_tokens.append(decoded) # Debug logging if idx == top_indices[0].item(): import logging logger = logging.getLogger(__name__) logger.info(f"Token generation - Input: '{text}', Predicted ID: {idx}, Context-aware decoded: '{decoded}'") steps.append(PipelineStep( step_number=step_counter, step_name="Output Projection", step_type="output", description="Project to vocabulary and compute probabilities", data={ "vocab_size": logits.shape[-1], "top_5_tokens": top_tokens, "top_5_probs": top_probs.cpu().numpy().tolist(), "predicted_token": top_tokens[0], "confidence": float(top_probs[0].item()) } )) step_counter += 1 # Step N: Generated Result # For code generation, we might want to show the first meaningful token # Check if the predicted token is just whitespace or quote predicted_token = top_tokens[0] display_token = predicted_token additional_info = "" # If it's a trivial token (quote, newline, whitespace), note what comes next if predicted_token in ["'", '"', "\n", " ", " ", "\t"]: additional_info = f"Next token: '{predicted_token}' (formatting)" # Show what would come after formatting tokens if len(top_tokens) > 1: for alt_token in top_tokens[1:]: if alt_token not in ["'", '"', "\n", " ", " ", "\t"]: additional_info += f", likely code token: '{alt_token}'" break generated_text = text + predicted_token steps.append(PipelineStep( step_number=step_counter, step_name="Generated Result", step_type="generated", description=f"Complete text with token #{token_position + 1}", data={ "original_text": text, "predicted_token": predicted_token, "complete_text": generated_text, "is_code": "def " in text.lower() or "class " in text.lower() or "import " in text.lower(), "additional_info": additional_info, "token_position": token_position + 1 } )) step_counter += 1 return steps def _process_layer(self, layer, hidden_states, layer_idx): """Process a single transformer layer""" output = {} try: # Process with attention weight capture with torch.no_grad(): # Get attention module using adapter for multi-architecture support attn_module = None if self.adapter: attn_module = self.adapter.get_attention_module(layer_idx) elif hasattr(layer, 'attn'): attn_module = layer.attn elif hasattr(layer, 'self_attn'): attn_module = layer.self_attn if attn_module: # Apply pre-attention layer norm # LLaMA uses input_layernorm, CodeGen uses ln_1 if hasattr(layer, 'input_layernorm'): ln_output = layer.input_layernorm(hidden_states) elif hasattr(layer, 'ln_1'): ln_output = layer.ln_1(hidden_states) else: ln_output = hidden_states # Try to extract attention manually for visualization attention_extracted = False # Check if this is CodeGen/GPT2 style (combined QKV) if hasattr(attn_module, 'qkv_proj'): # CodeGen architecture - has combined QKV projection qkv = attn_module.qkv_proj(ln_output) embed_dim = attn_module.embed_dim n_head = attn_module.num_attention_heads if hasattr(attn_module, 'num_attention_heads') else 8 # Split into Q, K, V query, key, value = qkv.split(embed_dim, dim=2) attention_extracted = True elif hasattr(attn_module, 'c_attn'): # GPT2-style architecture qkv = attn_module.c_attn(ln_output) embed_dim = attn_module.embed_dim n_head = attn_module.n_head if hasattr(attn_module, 'n_head') else 8 # Split into Q, K, V query, key, value = qkv.split(embed_dim, dim=2) attention_extracted = True elif hasattr(attn_module, 'q_proj') and hasattr(attn_module, 'k_proj') and hasattr(attn_module, 'v_proj'): # LLaMA architecture - separate Q, K, V projections query = attn_module.q_proj(ln_output) key = attn_module.k_proj(ln_output) value = attn_module.v_proj(ln_output) # Get dimensions if hasattr(attn_module, 'num_heads'): n_head = attn_module.num_heads elif hasattr(attn_module, 'num_attention_heads'): n_head = attn_module.num_attention_heads else: n_head = 32 # Default for LLaMA embed_dim = query.shape[-1] attention_extracted = True if attention_extracted: # Reshape for multi-head attention batch_size, seq_len = query.shape[:2] head_dim = embed_dim // n_head query = query.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) key = key.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) value = value.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2) # Compute attention scores attn_weights = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5) # Apply causal mask (for autoregressive models) causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=attn_weights.device) * -1e10, diagonal=1) attn_weights = attn_weights + causal_mask.unsqueeze(0).unsqueeze(0) # Apply softmax attn_probs = torch.softmax(attn_weights, dim=-1) # Average across heads for visualization avg_attn = attn_probs.mean(dim=1) # Shape: [batch, seq_len, seq_len] # Store the full attention pattern output["attention_pattern"] = avg_attn[0].cpu().numpy().tolist() logger.info(f"Extracted attention pattern with shape: {avg_attn[0].shape}") # Apply attention to values attn_output = torch.matmul(attn_probs, value) attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim) # Apply output projection if hasattr(attn_module, 'out_proj'): # CodeGen/LLaMA architecture attn_output = attn_module.out_proj(attn_output) if hasattr(attn_module, 'out_proj') else attn_output elif hasattr(attn_module, 'o_proj'): # LLaMA uses o_proj attn_output = attn_module.o_proj(attn_output) elif hasattr(attn_module, 'c_proj'): # GPT2-style architecture attn_output = attn_module.c_proj(attn_output) # Add residual connection attn_output = hidden_states + attn_output else: # Fallback: call the layer directly (won't get attention pattern) logger.warning(f"Could not extract attention manually for layer {layer_idx}, using layer forward pass") attn_result = layer(hidden_states) if isinstance(attn_result, tuple): attn_output = attn_result[0] else: attn_output = attn_result # Use identity matrix as fallback seq_len = hidden_states.shape[1] output["attention_pattern"] = np.eye(seq_len).tolist() # Apply MLP/FFN with detailed analysis # Get FFN module using adapter for multi-architecture support ffn_module = None if self.adapter: ffn_module = self.adapter.get_ffn_module(layer_idx) elif hasattr(layer, 'mlp'): ffn_module = layer.mlp if ffn_module: # Apply layer norm - LLaMA uses post_attention_layernorm, CodeGen uses ln_2 if hasattr(layer, 'post_attention_layernorm'): ln2_output = layer.post_attention_layernorm(attn_output) elif hasattr(layer, 'ln_2'): ln2_output = layer.ln_2(attn_output) else: ln2_output = attn_output # Extract detailed FFN information based on architecture intermediate = None if hasattr(ffn_module, 'gate_proj') and hasattr(ffn_module, 'up_proj'): # LLaMA architecture - uses gated FFN (SwiGLU) gate_output = ffn_module.gate_proj(ln2_output) up_output = ffn_module.up_proj(ln2_output) # SwiGLU activation: gate(x) * up(x) import torch.nn.functional as F intermediate = F.silu(gate_output) * up_output output["intermediate_size"] = ffn_module.gate_proj.out_features output["hidden_size"] = ffn_module.gate_proj.in_features # Store gate activation stats with torch.no_grad(): gate_values = F.silu(gate_output).detach() output["gate_values"] = { "mean": float(gate_values.mean().item()), "std": float(gate_values.std().item()), "max": float(gate_values.max().item()), "min": float(gate_values.min().item()) } elif hasattr(ffn_module, 'fc_in'): # CodeGen architecture intermediate = ffn_module.fc_in(ln2_output) output["intermediate_size"] = ffn_module.fc_in.out_features output["hidden_size"] = ffn_module.fc_in.in_features elif hasattr(ffn_module, 'c_fc'): # GPT2 architecture intermediate = ffn_module.c_fc(ln2_output) output["intermediate_size"] = ffn_module.c_fc.out_features output["hidden_size"] = ffn_module.c_fc.in_features if intermediate is not None: # Compute activation statistics with torch.no_grad(): act_values = intermediate.detach() output["activation_stats"] = { "mean": float(act_values.mean().item()), "std": float(act_values.std().item()), "max": float(act_values.max().item()), "min": float(act_values.min().item()), "sparsity": float((act_values == 0).float().mean().item()), # Fraction of zeros "active_neurons": int((act_values.abs() > 0.1).sum().item()) # Neurons with significant activation } # Get per-token magnitudes (average activation magnitude per token) token_mags = act_values.abs().mean(dim=-1)[0].cpu().numpy().tolist() output["token_magnitudes"] = token_mags # Apply full MLP mlp_output = ffn_module(ln2_output) output["ffn_output"] = mlp_output hidden_states = attn_output + mlp_output else: hidden_states = attn_output else: # BERT-style or other architecture hidden_states = layer(hidden_states)[0] output["hidden_states"] = hidden_states except Exception as e: logger.warning(f"Error processing layer {layer_idx}: {e}") import traceback logger.warning(f"Traceback: {traceback.format_exc()}") output["hidden_states"] = hidden_states # Fallback to simple pattern if real extraction fails if "attention_pattern" not in output: seq_len = hidden_states.shape[1] output["attention_pattern"] = np.eye(seq_len).tolist() # Identity matrix as fallback logger.warning(f"Using fallback attention pattern for layer {layer_idx}") return output def _get_num_heads(self, layer): """Get number of attention heads in a layer""" if hasattr(layer, 'attn'): if hasattr(layer.attn, 'num_attention_heads'): return layer.attn.num_attention_heads # CodeGen elif hasattr(layer.attn, 'n_head'): return layer.attn.n_head # GPT2 elif hasattr(layer.attn, 'num_heads'): return layer.attn.num_heads # Other architectures return 8 # Default guess def get_steps_dict(self) -> List[Dict]: """Convert steps to dictionary format for JSON serialization This is kept for backward compatibility but may not work with multi-token generation. Use the result from analyze_pipeline directly instead. """ # If we have stored steps from single token generation, return them if hasattr(self, 'last_single_token_steps'): return [asdict(step) for step in self.last_single_token_steps] return []