| from typing import Optional, Union, Dict, Any, TYPE_CHECKING
|
| import numpy as np
|
| from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, Device, DType, Layout
|
| from virtual_gpu_driver.src.stream import Stream
|
| from .module import HeliumModule
|
| from .core.db_manager import HeliumDBManager
|
|
|
| if TYPE_CHECKING:
|
| from .tensor import HeliumTensor
|
|
|
class HeliumLayerNorm(HeliumModule):
    """
    Hardware-accelerated Layer Normalization implementation.

    Applies Layer Normalization over a mini-batch of inputs as described in
    the paper "Layer Normalization" [Jimmy Lei Ba, Jamie Ryan Kiros,
    Geoffrey E. Hinton]:

        y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta

    where gamma (weight) and beta (bias) are learnable parameters.
    All intermediate computation happens in driver (virtual GPU) memory.
    """

    def __init__(
        self,
        normalized_shape: int,
        eps: float = 1e-5,
        device_id: int = 0,
        dtype: str = "float32"
    ):
        """
        Initialize layer normalization module.

        Args:
            normalized_shape: Size of the last dimension to normalize over.
            eps: Small value added to variance for numerical stability.
            device_id: Virtual GPU device ID.
            dtype: Data type for computations (e.g. "float32").
        """
        super().__init__(device_id=device_id, dtype=dtype)

        self.normalized_shape = normalized_shape
        self.eps = eps

        # Learnable affine parameters (gamma / beta), allocated in driver memory.
        self.weight = self._create_param(normalized_shape)
        self.bias = self._create_param(normalized_shape)

        # Dedicated stream for enqueueing this module's kernels.
        self.stream = Stream(self.driver)

        # Shared database manager singleton.
        self.db = HeliumDBManager.get_instance()

        # Live temporary tensors: internal id -> driver allocation handle.
        self._temp_tensors: Dict[str, Any] = {}
        self._counter = 0

    def _create_param(self, size: int) -> str:
        """Allocate a 1-D parameter tensor of `size` elements in driver memory.

        Returns the driver's allocation handle.
        """
        desc = TensorDescriptor(
            shape=(size,),
            dtype=getattr(DType, self.dtype.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )
        return self.driver.allocate_tensor(desc)

    def _get_temp_tensor(self, shape: tuple) -> str:
        """Allocate a temporary tensor for intermediate computations.

        Returns an internal id of the form "ln_temp_N" (not the driver
        handle); the driver handle is stored in self._temp_tensors so that
        _free_temp_tensor can release it later.

        NOTE(review): this returns the internal id, while _create_param
        returns the driver handle, and forward() passes both kinds of name
        to the same driver ops (e.g. scale_and_shift) — confirm the driver
        resolves both, or unify on one naming scheme.
        """
        tensor_id = f"ln_temp_{self._counter}"
        self._counter += 1

        desc = TensorDescriptor(
            shape=shape,
            dtype=getattr(DType, self.dtype.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )

        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(desc)
        return tensor_id

    def _free_temp_tensor(self, tensor_id: str):
        """Free a temporary tensor if it is still tracked (safe to call twice)."""
        if tensor_id in self._temp_tensors:
            self.driver.free_tensor(self._temp_tensors[tensor_id])
            del self._temp_tensors[tensor_id]

    def __del__(self):
        """Best-effort cleanup of all driver allocations owned by this module.

        Uses hasattr guards (in case __init__ failed partway through) and a
        broad try/except, because __del__ may run during interpreter shutdown
        after the driver has already been torn down — a destructor must never
        raise.
        """
        try:
            if hasattr(self, '_temp_tensors'):
                for tensor_id in list(self._temp_tensors.keys()):
                    self._free_temp_tensor(tensor_id)

            if hasattr(self, 'weight'):
                self.driver.free_tensor(self.weight)
            if hasattr(self, 'bias'):
                self.driver.free_tensor(self.bias)
        except Exception:
            # Leaking driver memory at shutdown is preferable to raising
            # from a finalizer.
            pass

    def _check_input_shape(self, input_shape: tuple):
        """Raise ValueError unless the last dimension equals normalized_shape."""
        if input_shape[-1] != self.normalized_shape:
            raise ValueError(
                f"Expected last dimension to be {self.normalized_shape}, "
                f"got {input_shape[-1]}"
            )

    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        scale: Optional[Union[str, "HeliumTensor"]] = None,
        offset: Optional[Union[str, "HeliumTensor"]] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Apply layer normalization.

        All computations are done in driver memory; the returned value is the
        name of the normalized tensor in the driver. The output tensor stays
        registered in self._temp_tensors, so it lives until this module is
        destroyed (or the caller releases it via _free_temp_tensor).

        Args:
            input_tensor: Input of shape (*, normalized_shape).
            scale: Optional override for the weight (gamma) parameter.
            offset: Optional override for the bias (beta) parameter.

        Returns:
            Normalized tensor of the same shape as the input.

        Raises:
            ValueError: If the input's last dimension != normalized_shape.
        """
        input_shape = self.driver.get_tensor_shape(input_tensor)
        self._check_input_shape(input_shape)

        with self.stream:
            # Per-row mean over the last axis.
            # NOTE(review): the temp is allocated with shape input_shape[:-1]
            # while the reduction passes keepdims=True — confirm the driver
            # tolerates this rank mismatch.
            mean = self._get_temp_tensor(input_shape[:-1])
            self.driver.reduce_mean(
                input_tensor,
                mean,
                axis=-1,
                keepdims=True
            )

            # Per-row variance, using the precomputed mean.
            variance = self._get_temp_tensor(input_shape[:-1])
            self.driver.reduce_variance(
                input_tensor,
                variance,
                mean,
                axis=-1,
                keepdims=True
            )

            # normalized = (x - mean) / sqrt(variance + eps)
            normalized = self._get_temp_tensor(input_shape)
            self.driver.normalize(
                input_tensor,
                mean,
                variance,
                normalized,
                eps=self.eps
            )

            # Affine transform: y = normalized * gamma + beta.
            scale_tensor = scale if scale is not None else self.weight
            offset_tensor = offset if offset is not None else self.bias

            output = self._get_temp_tensor(input_shape)
            self.driver.scale_and_shift(
                normalized,
                scale_tensor,
                offset_tensor,
                output
            )

            # Release intermediates; `output` is handed to the caller.
            self._free_temp_tensor(mean)
            self._free_temp_tensor(variance)
            self._free_temp_tensor(normalized)

        return output

    def compute_variance(self, state, driver, x_name, mean_name, gamma_name, beta_name, eps=1e-5):
        """
        Apply a full layer normalization via elementwise driver ops.

        NOTE: despite its name, this method computes the complete layer-norm
        output, not just the variance:

            out = (x - mean) / sqrt(mean((x - mean)^2) + eps) * gamma + beta

        Args:
            state: Temp-tensor manager exposing get_temp_tensor/free_temp_tensor.
            driver: Virtual GPU driver exposing elementwise tensor ops.
            x_name: Name of the input tensor in driver memory.
            mean_name: Name of the precomputed per-row mean tensor.
                WARNING: this tensor is freed by this method — the caller
                must not use it afterwards.
            gamma_name: Scale (gamma) parameter tensor name.
            beta_name: Offset (beta) parameter tensor name.
            eps: Numerical-stability constant added to the variance.

        Returns:
            Name of the output tensor in driver memory (caller owns it).
        """
        # diff = x - mean
        diff_name = state.get_temp_tensor(
            driver.sub(x_name, mean_name),
            "diff"
        )
        # squared = diff * diff
        squared_name = state.get_temp_tensor(
            driver.mul(diff_name, diff_name),
            "squared"
        )
        # var = mean(squared) over the last axis
        var_name = state.get_temp_tensor(
            driver.mean(squared_name, axis=-1, keepdims=True),
            "var"
        )

        state.free_temp_tensor(squared_name)

        # std = sqrt(var + eps)
        std_name = state.get_temp_tensor(
            driver.sqrt(driver.add_scalar(var_name, eps)),
            "std"
        )
        # normalized = diff / std
        normalized_name = state.get_temp_tensor(
            driver.div(diff_name, std_name),
            "normalized"
        )

        # Free everything the affine step no longer needs — including the
        # caller-supplied mean tensor (see WARNING above).
        state.free_temp_tensor(diff_name)
        state.free_temp_tensor(mean_name)
        state.free_temp_tensor(var_name)
        state.free_temp_tensor(std_name)

        # out = normalized * gamma + beta
        scaled_name = state.get_temp_tensor(
            driver.mul(normalized_name, gamma_name),
            "scaled"
        )
        output_name = state.get_temp_tensor(
            driver.add(scaled_name, beta_name),
            "output"
        )

        state.free_temp_tensor(normalized_name)
        state.free_temp_tensor(scaled_name)

        return output_name
|
|
|
def layer_norm(
    input_tensor: Union[str, "HeliumTensor"],
    normalized_shape: int,
    weight: Optional[Union[str, "HeliumTensor"]] = None,
    bias: Optional[Union[str, "HeliumTensor"]] = None,
    eps: float = 1e-5,
    device_id: int = 0,
    dtype: str = "float32"
) -> Union[str, "HeliumTensor"]:
    """
    Apply Layer Normalization over a mini-batch of inputs.

    Functional convenience wrapper: constructs a throwaway HeliumLayerNorm
    module and runs a single forward pass through it.

    Args:
        input_tensor: Input of shape (*, normalized_shape).
        normalized_shape: Size of the last dimension to normalize over.
        weight: Optional scale (gamma) parameter.
        bias: Optional offset (beta) parameter.
        eps: Small value for numerical stability.
        device_id: Virtual GPU device ID.
        dtype: Data type for computations.

    Returns:
        Normalized tensor of the same shape as the input.
    """
    ln_module = HeliumLayerNorm(
        normalized_shape=normalized_shape,
        eps=eps,
        device_id=device_id,
        dtype=dtype
    )
    return ln_module.forward(input_tensor, weight, bias)