# helium/pooling.py
"""
Hardware-accelerated pooling operations for Helium virtual GPU
"""
from typing import Optional, Union, Tuple, List, TYPE_CHECKING
from virtual_gpu_driver.src.ai.tensor_types import (
TensorDescriptor, DType, Device, Layout,
PoolingDescriptor, PoolingMode
)
from .main import get_device, get_default_device
if TYPE_CHECKING:
from .main import HeliumTensor
class HeliumPooling2D:
    """Base class for 2D pooling operations on virtual GPU.

    Normalizes kernel/stride/padding configuration to per-axis values and
    manages the lifetime of temporary device tensors allocated for outputs.
    """
    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = None,
        padding: Union[int, Tuple[int, int]] = 0,
        device_id: Optional[str] = None
    ):
        """
        Args:
            kernel_size: Pooling window size; a single int applies to both
                axes, otherwise an explicit (height, width) pair.
            stride: Step between windows; defaults to kernel_size
                (non-overlapping windows) when None.
            padding: Zero-padding per spatial axis, int or (height, width).
            device_id: Optional virtual GPU device id; falls back to the
                default device when omitted.
        """
        # Get virtual GPU driver
        self.driver = get_device(device_id) if device_id else get_default_device()
        self.device_id = device_id
        # Normalize all spatial parameters to (height, width) pairs.
        self.kernel_height, self.kernel_width = self._as_pair(kernel_size)
        if stride is None:
            # Convention: stride defaults to the kernel size.
            self.stride_height = self.kernel_height
            self.stride_width = self.kernel_width
        else:
            self.stride_height, self.stride_width = self._as_pair(stride)
        self.padding_height, self.padding_width = self._as_pair(padding)
        # Track allocated tensors: tensor id -> driver allocation handle.
        self._temp_tensors = {}
        self._counter = 0
    @staticmethod
    def _as_pair(value: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
        """Expand a scalar to (value, value); pass a 2-sequence through."""
        if isinstance(value, int):
            return value, value
        height, width = value
        return height, width
    def _get_temp_tensor(self, shape: Tuple[int, ...], dtype: Union[str, "DType"] = "float32") -> str:
        """Allocate a temporary tensor in device memory and return its id.

        Args:
            shape: Tensor shape to allocate.
            dtype: Either a dtype name like "float32" or a DType member.
                Callers pass `tensor_info.dtype` through directly, which may
                already be a DType — accept both rather than crashing on
                `.upper()` of a non-string.
        """
        tensor_id = f"pool_temp_{self._counter}"
        self._counter += 1
        resolved = dtype if isinstance(dtype, DType) else getattr(DType, str(dtype).upper())
        descriptor = TensorDescriptor(
            shape=shape,
            dtype=resolved,
            device=Device.VGPU,
            # NOTE(review): layout is NHWC while forward() docs describe NCHW
            # input — confirm the driver's expected output layout.
            layout=Layout.NHWC
        )
        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(descriptor)
        return tensor_id
    def _free_temp_tensor(self, tensor_id: str):
        """Release the device memory behind a temporary tensor, if still held."""
        if tensor_id in self._temp_tensors:
            self.driver.free_tensor(self._temp_tensors[tensor_id])
            del self._temp_tensors[tensor_id]
    def __del__(self):
        """Best-effort cleanup of any temporary tensors still held."""
        try:
            for tensor_id in list(self._temp_tensors.keys()):
                self._free_temp_tensor(tensor_id)
        except Exception:
            # During interpreter shutdown the driver (or its module globals)
            # may already be torn down; never raise from __del__.
            pass
    def _create_pooling_descriptor(self, mode: PoolingMode) -> PoolingDescriptor:
        """Build the hardware pooling descriptor from this layer's config."""
        return PoolingDescriptor(
            mode=mode,
            kernel_height=self.kernel_height,
            kernel_width=self.kernel_width,
            stride_height=self.stride_height,
            stride_width=self.stride_width,
            padding_height=self.padding_height,
            padding_width=self.padding_width
        )
class MaxPool2D(HeliumPooling2D):
    """2D max pooling layer"""
    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        stream_id: Optional[int] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Compute 2D max pooling.

        Args:
            input_tensor: Input tensor (NCHW format), given either as a
                driver tensor name or as a HeliumTensor handle.
            stream_id: Optional stream for async execution.

        Returns:
            The id of a freshly allocated device tensor holding the result.
        """
        # Resolve name/shape/dtype regardless of how the input was passed.
        if isinstance(input_tensor, str):
            info = self.driver.get_tensor_info(input_tensor)
            name, shape, dtype = input_tensor, info.shape, info.dtype
        else:
            name, shape, dtype = input_tensor.name, input_tensor.shape, input_tensor.dtype.name
        batch, channels, height, width = shape
        # Standard pooling output-size arithmetic (floor division).
        rows = (height + 2 * self.padding_height - self.kernel_height) // self.stride_height + 1
        cols = (width + 2 * self.padding_width - self.kernel_width) // self.stride_width + 1
        out_id = self._get_temp_tensor((batch, channels, rows, cols), dtype)
        # Dispatch the pooling kernel on the device.
        self.driver.pooling_forward(
            pooling_desc=self._create_pooling_descriptor(PoolingMode.MAX),
            input=name,
            output=out_id,
            stream_id=stream_id
        )
        return out_id
class AvgPool2D(HeliumPooling2D):
    """2D average pooling layer"""
    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        stream_id: Optional[int] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Compute 2D average pooling.

        Args:
            input_tensor: Input tensor (NCHW format), either a driver tensor
                name or a HeliumTensor handle.
            stream_id: Optional stream for async execution.

        Returns:
            The id of a newly allocated device tensor with the pooled result.
        """
        # Normalize the two accepted input forms to (name, shape, dtype).
        if isinstance(input_tensor, str):
            meta = self.driver.get_tensor_info(input_tensor)
            src_name = input_tensor
            src_shape = meta.shape
            src_dtype = meta.dtype
        else:
            src_name = input_tensor.name
            src_shape = input_tensor.shape
            src_dtype = input_tensor.dtype.name
        n, c, h, w = src_shape
        # Floor-division output arithmetic, identical to the max-pool case.
        oh = (h + 2 * self.padding_height - self.kernel_height) // self.stride_height + 1
        ow = (w + 2 * self.padding_width - self.kernel_width) // self.stride_width + 1
        result = self._get_temp_tensor((n, c, oh, ow), src_dtype)
        self.driver.pooling_forward(
            pooling_desc=self._create_pooling_descriptor(PoolingMode.AVERAGE),
            input=src_name,
            output=result,
            stream_id=stream_id
        )
        return result
class GlobalAvgPool2D:
    """Global average pooling layer.

    Reduces each channel's spatial plane to its mean via the driver's
    reduce_mean primitive, mapping an NCHW input to an (N, C) output.
    """
    def __init__(self, device_id: Optional[str] = None):
        """
        Args:
            device_id: Optional virtual GPU device id; falls back to the
                default device when omitted.
        """
        self.driver = get_device(device_id) if device_id else get_default_device()
        self.device_id = device_id
        # tensor id -> driver allocation handle, for explicit cleanup.
        self._temp_tensors = {}
        self._counter = 0
    def _get_temp_tensor(self, shape: Tuple[int, ...], dtype: Union[str, "DType"] = "float32") -> str:
        """Allocate a temporary device tensor and return its id.

        Accepts either a dtype name string or a DType member, since callers
        pass `tensor_info.dtype` through unchanged.
        """
        tensor_id = f"global_pool_temp_{self._counter}"
        self._counter += 1
        resolved = dtype if isinstance(dtype, DType) else getattr(DType, str(dtype).upper())
        descriptor = TensorDescriptor(
            shape=shape,
            dtype=resolved,
            device=Device.VGPU,
            # NOTE(review): NHWC layout on a rank-2 (N, C) tensor looks
            # questionable — confirm the driver ignores layout here.
            layout=Layout.NHWC
        )
        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(descriptor)
        return tensor_id
    def _free_temp_tensor(self, tensor_id: str):
        """Release a temporary tensor previously returned by _get_temp_tensor.

        Previously there was no way to release these allocations at all, so
        every output leaked device memory for the lifetime of the instance.
        No __del__ is defined on purpose: the functional wrapper
        global_avg_pool2d returns a tensor owned by a short-lived instance,
        and freeing on garbage collection would invalidate that result.
        """
        handle = self._temp_tensors.pop(tensor_id, None)
        if handle is not None:
            self.driver.free_tensor(handle)
    def forward(
        self,
        input_tensor: Union[str, "HeliumTensor"],
        stream_id: Optional[int] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Compute global average pooling over the spatial dimensions.

        Args:
            input_tensor: Input tensor (NCHW format), either a driver tensor
                name or a HeliumTensor handle.
            stream_id: Optional stream for async execution.

        Returns:
            The id of a new (batch, channels) device tensor of per-channel means.
        """
        # Get input info
        if isinstance(input_tensor, str):
            tensor_name = input_tensor
            tensor_info = self.driver.get_tensor_info(tensor_name)
            input_shape = tensor_info.shape
            dtype = tensor_info.dtype
        else:
            tensor_name = input_tensor.name
            input_shape = input_tensor.shape
            dtype = input_tensor.dtype.name
        batch_size, channels = input_shape[0], input_shape[1]
        output = self._get_temp_tensor((batch_size, channels), dtype)
        # Averaging over dims 2 and 3 (height, width) collapses each plane.
        self.driver.reduce_mean(
            input=tensor_name,
            output=output,
            dims=(2, 3),
            stream_id=stream_id
        )
        return output
# Functional interface
def max_pool2d(
    x: Union[str, "HeliumTensor"],
    kernel_size: Union[int, Tuple[int, int]],
    stride: Optional[Union[int, Tuple[int, int]]] = None,
    padding: Union[int, Tuple[int, int]] = 0,
    device_id: Optional[str] = None,
    stream_id: Optional[int] = None
) -> Union[str, "HeliumTensor"]:
    """Functional interface for 2D max pooling.

    Args:
        x: Input tensor (NCHW), as a driver tensor name or HeliumTensor.
        kernel_size: Window size, int or (height, width).
        stride: Window step; defaults to kernel_size when None.
        padding: Zero-padding per spatial axis.
        device_id: Optional virtual GPU device id.
        stream_id: Optional stream for async execution.

    Returns:
        The output tensor id; ownership passes to the caller.
    """
    module = MaxPool2D(kernel_size, stride, padding, device_id)
    output = module.forward(x, stream_id)
    # Detach the result from the throwaway module's cleanup bookkeeping:
    # MaxPool2D.__del__ frees every tensor left in _temp_tensors, and this
    # module becomes garbage as soon as we return, which would free the
    # very tensor we are handing back to the caller.
    module._temp_tensors.pop(output, None)
    return output
def avg_pool2d(
    x: Union[str, "HeliumTensor"],
    kernel_size: Union[int, Tuple[int, int]],
    stride: Optional[Union[int, Tuple[int, int]]] = None,
    padding: Union[int, Tuple[int, int]] = 0,
    device_id: Optional[str] = None,
    stream_id: Optional[int] = None
) -> Union[str, "HeliumTensor"]:
    """Functional interface for 2D average pooling.

    Args:
        x: Input tensor (NCHW), as a driver tensor name or HeliumTensor.
        kernel_size: Window size, int or (height, width).
        stride: Window step; defaults to kernel_size when None.
        padding: Zero-padding per spatial axis.
        device_id: Optional virtual GPU device id.
        stream_id: Optional stream for async execution.

    Returns:
        The output tensor id; ownership passes to the caller.
    """
    module = AvgPool2D(kernel_size, stride, padding, device_id)
    output = module.forward(x, stream_id)
    # Detach the result before the temporary module is garbage-collected;
    # otherwise AvgPool2D.__del__ frees the returned tensor's allocation.
    module._temp_tensors.pop(output, None)
    return output
def global_avg_pool2d(
    x: Union[str, "HeliumTensor"],
    device_id: Optional[str] = None,
    stream_id: Optional[int] = None
) -> Union[str, "HeliumTensor"]:
    """Functional interface for global average pooling.

    Args:
        x: Input tensor (NCHW), as a driver tensor name or HeliumTensor.
        device_id: Optional virtual GPU device id.
        stream_id: Optional stream for async execution.

    Returns:
        The (batch, channels) output tensor id.
    """
    # One-shot layer instance; delegates straight to the module forward.
    return GlobalAvgPool2D(device_id).forward(x, stream_id)