# File size: 269 Bytes (revision 4ec6f12)
from torch import nn
import torch
# Demonstrate nn.AvgPool2d used as a global average pool: the pooling window
# covers the whole spatial extent of the input, so each (sample, channel)
# plane collapses to a single mean value.
#
# For reference, other AvgPool2d configurations look like:
#   nn.AvgPool2d(3, stride=2)              # square 3x3 window, stride 2
#   nn.AvgPool2d((3, 2), stride=(2, 1))    # non-square 3x2 window

# Dummy batch laid out as (N, C, H, W): 32 samples, 256 channels, 5x5 grid.
# Named `x` rather than `input` to avoid shadowing the builtin.
x = torch.randn(32, 256, 5, 5)

# Size the window from the input's (H, W) instead of hard-coding 5, so the
# pool stays "global" if the spatial size changes. stride defaults to
# kernel_size, so exactly one window is applied per plane.
m = nn.AvgPool2d(kernel_size=tuple(x.shape[-2:]))

output = m(x)               # (N, C, 1, 1)
output = output.flatten(1)  # drop the trailing 1x1 spatial dims -> (N, C)
print(output.shape)