import torch


class LayerNorm(torch.nn.Module):
    """Layer normalization over the channel dimension.

    Normalizes channel-first input of shape ``(batch_size, channels,
    time_steps)`` across the channel axis, with a learnable per-channel
    scale (``gamma``) and shift (``beta``).

    Args:
        channels (int): Number of channels to normalize over.
        eps (float, optional): Epsilon added to the variance for
            numerical stability. Defaults to 1e-5.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Learnable per-channel affine parameters: scale and shift.
        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Forward pass.

        Args:
            x (torch.Tensor): Input tensor of shape
                (batch_size, channels, time_steps).

        Returns:
            torch.Tensor: Normalized tensor of the same shape.
        """
        # F.layer_norm normalizes over the trailing dimension(s), so move
        # channels to the last axis, normalize, then move them back.
        x = x.transpose(1, -1)
        x = torch.nn.functional.layer_norm(
            x, (x.size(-1),), self.gamma, self.beta, self.eps
        )
        return x.transpose(1, -1)
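

# A minimal usage sketch (the shapes below are illustrative assumptions,
# not taken from the surrounding project): LayerNorm(80) normalizes a
# (batch, 80, time) tensor across its 80 channels at every time step,
# leaving the shape unchanged.
if __name__ == "__main__":
    layer_norm = LayerNorm(channels=80)
    x = torch.randn(4, 80, 100)  # (batch_size, channels, time_steps)
    y = layer_norm(x)
    assert y.shape == x.shape
    # At initialization gamma=1 and beta=0, so each (batch, time) position
    # has ~zero mean and ~unit variance across the channel axis.
    print(y.mean(dim=1).abs().max())  # close to 0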