from typing import Optional, List, Dict, Tuple

import numpy as np
from dataclasses import dataclass
from enum import Enum
import warnings
from .block import TransformerBlock
from .core.db_manager import HeliumDBManager
import json
import hashlib
from contextlib import contextmanager


class ExecutionStrategy(Enum):
    """Execution strategies for the transformer stack."""

    SEQUENTIAL = "sequential"  # Process blocks one by one
    PIPELINED = "pipelined"    # Pipeline input chunks across multiple devices
    PARALLEL = "parallel"      # Process blocks in parallel where possible


@dataclass
class StackConfig:
    """Configuration for a transformer stack."""

    num_layers: int
    hidden_dim: int
    num_heads: int
    intermediate_size: int
    max_sequence_length: int
    dropout_rate: float = 0.1
    layer_norm_epsilon: float = 1e-5
    use_cache: bool = True
    use_checkpointing: bool = False
    execution_strategy: ExecutionStrategy = ExecutionStrategy.SEQUENTIAL
    dtype: np.dtype = np.float32
    gradient_checkpointing_steps: int = 2
    max_batch_size: Optional[int] = None


class TransformerStackCache:
    """Cache manager for transformer stack computations."""

    def __init__(self, config: StackConfig):
        self.config = config
        self.db = HeliumDBManager.get_instance()

    def _compute_cache_key(
        self,
        layer_idx: int,
        input_shape: Tuple,
        block_config: Dict
    ) -> str:
        """Compute a deterministic cache key for a layer's output."""
        cache_data = {
            'layer_idx': layer_idx,
            'input_shape': input_shape,
            'block_config': block_config,
            'dtype': str(self.config.dtype)
        }
        # sort_keys keeps the serialization (and therefore the hash) stable
        # regardless of dict insertion order.
        return hashlib.sha256(
            json.dumps(cache_data, sort_keys=True).encode()
        ).hexdigest()

    def get(self, key: str) -> Optional[np.ndarray]:
        """Get a cached computation result, or None on a miss."""
        return self.db.get_activation(key)

    def set(self, key: str, value: np.ndarray, metadata: Dict):
        """Cache a computation result."""
        self.db.set_activation(key, value, metadata)
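
# Illustrative sketch of the cache-key scheme (not part of the public API):
# the key is a SHA-256 digest over a canonical JSON encoding, so identical
# (layer_idx, input_shape, block_config, dtype) tuples always map to the same
# key. The `block_config` dict below is made up for illustration; the real
# one comes from TransformerBlock.get_config().
#
#   >>> cache = TransformerStackCache(config)
#   >>> key = cache._compute_cache_key(
#   ...     layer_idx=0,
#   ...     input_shape=(2, 16, 64),
#   ...     block_config={"num_heads": 4, "hidden_size": 64},
#   ... )
#   >>> len(key)  # hex-encoded SHA-256
#   64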
class ResourceManager:
    """Manages hardware resources and scheduling."""

    def __init__(self, driver=None):
        self.driver = driver
        self.available_devices = self._get_available_devices()
        self.device_queues = {device: [] for device in self.available_devices}

    def _get_available_devices(self) -> List[str]:
        """Get the list of available compute devices."""
        if self.driver and hasattr(self.driver, 'list_devices'):
            return self.driver.list_devices()
        return ['cpu']

    @contextmanager
    def acquire_device(self, preferred_device: Optional[str] = None):
        """Acquire a compute device for the duration of the context."""
        device = self._select_device(preferred_device)
        # Record the acquisition so queue lengths reflect in-flight work;
        # without this, _select_device would always see empty queues.
        self.device_queues[device].append(device)
        try:
            yield device
        finally:
            self._release_device(device)

    def _select_device(self, preferred_device: Optional[str] = None) -> str:
        """Select the best available device."""
        if preferred_device and preferred_device in self.available_devices:
            return preferred_device
        # Otherwise pick the device with the shortest queue
        return min(
            self.device_queues.items(),
            key=lambda item: len(item[1])
        )[0]

    def _release_device(self, device: str):
        """Release a device back to the pool."""
        if device in self.device_queues and self.device_queues[device]:
            self.device_queues[device].pop(0)


class TransformerStack:
    """
    Optimized transformer stack implementation with support for:
    - Multiple execution strategies
    - Hardware acceleration
    - Gradient checkpointing
    - Mixed precision
    - Memory optimization
    """

    def __init__(
        self,
        config: StackConfig,
        weights_list: List[Dict],
        driver=None
    ):
        """
        Initialize the transformer stack.

        Args:
            config: Stack configuration
            weights_list: List of per-block weight dicts
            driver: Optional hardware driver
        """
        self.config = config
        self.weights_list = weights_list
        self.driver = driver

        self._validate_config()
        self._setup_components()

    def _validate_config(self):
        """Validate configuration parameters."""
        if len(self.weights_list) != self.config.num_layers:
            raise ValueError(
                f"Expected {self.config.num_layers} weight dicts, "
                f"got {len(self.weights_list)}"
            )
        if self.config.num_heads <= 0:
            raise ValueError(f"Invalid number of heads: {self.config.num_heads}")
        if self.config.hidden_dim % self.config.num_heads != 0:
            raise ValueError(
                f"Hidden dimension {self.config.hidden_dim} must be divisible "
                f"by number of heads {self.config.num_heads}"
            )

    def _setup_components(self):
        """Set up stack components."""
        # Initialize one block per layer from its weight dict
        self.blocks = [
            TransformerBlock(
                hidden_size=self.config.hidden_dim,
                num_heads=self.config.num_heads,
                intermediate_size=self.config.intermediate_size,
                weights=weights,
                dropout_rate=self.config.dropout_rate,
                layer_norm_epsilon=self.config.layer_norm_epsilon,
                dtype=self.config.dtype,
                driver=self.driver
            )
            for weights in self.weights_list
        ]

        # Initialize cache
        self.cache = TransformerStackCache(self.config)

        # Initialize resource manager
        self.resource_manager = ResourceManager(self.driver)

    def _execute_sequential(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None,
        use_cache: bool = True
    ) -> np.ndarray:
        """Execute blocks one after another."""
        current_state = x
        for i, block in enumerate(self.blocks):
            cache_key = None
            if use_cache:
                cache_key = self.cache._compute_cache_key(
                    i, current_state.shape, block.get_config()
                )
                cached_result = self.cache.get(cache_key)
                if cached_result is not None:
                    current_state = cached_result
                    continue

            with self.resource_manager.acquire_device() as device:
                current_state = block(
                    current_state,
                    mask=mask,
                    device=device
                )

            if use_cache:
                self.cache.set(
                    cache_key,
                    current_state,
                    {'layer_idx': i, 'shape': current_state.shape}
                )
        return current_state

    def _execute_pipelined(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Execute blocks over input chunks in a pipelined fashion."""
        batch_size = x.shape[0]
        num_chunks = min(
            batch_size,
            len(self.resource_manager.available_devices)
        )

        # np.array_split tolerates batch sizes that do not divide evenly, so
        # the mask is split with the same call to stay aligned with its chunk
        # (fixed-size slicing would drift on uneven splits).
        chunks = np.array_split(x, num_chunks)
        mask_chunks = (
            np.array_split(mask, num_chunks) if mask is not None
            else [None] * num_chunks
        )
        results = []

        # Push each chunk through the full stack
        for chunk, chunk_mask in zip(chunks, mask_chunks):
            current_state = chunk
            for block in self.blocks:
                with self.resource_manager.acquire_device() as device:
                    current_state = block(
                        current_state,
                        mask=chunk_mask,
                        device=device
                    )
            results.append(current_state)

        # Reassemble the chunk results into a single batch
        return np.concatenate(results, axis=0)

    def _execute_parallel(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Execute blocks in parallel where possible."""
        if not self.driver or not hasattr(self.driver, 'parallel_execute'):
            warnings.warn(
                "Parallel execution not supported, falling back to sequential"
            )
            return self._execute_sequential(x, mask)
        return self.driver.parallel_execute(
            self.blocks,
            x,
            mask,
            self.config.num_layers
        )

    def forward(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None,
        use_cache: bool = True
    ) -> np.ndarray:
        """
        Forward pass through the transformer stack.

        Args:
            x: Input tensor of shape (batch_size, seq_len, hidden_dim)
            mask: Optional attention mask
            use_cache: Whether to use computation caching

        Returns:
            Output tensor of shape (batch_size, seq_len, hidden_dim)
        """
        # Input validation
        if x.ndim != 3:
            raise ValueError(f"Expected 3D input tensor, got shape {x.shape}")
        if x.shape[2] != self.config.hidden_dim:
            raise ValueError(
                f"Expected hidden dimension {self.config.hidden_dim}, "
                f"got {x.shape[2]}"
            )
        if (
            self.config.max_sequence_length
            and x.shape[1] > self.config.max_sequence_length
        ):
            raise ValueError(
                f"Input sequence length {x.shape[1]} exceeds maximum "
                f"allowed length {self.config.max_sequence_length}"
            )

        # Dispatch on the configured execution strategy
        if self.config.execution_strategy == ExecutionStrategy.PIPELINED:
            return self._execute_pipelined(x, mask)
        elif self.config.execution_strategy == ExecutionStrategy.PARALLEL:
            return self._execute_parallel(x, mask)
        else:
            return self._execute_sequential(x, mask, use_cache)

    def __call__(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None,
        use_cache: bool = True
    ) -> np.ndarray:
        """Callable interface."""
        return self.forward(x, mask, use_cache)
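
# Note on _execute_pipelined above: np.array_split is used rather than
# reshape because it tolerates batch sizes that do not divide evenly across
# devices, and splitting x and mask with the same call keeps each mask chunk
# aligned with its input chunk. A quick plain-numpy illustration:
#
#   >>> x = np.zeros((5, 8, 16))  # batch of 5 split over 2 devices
#   >>> [c.shape[0] for c in np.array_split(x, 2)]
#   [3, 2]
#   >>> mask = np.ones((5, 8))
#   >>> [m.shape[0] for m in np.array_split(mask, 2)]
#   [3, 2]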
# Legacy function for backward compatibility
def transformer_stack(
    x: np.ndarray,
    weights_list: List[Dict],
    num_heads: int,
    mask: Optional[np.ndarray] = None,
    driver=None,
    scheduler=None  # Unused; kept for signature compatibility
) -> np.ndarray:
    """Legacy transformer stack interface."""
    warnings.warn(
        "transformer_stack function is deprecated, "
        "use the TransformerStack class instead",
        DeprecationWarning
    )

    config = StackConfig(
        num_layers=len(weights_list),
        hidden_dim=x.shape[2],
        num_heads=num_heads,
        intermediate_size=4 * x.shape[2],  # Standard 4x FFN expansion
        max_sequence_length=x.shape[1]
    )

    stack = TransformerStack(config, weights_list, driver)
    return stack.forward(x, mask)
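
# Example usage (illustrative sketch). The keys each weight dict must contain
# are defined by TransformerBlock and are not shown here; load_block_weights
# is a hypothetical loader standing in for whatever produces them.
#
#   config = StackConfig(
#       num_layers=2,
#       hidden_dim=64,
#       num_heads=4,
#       intermediate_size=256,
#       max_sequence_length=128,
#   )
#   weights_list = [load_block_weights(i) for i in range(config.num_layers)]
#   stack = TransformerStack(config, weights_list)
#   hidden = np.random.randn(2, 16, 64).astype(config.dtype)
#   output = stack(hidden)  # -> shape (2, 16, 64)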