import copy
import numbers
from typing import Any, List, Tuple, Union

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from modules.general.scaling import ActivationBalancer
from modules.general.scaling import BasicNorm as _BasicNorm

_shape_t = Union[int, List[int], torch.Size]


class LayerNorm(nn.Module):
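    """Layer normalization with an extended ``forward``: it accepts either a
    plain tensor or an ``(input, embedding)`` tuple, normalizing the input
    and passing the embedding through unchanged. This keeps the class
    interchangeable with the embedding-conditioned norms defined below."""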
    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
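        """Apply layer norm; tuple inputs return ``(normalized, embedding)``."""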
        if isinstance(input, tuple):
            input, embedding = input
            output = F.layer_norm(
                input, self.normalized_shape, self.weight, self.bias, self.eps
            )
            return output, embedding

        assert embedding is None
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )


class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization"""

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
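        """Return ``weight * self.norm(input) + bias``; tuple inputs carry
        the embedding alongside and are returned as a tuple again."""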
        if isinstance(input, tuple):
            input, embedding = input
            weight, bias = torch.split(
                self.project_layer(embedding),
                split_size_or_sections=self.d_model,
                dim=-1,
            )
            return (weight * self.norm(input) + bias, embedding)

        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias


class BasicNorm(_BasicNorm):
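    """Wrapper around the scaling module's BasicNorm that adds the same
    ``(input, embedding)`` tuple convention as the other norms here. Note
    that ``device`` and ``dtype`` are accepted for interface parity only and
    are not forwarded to the parent class."""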
    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super(BasicNorm, self).__init__(d_model, eps=eps)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            return (
                super(BasicNorm, self).forward(input),
                embedding,
            )

        assert embedding is None
        return super(BasicNorm, self).forward(input)


class BalancedBasicNorm(nn.Module):
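    """BasicNorm preceded by an ActivationBalancer. The balancer is an
    identity in the forward pass; it adjusts gradients so that each channel
    stays roughly half positive (between 45% and 55%) and bounded in
    magnitude (max_abs=6.0)."""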
    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super(BalancedBasicNorm, self).__init__()
        self.balancer = ActivationBalancer(
            d_model,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            return self.norm((self.balancer(input), embedding))

        assert embedding is None
        return self.norm(self.balancer(input))


class IdentityNorm(nn.Module):
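    """No-op stand-in that preserves the ``(input, embedding)`` tuple
    convention, for configurations where no normalization is wanted."""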
    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super(IdentityNorm, self).__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            return input

        assert embedding is None
        return input
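

# Minimal usage sketch (illustrative only; assumes the
# modules.general.scaling import above resolves in your environment):
# condition an AdaptiveLayerNorm on an embedding and exercise the tuple path.
if __name__ == "__main__":
    x = torch.randn(2, 10, 512)
    emb = torch.randn(2, 10, 512)
    ada_norm = AdaptiveLayerNorm(512, norm=LayerNorm(512))
    y, e = ada_norm((x, emb))
    assert y.shape == x.shape and e is emb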