duycse1603's picture
[Add] source
6163604
from torch import nn
class Maxout(nn.Module):
"""
Maxout makes pools from the last dimension and keeps only the maximum value from
each pool.
"""
def __init__(self, pool_size):
"""
Args:
pool_size (int): Number of elements per pool
"""
super(Maxout, self).__init__()
self.pool_size = pool_size
def forward(self, x):
[*shape, last] = x.size()
out = x.view(*shape, last // self.pool_size, self.pool_size)
out, _ = out.max(-1)
return out