# api/backend/pipeline_analyzer.py
"""
Transformer Pipeline Analyzer
Captures and returns all intermediate states of transformer processing
"""
import torch
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict
import logging
logger = logging.getLogger(__name__)
@dataclass
class PipelineStep:
"""Represents a single step in the transformer pipeline"""
step_number: int
step_name: str
    step_type: str  # 'input', 'tokenization', 'embedding', 'attention', 'ffn', 'normalization', 'output', 'generated'
description: str
data: Dict[str, Any]
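# Example instance (illustrative), mirroring the "Raw Input" step built by the analyzer below:
#   PipelineStep(step_number=0, step_name="Raw Input", step_type="input",
#                description="The original text input provided by the user",
#                data={"text": "def hello():", "length": 12})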
class TransformerPipelineAnalyzer:
"""Analyzes the complete flow through a transformer model"""
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
self.device = next(model.parameters()).device
self.steps = []
self.intermediate_states = {}
def analyze_pipeline(self, text: str, max_new_tokens: int = 1,
temperature: float = 0.7, top_k: int = 50, top_p: float = 0.95) -> Dict[str, Any]:
"""
Capture all steps of transformer processing for multiple tokens
Args:
text: Input text to analyze
max_new_tokens: Number of tokens to generate (default 1)
temperature: Controls randomness in generation (default 0.7)
top_k: Limits to top K most likely tokens (default 50)
top_p: Cumulative probability cutoff (default 0.95)
Returns:
Dict containing tokens generated and their pipeline steps
"""
all_tokens = []
all_pipelines = []
current_text = text
# First generate all the tokens using the model's generate method
# This ensures proper autoregressive generation
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt", padding=False, truncation=True)
input_ids = inputs["input_ids"].to(self.device)
# Generate tokens properly using model.generate()
generated_ids = self.model.generate(
input_ids,
max_new_tokens=max_new_tokens,
do_sample=True, # Enable sampling for variety
temperature=temperature,
top_k=top_k,
top_p=top_p,
pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
)
# Extract only the new tokens
new_token_ids = generated_ids[0, input_ids.shape[1]:].tolist()
generated_tokens = [self.tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False) for tid in new_token_ids]
logger.info(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
# Now analyze the pipeline for each generated token
for token_idx, next_token in enumerate(generated_tokens):
# Analyze pipeline for current text (which will predict the next token)
pipeline_steps = self._analyze_single_token(current_text, token_idx)
# Update the output step with the actual generated token
# (since _analyze_single_token might predict differently due to sampling)
for step in reversed(pipeline_steps):
if step.step_type == 'output':
# Update with the actual generated token
step.data['predicted_token'] = next_token
step.data['actual_token_id'] = new_token_ids[token_idx] if token_idx < len(new_token_ids) else None
break
all_tokens.append(next_token)
all_pipelines.append(pipeline_steps)
current_text += next_token
# Store first pipeline for backward compatibility
if token_idx == 0:
self.last_single_token_steps = pipeline_steps
return {
'tokens': all_tokens,
'pipelines': all_pipelines,
'final_text': current_text,
'num_tokens': len(all_tokens)
}
def _analyze_single_token(self, text: str, token_position: int) -> List[PipelineStep]:
"""
Analyze the pipeline for generating a single token
Args:
text: Current text to continue from
token_position: Position of this token in the generation sequence
Returns:
List of PipelineStep objects for this token
"""
steps = []
step_counter = 0
# Step 1: Raw Input
steps.append(PipelineStep(
step_number=step_counter,
step_name="Raw Input",
step_type="input",
description="The original text input provided by the user",
data={"text": text, "length": len(text)}
))
step_counter += 1
# Step 2: Tokenization
inputs = self.tokenizer(text, return_tensors="pt", padding=False, truncation=True)
input_ids = inputs["input_ids"].to(self.device)
tokens = [self.tokenizer.decode([tid]) for tid in input_ids[0]]
token_ids = input_ids[0].tolist()
steps.append(PipelineStep(
step_number=step_counter,
step_name="Tokenization",
step_type="tokenization",
description="Text split into subword tokens using the model's tokenizer",
data={
"tokens": tokens,
"token_ids": token_ids,
"num_tokens": len(tokens),
"tokenizer_name": self.tokenizer.__class__.__name__
}
))
step_counter += 1
# Step 3: Token Embeddings
with torch.no_grad():
# Get token embeddings
if hasattr(self.model, 'transformer'):
embed_layer = self.model.transformer.wte
pos_embed_layer = self.model.transformer.wpe if hasattr(self.model.transformer, 'wpe') else None
else:
embed_layer = self.model.get_input_embeddings()
pos_embed_layer = None
token_embeddings = embed_layer(input_ids)
# Add positional embeddings if available
if pos_embed_layer:
position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long, device=self.device)
position_ids = position_ids.unsqueeze(0)
position_embeddings = pos_embed_layer(position_ids)
embeddings = token_embeddings + position_embeddings
else:
embeddings = token_embeddings
position_embeddings = None
steps.append(PipelineStep(
step_number=step_counter,
step_name="Initial Embeddings",
step_type="embedding",
description="Token embeddings combined with positional encodings",
data={
"embedding_dim": embeddings.shape[-1],
"has_position_encoding": pos_embed_layer is not None,
"embeddings_sample": embeddings[0, :3, :8].cpu().numpy().tolist(), # First 3 tokens, 8 dims
"embeddings_shape": list(embeddings.shape)
}
))
step_counter += 1
# Step 4-N: Process through layers
current_hidden = embeddings
# Get model layers
if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
layers = self.model.transformer.h
else:
layers = self.model.encoder.layer if hasattr(self.model, 'encoder') else []
# Process through each layer
for layer_idx, layer in enumerate(layers[:4]): # Sample first 4 layers for performance
# Attention mechanism
layer_output = self._process_layer(layer, current_hidden, layer_idx)
# Add attention step with tokens for labeling
steps.append(PipelineStep(
step_number=step_counter,
step_name=f"Layer {layer_idx} - Multi-Head Attention",
step_type="attention",
description=f"Self-attention computation in layer {layer_idx}",
data={
"layer": layer_idx,
"num_heads": self._get_num_heads(layer),
"attention_pattern": layer_output.get("attention_pattern", None),
"tokens": tokens, # Include tokens for labeling the attention matrix
"hidden_state_norm": float(torch.norm(layer_output["hidden_states"]).item())
}
))
step_counter += 1
# Feed-forward network
if "ffn_output" in layer_output:
steps.append(PipelineStep(
step_number=step_counter,
step_name=f"Layer {layer_idx} - Feed-Forward Network",
step_type="ffn",
description=f"Feed-forward transformation in layer {layer_idx}",
data={
"layer": layer_idx,
"activation": "gelu", # Most transformers use GELU
"hidden_state_norm": float(torch.norm(layer_output["ffn_output"]).item()),
"intermediate_size": layer_output.get("intermediate_size", 4096),
"hidden_size": layer_output.get("hidden_size", 1024),
"activation_stats": layer_output.get("activation_stats", {}),
"gate_values": layer_output.get("gate_values", None),
"tokens": tokens, # Include tokens for context
"token_magnitudes": layer_output.get("token_magnitudes", [])
}
))
step_counter += 1
current_hidden = layer_output["hidden_states"]
# Final layer norm (if exists)
if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'ln_f'):
current_hidden = self.model.transformer.ln_f(current_hidden)
steps.append(PipelineStep(
step_number=step_counter,
step_name="Final Layer Normalization",
step_type="normalization",
description="Normalize hidden states before output projection",
data={
"norm_type": "LayerNorm",
"hidden_state_norm": float(torch.norm(current_hidden).item())
}
))
step_counter += 1
# Output projection
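        # Note: only the first few layers are processed above (layers[:4]), so these logits
        # approximate the full model's distribution; analyze_pipeline later overwrites the
        # predicted token with the token actually produced by model.generate().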
if hasattr(self.model, 'lm_head'):
logits = self.model.lm_head(current_hidden)
else:
logits = current_hidden
# Get probabilities for the last token
last_token_logits = logits[0, -1, :]
probs = torch.softmax(last_token_logits, dim=-1)
# Get top 5 predictions
top_probs, top_indices = torch.topk(probs, 5)
# Decode tokens properly, preserving whitespace and special characters
top_tokens = []
for idx in top_indices.tolist():
decoded = self.tokenizer.decode([idx], skip_special_tokens=False, clean_up_tokenization_spaces=False)
top_tokens.append(decoded)
# Debug logging
if idx == top_indices[0].item():
logger.info(f"Token generation - Input: '{text}', Predicted ID: {idx}, Decoded: '{decoded}'")
steps.append(PipelineStep(
step_number=step_counter,
step_name="Output Projection",
step_type="output",
description="Project to vocabulary and compute probabilities",
data={
"vocab_size": logits.shape[-1],
"top_5_tokens": top_tokens,
"top_5_probs": top_probs.cpu().numpy().tolist(),
"predicted_token": top_tokens[0],
"confidence": float(top_probs[0].item())
}
))
step_counter += 1
# Step N: Generated Result
# For code generation, we might want to show the first meaningful token
# Check if the predicted token is just whitespace or quote
predicted_token = top_tokens[0]
display_token = predicted_token
additional_info = ""
# If it's a trivial token (quote, newline, whitespace), note what comes next
if predicted_token in ["'", '"', "\n", " ", " ", "\t"]:
additional_info = f"Next token: '{predicted_token}' (formatting)"
# Show what would come after formatting tokens
if len(top_tokens) > 1:
for alt_token in top_tokens[1:]:
if alt_token not in ["'", '"', "\n", " ", " ", "\t"]:
additional_info += f", likely code token: '{alt_token}'"
break
generated_text = text + predicted_token
steps.append(PipelineStep(
step_number=step_counter,
step_name="Generated Result",
step_type="generated",
description=f"Complete text with token #{token_position + 1}",
data={
"original_text": text,
"predicted_token": predicted_token,
"complete_text": generated_text,
"is_code": "def " in text.lower() or "class " in text.lower() or "import " in text.lower(),
"additional_info": additional_info,
"token_position": token_position + 1
}
))
step_counter += 1
return steps
def _process_layer(self, layer, hidden_states, layer_idx):
"""Process a single transformer layer"""
output = {}
try:
# Process with attention weight capture
with torch.no_grad():
if hasattr(layer, 'attn'):
# GPT-style architecture - capture attention weights
# First apply layer norm if present
ln_output = layer.ln_1(hidden_states) if hasattr(layer, 'ln_1') else hidden_states
# Get attention weights by calling the attention module with output_attentions
qkv = None
if hasattr(layer.attn, 'qkv_proj'):
# CodeGen architecture - has combined QKV projection
qkv = layer.attn.qkv_proj(ln_output)
embed_dim = layer.attn.embed_dim
n_head = layer.attn.num_attention_heads if hasattr(layer.attn, 'num_attention_heads') else 8
elif hasattr(layer.attn, 'c_attn'):
# GPT2-style architecture
qkv = layer.attn.c_attn(ln_output)
embed_dim = layer.attn.embed_dim
n_head = layer.attn.n_head if hasattr(layer.attn, 'n_head') else 8
if qkv is not None:
# Split into Q, K, V
query, key, value = qkv.split(embed_dim, dim=2)
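                        # NOTE: this simple split matches GPT-2's c_attn layout exactly; CodeGen's fused
                        # qkv_proj interleaves heads in blocks, so for CodeGen this split is only an
                        # approximation of the exact Q/K/V layout.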
# Reshape for multi-head attention
batch_size, seq_len = query.shape[:2]
head_dim = embed_dim // n_head
query = query.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
key = key.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
value = value.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
# Compute attention scores
attn_weights = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5)
# Apply causal mask (for autoregressive models)
                        if hasattr(layer.attn, 'bias') and layer.attn.bias is not None:
                            # GPT-2 style models store the causal mask as a 0/1 buffer, so use it to
                            # mask out future positions rather than adding it to the scores
                            causal_mask = layer.attn.bias[:, :, :seq_len, :seq_len].to(torch.bool)
                            attn_weights = attn_weights.masked_fill(~causal_mask, float('-inf'))
                        else:
                            # Create an additive causal mask manually if no mask buffer exists
                            causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=attn_weights.device) * -1e4, diagonal=1)
                            attn_weights = attn_weights + causal_mask.unsqueeze(0).unsqueeze(0)
# Apply softmax
attn_probs = torch.softmax(attn_weights, dim=-1)
# Average across heads for visualization
avg_attn = attn_probs.mean(dim=1) # Shape: [batch, seq_len, seq_len]
# Store the full attention pattern
output["attention_pattern"] = avg_attn[0].cpu().numpy().tolist() # Full seq_len x seq_len
logger.info(f"Extracted attention pattern with shape: {avg_attn[0].shape}")
# Apply attention to values and continue processing
attn_output = torch.matmul(attn_probs, value)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
# Apply output projection
if hasattr(layer.attn, 'out_proj'):
# CodeGen architecture
attn_output = layer.attn.out_proj(attn_output)
elif hasattr(layer.attn, 'c_proj'):
# GPT2-style architecture
attn_output = layer.attn.c_proj(attn_output)
# Apply residual dropout if present
if hasattr(layer.attn, 'resid_dropout'):
attn_output = layer.attn.resid_dropout(attn_output)
# Add residual connection
attn_output = hidden_states + attn_output
else:
# Fallback for different architecture
attn_output = layer.attn(hidden_states)
if isinstance(attn_output, tuple):
attn_output = attn_output[0]
# Apply MLP with detailed analysis
if hasattr(layer, 'mlp'):
ln2_output = layer.ln_2(attn_output) if hasattr(layer, 'ln_2') else attn_output
# Extract detailed FFN information
if hasattr(layer.mlp, 'fc_in') or hasattr(layer.mlp, 'c_fc'):
# Get intermediate layer
if hasattr(layer.mlp, 'fc_in'):
# CodeGen architecture
intermediate = layer.mlp.fc_in(ln2_output)
output["intermediate_size"] = layer.mlp.fc_in.out_features
output["hidden_size"] = layer.mlp.fc_in.in_features
elif hasattr(layer.mlp, 'c_fc'):
# GPT2 architecture
intermediate = layer.mlp.c_fc(ln2_output)
output["intermediate_size"] = layer.mlp.c_fc.out_features
output["hidden_size"] = layer.mlp.c_fc.in_features
# Compute activation statistics
with torch.no_grad():
act_values = intermediate.detach()
output["activation_stats"] = {
"mean": float(act_values.mean().item()),
"std": float(act_values.std().item()),
"max": float(act_values.max().item()),
"min": float(act_values.min().item()),
"sparsity": float((act_values == 0).float().mean().item()), # Fraction of zeros
"active_neurons": int((act_values.abs() > 0.1).sum().item()) # Neurons with significant activation
}
# Get per-token magnitudes (average activation magnitude per token)
token_mags = act_values.abs().mean(dim=-1)[0].cpu().numpy().tolist()
output["token_magnitudes"] = token_mags
mlp_output = layer.mlp(ln2_output)
output["ffn_output"] = mlp_output
hidden_states = attn_output + mlp_output
else:
hidden_states = attn_output
else:
# BERT-style or other architecture
hidden_states = layer(hidden_states)[0]
output["hidden_states"] = hidden_states
except Exception as e:
logger.warning(f"Error processing layer {layer_idx}: {e}")
import traceback
logger.warning(f"Traceback: {traceback.format_exc()}")
output["hidden_states"] = hidden_states
# Fallback to simple pattern if real extraction fails
if "attention_pattern" not in output:
seq_len = hidden_states.shape[1]
output["attention_pattern"] = np.eye(seq_len).tolist() # Identity matrix as fallback
logger.warning(f"Using fallback attention pattern for layer {layer_idx}")
return output
def _get_num_heads(self, layer):
"""Get number of attention heads in a layer"""
if hasattr(layer, 'attn'):
if hasattr(layer.attn, 'num_attention_heads'):
return layer.attn.num_attention_heads # CodeGen
elif hasattr(layer.attn, 'n_head'):
return layer.attn.n_head # GPT2
elif hasattr(layer.attn, 'num_heads'):
return layer.attn.num_heads # Other architectures
return 8 # Default guess
def get_steps_dict(self) -> List[Dict]:
"""Convert steps to dictionary format for JSON serialization
This is kept for backward compatibility but may not work with multi-token generation.
Use the result from analyze_pipeline directly instead.
"""
# If we have stored steps from single token generation, return them
if hasattr(self, 'last_single_token_steps'):
return [asdict(step) for step in self.last_single_token_steps]
return []
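
# Minimal usage sketch (illustrative only). Loading "gpt2" here is an assumption for the
# example; swap in whichever causal LM checkpoint and tokenizer the backend actually serves.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()

    analyzer = TransformerPipelineAnalyzer(model, tokenizer)
    result = analyzer.analyze_pipeline("def hello():", max_new_tokens=2)

    print(f"Generated tokens: {result['tokens']}")
    print(f"Final text: {result['final_text']!r}")
    # Walk through the pipeline captured for the first generated token
    for step in result['pipelines'][0]:
        print(f"  [{step.step_number}] {step.step_name} ({step.step_type})")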