| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Normalization modules.""" |
| |
|
| | import typing as tp |
| |
|
| | import einops |
| | import torch |
| | from torch import nn |
| |
|
| |
|
class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.

    Expects convolution-style input of shape ``(batch, *channels, time)``;
    ``normalized_shape`` must match the channel dimension(s) being normalized.
    """

    def __init__(
        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
    ):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # LayerNorm normalizes over the trailing dimensions, so move the
        # channel dims to the end (time becomes dim 1) before normalizing.
        x = einops.rearrange(x, "b ... t -> b t ...")
        x = super().forward(x)
        # Restore the original (batch, *channels, time) layout.
        x = einops.rearrange(x, "b t ... -> b ... t")
        # Bug fix: the original ended with a bare `return`, silently
        # returning None instead of the normalized tensor.
        return x
| |
|