| |
| |
| |
| |
|
|
|
|
| import copy |
| import numbers |
| from typing import Any, List, Tuple, Union |
|
|
| import torch |
| from torch import Tensor, nn |
| from torch.nn import functional as F |
|
|
| from modules.general.scaling import ActivationBalancer |
| from modules.general.scaling import BasicNorm as _BasicNorm |
|
|
|
|
# Accepted forms of a normalized shape (mirrors torch.nn.LayerNorm's
# internal ``_shape_t`` alias): a single int, a list of ints, or a Size.
_shape_t = Union[int, List[int], torch.Size]
|
|
|
|
class LayerNorm(nn.Module):
    """Layer normalization that also accepts an ``(input, embedding)`` tuple.

    When called with a tuple, the tensor half is normalized and the
    embedding half is passed through untouched, so the pair can flow
    through sequential stacks that thread a conditioning embedding
    alongside the activations.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: Union[int, List[int], torch.Size],
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        # A bare int means "normalize over the trailing dim of that size".
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
        else:
            # Register as None so state_dict keys stay consistent either way.
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the affine parameters to the identity transform."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            # Tuple path: normalize the tensor, forward the embedding as-is.
            input, embedding = input
            normed = F.layer_norm(
                input, self.normalized_shape, self.weight, self.bias, self.eps
            )
            return normed, embedding

        # Plain-tensor path takes no embedding.
        assert embedding is None
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )
|
|
|
|
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    Wraps a base ``norm`` module and rescales its output with a
    per-position (weight, bias) pair predicted linearly from a
    conditioning embedding.  Accepts either ``(input, embedding)`` as a
    tuple or the two tensors as separate arguments.

    NOTE(review): on the non-tuple path a ``None`` embedding reaches the
    linear projection and raises, same as the original — confirm callers
    always supply one.
    """

    def __init__(self, d_model, norm) -> None:
        super().__init__()
        # Projects the embedding to a concatenated (weight, bias) vector.
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def _scale_and_shift(self, embedding: Tensor) -> Tuple[Tensor, Tensor]:
        # Split the 2*d_model projection into equal weight/bias halves.
        projected = self.project_layer(embedding)
        return torch.split(
            projected, split_size_or_sections=self.d_model, dim=-1
        )

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            weight, bias = self._scale_and_shift(embedding)
            return (weight * self.norm(input) + bias, embedding)

        weight, bias = self._scale_and_shift(embedding)
        return weight * self.norm(input) + bias
|
|
|
|
class BasicNorm(_BasicNorm):
    """Tuple-aware wrapper around the project's ``_BasicNorm``.

    NOTE(review): ``device`` and ``dtype`` are accepted for signature
    parity with the other norms but are not forwarded to ``_BasicNorm``
    — presumably it takes no factory kwargs; confirm upstream.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super().__init__(d_model, eps=eps)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            # Normalize the tensor half, pass the embedding through.
            input, embedding = input
            return super().forward(input), embedding

        assert embedding is None
        return super().forward(input)
|
|
|
|
class BalancedBasicNorm(nn.Module):
    """``BasicNorm`` preceded by an ``ActivationBalancer``.

    The balancer conditions the activations on the channel (last) dim
    before normalization; the tuple-passing convention of the other
    norm classes is preserved.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.balancer = ActivationBalancer(
            d_model,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if not isinstance(input, tuple):
            assert embedding is None
            return self.norm(self.balancer(input))

        # Tuple path: balance the tensor half, then let BasicNorm's own
        # tuple handling carry the embedding through.
        input, embedding = input
        return self.norm((self.balancer(input), embedding))
|
|
|
|
class IdentityNorm(nn.Module):
    """No-op normalization: returns its input (tensor or tuple) unchanged.

    Constructor arguments exist only for signature compatibility with
    the other norm classes; nothing is stored.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        # A tuple is passed straight through; a bare tensor must not be
        # accompanied by an embedding.
        if not isinstance(input, tuple):
            assert embedding is None
        return input
|
|