Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| # Copyright 2019 Shigeki Karita | |
| # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) | |
| """Layer normalization module.""" | |
| import torch | |
class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Wrapper around :class:`torch.nn.LayerNorm` that can normalize over an
    arbitrary dimension of the input instead of only the last one.

    :param int nout: output dim size (size of the normalized dimension)
    :param int dim: dimension to be normalized
    :param float eps: value added to the variance for numerical stability
    """

    def __init__(self, nout, dim=-1, eps=1e-12):
        """Construct a LayerNorm object."""
        super().__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        :param torch.Tensor x: input tensor; its size along ``self.dim`` must
            equal ``nout``
        :return: layer normalized tensor with the same shape as ``x``
        :rtype torch.Tensor
        """
        if self.dim == -1:
            return super().forward(x)
        # torch.nn.LayerNorm always normalizes the trailing dimension, so swap
        # ``self.dim`` to the end, normalize, and swap back (a transpose of the
        # same pair of axes is its own inverse). The previous code always
        # swapped axis 1, silently ignoring any other requested ``dim``.
        swapped = x.transpose(self.dim, -1)
        return super().forward(swapped).transpose(self.dim, -1)