| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """Positionwise feed forward layer definition.""" |
| |
|
| | import torch |
| |
|
| |
|
class PositionwiseFeedForward(torch.nn.Module):
    """Position-wise feed-forward block.

    The same two-layer network (linear -> activation -> dropout -> linear)
    is applied along the last dimension of the input, i.e. to every position
    of the sequence. The output dimension equals the input dimension.

    Args:
        idim (int): Input dimension.
        hidden_units (int): Number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function.
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        """Construct a PositionwiseFeedForward object."""
        super().__init__()
        # Submodules are registered in the order they are applied in forward().
        self.w_1 = torch.nn.Linear(idim, hidden_units)  # idim -> hidden_units
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)  # hidden_units -> idim

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward network to the input.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        hidden = self.w_1(xs)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.w_2(hidden)
| |
|
| |
|
class MoEFFNLayer(torch.nn.Module):
    """Mixture of experts built from position-wise feed-forward layers.

    See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
    The output dim is same with the input dim.

    Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
                  https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219

    Args:
        n_expert (int): Number of experts.
        n_expert_per_token (int): Number of experts actually used for each
            frame. Must not exceed ``n_expert``.
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function. The same module
            instance is passed to every expert.

    Raises:
        ValueError: If ``n_expert_per_token`` is larger than ``n_expert``.
    """

    def __init__(
        self,
        n_expert: int,
        n_expert_per_token: int,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        """Construct a MoEFFNLayer object."""
        super().__init__()
        # Fail at construction time instead of with an opaque torch.topk
        # error on the first forward pass.
        if n_expert_per_token > n_expert:
            raise ValueError(
                f"n_expert_per_token ({n_expert_per_token}) must not exceed "
                f"n_expert ({n_expert})")
        # Router: one un-biased score per expert for every frame.
        self.gate = torch.nn.Linear(idim, n_expert, bias=False)
        self.experts = torch.nn.ModuleList(
            PositionwiseFeedForward(idim, hidden_units, dropout_rate,
                                    activation) for _ in range(n_expert))
        self.n_expert_per_token = n_expert_per_token

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)

        """
        B, L, D = xs.size()
        # reshape (not view) so that non-contiguous inputs, e.g. the result
        # of a transpose, are accepted; contiguous inputs are not copied.
        xs = xs.reshape(-1, D)  # (B*L, D)
        router = self.gate(xs)  # (B*L, n_expert)
        # Keep only the n_expert_per_token best-scoring experts per frame.
        logits, indices = torch.topk(
            router, self.n_expert_per_token
        )  # both (B*L, n_expert_per_token)
        # Normalise over the selected experts only; the softmax is computed
        # in float and cast back to the input dtype.
        weights = torch.nn.functional.softmax(
            logits, dim=1,
            dtype=torch.float).to(dtype=xs.dtype)  # (B*L, n_expert_per_token)
        output = torch.zeros_like(xs)  # (B*L, D)
        for i, expert in enumerate(self.experts):
            # Rows routed to expert i, and the top-k slot in which it appears.
            mask = indices == i
            batch_idx, ith_expert = torch.where(mask)
            # Every expert is invoked, even when no frame selected it (the
            # call then runs on an empty batch).
            output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
                xs[batch_idx])
        return output.view(B, L, D)
| |
|