Spaces:
Runtime error
Runtime error
| from torch import nn | |
| class Maxout(nn.Module): | |
| """ | |
| Maxout makes pools from the last dimension and keeps only the maximum value from | |
| each pool. | |
| """ | |
| def __init__(self, pool_size): | |
| """ | |
| Args: | |
| pool_size (int): Number of elements per pool | |
| """ | |
| super(Maxout, self).__init__() | |
| self.pool_size = pool_size | |
| def forward(self, x): | |
| [*shape, last] = x.size() | |
| out = x.view(*shape, last // self.pool_size, self.pool_size) | |
| out, _ = out.max(-1) | |
| return out | |