| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
def normalization(channels: int, groups: int = 32):
    r"""Build the default normalization layer, a ``nn.GroupNorm``.

    Args:
        channels: number of channels of the input that will be normalized.
        groups: number of groups the channels are divided into.

    Returns:
        a ``nn.GroupNorm`` module with ``groups`` groups over ``channels``
        channels.
    """
    assert groups > 0, f"invalid number of groups: {groups}"
    group_norm = nn.GroupNorm(num_groups=groups, num_channels=channels)
    return group_norm
| |
|
| |
|
def Linear(*args, **kwargs):
    r"""Create a ``nn.Linear`` and re-initialize its weight with kaiming_normal_.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Linear``. Only the weight is re-initialized.
    """
    linear = nn.Linear(*args, **kwargs)
    weight = linear.weight
    nn.init.kaiming_normal_(weight)
    return linear
| |
|
| |
|
def Conv1d(*args, **kwargs):
    r"""Create a ``nn.Conv1d`` and re-initialize its weight with kaiming_normal_.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv1d``. Only the weight is re-initialized.
    """
    conv = nn.Conv1d(*args, **kwargs)
    kernel = conv.weight
    nn.init.kaiming_normal_(kernel)
    return conv
| |
|
| |
|
def Conv2d(*args, **kwargs):
    r"""Create a ``nn.Conv2d`` and re-initialize its weight with kaiming_normal_.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv2d``. Only the weight is re-initialized.
    """
    conv = nn.Conv2d(*args, **kwargs)
    kernel = conv.weight
    nn.init.kaiming_normal_(kernel)
    return conv
| |
|
| |
|
def Conv3d(*args, **kwargs):
    r"""Wrapper of ``nn.Conv3d`` with kaiming_normal_ initialization."""
    layer = nn.Conv3d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


def ConvNd(dims: int = 1, *args, **kwargs):
    r"""Wrapper of N-dimension convolution with kaiming_normal_ initialization.

    Args:
        dims: number of dimensions of the convolution (1, 2 or 3).
        *args: positional arguments forwarded to the selected convolution
            wrapper.
        **kwargs: keyword arguments forwarded to the selected convolution
            wrapper.

    Returns:
        the convolution layer built by ``Conv1d``, ``Conv2d`` or ``Conv3d``.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    if dims == 1:
        return Conv1d(*args, **kwargs)
    if dims == 2:
        return Conv2d(*args, **kwargs)
    if dims == 3:
        return Conv3d(*args, **kwargs)
    raise ValueError(f"invalid number of dimensions: {dims}")
| |
|
| |
|
def zero_module(module: nn.Module):
    r"""Zero out the parameters of a module and return it.

    Every parameter yielded by ``module.parameters()`` is zeroed in place.
    This covers modules created with ``bias=False`` and modules that have no
    ``weight`` attribute of their own (e.g. containers).

    Args:
        module: the module whose parameters are zeroed in place.

    Returns:
        the same ``module`` object, to allow inline use.
    """
    for param in module.parameters():
        # nn.init.zeros_ fills the tensor in place without recording the
        # operation in autograd; ``requires_grad`` is left unchanged.
        nn.init.zeros_(param)
    return module
| |
|
| |
|
def scale_module(module: nn.Module, scale):
    r"""Multiply every parameter of a module by ``scale`` in place.

    Args:
        module: the module whose parameters are rescaled.
        scale: the factor each parameter is multiplied by.

    Returns:
        the same ``module`` object.
    """
    for param in module.parameters():
        # The detached tensor shares storage with the parameter, so the
        # in-place multiply updates the parameter's values.
        detached = param.detach()
        detached.mul_(scale)
    return module
| |
|
| |
|
def mean_flat(tensor: torch.Tensor):
    r"""Take the mean over all non-batch dimensions.

    Dimension 0 is treated as the batch dimension and is kept; every other
    dimension is averaged away.

    Args:
        tensor: input tensor.

    Returns:
        a tensor holding one mean per entry along dimension 0. If ``tensor``
        has no dimensions beyond the first, there is nothing to average and a
        copy of ``tensor`` is returned.
    """
    reduce_dims = tuple(range(1, tensor.dim()))
    if not reduce_dims:
        # With an empty ``dim`` tuple ``Tensor.mean`` reduces over every
        # dimension, which would collapse the batch dimension too.
        return tensor.clone()
    return tensor.mean(dim=reduce_dims)
| |
|
| |
|
def append_dims(x, target_dims):
    r"""Add trailing singleton dimensions to a tensor until it has
    ``target_dims`` dimensions.

    Args:
        x: input tensor.
        target_dims: number of dimensions the result should have.

    Returns:
        ``x`` indexed so that ``target_dims - x.dim()`` size-1 dimensions
        follow its existing ones.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.dim()
    if missing < 0:
        raise ValueError(
            f"input has {x.dim()} dims but target_dims is {target_dims}, which is less"
        )
    # Ellipsis keeps every existing dimension; each None adds one new axis.
    index = (Ellipsis,) + (None,) * missing
    return x[index]
| |
|
| |
|
def append_zero(x, count=1):
    r"""Concatenate ``count`` zeros onto the last dimension of a tensor.

    Args:
        x: input tensor.
        count: number of zeros appended along the last dimension.

    Returns:
        a tensor whose last dimension is ``count`` longer than that of ``x``.
    """
    assert count > 0, f"invalid count: {count}"
    leading_shape = tuple(x.size()[:-1])
    # new_zeros creates the padding with the dtype and device of ``x``.
    padding = x.new_zeros(leading_shape + (count,))
    return torch.cat((x, padding), dim=-1)
| |
|
| |
|
class Transpose(nn.Identity):
    """Module that swaps dimensions 1 and 2: (N, T, D) -> (N, D, T)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Dimension 0 is left where it is; only dimensions 1 and 2 trade places.
        first_dim, second_dim = 1, 2
        return input.transpose(first_dim, second_dim)
| |
|