import torch.nn as nn
import torch.nn.functional as F


def get_activation_layer(act_type):
    """Return a zero-argument factory for the requested activation module."""
    if act_type == "gelu":
        return lambda: nn.GELU()
    elif act_type == "gelu_tanh":
        # Approximate `tanh` requires torch >= 1.13
        return lambda: nn.GELU(approximate="tanh")
    elif act_type == "relu":
        return nn.ReLU
    elif act_type == "silu":
        return nn.SiLU
    else:
        raise ValueError(f"Unknown activation type: {act_type}")
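

# Usage sketch (an assumption about intent, not part of the original file):
# every branch returns a zero-argument callable that builds a fresh activation
# module, so the helper is called twice, once to select the type and once to
# instantiate it, e.g.
#
#     act_layer = get_activation_layer("gelu_tanh")
#     act = act_layer()  # nn.GELU(approximate="tanh")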


class SwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        """
        Initialize the SwiGLU FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            out_dim (int): Output dimension.

        Attributes:
            w1: Linear gate projection (dim -> hidden_dim), passed through SiLU.
            w2: Linear output projection (hidden_dim -> out_dim).
            w3: Linear up projection (dim -> hidden_dim).
        """
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        # SwiGLU: elementwise SiLU gate times linear branch, then down-projection
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
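

# Minimal smoke test, assuming hypothetical dimensions (not part of the
# original file); it only checks that the gated feed-forward maps the input
# to the requested output size.
if __name__ == "__main__":
    import torch

    ff = SwiGLU(dim=512, hidden_dim=2048, out_dim=512)
    x = torch.randn(2, 16, 512)       # (batch, sequence, dim)
    out = ff(x)                       # SiLU(w1(x)) * w3(x), then w2 projection
    assert out.shape == (2, 16, 512)  # last axis equals out_dim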