| """
|
| Hardware-accelerated pooling operations for Helium virtual GPU
|
| """
|
| from typing import Optional, Union, Tuple, List, TYPE_CHECKING
|
| from virtual_gpu_driver.src.ai.tensor_types import (
|
| TensorDescriptor, DType, Device, Layout,
|
| PoolingDescriptor, PoolingMode
|
| )
|
| from .main import get_device, get_default_device
|
|
|
| if TYPE_CHECKING:
|
| from .main import HeliumTensor
|
|
|
class HeliumPooling2D:
    """Base class for 2D pooling operations on virtual GPU.

    Normalises the pooling geometry (kernel, stride, padding) into separate
    height/width attributes, holds the driver object used to talk to the
    device, and keeps a registry of the temporary tensors it allocated so
    they can be released again.
    """

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = None,
        padding: Union[int, Tuple[int, int]] = 0,
        device_id: Optional[str] = None
    ):
        """Store the pooling geometry and acquire the device driver.

        Args:
            kernel_size: Pooling window size; a single int is used for both
                height and width, otherwise a (height, width) pair.
            stride: Step between windows, int or (height, width) pair.
                Defaults to the kernel size when None.
            padding: Padding, int or (height, width) pair.
            device_id: Device to look up; the default device is used when
                this is None (or otherwise falsy).

        Raises:
            ValueError: If a kernel or stride component is not positive, or
                a pair does not unpack into exactly two values.
        """
        # Bookkeeping comes first: __del__ also runs on instances whose
        # __init__ raised, and it must find these attributes.
        self._temp_tensors = {}
        self._counter = 0

        self.kernel_height, self.kernel_width = self._as_pair(kernel_size)
        if stride is None:
            self.stride_height = self.kernel_height
            self.stride_width = self.kernel_width
        else:
            self.stride_height, self.stride_width = self._as_pair(stride)
        self.padding_height, self.padding_width = self._as_pair(padding)

        # Reject geometry that could only fail later (division by a zero
        # stride, empty window) with a much less helpful error.
        if self.kernel_height <= 0 or self.kernel_width <= 0:
            raise ValueError(
                f"kernel_size must be positive, got "
                f"({self.kernel_height}, {self.kernel_width})"
            )
        if self.stride_height <= 0 or self.stride_width <= 0:
            raise ValueError(
                f"stride must be positive, got "
                f"({self.stride_height}, {self.stride_width})"
            )

        # Touch the device only once the arguments are known to be usable.
        self.driver = get_device(device_id) if device_id else get_default_device()
        self.device_id = device_id

    @staticmethod
    def _as_pair(value):
        """Expand an int to (value, value); unpack anything else as a pair."""
        if isinstance(value, int):
            return value, value
        first, second = value
        return first, second

    def _get_temp_tensor(self, shape: Tuple[int, ...], dtype: str = "float32") -> str:
        """Allocate temporary tensor in device memory.

        Args:
            shape: Shape for the new tensor.
            dtype: Data type name (case-insensitive), matched against the
                members of ``DType``. An object exposing the name through a
                ``name`` attribute is accepted as well.

        Returns:
            The generated identifier under which the allocation is tracked
            in ``self._temp_tensors``.
        """
        tensor_id = f"pool_temp_{self._counter}"
        self._counter += 1

        # Callers hand over either a plain name or a dtype object; reduce
        # both to the name before looking it up on DType.
        dtype_name = getattr(dtype, "name", dtype)

        # NOTE(review): subclasses pass shapes ordered (batch, channels,
        # height, width) and document NCHW input, yet the descriptor is
        # tagged Layout.NHWC -- confirm which layout the driver expects.
        descriptor = TensorDescriptor(
            shape=shape,
            dtype=getattr(DType, dtype_name.upper()),
            device=Device.VGPU,
            layout=Layout.NHWC
        )

        # NOTE(review): tensor_id is generated per instance and is not
        # passed to allocate_tensor -- confirm how the driver resolves this
        # id when it is later used as an output name.
        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(descriptor)
        return tensor_id

    def _free_temp_tensor(self, tensor_id: str):
        """Release temporary tensor memory.

        Unknown identifiers are ignored. The registry entry is removed only
        after the driver call returned.
        """
        if tensor_id not in self._temp_tensors:
            return
        self.driver.free_tensor(self._temp_tensors[tensor_id])
        del self._temp_tensors[tensor_id]

    def __del__(self):
        """Cleanup temporary tensors."""
        # The registry is missing when __init__ never ran to that point;
        # there is nothing to release in that case.
        temp_tensors = getattr(self, "_temp_tensors", None)
        if not temp_tensors:
            return
        # Iterate over a snapshot: _free_temp_tensor mutates the registry.
        for tensor_id in list(temp_tensors):
            self._free_temp_tensor(tensor_id)

    def _create_pooling_descriptor(self, mode: "PoolingMode") -> "PoolingDescriptor":
        """Create pooling descriptor for hardware.

        Args:
            mode: Pooling mode to put into the descriptor.

        Returns:
            A ``PoolingDescriptor`` carrying ``mode`` plus this layer's
            kernel, stride and padding heights/widths.
        """
        return PoolingDescriptor(
            mode=mode,
            kernel_height=self.kernel_height,
            kernel_width=self.kernel_width,
            stride_height=self.stride_height,
            stride_width=self.stride_width,
            padding_height=self.padding_height,
            padding_width=self.padding_width
        )
