"""
Hardware-accelerated softmax implementation for Helium virtual GPU
"""
from typing import Optional, Tuple, Union, TYPE_CHECKING

import helium as he
from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, DType, Device, Layout, Tensor

if TYPE_CHECKING:
    # Only needed for the type hints below; the import path is an assumption --
    # adjust it to wherever HeliumTensor is actually defined in this project.
    from helium import HeliumTensor


def softmax_reference(x: Tensor, dim: int = -1) -> Tensor:
    """
    Reference softmax that operates directly on Tensor objects
    (the driver-backed functional interface is defined at the bottom of this module).

    Args:
        x: Input tensor
        dim: Dimension along which to apply softmax

    Returns:
        Softmax output tensor
    """
    # For numerical stability, subtract the maximum value before applying exp
    x_max = x.max(dim=dim, keepdim=True)
    exp_x = (x - x_max).exp()
    return exp_x / exp_x.sum(dim=dim, keepdim=True)
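
# Why the max is subtracted first -- a minimal numeric sketch (the numbers are
# illustrative only and not part of the Helium API):
#
#   row = [1000.0, 1001.0, 1002.0]
#   exp(1000.0) overflows to inf, so the naive exp(row) / sum(exp(row))
#   produces NaNs.  Shifting by the row maximum gives exp([-2.0, -1.0, 0.0]);
#   every exponent is now <= 0, and the normalized result (~[0.090, 0.245,
#   0.665]) is unchanged because the common factor exp(-max) cancels out.

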
class HeliumSoftmax:
"""
Optimized softmax implementation for virtual GPU
Handles multi-head attention patterns efficiently
"""
def __init__(self, device_id: Optional[str] = None):
# Get virtual GPU driver
self.driver = he.get_device(device_id) if device_id else he.get_default_device()
self.device_id = device_id
# Track allocated tensors for cleanup
self._temp_tensors = {}
self._counter = 0
    def _get_temp_tensor(self, shape: Tuple[int, ...], dtype: Union[str, DType] = "float32") -> str:
        """Allocate a temporary tensor in device memory and return its id"""
        tensor_id = f"softmax_temp_{self._counter}"
        self._counter += 1
        descriptor = TensorDescriptor(
            shape=tuple(shape),
            # Accept either a DType enum or its (case-insensitive) string name
            dtype=dtype if isinstance(dtype, DType) else getattr(DType, dtype.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )
        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(descriptor)
        return tensor_id
def _free_temp_tensor(self, tensor_id: str):
"""Release temporary tensor memory"""
if tensor_id in self._temp_tensors:
self.driver.free_tensor(self._temp_tensors[tensor_id])
del self._temp_tensors[tensor_id]
def __del__(self):
"""Cleanup all temporary tensors"""
for tensor_id in list(self._temp_tensors.keys()):
self._free_temp_tensor(tensor_id)
    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        dim: int = -1,
        memory_efficient: bool = True,
        stream_id: Optional[int] = None
    ) -> str:
        """
        Compute softmax along the specified dimension

        Args:
            input_tensor: Input tensor or tensor name in the driver
            dim: Dimension to compute softmax over (-1 for last dim)
            memory_efficient: Use the memory-efficient (streaming) algorithm
            stream_id: Optional stream for async execution

        Returns:
            Name of the output tensor allocated in device memory; the caller
            owns it and should free it via the driver when done
        """
# Get input tensor info
if isinstance(input_tensor, str):
tensor_name = input_tensor
tensor_info = self.driver.get_tensor_info(tensor_name)
shape = tensor_info.shape
dtype = tensor_info.dtype
else:
tensor_name = input_tensor.name
shape = input_tensor.shape
dtype = input_tensor.dtype.name
# Handle negative dim
if dim < 0:
dim = len(shape) + dim
# Memory-efficient implementation (streaming)
if memory_efficient:
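            # Streaming pipeline: five kernel launches, two scratch buffers
            # (max/sum share one allocation, shifted/exp share the other):
            #   m   = reduce_max(x, dim)       -> max_tensor
            #   s   = x - m   (broadcast_sub)  -> shifted
            #   e   = exp(s)                   -> written in place over shifted
            #   z   = reduce_sum(e, dim)       -> written in place over max_tensor
            #   out = e / z   (broadcast_div)  -> output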
# Compute max along dim
max_shape = list(shape)
max_shape[dim] = 1
max_tensor = self._get_temp_tensor(tuple(max_shape), dtype)
self.driver.reduce_max(
input=tensor_name,
output=max_tensor,
dim=dim,
stream_id=stream_id
)
# Subtract max (for numerical stability)
shifted = self._get_temp_tensor(shape, dtype)
self.driver.broadcast_sub(
input=tensor_name,
other=max_tensor,
output=shifted,
dim=dim,
stream_id=stream_id
)
# Compute exp
exp_tensor = shifted # Reuse memory
self.driver.exp(
input=shifted,
output=exp_tensor,
stream_id=stream_id
)
# Compute sum
sum_tensor = max_tensor # Reuse memory
self.driver.reduce_sum(
input=exp_tensor,
output=sum_tensor,
dim=dim,
stream_id=stream_id
)
# Final division
output = self._get_temp_tensor(shape, dtype)
self.driver.broadcast_div(
input=exp_tensor,
other=sum_tensor,
output=output,
dim=dim,
stream_id=stream_id
)
            # Free the intermediates; the output tensor is handed off to the
            # caller, so drop it from the temp registry to keep __del__ from
            # freeing it underneath them
            self._free_temp_tensor(shifted)
            self._free_temp_tensor(sum_tensor)
            self._temp_tensors.pop(output, None)
            return output
        else:
            # Direct implementation: a single fused driver kernel (may use more
            # device memory internally, but fewer kernel launches)
            output = self._get_temp_tensor(shape, dtype)
            self.driver.softmax(
                input=tensor_name,
                output=output,
                dim=dim,
                stream_id=stream_id
            )
            # Hand ownership of the output tensor to the caller
            self._temp_tensors.pop(output, None)
            return output
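

# Reuse sketch for the multi-head attention pattern mentioned in the class
# docstring.  The tensor names and device id below are hypothetical -- they
# assume per-layer attention logits are already registered with the driver:
#
#   sm = HeliumSoftmax(device_id="vgpu:0")
#   for layer in range(num_layers):
#       probs = sm.forward(f"attn_logits_{layer}", dim=-1, stream_id=layer % 2)
#
# Instantiating the module once and reusing it avoids re-creating it (and
# re-querying the device) on every call, unlike the one-shot functional
# wrapper below.
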
def softmax(
    x: Union[str, "HeliumTensor"],
    dim: int = -1,
    device_id: Optional[str] = None,
    memory_efficient: bool = True,
    stream_id: Optional[int] = None
) -> str:
    """
    Functional interface for the softmax operation

    Args:
        x: Input tensor or tensor name in the driver
        dim: Dimension to compute softmax over
        device_id: Virtual GPU device ID
        memory_efficient: Use the memory-efficient algorithm
        stream_id: Optional stream for async execution

    Returns:
        Name of the output tensor allocated in device memory
    """
module = HeliumSoftmax(device_id)
return module.forward(x, dim, memory_efficient, stream_id)
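

# One-shot usage sketch.  "attn_scores" is a hypothetical tensor name; it
# assumes the driver already holds a (batch, heads, seq, seq) tensor of
# attention logits registered under that id:
#
#   probs = softmax("attn_scores", dim=-1, memory_efficient=True)
#
# The return value is a new tensor id of the form "softmax_temp_<n>" that can
# be passed to subsequent driver ops; the caller is responsible for freeing it
# through the driver once it is no longer needed.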