# File size: 412 Bytes (revision 97aa5af)
import torch
import torch.nn as nn
import torch.nn.functional as F
class Pooling(torch.nn.Module):
    """Collapse dimension 2 of a tensor with a max or mean reduction.

    Args:
        pool_type: Reduction to apply. ``'max'`` selects the maximum along
            dimension 2; ``'avg'`` and ``'average'`` both select the mean.
    """

    def __init__(self, pool_type: str = 'max') -> None:
        # Initialise nn.Module first so attribute registration is always
        # set up before any state is attached to the instance.
        super().__init__()
        self.pool_type = pool_type

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Reduce ``input`` along dimension 2.

        Args:
            input: Tensor with at least three dimensions (dimension 2 is
                the one removed by the reduction).

        Returns:
            A contiguous tensor with dimension 2 removed.

        Raises:
            ValueError: If ``self.pool_type`` is not one of ``'max'``,
                ``'avg'`` or ``'average'``.
        """
        if self.pool_type == 'max':
            # torch.max over a dim returns (values, indices); keep values.
            return torch.max(input, 2)[0].contiguous()
        if self.pool_type in ('avg', 'average'):
            return torch.mean(input, 2).contiguous()
        # Fail loudly instead of implicitly returning None, which would
        # only surface later as an unrelated error in downstream layers.
        raise ValueError(
            "Unsupported pool_type {!r}; expected 'max', 'avg' or "
            "'average'.".format(self.pool_type)
        )