File size: 269 Bytes
4ec6f12
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
from torch import nn
import torch


# Demo: average pooling with a 5x5 kernel over a 5x5 feature map collapses
# each channel to a single value, i.e. it acts as a global average pool.
# Note: an int kernel_size always gives a *square* window (5 -> 5x5); pass a
# tuple such as nn.AvgPool2d((3, 2), stride=(2, 1)) for a non-square window.
m = nn.AvgPool2d(5)
inputs = torch.randn(32, 256, 5, 5)  # (batch, channels, height, width)
output = m(inputs)  # -> (32, 256, 1, 1)
# Drop the two singleton spatial dims to get one feature vector per sample.
output = torch.flatten(output, start_dim=1)  # -> (32, 256)
print(output.shape)