File size: 2,693 Bytes
8870824
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""

Pure PyTorch implementation of SoftPool.

This is a fallback that doesn't require CUDA kernel compilation.

SoftPool: https://arxiv.org/abs/2101.00440

"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def soft_pool2d(x, kernel_size=(2, 2), stride=None, force_inplace=False):
    """
    Apply soft pooling on a 2D input tensor.

    SoftPool approximates max pooling while staying differentiable by
    weighting every element of a pooling window with its softmax weight:
    y = sum(x * exp(x)) / sum(exp(x)).

    Args:
        x: Input tensor of shape (N, C, H, W), or unbatched (C, H, W).
        kernel_size: Pooling kernel size, an int or a pair (kh, kw).
        stride: Stride, an int or a pair (sh, sw). Defaults to kernel_size.
        force_inplace: Unused, kept for API compatibility.

    Returns:
        Pooled tensor of shape (N, C, out_h, out_w), or (C, out_h, out_w)
        for unbatched input, where out_h = (H - kh) // sh + 1 and
        out_w = (W - kw) // sw + 1. No padding is applied, so trailing
        rows/columns that do not fill a whole window are dropped.

    Raises:
        ValueError: If x is neither 3D nor 4D.
    """
    if stride is None:
        stride = kernel_size

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride)

    if x.dim() not in (3, 4):
        raise ValueError(
            "soft_pool2d expects a 3D (C, H, W) or 4D (N, C, H, W) tensor, "
            f"got shape {tuple(x.shape)}"
        )
    # Treat unbatched input as a batch of one and strip that dim at the end.
    unbatched = x.dim() == 3
    if unbatched:
        x = x.unsqueeze(0)

    batch, channels, height, width = x.shape
    kh, kw = kernel_size
    sh, sw = stride

    # Output spatial size of an unpadded sliding window.
    out_h = (height - kh) // sh + 1
    out_w = (width - kw) // sw + 1

    # F.unfold yields (N, C*kh*kw, L) with C as the outer factor of dim 1,
    # so it splits cleanly into (N, C, kh*kw, L) with one window per column.
    x_unfold = F.unfold(x, kernel_size=kernel_size, stride=stride)
    x_unfold = x_unfold.reshape(batch, channels, kh * kw, out_h * out_w)

    # Subtracting the per-window max keeps exp() from overflowing. The
    # common factor exp(-max) cancels in the ratio below, so the result is
    # unchanged.
    x_max = x_unfold.max(dim=2, keepdim=True)[0]
    exp_x = torch.exp(x_unfold - x_max)

    # Weighted sum: sum(x * exp(x)) / sum(exp(x)). The epsilon is only a
    # guard; after the shift the denominator is already >= 1.
    softpool = (x_unfold * exp_x).sum(dim=2) / (exp_x.sum(dim=2) + 1e-8)

    softpool = softpool.reshape(batch, channels, out_h, out_w)
    if unbatched:
        softpool = softpool.squeeze(0)
    return softpool


class SoftPool2d(nn.Module):
    """
    SoftPool 2D layer.

    Module wrapper around ``soft_pool2d``: a differentiable pooling
    operation that approximates max pooling using exponential weighting.

    Args:
        kernel_size: Pooling kernel size, an int or a pair (kh, kw) given
            as a tuple or list.
        stride: Stride; defaults to the normalized kernel size.
        force_inplace: Stored and forwarded to ``soft_pool2d``.
    """

    def __init__(self, kernel_size=(2, 2), stride=None, force_inplace=False):
        super().__init__()
        # Accept any (kh, kw) pair given as a tuple or list. Previously a
        # list was treated as a scalar and duplicated into ([kh, kw], [kh, kw]).
        if isinstance(kernel_size, (tuple, list)):
            self.kernel_size = tuple(kernel_size)
        else:
            self.kernel_size = (kernel_size, kernel_size)
        self.stride = stride if stride is not None else self.kernel_size
        self.force_inplace = force_inplace

    def forward(self, x):
        """Soft-pool ``x`` with this layer's kernel size and stride."""
        return soft_pool2d(x, self.kernel_size, self.stride, self.force_inplace)

    def extra_repr(self):
        return f'kernel_size={self.kernel_size}, stride={self.stride}'