| """
|
| Hardware-accelerated softmax implementation for Helium virtual GPU
|
| """
|
| from typing import Optional, Union, Tuple, TYPE_CHECKING
|
| import helium as he
|
| from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, DType, Device, Layout, Tensor
|
| import numpy as np
|
|
|
def softmax(x: "Tensor", dim: int = -1) -> "Tensor":
    """Apply the softmax function along a dimension.

    Numerically stable formulation: the per-slice maximum is subtracted
    before exponentiation so ``exp`` cannot overflow for large inputs.

    NOTE(review): this module defines a second top-level ``softmax``
    further down, which shadows this one on import — confirm which of
    the two callers are meant to get.

    Args:
        x: Input tensor. Must support ``max``/``sum`` with
            ``dim``/``keepdim`` keywords, ``exp()``, and ``-`` / ``/``.
        dim: Dimension along which to apply softmax.

    Returns:
        Softmax output tensor (same shape as ``x``; slices along ``dim``
        sum to 1).
    """
    # keepdim=True keeps the reduced axis so broadcasting lines up.
    shifted = x - x.max(dim=dim, keepdim=True)
    numer = shifted.exp()
    return numer / numer.sum(dim=dim, keepdim=True)
|
|
|
class HeliumSoftmax:
    """
    Optimized softmax implementation for virtual GPU.

    Handles multi-head attention patterns efficiently by reusing scratch
    buffers in the memory-efficient path: the row-max buffer is reused
    for the row-sum, and exponentiation is performed in place on the
    shifted tensor.

    NOTE(review): output tensors returned by :meth:`forward` stay
    registered in ``_temp_tensors`` and are freed when this object is
    destroyed — callers must consume the result while this instance is
    alive (or detach the id from ``_temp_tensors`` to take ownership).
    """

    def __init__(self, device_id: Optional[str] = None):
        """Bind to a virtual GPU device (default device when id is None).

        Args:
            device_id: Virtual GPU device identifier, or None for the
                driver's default device.
        """
        self.driver = he.get_device(device_id) if device_id else he.get_default_device()
        self.device_id = device_id

        # tensor_id -> driver tensor handle for every scratch/output
        # buffer this instance has allocated and not yet freed.
        self._temp_tensors = {}
        # Monotonically increasing suffix guaranteeing unique tensor ids
        # within this instance.
        self._counter = 0

    def _get_temp_tensor(self, shape: Tuple[int, ...], dtype: str = "float32") -> str:
        """Allocate a temporary tensor in device memory.

        Args:
            shape: Shape of the tensor to allocate.
            dtype: Dtype name, e.g. "float32" (looked up on ``DType``
                case-insensitively).

        Returns:
            The id under which the new tensor is registered.
        """
        tensor_id = f"softmax_temp_{self._counter}"
        self._counter += 1

        descriptor = TensorDescriptor(
            shape=shape,
            dtype=getattr(DType, dtype.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR,
        )

        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(descriptor)
        return tensor_id

    def _free_temp_tensor(self, tensor_id: str):
        """Release temporary tensor memory (no-op for unknown ids)."""
        if tensor_id in self._temp_tensors:
            self.driver.free_tensor(self._temp_tensors[tensor_id])
            del self._temp_tensors[tensor_id]

    def __del__(self):
        """Cleanup all temporary tensors.

        Guarded so that a partially-run ``__init__`` (no ``_temp_tensors``
        attribute yet) or interpreter teardown (driver already finalized)
        cannot raise from the finalizer.
        """
        # Snapshot the keys: _free_temp_tensor mutates the dict.
        for tensor_id in list(getattr(self, "_temp_tensors", {})):
            try:
                self._free_temp_tensor(tensor_id)
            except Exception:
                # Never propagate from __del__; best-effort cleanup only.
                pass

    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        dim: int = -1,
        memory_efficient: bool = True,
        stream_id: Optional[int] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Compute softmax along specified dimension.

        Args:
            input_tensor: Input tensor or tensor name in driver.
            dim: Dimension to compute softmax over (-1 for last dim).
            memory_efficient: Use memory-efficient algorithm (explicit
                max-shift with buffer reuse) instead of the driver's
                fused softmax kernel.
            stream_id: Optional stream for async execution.

        Returns:
            Id of the output tensor, registered in ``_temp_tensors``
            (see class note on ownership).

        Raises:
            ValueError: If ``dim`` is out of range for the input's rank.
        """
        if isinstance(input_tensor, str):
            tensor_name = input_tensor
            tensor_info = self.driver.get_tensor_info(tensor_name)
            shape = tensor_info.shape
            # Normalize to the dtype *name*: the driver info may expose a
            # DType enum (the object branch below uses .dtype.name), while
            # _get_temp_tensor expects a string it can .upper().
            dtype = getattr(tensor_info.dtype, "name", tensor_info.dtype)
        else:
            tensor_name = input_tensor.name
            shape = input_tensor.shape
            dtype = input_tensor.dtype.name

        # Normalize negative dims and reject out-of-range values early,
        # before any device memory is allocated.
        if dim < 0:
            dim += len(shape)
        if not 0 <= dim < len(shape):
            raise ValueError(
                f"softmax dim out of range for tensor of rank {len(shape)}"
            )

        if memory_efficient:
            # 1) Per-slice max along `dim` (kept as a size-1 axis so the
            #    broadcast ops below line up).
            max_shape = list(shape)
            max_shape[dim] = 1
            max_tensor = self._get_temp_tensor(tuple(max_shape), dtype)

            self.driver.reduce_max(
                input=tensor_name,
                output=max_tensor,
                dim=dim,
                stream_id=stream_id
            )

            # 2) Shift by the max for numerical stability.
            shifted = self._get_temp_tensor(shape, dtype)
            self.driver.broadcast_sub(
                input=tensor_name,
                other=max_tensor,
                output=shifted,
                dim=dim,
                stream_id=stream_id
            )

            # 3) Exponentiate in place: reuse the `shifted` buffer.
            exp_tensor = shifted
            self.driver.exp(
                input=shifted,
                output=exp_tensor,
                stream_id=stream_id
            )

            # 4) Sum of exponentials: reuse the max buffer (same reduced
            #    shape, max values no longer needed).
            sum_tensor = max_tensor
            self.driver.reduce_sum(
                input=exp_tensor,
                output=sum_tensor,
                dim=dim,
                stream_id=stream_id
            )

            # 5) Normalize into a fresh output buffer.
            output = self._get_temp_tensor(shape, dtype)
            self.driver.broadcast_div(
                input=exp_tensor,
                other=sum_tensor,
                output=output,
                dim=dim,
                stream_id=stream_id
            )

            # Release scratch buffers; `sum_tensor` aliases `max_tensor`,
            # and `exp_tensor` aliases `shifted`, so these two frees cover
            # all intermediates except the returned output.
            self._free_temp_tensor(shifted)
            self._free_temp_tensor(sum_tensor)

            return output

        else:
            # Fast path: single fused kernel on the device.
            output = self._get_temp_tensor(shape, dtype)

            self.driver.softmax(
                input=tensor_name,
                output=output,
                dim=dim,
                stream_id=stream_id
            )

            return output
|
|
|
def softmax(
    x: Union[str, "HeliumTensor"],
    dim: int = -1,
    device_id: Optional[str] = None,
    memory_efficient: bool = True,
    stream_id: Optional[int] = None
) -> Union[str, "HeliumTensor"]:
    """
    Functional interface for softmax operation.

    Args:
        x: Input tensor (or tensor name in the driver).
        dim: Dimension to compute softmax over.
        device_id: Virtual GPU device ID.
        memory_efficient: Use memory-efficient algorithm.
        stream_id: Optional stream for async execution.

    Returns:
        Softmax output tensor (or tensor name). Ownership of the output
        buffer is transferred to the caller, who becomes responsible for
        eventually freeing it via the driver.
    """
    module = HeliumSoftmax(device_id)
    output = module.forward(x, dim, memory_efficient, stream_id)
    # BUG FIX: `module` becomes garbage as soon as this function returns,
    # and its __del__ frees every tensor still registered in
    # _temp_tensors — including the freshly computed output. Detach the
    # output from the registry so the caller does not receive a dangling
    # tensor id.
    if isinstance(output, str):
        module._temp_tensors.pop(output, None)
    return output
|
|
|