|
|
|
|
|
|
|
|
|
|
| from typing import Union
|
|
|
| import torch
|
| from torch import Tensor, nn
|
|
|
|
|
class LayerScale(nn.Module):
    """Scale the input by a learnable per-channel vector ``gamma`` of length ``dim``.

    ``gamma`` is 1-D, so it broadcasts against the last dimension of the input.

    Args:
        dim: Length of the learnable scale vector ``gamma``.
        init_values: Initial value for ``gamma``. A Python float or a 0-dim
            tensor fills every entry; a tensor with dimensions is broadcast
            to shape ``(dim,)`` and copied in.
        inplace: If True, ``forward`` multiplies the input in place with
            ``Tensor.mul_`` and returns that same tensor object.
        device: Device on which ``gamma`` is allocated.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # Kept so reset_parameters() can re-apply the same initialization later.
        self.init_values = init_values
        self.gamma = nn.Parameter(torch.empty(dim, device=device))
        # torch.empty leaves the storage uninitialized; fill it right away so the
        # module is usable without an explicit reset_parameters() call.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re)initialize ``gamma`` from ``init_values``.

        Raises:
            RuntimeError: If a non-scalar ``init_values`` tensor cannot be
                broadcast to the shape of ``gamma``.
        """
        init = self.init_values
        if isinstance(init, Tensor) and init.dim() > 0:
            # nn.init.constant_ relies on fill_, which only accepts scalars and
            # 0-dim tensors; copy per-channel values instead (copy_ also accepts
            # a source with a different dtype/device).
            with torch.no_grad():
                self.gamma.copy_(init.expand_as(self.gamma))
        else:
            nn.init.constant_(self.gamma, init)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x * gamma``, broadcasting ``gamma`` over the last dimension of ``x``.

        When ``inplace`` is set, ``x`` itself is overwritten and returned.
        """
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
|
|