|
|
|
class MaxPool2D(HeliumPooling2D):
    """2D max pooling layer."""

    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        stream_id: Optional[int] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Compute 2D max pooling.

        Args:
            input_tensor: Input tensor (NCHW format), given either as a
                tensor name that is looked up through the driver, or as a
                tensor object exposing ``name``, ``shape`` and ``dtype``.
            stream_id: Optional stream for async execution

        Returns:
            Identifier of the newly allocated output tensor. The tensor
            stays registered with this layer and is released together with
            the layer's other temporary tensors.

        Raises:
            ValueError: If the input is not 4-dimensional, or the pooling
                window does not fit into the padded input.
        """
        if isinstance(input_tensor, str):
            tensor_name = input_tensor
            tensor_info = self.driver.get_tensor_info(tensor_name)
            input_shape = tensor_info.shape
            dtype = tensor_info.dtype
        else:
            tensor_name = input_tensor.name
            input_shape = input_tensor.shape
            dtype = input_tensor.dtype
        # The allocator looks the dtype up by name; accept both a plain
        # name and a dtype object from either branch.
        dtype = getattr(dtype, "name", dtype)

        if len(input_shape) != 4:
            raise ValueError(
                f"expected a 4D (N, C, H, W) input, got shape {tuple(input_shape)}"
            )
        batch_size, channels, in_height, in_width = input_shape

        out_height = (in_height + 2 * self.padding_height - self.kernel_height) // self.stride_height + 1
        out_width = (in_width + 2 * self.padding_width - self.kernel_width) // self.stride_width + 1
        # A window larger than the padded input makes the floor division go
        # negative; refuse instead of allocating a non-positive extent.
        if out_height <= 0 or out_width <= 0:
            raise ValueError(
                f"pooling window ({self.kernel_height}, {self.kernel_width}) does not fit "
                f"input of size ({in_height}, {in_width}) with padding "
                f"({self.padding_height}, {self.padding_width})"
            )
        output_shape = (batch_size, channels, out_height, out_width)

        output = self._get_temp_tensor(output_shape, dtype)
        pool_desc = self._create_pooling_descriptor(PoolingMode.MAX)

        self.driver.pooling_forward(
            pooling_desc=pool_desc,
            input=tensor_name,
            output=output,
            stream_id=stream_id
        )

        return output
|
|
|
class AvgPool2D(HeliumPooling2D):
    """2D average pooling layer."""

    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        stream_id: Optional[int] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Compute 2D average pooling.

        Args:
            input_tensor: Input tensor (NCHW format), given either as a
                tensor name that is looked up through the driver, or as a
                tensor object exposing ``name``, ``shape`` and ``dtype``.
            stream_id: Optional stream for async execution

        Returns:
            Identifier of the newly allocated output tensor. The tensor
            stays registered with this layer and is released together with
            the layer's other temporary tensors.

        Raises:
            ValueError: If the input is not 4-dimensional, or the pooling
                window does not fit into the padded input.
        """
        if isinstance(input_tensor, str):
            tensor_name = input_tensor
            tensor_info = self.driver.get_tensor_info(tensor_name)
            input_shape = tensor_info.shape
            dtype = tensor_info.dtype
        else:
            tensor_name = input_tensor.name
            input_shape = input_tensor.shape
            dtype = input_tensor.dtype
        # The allocator looks the dtype up by name; accept both a plain
        # name and a dtype object from either branch.
        dtype = getattr(dtype, "name", dtype)

        if len(input_shape) != 4:
            raise ValueError(
                f"expected a 4D (N, C, H, W) input, got shape {tuple(input_shape)}"
            )
        batch_size, channels, in_height, in_width = input_shape

        out_height = (in_height + 2 * self.padding_height - self.kernel_height) // self.stride_height + 1
        out_width = (in_width + 2 * self.padding_width - self.kernel_width) // self.stride_width + 1
        # A window larger than the padded input makes the floor division go
        # negative; refuse instead of allocating a non-positive extent.
        if out_height <= 0 or out_width <= 0:
            raise ValueError(
                f"pooling window ({self.kernel_height}, {self.kernel_width}) does not fit "
                f"input of size ({in_height}, {in_width}) with padding "
                f"({self.padding_height}, {self.padding_width})"
            )
        output_shape = (batch_size, channels, out_height, out_width)

        output = self._get_temp_tensor(output_shape, dtype)
        pool_desc = self._create_pooling_descriptor(PoolingMode.AVERAGE)

        self.driver.pooling_forward(
            pooling_desc=pool_desc,
            input=tensor_name,
            output=output,
            stream_id=stream_id
        )

        return output
