import torch
import torch.nn.functional as F
from nnunetv2.utilities.helpers import empty_cache
from torch.backends import cudnn
@torch.inference_mode()
def iterative_3x3_same_padding_pool3d(x, kernel_size: int, use_min_pool: bool = False):
    """
    Applies 3D max (or min) pooling with 'same' output shape by decomposing a
    kernel_size pooling into repeated 3x3x3 poolings over a replicate-padded
    input, so the output (D, H, W) matches the input.

    Args:
        x (Tensor): Input tensor of shape (N, C, D, H, W).
        kernel_size (int): Odd pooling kernel size, applied identically to all
            three spatial dimensions.
        use_min_pool (bool): If True, performs min pooling instead of max
            pooling (implemented as -max_pool(-x)).

    Returns:
        Tensor: Output tensor with the same (D, H, W) dimensions as the input.

    Raises:
        AssertionError: If kernel_size is even.
    """
    assert kernel_size % 2 == 1, 'Only works with odd kernels'
    # Temporarily disable cudnn autotuning: the iterative pooling would
    # otherwise trigger benchmarking for every intermediate shape.
    benchmark = cudnn.benchmark
    cudnn.benchmark = False
    try:
        # 'same' padding: (kernel_size - 1) extra voxels per spatial dim,
        # split between the two sides (equal halves for odd kernels).
        pad_front = (kernel_size - 1) // 2
        pad_back = (kernel_size - 1) - pad_front
        # F.pad padding order for a 5D input is
        # (W_left, W_right, H_top, H_bottom, D_front, D_back).
        x = F.pad(x, (pad_front, pad_back,
                      pad_front, pad_back,
                      pad_front, pad_back), mode='replicate')
        # Each 3x3x3 pooling (stride 1, no padding) shrinks every spatial dim
        # by 2, so (kernel_size - 1) // 2 iterations emulate one k-sized pool.
        for _ in range((kernel_size - 1) // 2):
            if use_min_pool:
                # Min pooling via negation: min(x) == -max(-x).
                x = -F.max_pool3d(-x, kernel_size=3, stride=1, padding=0)
            else:
                x = F.max_pool3d(x, kernel_size=3, stride=1, padding=0)
        empty_cache(x.device)
        return x
    finally:
        # Restore the caller's cudnn.benchmark setting even if pooling raised.
        cudnn.benchmark = benchmark