File size: 1,103 Bytes
e101805 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from typing import Callable, Optional
import torch.nn.functional as F
from torch import Tensor, nn
class SwiGLUFFN(nn.Module):
    """Feed-forward block with a SwiGLU gating activation.

    The input is projected twice — a gate branch and a value branch — and
    the branches are combined as ``silu(gate) * value`` before the output
    projection. The gated hidden width is 2/3 of ``hidden_features``,
    rounded up to a multiple of ``align_to`` (the 2/3 factor keeps the
    parameter count comparable to a plain two-layer MLP with the same
    nominal hidden size; the rounding keeps matmul shapes hardware-friendly).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0,
        bias: bool = True,
        align_to: int = 8,
        device=None,
    ) -> None:
        super().__init__()
        # Unspecified widths default to the input width.
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = in_features
        # Shrink by 2/3, then round up to the nearest multiple of align_to.
        raw_width = int(hidden_features * 2 / 3)
        remainder = raw_width % align_to
        gated_width = raw_width if remainder == 0 else raw_width + align_to - remainder
        # w1: gate branch, w2: value branch, w3: output projection.
        self.w1 = nn.Linear(in_features, gated_width, bias=bias, device=device)
        self.w2 = nn.Linear(in_features, gated_width, bias=bias, device=device)
        self.w3 = nn.Linear(gated_width, out_features, bias=bias, device=device)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        """Compute ``dropout(w3(silu(w1(x)) * w2(x)))``."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.drop(self.w3(gate * value))