# helium/layer_norm.py
# Provenance: uploaded by Fred808 ("Upload 256 files", commit 7a0c684, verified).
from typing import Optional, Union, Dict, Any, TYPE_CHECKING
import numpy as np
from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, Device, DType, Layout
from virtual_gpu_driver.src.stream import Stream
from .module import HeliumModule
from .core.db_manager import HeliumDBManager
if TYPE_CHECKING:
from .tensor import HeliumTensor
class HeliumLayerNorm(HeliumModule):
    """
    Hardware-accelerated Layer Normalization implementation.

    Applies Layer Normalization over a mini-batch of inputs as described in
    the paper "Layer Normalization" [Jimmy Lei Ba, Jamie Ryan Kiros,
    Geoffrey E. Hinton]:

        y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta

    where gamma (weight) and beta (bias) are learnable parameters. All
    tensors live in virtual-GPU driver memory and are referred to by the
    names returned from ``driver.allocate_tensor``.
    """

    def __init__(
        self,
        normalized_shape: int,
        eps: float = 1e-5,
        device_id: int = 0,
        dtype: str = "float32"
    ):
        """
        Initialize the layer-normalization module.

        Args:
            normalized_shape: Size of the last (normalized) dimension.
            eps: Small value added to the variance for numerical stability.
            device_id: Virtual GPU device ID.
            dtype: Data-type name for computations (e.g. "float32").
        """
        super().__init__(device_id=device_id, dtype=dtype)
        self.normalized_shape = normalized_shape
        self.eps = eps
        # Learnable per-feature scale (gamma) and shift (beta), stored as
        # driver tensor names.
        self.weight = self._create_param(normalized_shape)
        self.bias = self._create_param(normalized_shape)
        # Stream for asynchronous execution of the driver kernels.
        self.stream = Stream(self.driver)
        # Database manager singleton (kept for parity with other modules).
        self.db = HeliumDBManager.get_instance()
        # Driver names of live temporary tensors. Keyed by the driver name
        # itself so entries can be freed individually or en masse in __del__.
        self._temp_tensors = {}
        # Counts temporaries ever allocated; retained for compatibility.
        self._counter = 0

    def _tensor_descriptor(self, shape: tuple) -> TensorDescriptor:
        """Build a row-major VGPU tensor descriptor with this module's dtype."""
        return TensorDescriptor(
            shape=tuple(shape),
            dtype=getattr(DType, self.dtype.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )

    def _create_param(self, size: int) -> str:
        """Allocate a 1-D parameter tensor and return its driver name."""
        return self.driver.allocate_tensor(self._tensor_descriptor((size,)))

    def _get_temp_tensor(self, shape: tuple) -> str:
        """
        Allocate a temporary tensor for intermediate computations.

        Returns the *driver* tensor name. (A previous revision returned a
        local bookkeeping id such as "ln_temp_0"; the driver kernels receive
        driver names everywhere else — weight, bias, _create_param — so the
        local id could not be resolved by the driver ops in forward().)
        """
        name = self.driver.allocate_tensor(self._tensor_descriptor(shape))
        self._temp_tensors[name] = name
        self._counter += 1
        return name

    def _free_temp_tensor(self, tensor_id: str):
        """Free a temporary tensor if this module still owns it; no-op otherwise."""
        if tensor_id in self._temp_tensors:
            self.driver.free_tensor(self._temp_tensors[tensor_id])
            del self._temp_tensors[tensor_id]

    def __del__(self):
        """Release every driver tensor this module allocated."""
        # hasattr guards: __init__ may have raised before these were set.
        if hasattr(self, '_temp_tensors'):
            for tensor_id in list(self._temp_tensors.keys()):
                self._free_temp_tensor(tensor_id)
        # Free parameter tensors
        if hasattr(self, 'weight'):
            self.driver.free_tensor(self.weight)
        if hasattr(self, 'bias'):
            self.driver.free_tensor(self.bias)

    def _check_input_shape(self, input_shape: tuple):
        """
        Validate the input shape.

        Raises:
            ValueError: if the shape is empty or its last dimension does not
                equal ``normalized_shape`` (previously an empty shape raised
                an opaque IndexError).
        """
        last_dim = input_shape[-1] if input_shape else None
        if last_dim != self.normalized_shape:
            raise ValueError(
                f"Expected last dimension to be {self.normalized_shape}, "
                f"got {last_dim}"
            )

    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        scale: Optional[Union[str, "HeliumTensor"]] = None,
        offset: Optional[Union[str, "HeliumTensor"]] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Apply layer normalization.

        All computations are done in driver memory.

        Args:
            input_tensor: Input of shape (*, normalized_shape).
            scale: Optional override for the weight parameter.
            offset: Optional override for the bias parameter.

        Returns:
            Name of the normalized tensor in the driver (same shape as the
            input). The output tensor remains tracked by this module and is
            freed in __del__.
        """
        input_shape = self.driver.get_tensor_shape(input_tensor)
        self._check_input_shape(input_shape)
        with self.stream:
            # Per-row mean over the last axis.
            mean = self._get_temp_tensor(input_shape[:-1])
            self.driver.reduce_mean(
                input_tensor,
                mean,
                axis=-1,
                keepdims=True
            )
            # Per-row variance (driver computes it given the precomputed mean).
            variance = self._get_temp_tensor(input_shape[:-1])
            self.driver.reduce_variance(
                input_tensor,
                variance,
                mean,
                axis=-1,
                keepdims=True
            )
            # (x - mean) / sqrt(variance + eps)
            normalized = self._get_temp_tensor(input_shape)
            self.driver.normalize(
                input_tensor,
                mean,
                variance,
                normalized,
                eps=self.eps
            )
            # Affine transform: caller-supplied overrides win over the
            # module's own learnable parameters.
            scale_tensor = scale if scale is not None else self.weight
            offset_tensor = offset if offset is not None else self.bias
            output = self._get_temp_tensor(input_shape)
            self.driver.scale_and_shift(
                normalized,
                scale_tensor,
                offset_tensor,
                output
            )
            # Clean up intermediate tensors (output is intentionally kept).
            self._free_temp_tensor(mean)
            self._free_temp_tensor(variance)
            self._free_temp_tensor(normalized)
            return output

    def compute_variance(self, state, driver, x_name, mean_name, gamma_name, beta_name, eps=1e-5):
        """
        Run the full layer-norm pipeline in driver memory using elementwise
        driver primitives (sub/mul/mean/sqrt/div/add), managing temporaries
        through the caller-supplied ``state``.

        NOTE(review): despite its name, this computes the complete
        normalized, scaled and shifted output, not just the variance.

        Warning:
            This method *consumes* ``mean_name`` — it is freed via
            ``state.free_temp_tensor`` before returning. Callers must not
            reuse it afterwards.

        Args:
            state: Temp-tensor manager exposing get_temp_tensor/free_temp_tensor.
            driver: Virtual GPU driver providing the elementwise kernels.
            x_name: Driver name of the input tensor.
            mean_name: Driver name of the precomputed per-row mean (freed here).
            gamma_name: Driver name of the scale parameter.
            beta_name: Driver name of the shift parameter.
            eps: Numerical-stability constant added to the variance.

        Returns:
            Driver name of the output tensor (tracked by ``state``).
        """
        # diff = x - mean
        diff_name = state.get_temp_tensor(
            driver.sub(x_name, mean_name),
            "diff"
        )
        # squared = diff * diff
        squared_name = state.get_temp_tensor(
            driver.mul(diff_name, diff_name),
            "squared"
        )
        # var = mean(squared) over the last axis
        var_name = state.get_temp_tensor(
            driver.mean(squared_name, axis=-1, keepdims=True),
            "var"
        )
        # Free intermediates
        state.free_temp_tensor(squared_name)
        # std = sqrt(var + eps); normalized = diff / std
        std_name = state.get_temp_tensor(
            driver.sqrt(driver.add_scalar(var_name, eps)),
            "std"
        )
        normalized_name = state.get_temp_tensor(
            driver.div(diff_name, std_name),
            "normalized"
        )
        # Free more intermediates (including the caller's mean — see Warning).
        state.free_temp_tensor(diff_name)
        state.free_temp_tensor(mean_name)
        state.free_temp_tensor(var_name)
        state.free_temp_tensor(std_name)
        # output = normalized * gamma + beta
        scaled_name = state.get_temp_tensor(
            driver.mul(normalized_name, gamma_name),
            "scaled"
        )
        output_name = state.get_temp_tensor(
            driver.add(scaled_name, beta_name),
            "output"
        )
        # Free final intermediates
        state.free_temp_tensor(normalized_name)
        state.free_temp_tensor(scaled_name)
        return output_name
def layer_norm(
    input_tensor: Union[str, "HeliumTensor"],
    normalized_shape: int,
    weight: Optional[Union[str, "HeliumTensor"]] = None,
    bias: Optional[Union[str, "HeliumTensor"]] = None,
    eps: float = 1e-5,
    device_id: int = 0,
    dtype: str = "float32"
) -> Union[str, "HeliumTensor"]:
    """
    Functional one-shot interface to :class:`HeliumLayerNorm`.

    Builds a throwaway module and runs a single forward pass over
    ``input_tensor``.

    Args:
        input_tensor: Input of shape (*, normalized_shape).
        normalized_shape: Size of the last dimension to normalize over.
        weight: Optional scale parameter (overrides the module's own).
        bias: Optional offset parameter (overrides the module's own).
        eps: Small value for numerical stability.
        device_id: Virtual GPU device ID.
        dtype: Data type for computations.

    Returns:
        Normalized tensor of the same shape as the input.
    """
    ln = HeliumLayerNorm(
        normalized_shape=normalized_shape,
        eps=eps,
        device_id=device_id,
        dtype=dtype
    )
    result = ln.forward(input_tensor, weight, bias)
    return result