Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from nnunetv2.utilities.helpers import empty_cache | |
| from torch.backends import cudnn | |
def iterative_3x3_same_padding_pool3d(x, kernel_size: int, use_min_pool: bool = False):
    """
    Apply a stride-1 cubic max (or min) pool whose output keeps the input's spatial shape.

    The k x k x k window is built by padding the last three dimensions by
    (k - 1) // 2 on every side with mode='replicate' (border values are repeated).
    An unpadded 3 x 3 x 3 max pool is then applied (k - 1) // 2 times. Each pass
    shrinks every padded dimension by 2, so the padding is consumed exactly.
    Repeated 3-wide maxima compose into a k-wide maximum. Because of the
    replicate padding, each output voxel is the max (or min) over the part of its
    k x k x k neighbourhood that lies inside the input.

    Args:
        x (Tensor): Input tensor of shape (N, C, D, H, W).
        kernel_size (int): Edge length of the cubic pooling window. Must be a
            positive odd integer. A value of 1 returns the input values unchanged.
        use_min_pool (bool): If True, compute a min pool instead of a max pool,
            implemented as -max_pool(-x).

    Returns:
        Tensor: Pooled tensor with the same (D, H, W) dimensions as the input.

    Raises:
        AssertionError: If kernel_size is not a positive odd integer.
    """
    # Validate before touching global cudnn state, so a bad argument cannot leave
    # cudnn.benchmark switched off. Non-positive sizes must be rejected explicitly:
    # -1 % 2 == 1 in Python, and negative pads would make F.pad crop the tensor.
    assert kernel_size >= 1 and kernel_size % 2 == 1, 'Only works with odd, positive kernels'

    # Odd kernels give a symmetric window: the same pad on both sides of each axis,
    # and one 3x3x3 pass per voxel of radius.
    radius = (kernel_size - 1) // 2

    previous_benchmark = cudnn.benchmark
    # cudnn autotuning is disabled for the duration of this call. It is restored
    # in `finally`, even if padding or pooling raises (e.g. out of memory).
    cudnn.benchmark = False
    try:
        # F.pad takes (before, after) pairs starting from the LAST dimension:
        # (W_left, W_right, H_top, H_bottom, D_front, D_back).
        x = F.pad(x, (radius,) * 6, mode='replicate')
        if use_min_pool:
            # min pooling == max pooling of the negated input, negated back.
            # Negating once up front is equivalent to negating around every pass.
            x = -x
        for _ in range(radius):
            x = F.max_pool3d(x, kernel_size=3, stride=1, padding=0)
            empty_cache(x.device)
        if use_min_pool:
            x = -x
        return x
    finally:
        cudnn.benchmark = previous_benchmark