hanjang's picture
Upload folder using huggingface_hub
24e5510 verified
import torch
import torch.nn.functional as F
from nnunetv2.utilities.helpers import empty_cache
from torch.backends import cudnn
@torch.inference_mode()
def iterative_3x3_same_padding_pool3d(x, kernel_size: int, use_min_pool: bool = False):
"""
Applies 3D max pooling with manual asymmetric padding such that
the output shape is the same as the input shape.
Args:
x (Tensor): Input tensor of shape (N, C, D, H, W)
kernel_size (int or tuple): Kernel size for the pooling.
If int, the same kernel size is used for all three dimensions.
Returns:
Tensor: Output tensor with the same (D, H, W) dimensions as the input.
"""
benchmark = cudnn.benchmark
cudnn.benchmark = False
assert kernel_size % 2 == 1, 'Only works with odd kernels'
# Compute asymmetric padding for each dimension:
pad_front = (kernel_size - 1) // 2
pad_back = (kernel_size - 1) - pad_front
# For 3D (input shape: [N, C, D, H, W]), F.pad expects the padding in the order:
# (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back)
x = F.pad(x, (pad_front, pad_back,
pad_front, pad_back,
pad_front, pad_back), mode='replicate')
iters = (kernel_size - 1) // 2
# Apply max pooling with no additional padding.
if not use_min_pool:
for _ in range(iters):
x = F.max_pool3d(x, kernel_size=3, stride=1, padding=0)
empty_cache(x.device)
cudnn.benchmark = benchmark
return x
else:
for _ in range(iters):
x = - F.max_pool3d(- x, kernel_size=3, stride=1, padding=0)
empty_cache(x.device)
cudnn.benchmark = benchmark
return x