from typing import Optional

import torch.nn.functional as F
from torch import Tensor, nn


class SwiGLUFFN(nn.Module):
    """Feed-forward block with a SwiGLU activation (Shazeer, 2020).

    The gated hidden width is scaled by 2/3 relative to ``hidden_features``
    so the parameter count roughly matches a conventional two-layer FFN,
    then rounded up to a multiple of ``align_to`` for GPU-friendly shapes.
    """
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0,
        bias: bool = True,
        align_to: int = 8,
        device=None,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Shrink by 2/3 to offset the extra gating projection, then round
        # up to the nearest multiple of align_to.
        d = int(hidden_features * 2 / 3)
        swiglu_hidden_features = d + (-d % align_to)
        self.w1 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device)
        self.w2 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device)
        self.w3 = nn.Linear(swiglu_hidden_features, out_features, bias=bias, device=device)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.w1(x)  # gate branch
        x2 = self.w2(x)  # value branch
        hidden = F.silu(x1) * x2  # SwiGLU: SiLU(gate) * value, elementwise
        return self.drop(self.w3(hidden))
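

# A minimal usage sketch, not part of the module above: the width, token
# count, and batch size below are illustrative assumptions (roughly ViT-S
# scale), chosen only to show that the block preserves the input shape.
if __name__ == "__main__":
    import torch

    ffn = SwiGLUFFN(in_features=384, hidden_features=4 * 384, drop=0.1)
    x = torch.randn(2, 197, 384)  # (batch, tokens, features)
    y = ffn(x)
    # out_features defaults to in_features, so the shape is preserved.
    assert y.shape == x.shape
    # Gated hidden width: int(1536 * 2 / 3) = 1024, already a multiple of 8.
    print(ffn.w1.out_features)  # -> 1024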