from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional
import hashlib
import json
import time
import warnings

import numpy as np

from .block import TransformerBlock
from .core.db_manager import HeliumDBManager


class ExecutionStrategy(Enum):
    """Execution strategies for the transformer stack."""

    SEQUENTIAL = "sequential"
    PIPELINED = "pipelined"
    PARALLEL = "parallel"


@dataclass
class StackConfig:
    """Configuration for the transformer stack."""

    num_layers: int
    hidden_dim: int
    num_heads: int
    intermediate_size: int
    max_sequence_length: int
    dropout_rate: float = 0.1
    layer_norm_epsilon: float = 1e-5
    use_cache: bool = True
    use_checkpointing: bool = False
    execution_strategy: ExecutionStrategy = ExecutionStrategy.SEQUENTIAL
    dtype: np.dtype = np.float32
    gradient_checkpointing_steps: int = 2
    max_batch_size: Optional[int] = None

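
# Example configuration (illustrative sketch; the values are arbitrary,
# only the required fields are set, and the rest keep their defaults):
#
#   config = StackConfig(
#       num_layers=12,
#       hidden_dim=768,
#       num_heads=12,
#       intermediate_size=3072,
#       max_sequence_length=512,
#   )

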
class TransformerStackCache:
    """Cache manager for transformer stack computations."""

    def __init__(self, config: StackConfig):
        self.config = config
        self.db = HeliumDBManager.get_instance()

    def _compute_cache_key(
        self,
        layer_idx: int,
        input_array: np.ndarray,
        block_config: Dict
    ) -> str:
        """Compute a cache key for a layer's output.

        The key includes a digest of the input's contents, not just its
        shape, so two different inputs of the same shape cannot collide.
        """
        cache_data = {
            'layer_idx': layer_idx,
            'input_shape': input_array.shape,
            'input_digest': hashlib.sha256(input_array.tobytes()).hexdigest(),
            'block_config': block_config,
            'dtype': str(self.config.dtype)
        }
        return hashlib.sha256(
            json.dumps(cache_data, sort_keys=True).encode()
        ).hexdigest()

    def get(self, key: str) -> Optional[np.ndarray]:
        """Get a cached computation result."""
        return self.db.get_activation(key)

    def set(self, key: str, value: np.ndarray, metadata: Dict):
        """Cache a computation result."""
        self.db.set_activation(key, value, metadata)

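
# Lookup flow sketch (illustrative; `block` and `x` are placeholders here,
# while get_activation/set_activation are the HeliumDBManager calls used
# above):
#
#   cache = TransformerStackCache(config)
#   key = cache._compute_cache_key(0, x, block.get_config())
#   out = cache.get(key)
#   if out is None:
#       out = block(x)
#       cache.set(key, out, {'layer_idx': 0, 'shape': out.shape})

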
class ResourceManager:
    """Manages hardware resources and scheduling."""

    def __init__(self, driver=None):
        self.driver = driver
        self.available_devices = self._get_available_devices()
        self.device_queues = {device: [] for device in self.available_devices}

    def _get_available_devices(self) -> List[str]:
        """Get the list of available compute devices."""
        if self.driver and hasattr(self.driver, 'list_devices'):
            return self.driver.list_devices()
        return ['cpu']

    @contextmanager
    def acquire_device(self, preferred_device: Optional[str] = None):
        """Acquire a compute device for the duration of a with-block."""
        device = self._select_device(preferred_device)
        # Record the acquisition so queue lengths reflect in-flight work;
        # _select_device balances on these lengths.
        self.device_queues[device].append(time.monotonic())
        try:
            yield device
        finally:
            self._release_device(device)

    def _select_device(self, preferred_device: Optional[str] = None) -> str:
        """Select the best available device."""
        if preferred_device and preferred_device in self.available_devices:
            return preferred_device

        # Otherwise pick the device with the shortest queue.
        return min(
            self.device_queues.items(),
            key=lambda x: len(x[1])
        )[0]

    def _release_device(self, device: str):
        """Release a device back to the pool."""
        queue = self.device_queues.get(device)
        if queue:
            queue.pop(0)

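
# Device scoping sketch (with no driver the manager falls back to a single
# 'cpu' device; `run_block` is a hypothetical placeholder):
#
#   rm = ResourceManager()
#   with rm.acquire_device() as device:
#       result = run_block(device)  # the device is released on exit

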
class TransformerStack:
    """
    Optimized transformer stack implementation with support for:
    - Multiple execution strategies
    - Hardware acceleration
    - Gradient checkpointing
    - Mixed precision
    - Memory optimization
    """

    def __init__(
        self,
        config: StackConfig,
        weights_list: List[Dict],
        driver=None
    ):
        """
        Initialize the transformer stack.

        Args:
            config: Stack configuration
            weights_list: List of per-block weight dicts
            driver: Optional hardware driver
        """
        self.config = config
        self.weights_list = weights_list
        self.driver = driver

        self._validate_config()
        self._setup_components()

    def _validate_config(self):
        """Validate configuration parameters."""
        if len(self.weights_list) != self.config.num_layers:
            raise ValueError(
                f"Expected {self.config.num_layers} weight dicts, "
                f"got {len(self.weights_list)}"
            )

        if self.config.num_heads <= 0:
            raise ValueError(f"Invalid number of heads: {self.config.num_heads}")

        if self.config.hidden_dim % self.config.num_heads != 0:
            raise ValueError(
                f"Hidden dimension {self.config.hidden_dim} must be divisible "
                f"by number of heads {self.config.num_heads}"
            )

    def _setup_components(self):
        """Set up stack components."""
        # One TransformerBlock per entry in weights_list.
        self.blocks = [
            TransformerBlock(
                hidden_size=self.config.hidden_dim,
                num_heads=self.config.num_heads,
                intermediate_size=self.config.intermediate_size,
                weights=weights,
                dropout_rate=self.config.dropout_rate,
                layer_norm_epsilon=self.config.layer_norm_epsilon,
                dtype=self.config.dtype,
                driver=self.driver
            )
            for weights in self.weights_list
        ]

        self.cache = TransformerStackCache(self.config)
        self.resource_manager = ResourceManager(self.driver)

    def _execute_sequential(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None,
        use_cache: bool = True
    ) -> np.ndarray:
        """Execute blocks sequentially."""
        current_state = x

        for i, block in enumerate(self.blocks):
            cache_key = None
            if use_cache:
                # Key on the layer input's contents (see
                # TransformerStackCache._compute_cache_key), not just its shape.
                cache_key = self.cache._compute_cache_key(
                    i, current_state, block.get_config()
                )
                cached_result = self.cache.get(cache_key)
                if cached_result is not None:
                    current_state = cached_result
                    continue

            with self.resource_manager.acquire_device() as device:
                current_state = block(
                    current_state,
                    mask=mask,
                    device=device
                )

            if use_cache:
                self.cache.set(
                    cache_key,
                    current_state,
                    {'layer_idx': i, 'shape': current_state.shape}
                )

        return current_state

    def _execute_pipelined(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Execute blocks in a pipelined fashion over batch chunks."""
        batch_size = x.shape[0]
        num_chunks = max(1, min(
            batch_size,
            len(self.resource_manager.available_devices)
        ))

        # Split the batch and the mask with the same np.array_split call so
        # mask chunk i always lines up with input chunk i, even when the
        # batch size does not divide evenly.
        chunks = np.array_split(x, num_chunks)
        mask_chunks = (
            np.array_split(mask, num_chunks)
            if mask is not None else [None] * num_chunks
        )
        results = []

        for chunk, mask_chunk in zip(chunks, mask_chunks):
            current_state = chunk
            for block in self.blocks:
                with self.resource_manager.acquire_device() as device:
                    current_state = block(
                        current_state,
                        mask=mask_chunk,
                        device=device
                    )
            results.append(current_state)

        return np.concatenate(results, axis=0)

    def _execute_parallel(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Execute blocks in parallel where possible."""
        if not self.driver or not hasattr(self.driver, 'parallel_execute'):
            warnings.warn(
                "Parallel execution not supported, falling back to sequential"
            )
            return self._execute_sequential(x, mask)

        return self.driver.parallel_execute(
            self.blocks,
            x,
            mask,
            self.config.num_layers
        )

    def forward(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None,
        use_cache: bool = True
    ) -> np.ndarray:
        """
        Forward pass through the transformer stack.

        Args:
            x: Input tensor of shape (batch_size, seq_len, hidden_dim)
            mask: Optional attention mask
            use_cache: Whether to use computation caching

        Returns:
            Output tensor of shape (batch_size, seq_len, hidden_dim)
        """
        if x.ndim != 3:
            raise ValueError(f"Expected 3D input tensor, got shape {x.shape}")

        if x.shape[2] != self.config.hidden_dim:
            raise ValueError(
                f"Expected hidden dimension {self.config.hidden_dim}, "
                f"got {x.shape[2]}"
            )

        if (
            self.config.max_sequence_length and
            x.shape[1] > self.config.max_sequence_length
        ):
            raise ValueError(
                f"Input sequence length {x.shape[1]} exceeds maximum "
                f"allowed length {self.config.max_sequence_length}"
            )

        # Enforce max_batch_size, which was previously accepted in the
        # config but never checked.
        if (
            self.config.max_batch_size and
            x.shape[0] > self.config.max_batch_size
        ):
            raise ValueError(
                f"Input batch size {x.shape[0]} exceeds maximum "
                f"allowed batch size {self.config.max_batch_size}"
            )

        if self.config.execution_strategy == ExecutionStrategy.PIPELINED:
            return self._execute_pipelined(x, mask)
        elif self.config.execution_strategy == ExecutionStrategy.PARALLEL:
            return self._execute_parallel(x, mask)
        else:
            return self._execute_sequential(x, mask, use_cache)

    def __call__(
        self,
        x: np.ndarray,
        mask: Optional[np.ndarray] = None,
        use_cache: bool = True
    ) -> np.ndarray:
        """Callable interface; delegates to forward()."""
        return self.forward(x, mask, use_cache)


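# End-to-end usage sketch (illustrative; the contents of each weights dict
# depend on TransformerBlock and are assumed here, not specified by this
# module):
#
#   config = StackConfig(
#       num_layers=2, hidden_dim=64, num_heads=4,
#       intermediate_size=256, max_sequence_length=128,
#   )
#   stack = TransformerStack(config, weights_list=my_weights)
#   x = np.zeros((1, 16, 64), dtype=config.dtype)
#   out = stack(x)  # shape (1, 16, 64)

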
def transformer_stack(
    x: np.ndarray,
    weights_list: List[Dict],
    num_heads: int,
    mask: Optional[np.ndarray] = None,
    driver=None,
    scheduler=None
) -> np.ndarray:
    """Legacy transformer stack interface.

    Note: ``scheduler`` is accepted for backward compatibility but ignored.
    """
    warnings.warn(
        "transformer_stack function is deprecated, use TransformerStack class instead",
        DeprecationWarning,
        stacklevel=2
    )

    config = StackConfig(
        num_layers=len(weights_list),
        hidden_dim=x.shape[2],
        num_heads=num_heads,
        intermediate_size=4 * x.shape[2],
        max_sequence_length=x.shape[1]
    )

    stack = TransformerStack(config, weights_list, driver)
    return stack.forward(x, mask)
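
# Migration sketch from the deprecated function to the class API
# (equivalent to the config constructed above):
#
#   # old
#   out = transformer_stack(x, weights_list, num_heads=8)
#   # new
#   config = StackConfig(
#       num_layers=len(weights_list), hidden_dim=x.shape[2], num_heads=8,
#       intermediate_size=4 * x.shape[2], max_sequence_length=x.shape[1],
#   )
#   out = TransformerStack(config, weights_list)(x)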