| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from typing import Union |
|
|
| import torch |
| from torch import nn |
|
|
|
|
class LayerScale(nn.Module):
    """Layer scale module: multiplies its input by a learnable per-channel vector.

    Parameters
    ----------
    dim : int
        Length of the learnable scale vector ``gamma``; it is broadcast against the
        trailing dimension of the input.
    init_values : float or torch.Tensor, optional
        Initial values for ``gamma``, by default 1e-5. If a tensor is provided, it
        should have shape ``(dim,)`` (or be broadcastable to it).
    inplace : bool, optional
        Whether :meth:`forward` multiplies the input in place, by default False.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, torch.Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        """Inits :class:`LayerScale`.

        Parameters
        ----------
        dim : int
            Length of the learnable scale vector ``gamma``.
        init_values : float or torch.Tensor, optional
            Initial values for ``gamma``, by default 1e-5. If a tensor is provided,
            it should have shape ``(dim,)`` (or be broadcastable to it); ``gamma``
            is created on that tensor's device.
        inplace : bool, optional
            Whether to perform the operation in-place, by default False.
        """
        super().__init__()

        self.inplace = inplace
        # Create the ones vector on the same device as a tensor ``init_values``:
        # multiplying e.g. a CUDA ``(dim,)`` tensor by a CPU ``(dim,)`` tensor
        # raises a device-mismatch error. For a Python float the default device
        # (and default dtype) is used, exactly as before.
        if isinstance(init_values, torch.Tensor):
            ones = torch.ones(dim, device=init_values.device)
        else:
            ones = torch.ones(dim)
        self.gamma = nn.Parameter(init_values * ones)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of :class:`LayerScale`.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor whose trailing dimension broadcasts with ``gamma``,
            typically of shape (B, N, C) with C == ``dim``.

        Returns
        -------
        torch.Tensor
            ``x * gamma``. When ``inplace`` is True, ``x`` itself is modified via
            ``mul_`` and returned.
        """
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
|