"""
Simple MoE routing implementations that replace the MLP block in a standard transformer.

References:
1) Mistral source for Mixtral MoEs:
   https://github.com/mistralai/mistral-src
2) ST-MoE:
   https://arxiv.org/abs/2202.08906
3) Our notepad of MoE resources:
   https://docs.google.com/document/d/1NuQ5jr7V-Jv1ui7p4KrxO_JTz-7bpYcYMmh49EeJ-QA/edit?usp=sharing
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import bisect
import math


class MoE(nn.Module):
    """
    Simplest MoE implementation with a linear router and a softmax over experts.

    Note that in this implementation we simply loop over the experts and
    aggregate the results. This is not the most efficient way to do it, but it
    avoids the large memory overhead of a capacity-based dispatch _and_ drops
    no tokens (because no capacity factor is needed).
    """

    def __init__(self, config, mlp):
        super().__init__()
        assert config.moe_num_experts > 0
        # One independent MLP block per expert.
        self.experts = nn.ModuleList(
            [mlp(config=config) for _ in range(config.moe_num_experts)]
        )
        # Linear router producing one logit per expert for every token.
        self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False)
        self.top_k = config.moe_num_experts_per_tok
        self.softmax_order = config.moe_softmax_order

    def forward(self, inputs: torch.Tensor):
        # [batch_size * sequence_length, n_embd]
        inputs_squashed = inputs.view(-1, inputs.shape[-1])
        # [batch_size * sequence_length, num_experts]
        router_logits = self.router(inputs_squashed)

        # Either softmax over all experts and then keep the top-k probabilities,
        # or keep the top-k logits and renormalize them with a softmax.
        if self.softmax_order == "softmax_topk":
            all_probs = F.softmax(router_logits, dim=1)
            weights, selected_experts = torch.topk(all_probs, self.top_k)
        elif self.softmax_order == "topk_softmax":
            weights, selected_experts = torch.topk(router_logits, self.top_k)
            weights = F.softmax(weights, dim=-1)
        else:
            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")

        results = torch.zeros_like(inputs_squashed)
        # Loop over the experts and dispatch to each one only the tokens that
        # selected it.
        for i, expert in enumerate(self.experts):
            # Tokens routed to expert i, and which of their top-k slots chose it.
            batch_idx, nth_expert = torch.where(selected_experts == i)
            expert_input = inputs_squashed[batch_idx]
            output, _ = expert(expert_input)
            results[batch_idx] += weights[batch_idx, nth_expert, None] * output.squeeze(0)

        return results.view_as(inputs), {
            "router_logits": router_logits,
            "selected_experts": selected_experts,
        }

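# A minimal usage sketch for MoE, kept as a comment so importing this module
# stays side-effect free. It assumes a `config` exposing `n_embd`,
# `moe_num_experts`, `moe_num_experts_per_tok` and `moe_softmax_order`, and an
# `mlp` class whose forward returns an `(output, aux)` tuple, as the dispatch
# loop above expects:
#
#     moe = MoE(config, mlp)            # drop-in replacement for the MLP block
#     x = torch.randn(batch, seq_len, config.n_embd)
#     y, aux = moe(x)                   # y has the same shape as x
#     aux["router_logits"]              # (batch * seq_len, moe_num_experts)
#     aux["selected_experts"]           # (batch * seq_len, top_k)

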
class DummyExpert(nn.Module):
    """A no-op expert that returns zeros. MaskedMoE appends one of these (with
    an always-unmasked router column), so routing a token to it adds nothing."""

    def __init__(self, output_size: int):
        super().__init__()
        self._output_size = output_size

    def forward(self, inputs: torch.Tensor):
        out = torch.zeros((self._output_size,), device=inputs.device)
        return out, {}


class MaskedMoE(MoE):
    def __init__(self, config, mlp):
        super().__init__(config, mlp)
        self._sequence_length = config.sequence_length
        # Add an always-available dummy expert and widen the router to produce
        # one extra logit for it.
        self.experts.append(DummyExpert(config.n_embd))
        self.router = nn.Linear(config.n_embd, config.moe_num_experts + 1, bias=False)

    def forward(self, inputs: torch.Tensor, mask: torch.Tensor):
        inputs_squashed = inputs.view(-1, inputs.shape[-1])
        router_logits = self.router(inputs_squashed)
        # Append an always-on column for the dummy expert, then expand the
        # per-sample mask to one row per token.
        mask = torch.cat(
            (mask, torch.ones((mask.shape[0], 1), device=mask.device)),
            dim=1,
        )
        mask = mask.repeat_interleave(self._sequence_length, dim=0)
        # Multiplicative masking: masked experts get a logit of zero, not -inf.
        router_logits = router_logits * mask

        if self.softmax_order == "softmax_topk":
            all_probs = F.softmax(router_logits, dim=1)
            weights, selected_experts = torch.topk(all_probs, self.top_k)
        elif self.softmax_order == "topk_softmax":
            weights, selected_experts = torch.topk(router_logits, self.top_k)
            weights = F.softmax(weights, dim=-1)
        else:
            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")

        results = torch.zeros_like(inputs_squashed)
        for i, expert in enumerate(self.experts):
            batch_idx, nth_expert = torch.where(selected_experts == i)
            expert_input = inputs_squashed[batch_idx]
            output, _ = expert(expert_input)
            results[batch_idx] += weights[batch_idx, nth_expert, None] * output.squeeze(0)

        return results.view_as(inputs), {
            "router_logits": router_logits,
            "selected_experts": selected_experts,
        }

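# Illustrative sketch of the mask convention assumed by MaskedMoE.forward
# above: `mask` is a per-sample 0/1 tensor of shape (batch, moe_num_experts);
# a column of ones is appended for the DummyExpert, and every row is repeated
# `sequence_length` times so that all tokens of a sample share its mask. For
# example, with 3 experts and sequence_length = 2:
#
#     mask = torch.tensor([[1., 1., 0.]])   # this sample may use experts 0 and 1
#     # after the cat + repeat_interleave:
#     # [[1., 1., 0., 1.],
#     #  [1., 1., 0., 1.]]                  # one row per token, dummy always on
#
# Because the mask multiplies the logits, masked experts end up with a logit of
# zero rather than -inf, so they are not strictly excluded from the top-k
# selection.

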
class TimeDependantMoE(nn.Module):
    def __init__(self, config, mlp):
        super().__init__()
        self._num_experts = config.moe_num_experts
        self._mask_moe = MaskedMoE(config, mlp)

    def forward(self, x, date):
        # The mask allows expert i for sample b iff i < date[b]; the dummy
        # expert appended by MaskedMoE is always allowed.
        range_tensor = torch.arange(self._num_experts, device=x.device).unsqueeze(0)
        mask_date = (range_tensor < date.unsqueeze(1)).float()
        return self._mask_moe(x, mask_date)
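

# A minimal smoke-test sketch, illustrative only: SimpleConfig and TinyMLP are
# hypothetical stand-ins chosen solely to satisfy the interfaces used above
# (the config fields read in __init__, an mlp(config=config) constructor, and
# a forward that returns an (output, aux) tuple).
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class SimpleConfig:
        n_embd: int = 32
        sequence_length: int = 8
        moe_num_experts: int = 4
        moe_num_experts_per_tok: int = 2
        moe_softmax_order: str = "topk_softmax"

    class TinyMLP(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.fc = nn.Linear(config.n_embd, config.n_embd)

        def forward(self, x):
            return self.fc(x), {}

    cfg = SimpleConfig()
    moe = TimeDependantMoE(cfg, TinyMLP)
    x = torch.randn(3, cfg.sequence_length, cfg.n_embd)  # (batch, seq, n_embd)
    date = torch.tensor([1, 2, 4])  # number of experts allowed per sample
    out, aux = moe(x, date)
    print(out.shape)  # torch.Size([3, 8, 32])
    print(aux["selected_experts"].shape)  # torch.Size([24, 2])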