"""Demonstrate reducing a feature map to one value per channel with AvgPool2d.

The pooling window is as large as the feature map itself and there is no
padding. Each (sample, channel) 5x5 plane is therefore averaged down to a
single value: global average pooling.
"""
from torch import nn
import torch

# Spatial size of the example feature map. The pooling window reuses it, so
# the output spatial size is floor((5 - 5) / 5) + 1 = 1.
FEATURE_SIZE = 5

# Square FEATURE_SIZE x FEATURE_SIZE window. The stride defaults to kernel_size.
m = nn.AvgPool2d(FEATURE_SIZE)
# Named `features` rather than `input` to avoid shadowing the builtin.
features = torch.randn(32, 256, FEATURE_SIZE, FEATURE_SIZE)
output = m(features)  # shape (32, 256, 1, 1)
# Collapse the two trailing singleton spatial dims -> (32, 256).
output = torch.flatten(output, start_dim=1)
print(output.shape)