| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| |
|
| |
|
class ChannelLastConv1d(nn.Conv1d):
    """1-D convolution over channel-last input.

    Behaves exactly like ``nn.Conv1d`` but accepts and returns tensors in
    ``(batch, length, channels)`` layout instead of the native
    ``(batch, channels, length)`` layout.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Swap length/channel axes so nn.Conv1d sees (B, C, L), convolve,
        # then swap back so the caller keeps the channel-last layout.
        out = super().forward(x.transpose(1, 2))
        return out.transpose(1, 2)
| |
|
| |
|
| | |
class MLP(nn.Module):
    """SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
    ):
        """
        Initialize the feed-forward module.

        Args:
            dim (int): Input (and output) dimension.
            hidden_dim (int): Nominal hidden dimension; the effective hidden
                size is ``2/3 * hidden_dim`` rounded up to a multiple of
                ``multiple_of``.
            multiple_of (int): Value to ensure the hidden dimension is a
                multiple of this value.

        Attributes:
            w1 (nn.Linear): Gate projection (dim -> hidden), no bias.
            w2 (nn.Linear): Output projection (hidden -> dim), no bias.
            w3 (nn.Linear): Value projection (dim -> hidden), no bias.
        """
        super().__init__()
        # The 2/3 factor keeps the parameter count of this gated (3-matrix)
        # MLP comparable to a conventional 2-matrix MLP of width hidden_dim;
        # then round up to a hardware-friendly multiple.
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        # SwiGLU: silu-gated w1 branch modulates the w3 branch, then
        # project back to the model dimension.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
| |
|
| |
|
class ConvMLP(nn.Module):
    """SwiGLU feed-forward block with channel-last 1-D convolutions.

    Same gating structure as the linear SwiGLU MLP
    (``w2(silu(w1(x)) * w3(x))``) but each projection is a
    ``ChannelLastConv1d``, mixing information along the sequence axis.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        """
        Initialize the convolutional feed-forward module.

        Args:
            dim (int): Input (and output) channel dimension.
            hidden_dim (int): Nominal hidden dimension; the effective hidden
                size is ``2/3 * hidden_dim`` rounded up to a multiple of
                ``multiple_of``.
            multiple_of (int): Value to ensure the hidden dimension is a
                multiple of this value.
            kernel_size (int): Convolution kernel size along the sequence axis.
            padding (int): Convolution padding; the default (3, 1) preserves
                sequence length.

        Attributes:
            w1 (ChannelLastConv1d): Gate projection (dim -> hidden), no bias.
            w2 (ChannelLastConv1d): Output projection (hidden -> dim), no bias.
            w3 (ChannelLastConv1d): Value projection (dim -> hidden), no bias.
        """
        super().__init__()
        # Same sizing rule as the linear MLP: 2/3 factor to offset the third
        # weight matrix, then round up to a hardware-friendly multiple.
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ChannelLastConv1d(dim,
                                    hidden_dim,
                                    bias=False,
                                    kernel_size=kernel_size,
                                    padding=padding)
        self.w2 = ChannelLastConv1d(hidden_dim,
                                    dim,
                                    bias=False,
                                    kernel_size=kernel_size,
                                    padding=padding)
        self.w3 = ChannelLastConv1d(dim,
                                    hidden_dim,
                                    bias=False,
                                    kernel_size=kernel_size,
                                    padding=padding)

    def forward(self, x):
        # SwiGLU gating, identical to MLP.forward but convolutional.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
| |
|