YasiiKB's picture
initial commit
97aa5af verified
import torch
import torch.nn as nn
import torch.nn.functional as F
class Pooling(torch.nn.Module):
    """Collapse dimension 2 of a tensor with a max or mean reduction.

    The reduced dimension is dropped from the result, so the input must
    have at least three dimensions.

    Args:
        pool_type: Reduction to apply. ``'max'`` keeps the largest value
            along dimension 2; ``'avg'`` and ``'average'`` both take the
            arithmetic mean along dimension 2.
    """

    # Accepted spellings for the mean reduction.
    _AVERAGE_NAMES = ('avg', 'average')

    def __init__(self, pool_type='max'):
        # nn.Module must be initialised before any attribute is assigned on
        # the instance, so the parent constructor runs first.
        super().__init__()
        self.pool_type = pool_type

    def forward(self, input):
        """Reduce ``input`` along dimension 2 according to ``self.pool_type``.

        Args:
            input: Tensor with at least three dimensions.

        Returns:
            A contiguous tensor with dimension 2 removed.

        Raises:
            ValueError: If ``self.pool_type`` is not one of ``'max'``,
                ``'avg'`` or ``'average'``.
        """
        if self.pool_type == 'max':
            # torch.max with a dim returns (values, indices); keep the values.
            return torch.max(input, 2)[0].contiguous()
        if self.pool_type in self._AVERAGE_NAMES:
            return torch.mean(input, 2).contiguous()
        # Fail loudly instead of implicitly returning None for a bad mode.
        raise ValueError(
            f"Unsupported pool_type {self.pool_type!r}; "
            "expected 'max', 'avg' or 'average'."
        )