|
|
|
class GlobalAvgPool2D:
    """Global average pooling layer.

    Averages a 4D input over its last two dimensions through the driver's
    ``reduce_mean``. Output tensors are tracked in ``self._temp_tensors``;
    this class never releases them on its own, use ``_free_temp_tensor``
    once a result is no longer needed.
    """

    def __init__(self, device_id: Optional[str] = None):
        """Acquire the driver for ``device_id`` (default device if falsy)."""
        self.driver = get_device(device_id) if device_id else get_default_device()
        self.device_id = device_id
        self._temp_tensors = {}
        self._counter = 0

    def _get_temp_tensor(self, shape: Tuple[int, ...], dtype: str = "float32") -> str:
        """Allocate a temporary tensor and return its generated identifier.

        Args:
            shape: Shape for the new tensor.
            dtype: Data type name (case-insensitive), matched against the
                members of ``DType``. An object exposing the name through a
                ``name`` attribute is accepted as well.
        """
        tensor_id = f"global_pool_temp_{self._counter}"
        self._counter += 1

        # Reduce a dtype object to its name before the DType lookup.
        dtype_name = getattr(dtype, "name", dtype)

        # NOTE(review): forward() passes a 2D (batch, channels) shape while
        # the descriptor is tagged Layout.NHWC -- confirm this is what the
        # driver expects.
        descriptor = TensorDescriptor(
            shape=shape,
            dtype=getattr(DType, dtype_name.upper()),
            device=Device.VGPU,
            layout=Layout.NHWC
        )

        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(descriptor)
        return tensor_id

    def _free_temp_tensor(self, tensor_id: str):
        """Release a tensor allocated by ``_get_temp_tensor``.

        Unknown identifiers are ignored. The registry entry is removed only
        after the driver call returned.
        """
        if tensor_id not in self._temp_tensors:
            return
        self.driver.free_tensor(self._temp_tensors[tensor_id])
        del self._temp_tensors[tensor_id]

    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        stream_id: Optional[int] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Compute global average pooling.

        Args:
            input_tensor: Input tensor (NCHW format), given either as a
                tensor name that is looked up through the driver, or as a
                tensor object exposing ``name``, ``shape`` and ``dtype``.
            stream_id: Optional stream for async execution

        Returns:
            Identifier of the newly allocated (batch, channels) output
            tensor.

        Raises:
            ValueError: If the input is not 4-dimensional.
        """
        if isinstance(input_tensor, str):
            tensor_name = input_tensor
            tensor_info = self.driver.get_tensor_info(tensor_name)
            input_shape = tensor_info.shape
            dtype = tensor_info.dtype
        else:
            tensor_name = input_tensor.name
            input_shape = input_tensor.shape
            dtype = input_tensor.dtype
        # The allocator looks the dtype up by name; accept both a plain
        # name and a dtype object from either branch.
        dtype = getattr(dtype, "name", dtype)

        # Reducing dims (2, 3) into a (batch, channels) output only lines
        # up for rank-4 input; fail before anything is allocated.
        if len(input_shape) != 4:
            raise ValueError(
                f"expected a 4D (N, C, H, W) input, got shape {tuple(input_shape)}"
            )

        batch_size, channels = input_shape[0], input_shape[1]
        output_shape = (batch_size, channels)

        output = self._get_temp_tensor(output_shape, dtype)

        self.driver.reduce_mean(
            input=tensor_name,
            output=output,
            dims=(2, 3),
            stream_id=stream_id
        )

        return output
|
|
|
|
|
def max_pool2d(
    x: Union[str, "HeliumTensor"],
    kernel_size: Union[int, Tuple[int, int]],
    stride: Optional[Union[int, Tuple[int, int]]] = None,
    padding: Union[int, Tuple[int, int]] = 0,
    device_id: Optional[str] = None,
    stream_id: Optional[int] = None
) -> Union[str, "HeliumTensor"]:
    """Functional interface for 2D max pooling.

    Builds a throwaway ``MaxPool2D`` from the given geometry and runs its
    ``forward`` on ``x``.

    Args:
        x: Input tensor, as accepted by ``MaxPool2D.forward``.
        kernel_size: Pooling window size, int or (height, width).
        stride: Window step, int or (height, width); None uses the kernel size.
        padding: Padding, int or (height, width).
        device_id: Device to run on; None selects the default device.
        stream_id: Optional stream for async execution.

    Returns:
        Identifier of the output tensor. It is no longer tracked by the
        temporary layer, so it is not released when this function returns.
    """
    module = MaxPool2D(kernel_size, stride, padding, device_id)
    output = module.forward(x, stream_id)
    # The layer is destroyed as soon as this function returns, and its
    # destructor frees every tensor still in its registry -- including the
    # result being handed back. Drop the result from the registry so the
    # caller receives a live tensor.
    # NOTE(review): the caller only gets the id, not the allocation handle
    # stored in the registry -- confirm how such outputs are freed later.
    module._temp_tensors.pop(output, None)
    return output
|
|
|
def avg_pool2d(
    x: Union[str, "HeliumTensor"],
    kernel_size: Union[int, Tuple[int, int]],
    stride: Optional[Union[int, Tuple[int, int]]] = None,
    padding: Union[int, Tuple[int, int]] = 0,
    device_id: Optional[str] = None,
    stream_id: Optional[int] = None
) -> Union[str, "HeliumTensor"]:
    """Functional interface for 2D average pooling.

    Builds a throwaway ``AvgPool2D`` from the given geometry and runs its
    ``forward`` on ``x``.

    Args:
        x: Input tensor, as accepted by ``AvgPool2D.forward``.
        kernel_size: Pooling window size, int or (height, width).
        stride: Window step, int or (height, width); None uses the kernel size.
        padding: Padding, int or (height, width).
        device_id: Device to run on; None selects the default device.
        stream_id: Optional stream for async execution.

    Returns:
        Identifier of the output tensor. It is no longer tracked by the
        temporary layer, so it is not released when this function returns.
    """
    module = AvgPool2D(kernel_size, stride, padding, device_id)
    output = module.forward(x, stream_id)
    # The layer is destroyed as soon as this function returns, and its
    # destructor frees every tensor still in its registry -- including the
    # result being handed back. Drop the result from the registry so the
    # caller receives a live tensor.
    # NOTE(review): the caller only gets the id, not the allocation handle
    # stored in the registry -- confirm how such outputs are freed later.
    module._temp_tensors.pop(output, None)
    return output
|
|
|
def global_avg_pool2d(
    x: Union[str, "HeliumTensor"],
    device_id: Optional[str] = None,
    stream_id: Optional[int] = None
) -> Union[str, "HeliumTensor"]:
    """Functional interface for global average pooling.

    Creates a one-off ``GlobalAvgPool2D`` for ``device_id`` and returns
    whatever its ``forward`` produces for ``x`` on ``stream_id``.
    """
    pooling_layer = GlobalAvgPool2D(device_id)
    pooled = pooling_layer.forward(x, stream_id)
    return pooled
|
|
|