import torch import torch.nn as nn import torch.nn.functional as F class Pooling(torch.nn.Module): def __init__(self, pool_type='max'): self.pool_type = pool_type super(Pooling, self).__init__() def forward(self, input): if self.pool_type == 'max': return torch.max(input, 2)[0].contiguous() elif self.pool_type == 'avg' or self.pool_type == 'average': return torch.mean(input, 2).contiguous()