from typing import Callable, Optional

import torch.nn.functional as F
from torch import Tensor, nn


class SwiGLUFFN(nn.Module):
    """Gated feed-forward block: ``drop(w3(silu(w1(x)) * w2(x)))``.

    Two parallel linear projections (``w1`` and ``w2``) map the input to a
    shared hidden width; the SiLU-activated ``w1`` branch gates the ``w2``
    branch element-wise, and ``w3`` projects the product to ``out_features``.

    The hidden width actually used is ``int(hidden_features * 2 / 3)`` rounded
    up to the next multiple of ``align_to``.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Nominal hidden width before the 2/3 scaling. Falls
            back to ``in_features`` when ``None`` (or otherwise falsy).
        out_features: Size of the last dimension of the output. Falls back to
            ``in_features`` when ``None`` (or otherwise falsy).
        drop: Dropout probability applied to the output of ``w3``.
        bias: Whether the three linear layers carry a bias term.
        align_to: Positive integer the hidden width is rounded up to a
            multiple of.
        device: Forwarded to every ``nn.Linear`` constructor.

    Raises:
        ValueError: If ``align_to`` is not strictly positive.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0,
        bias: bool = True,
        align_to: int = 8,
        device=None,
    ) -> None:
        super().__init__()

        # A zero modulus would raise an opaque ZeroDivisionError below, and a
        # negative one would silently round the hidden width *down* (Python's
        # ``%`` takes the sign of the divisor), so reject both up front.
        if align_to <= 0:
            raise ValueError(f"align_to must be a positive integer, got {align_to!r}")

        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # Scale the nominal width by 2/3, then round up to a multiple of
        # ``align_to`` (``-d % align_to`` is the distance to the next multiple).
        d = int(hidden_features * 2 / 3)
        swiglu_hidden_features = d + (-d % align_to)

        self.w1 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device)
        self.w2 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device)
        self.w3 = nn.Linear(swiglu_hidden_features, out_features, bias=bias, device=device)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated feed-forward transform to ``x``.

        Args:
            x: Input tensor whose last dimension is ``in_features``.

        Returns:
            Tensor with the last dimension replaced by ``out_features``.
        """
        x1 = self.w1(x)
        x2 = self.w2(x)
        # SiLU-activated branch gates the linear branch element-wise.
        hidden = F.silu(x1) * x2
        return self.drop(self.w3(hidden))