Dhenenjay commited on
Commit
a9abf27
·
verified ·
1 Parent(s): 298f49d

Upload softpool.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. softpool.py +37 -0
softpool.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pure PyTorch SoftPool implementation."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def soft_pool2d(x, kernel_size=(2, 2), stride=None, force_inplace=False):
9
+ if stride is None:
10
+ stride = kernel_size
11
+ if isinstance(kernel_size, int):
12
+ kernel_size = (kernel_size, kernel_size)
13
+ if isinstance(stride, int):
14
+ stride = (stride, stride)
15
+
16
+ batch, channels, height, width = x.shape
17
+ kh, kw = kernel_size
18
+ sh, sw = stride
19
+ out_h = (height - kh) // sh + 1
20
+ out_w = (width - kw) // sw + 1
21
+
22
+ x_unfold = F.unfold(x, kernel_size=kernel_size, stride=stride)
23
+ x_unfold = x_unfold.view(batch, channels, kh * kw, out_h * out_w)
24
+ x_max = x_unfold.max(dim=2, keepdim=True)[0]
25
+ exp_x = torch.exp(x_unfold - x_max)
26
+ softpool = (x_unfold * exp_x).sum(dim=2) / (exp_x.sum(dim=2) + 1e-8)
27
+ return softpool.view(batch, channels, out_h, out_w)
28
+
29
+
30
+ class SoftPool2d(nn.Module):
31
+ def __init__(self, kernel_size=(2, 2), stride=None, force_inplace=False):
32
+ super(SoftPool2d, self).__init__()
33
+ self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
34
+ self.stride = stride if stride is not None else self.kernel_size
35
+
36
+ def forward(self, x):
37
+ return soft_pool2d(x, self.kernel_size, self.stride)