| """ | |
| Instrumentation layer for capturing model internals during generation. | |
| Designed for PhD study on architectural transparency. | |
| Captures: | |
| - Attention tensors A[L,H,T,T] per layer/head | |
| - Residual norms ||x_l|| per layer | |
| - Logits, logprobs, entropy per token | |
| - Timing per layer | |
| """ | |
| import torch | |
| import numpy as np | |
| from typing import Dict, List, Optional, Tuple | |
| from dataclasses import dataclass, field | |
| from datetime import datetime | |
| import time | |
| import logging | |
| logger = logging.getLogger(__name__) | |

@dataclass
class TokenMetadata:
    """Metadata for a single generated token"""
    token_id: int
    text: str
    position: int
    logprob: float
    entropy: float
    top_k_tokens: List[Tuple[str, float]]  # (token_text, probability)
    byte_length: int
    timestamp_ms: float

@dataclass
class LayerMetadata:
    """Metadata captured per layer during forward pass"""
    layer_idx: int
    residual_norm: float
    time_ms: float
    attention_output_norm: Optional[float] = None
    ffn_output_norm: Optional[float] = None

@dataclass
class InstrumentationData:
    """Complete instrumentation capture for a generation run"""
    # Run identification
    run_id: str
    seed: int
    model_name: str
    timestamp: float

    # Generation parameters
    prompt: str
    max_tokens: int
    temperature: float
    top_k: Optional[int]
    top_p: Optional[float]

    # Token-level data
    tokens: List[TokenMetadata] = field(default_factory=list)

    # Tensor data (will be stored separately in Zarr)
    attention_tensors: Optional[torch.Tensor] = None  # [num_tokens, num_layers, num_heads, seq_len, seq_len]
    logits_history: Optional[torch.Tensor] = None  # [num_tokens, vocab_size]

    # Layer-level metadata
    layer_metadata: List[List[LayerMetadata]] = field(default_factory=list)  # [num_tokens][num_layers]

    # Summary statistics
    total_time_ms: float = 0.0
    num_layers: int = 0
    num_heads: int = 0
    seq_length: int = 0
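
# --- Hedged sketch: persisting attention tensors to Zarr --------------------
# The comment on InstrumentationData.attention_tensors says tensor data "will
# be stored separately in Zarr". The helper below is one minimal way that
# could look, assuming the `zarr` package is available; the store path and
# the float16 downcast are illustrative choices, not part of the original
# design.
def save_attention_to_zarr(data: InstrumentationData, path: str = "attention.zarr") -> None:
    """Write the [num_tokens, num_layers, num_heads, seq, seq] tensor to a Zarr store."""
    import zarr  # assumed optional dependency

    if data.attention_tensors is None:
        logger.warning("No attention tensor to save")
        return
    # Hooks already moved the weights to CPU, so .numpy() is safe here.
    arr = data.attention_tensors.to(torch.float16).numpy()
    zarr.save(path, attention=arr)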

class ModelInstrumentor:
    """
    Attaches PyTorch hooks to capture model internals during generation.

    Usage:
        instrumentor = ModelInstrumentor(model, tokenizer, device)
        with instrumentor.capture():
            outputs = model.generate(...)
        data = instrumentor.get_data(...)
    """

    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

        # Hook handles (for cleanup)
        self.hook_handles = []

        # Capture buffers
        self.attention_buffer = []
        self.residual_buffer = []
        self.timing_buffer = []
        self.logits_buffer = []

        # Metadata
        self.config = model.config
        self.num_layers = getattr(self.config, 'num_hidden_layers', getattr(self.config, 'n_layer', 0))
        self.num_heads = getattr(self.config, 'num_attention_heads', getattr(self.config, 'n_head', 0))

        # State
        self.capturing = False
        self.start_time = None
    def _create_attention_hook(self, layer_idx: int):
        """
        Create forward hook to capture attention weights for a specific layer.

        Layer outputs vary by model:
        - GPT-2/CodeGen blocks: (hidden_states, present_key_value, attention_weights)
        - Llama decoder layers: (hidden_states, attention_weights, ...)

        Attention weights are only emitted when the forward pass runs with
        output_attentions=True. We extract the attention_weights tensor, which
        has shape [batch_size, num_heads, seq_len, seq_len].
        """
        def hook(module, input, output):
            if not self.capturing:
                return
            start_time = time.perf_counter()
            try:
                # Extract attention weights from the layer output.
                # When present, the attention weights are typically the second element.
                if isinstance(output, tuple) and len(output) >= 2:
                    attention_weights = output[1]
                    if attention_weights is not None and torch.is_tensor(attention_weights):
                        # Store attention weights
                        # Shape: [batch_size, num_heads, seq_len, seq_len]
                        self.attention_buffer.append({
                            'layer_idx': layer_idx,
                            'weights': attention_weights.detach().cpu(),
                            'timestamp': time.perf_counter()
                        })
            except Exception as e:
                logger.warning(f"Attention hook failed for layer {layer_idx}: {e}")
            elapsed_ms = (time.perf_counter() - start_time) * 1000
            self.timing_buffer.append({
                'layer_idx': layer_idx,
                'time_ms': elapsed_ms,
                'stage': 'attention'
            })
        return hook
    def _create_residual_hook(self, layer_idx: int):
        """
        Create forward hook to capture residual stream norms.

        For transformer layers, the output includes the hidden states (residual stream).
        We compute ||x_l|| to track representation magnitude.
        """
        def hook(module, input, output):
            if not self.capturing:
                return
            try:
                # Output is typically (hidden_states, ...) or just hidden_states
                hidden_states = output[0] if isinstance(output, tuple) else output
                if torch.is_tensor(hidden_states):
                    # Compute L2 norm across the hidden dimension
                    # Shape: [batch_size, seq_len, hidden_dim] -> [batch_size, seq_len]
                    residual_norm = torch.norm(hidden_states, p=2, dim=-1)
                    # Store mean norm across batch and sequence
                    mean_norm = residual_norm.mean().item()
                    self.residual_buffer.append({
                        'layer_idx': layer_idx,
                        'norm': mean_norm,
                        'timestamp': time.perf_counter()
                    })
            except Exception as e:
                logger.warning(f"Residual hook failed for layer {layer_idx}: {e}")
        return hook
    def attach_hooks(self):
        """Attach forward hooks to all transformer layers"""
        logger.info(f"Attaching instrumentation hooks to {self.num_layers} layers...")

        # Get model layers based on architecture
        # Most models: model.transformer.h (GPT-2, CodeGen) or model.model.layers (Llama)
        if hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
            layers = self.model.transformer.h
        elif hasattr(self.model, 'model') and hasattr(self.model.model, 'layers'):
            layers = self.model.model.layers
        else:
            logger.error("Could not find transformer layers in model")
            return

        for layer_idx, layer in enumerate(layers):
            # Attention hook
            attn_hook = self._create_attention_hook(layer_idx)
            handle = layer.register_forward_hook(attn_hook)
            self.hook_handles.append(handle)

            # Residual hook (attach to layer output)
            res_hook = self._create_residual_hook(layer_idx)
            handle = layer.register_forward_hook(res_hook)
            self.hook_handles.append(handle)

        logger.info(f"✅ Attached {len(self.hook_handles)} hooks")
    def remove_hooks(self):
        """Remove all forward hooks"""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []
        logger.info("Removed instrumentation hooks")
    def capture(self):
        """Context manager for capturing generation"""
        class CaptureContext:
            def __init__(self, instrumentor):
                self.instrumentor = instrumentor

            def __enter__(self):
                self.instrumentor.start_capture()
                return self.instrumentor

            def __exit__(self, exc_type, exc_val, exc_tb):
                self.instrumentor.stop_capture()
                return False

        return CaptureContext(self)
    def start_capture(self):
        """Start capturing data"""
        self.capturing = True
        self.start_time = time.perf_counter()
        self.clear_buffers()
        self.attach_hooks()
        logger.info("Started instrumentation capture")

    def stop_capture(self):
        """Stop capturing data"""
        self.capturing = False
        self.remove_hooks()
        logger.info("Stopped instrumentation capture")
    def clear_buffers(self):
        """Clear all capture buffers"""
        self.attention_buffer = []
        self.residual_buffer = []
        self.timing_buffer = []
        self.logits_buffer = []
    def compute_token_metadata(self, token_ids: torch.Tensor, logits: torch.Tensor, position: int) -> TokenMetadata:
        """
        Compute metadata for a single token from logits.

        Args:
            token_ids: Generated token IDs [batch_size]
            logits: Model logits [batch_size, vocab_size]
            position: Position in sequence

        Returns:
            TokenMetadata with probabilities, entropy, top-k alternatives
        """
        # Get probabilities via softmax
        probs = torch.softmax(logits[0], dim=-1)  # [vocab_size]

        # Get generated token info
        token_id = token_ids[0].item()
        token_text = self.tokenizer.decode([token_id])
        token_prob = probs[token_id].item()
        logprob = np.log(token_prob + 1e-10)

        # Compute entropy: H = -sum(p * log(p))
        entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()

        # Get top-k alternatives
        top_k = 5
        top_probs, top_indices = torch.topk(probs, k=top_k)
        top_k_tokens = [
            (self.tokenizer.decode([idx.item()]), prob.item())
            for idx, prob in zip(top_indices, top_probs)
        ]

        # Byte length
        byte_length = len(token_text.encode('utf-8'))

        return TokenMetadata(
            token_id=token_id,
            text=token_text,
            position=position,
            logprob=logprob,
            entropy=entropy,
            top_k_tokens=top_k_tokens,
            byte_length=byte_length,
            timestamp_ms=(time.perf_counter() - self.start_time) * 1000
        )
    def process_buffers(self) -> Tuple[torch.Tensor, List[List[LayerMetadata]]]:
        """
        Process captured buffers into structured tensors.

        Returns:
            attention_tensor: [num_tokens, num_layers, num_heads, seq_len, seq_len]
            layer_metadata: [num_tokens][num_layers]
        """
        # Group attention by token step.
        # Each forward pass captures attention for all layers, so estimate the
        # number of tokens from the buffer size (num_layers captures per token).
        num_tokens = len(self.attention_buffer) // self.num_layers if self.attention_buffer else 0
        if num_tokens == 0:
            logger.warning("No attention data captured")
            return None, []

        # Organize attention tensors by token and layer
        attention_list = []
        layer_metadata_list = []

        for token_idx in range(num_tokens):
            token_attentions = []
            token_layer_meta = []

            for layer_idx in range(self.num_layers):
                buffer_idx = token_idx * self.num_layers + layer_idx
                if buffer_idx < len(self.attention_buffer):
                    attn_data = self.attention_buffer[buffer_idx]
                    token_attentions.append(attn_data['weights'])

                # Get residual norm
                residual_norm = 0.0
                if buffer_idx < len(self.residual_buffer):
                    residual_norm = self.residual_buffer[buffer_idx]['norm']

                # Get timing
                time_ms = 0.0
                if buffer_idx < len(self.timing_buffer):
                    time_ms = self.timing_buffer[buffer_idx]['time_ms']

                token_layer_meta.append(LayerMetadata(
                    layer_idx=layer_idx,
                    residual_norm=residual_norm,
                    time_ms=time_ms
                ))

            if token_attentions:
                # Stack layer attentions: [num_layers, num_heads, seq_len, seq_len]
                attention_list.append(torch.stack(token_attentions))
                layer_metadata_list.append(token_layer_meta)

        # Stack token attentions with padding for varying sequence lengths.
        # During autoregressive generation, seq_len grows with each token.
        if attention_list:
            # Find maximum sequence length across all tokens
            max_seq_len = max(attn.shape[-1] for attn in attention_list)

            # Pad all tensors to max_seq_len
            padded_attentions = []
            for attn in attention_list:
                # attn shape: [num_layers, num_heads, seq_len, seq_len]
                current_seq_len = attn.shape[-1]
                if current_seq_len < max_seq_len:
                    # Create zero tensor with correct dtype for padding
                    pad_shape = list(attn.shape)
                    pad_shape[-1] = max_seq_len
                    pad_shape[-2] = max_seq_len
                    padded = torch.zeros(pad_shape, dtype=attn.dtype, device=attn.device)
                    # Copy original data into padded tensor
                    padded[..., :current_seq_len, :current_seq_len] = attn
                    attn = padded
                padded_attentions.append(attn)

            # Now stack: [num_tokens, num_layers, num_heads, max_seq_len, max_seq_len]
            attention_tensor = torch.stack(padded_attentions)
        else:
            attention_tensor = None

        return attention_tensor, layer_metadata_list
    def get_data(self, run_id: str, prompt: str, max_tokens: int,
                 temperature: float, seed: int, tokens: List[TokenMetadata],
                 top_k: Optional[int] = None, top_p: Optional[float] = None) -> InstrumentationData:
        """
        Package all captured data into an InstrumentationData structure.

        Args:
            run_id: Unique run identifier
            prompt: Original prompt
            max_tokens: Max tokens setting
            temperature: Temperature setting
            seed: Random seed used
            tokens: List of TokenMetadata for generated tokens
            top_k: Top-k sampling parameter
            top_p: Top-p sampling parameter

        Returns:
            InstrumentationData with all captured tensors and metadata
        """
        # Process buffers
        attention_tensor, layer_metadata = self.process_buffers()

        # Calculate total time
        total_time_ms = (time.perf_counter() - self.start_time) * 1000 if self.start_time else 0.0

        # Get sequence length from attention tensor
        seq_length = attention_tensor.shape[-1] if attention_tensor is not None else 0

        data = InstrumentationData(
            run_id=run_id,
            seed=seed,
            model_name=self.model.config._name_or_path,
            timestamp=datetime.now().timestamp(),
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            tokens=tokens,
            attention_tensors=attention_tensor,
            logits_history=None,  # Could capture this if needed
            layer_metadata=layer_metadata,
            total_time_ms=total_time_ms,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            seq_length=seq_length
        )

        logger.info(f"Instrumentation data: {len(tokens)} tokens, "
                    f"{self.num_layers} layers, {self.num_heads} heads, "
                    f"seq_len={seq_length}, total_time={total_time_ms:.1f}ms")

        return data
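
# --- Hedged usage sketch -----------------------------------------------------
# Minimal end-to-end example of how this module could be wired up, assuming a
# Hugging Face causal LM. The checkpoint name "gpt2", the prompt, and the
# sampling settings are illustrative only, not the model this study targets.
# Two details matter for the hooks above: generate() should pass
# output_attentions=True (and, for GPT-2-style blocks whose second output
# element is otherwise the KV cache, use_cache=False) so layer outputs contain
# attention weights, and return_dict_in_generate=True with output_scores=True
# so per-step logits are available for compute_token_metadata().
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    torch.manual_seed(0)
    instrumentor = ModelInstrumentor(model, tokenizer, device)
    prompt = "def fibonacci(n):"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with instrumentor.capture():
        outputs = model.generate(
            **inputs,
            max_new_tokens=16,
            do_sample=True,
            temperature=0.8,
            output_attentions=True,
            use_cache=False,
            output_scores=True,
            return_dict_in_generate=True,
        )

    # Build per-token metadata from the per-step scores returned by generate().
    prompt_len = inputs["input_ids"].shape[1]
    generated_ids = outputs.sequences[0, prompt_len:]
    tokens = [
        instrumentor.compute_token_metadata(
            generated_ids[i:i + 1], step_scores, position=prompt_len + i
        )
        for i, step_scores in enumerate(outputs.scores)
    ]

    data = instrumentor.get_data(
        run_id="demo-run",
        prompt=prompt,
        max_tokens=16,
        temperature=0.8,
        seed=0,
        tokens=tokens,
    )
    print(f"Captured {len(data.tokens)} tokens across {data.num_layers} layers, "
          f"seq_len={data.seq_length}, total_time={data.total_time_ms:.1f}ms")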