File size: 1,743 Bytes
24e5510
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn.functional as F

from nnunetv2.utilities.helpers import empty_cache
from torch.backends import cudnn


@torch.inference_mode()
def iterative_3x3_same_padding_pool3d(x, kernel_size: int, use_min_pool: bool = False):
    """
    Apply stride-1 3D max (or min) pooling whose output keeps the input shape.

    The input is replicate-padded once on its last three dimensions and then
    pooled ``(kernel_size - 1) // 2`` times with an unpadded 3x3x3 max pooling.
    Every pass trims one voxel per side, so after all passes the padding is
    consumed and the spatial size matches the input again. Chaining 3x3x3
    max pools this way covers the same window as one pooling with an edge
    length of ``kernel_size``.

    ``cudnn.benchmark`` is switched off while the function runs and restored to
    its previous value on exit, including when an exception is raised.
    ``empty_cache`` is called on the result's device after pooling.

    Args:
        x (Tensor): Input tensor of shape (N, C, D, H, W).
        kernel_size (int): Edge length of the effective pooling window.
            Must be odd; the same size is used for all three dimensions.
        use_min_pool (bool): If True, compute min pooling instead of max
            pooling (implemented as ``-max_pool(-x)``).

    Returns:
        Tensor: Output tensor with the same (D, H, W) dimensions as the input.

    Raises:
        AssertionError: If ``kernel_size`` is even.
    """
    # Validate before touching global state so that a bad argument cannot
    # leave cudnn.benchmark modified.
    assert kernel_size % 2 == 1, 'Only works with odd kernels'

    # Padding per side; for odd kernels both values are identical.
    pad_front = (kernel_size - 1) // 2
    pad_back = (kernel_size - 1) - pad_front
    # Each 3x3x3 pass removes one voxel per side.
    iters = (kernel_size - 1) // 2

    # min(x) == -max(-x). Negating once around the whole loop gives the same
    # result as negating around every pass. Skip it when no pass runs so the
    # input is left untouched in that case.
    negate = use_min_pool and iters > 0

    benchmark = cudnn.benchmark
    # NOTE(review): presumably disabled to avoid cuDNN autotuning for this
    # one-off op - confirm.
    cudnn.benchmark = False
    try:
        # For 3D (input shape: [N, C, D, H, W]), F.pad expects the padding in the order:
        # (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back)
        x = F.pad(x, (pad_front, pad_back,
                      pad_front, pad_back,
                      pad_front, pad_back), mode='replicate')

        if negate:
            x = -x
        # Pool with no additional padding; the up-front padding is consumed here.
        for _ in range(iters):
            x = F.max_pool3d(x, kernel_size=3, stride=1, padding=0)
        if negate:
            x = -x

        empty_cache(x.device)
        return x
    finally:
        # Always restore the caller's setting, even on failure.
        cudnn.benchmark = benchmark