| """ |
| Simple MoE routing implementations that replace the MLP block in a standard transformer. |
| References: |
| 1) Mistral Source for Mixtral MoEs: |
| https://github.com/mistralai/mistral-src |
| 2) ST-MoE: |
| https://arxiv.org/abs/2202.08906 |
| 3) Our notepad of MoE resources: |
| https://docs.google.com/document/d/1NuQ5jr7V-Jv1ui7p4KrxO_JTz-7bpYcYMmh49EeJ-QA/edit?usp=sharing |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|


class MoE(nn.Module):
    """
    Simplest MoE implementation with a linear router and softmax over experts.

    Note that in this implementation, we simply loop over the experts and
    aggregate the results. This is not the most efficient way to do it, but
    it also avoids the large memory overhead _and_ has no token dropping
    (because we do not need the capacity factor).
    """

    def __init__(self, config, mlp):
        super().__init__()
        assert config.moe_num_experts > 0
        self.experts = nn.ModuleList(
            [mlp(config=config) for _ in range(config.moe_num_experts)]
        )
        self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False)
        self.top_k = config.moe_num_experts_per_tok
        self.softmax_order = config.moe_softmax_order

    def forward(self, inputs: torch.Tensor):
        # [batch_size * sequence_length, n_embd]
        inputs_squashed = inputs.view(-1, inputs.shape[-1])
        # [batch_size * sequence_length, num_experts]
        router_logits = self.router(inputs_squashed)

        # the selected experts are the same for both orders, since softmax is
        # monotonic and does not change the top-k indices -- only the weights
        # differ (normalized over all experts vs. over the selected k)
        if self.softmax_order == "softmax_topk":
            all_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
            weights, selected_experts = torch.topk(all_probs, self.top_k)
        elif self.softmax_order == "topk_softmax":
            weights, selected_experts = torch.topk(router_logits, self.top_k)
            weights = F.softmax(weights, dim=-1, dtype=torch.float32)
        else:
            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")

        # naive looping over experts: gather the tokens routed to each expert,
        # run the expert, and scatter the weighted outputs back
        results = torch.zeros_like(inputs_squashed)
        for i, expert in enumerate(self.experts):
            batch_idx, nth_expert = torch.where(selected_experts == i)
            output, _ = expert(inputs_squashed[batch_idx])
            results[batch_idx] += weights[batch_idx, nth_expert, None] * output

        # return the outputs and the router logits (e.g. for the auxiliary
        # load-balancing and z-losses of ST-MoE)
        return results.view_as(inputs), {
            "router_logits": router_logits,
            "selected_experts": selected_experts,
        }
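

# A minimal usage sketch for the MoE block above. `ToyMLP` and the
# SimpleNamespace config are illustrative stand-ins: any mlp factory works as
# long as its modules return an (output, aux) tuple (forward() unpacks
# `output, _ = expert(...)`) and the config exposes the fields read in
# __init__.
def _demo_moe():
    from types import SimpleNamespace

    class ToyMLP(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.fc = nn.Linear(config.n_embd, config.n_embd)

        def forward(self, x):
            # return a tuple to match the `output, _ = expert(...)` unpacking
            return F.gelu(self.fc(x)), None

    config = SimpleNamespace(
        n_embd=16,
        moe_num_experts=4,
        moe_num_experts_per_tok=2,
        moe_softmax_order="softmax_topk",
    )
    moe = MoE(config, ToyMLP)
    x = torch.randn(2, 8, config.n_embd)  # [batch, seq, n_embd]
    out, aux = moe(x)
    assert out.shape == x.shape
    print(out.shape, aux["selected_experts"].shape)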


class ExpertChoiceMoE(nn.Module):
    """
    This is the MoE implementation that uses the expert choice method from
    https://arxiv.org/pdf/2202.09368v2.pdf.

    The main difference is that the router takes the softmax over the tokens,
    not the experts (i.e. each expert chooses its top-k tokens, not the other
    way around). For the same capacity factor, the same amount of compute is
    used in theory as in standard top-k routing. As far as I can tell, there
    is no way around the capacity factor here (whereas the standard routing
    above does not need one).
    """

    def __init__(self, config, mlp):
        super().__init__()
        assert config.moe_num_experts > 0
        self.n_experts = config.moe_num_experts
        self.experts = nn.ModuleList(
            [mlp(config=config) for _ in range(config.moe_num_experts)]
        )
        self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False)
        self.capacity_factor = config.capacity_factor
        self.softmax_order = config.moe_softmax_order
        self.top_k = int(
            self.capacity_factor
            * config.batch_size
            * config.sequence_length
            / config.moe_num_experts
        )
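        # For intuition (illustrative numbers): capacity_factor=1.0,
        # batch_size=8, sequence_length=512, moe_num_experts=8 gives
        # top_k = 1.0 * 8 * 512 / 8 = 512 tokens per expert -- in total the
        # same compute as routing every token to exactly one expert.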

    def forward(self, inputs: torch.Tensor):
        # [batch_size * sequence_length, n_embd]
        inputs_squashed = inputs.view(-1, inputs.shape[-1])
        num_tokens = inputs_squashed.shape[0]
        # re-derive the capacity from the actual number of tokens,
        # e.g. for smaller evaluation batches
        top_k = min(self.top_k, int(self.capacity_factor * num_tokens / self.n_experts))
        # [batch_size * sequence_length, num_experts]
        router_logits = self.router(inputs_squashed)

        # softmax over the token dimension (dim=0), so that each expert
        # then picks its top-k tokens
        if self.softmax_order == "softmax_topk":
            all_probs = F.softmax(router_logits, dim=0, dtype=torch.float32)
            # weights and selected tokens: [num_experts, top_k]
            weights, selected_tokens = torch.topk(all_probs.T, top_k)
        elif self.softmax_order == "topk_softmax":
            # weights and selected tokens: [num_experts, top_k]
            weights, selected_tokens = torch.topk(router_logits.T, top_k)
            weights = F.softmax(weights, dim=-1, dtype=torch.float32)
        else:
            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")

| """ this is the full parallel version with einsum -- this can OOM quickly """ |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ this is the naive loop version """ |
| |
| |
| |
| results = torch.zeros_like(inputs_squashed) |
| for i, expert in enumerate(self.experts): |
| |
| batch_idx = selected_tokens[i] |
| |
| output, _ = expert(inputs_squashed[batch_idx]) |
| results[batch_idx] += weights[i, :, None] * output |
|
|
| |
| return results.view_as(inputs), { |
| "router_logits": router_logits, |
| "selected_experts": selected_tokens, |
| } |
|
|
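

# The same kind of illustrative sketch for the expert choice block. Note that
# batch_size and sequence_length in the config must match the routed batch,
# since the per-expert capacity is derived from them in __init__.
def _demo_expert_choice_moe():
    from types import SimpleNamespace

    class ToyMLP(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.fc = nn.Linear(config.n_embd, config.n_embd)

        def forward(self, x):
            return F.gelu(self.fc(x)), None

    config = SimpleNamespace(
        n_embd=16,
        moe_num_experts=4,
        moe_softmax_order="softmax_topk",
        capacity_factor=1.0,
        batch_size=2,
        sequence_length=8,
    )
    moe = ExpertChoiceMoE(config, ToyMLP)
    x = torch.randn(config.batch_size, config.sequence_length, config.n_embd)
    out, aux = moe(x)
    assert out.shape == x.shape
    # each expert selected its own top_k tokens
    print(out.shape, aux["selected_experts"].shape)


if __name__ == "__main__":
    _demo_moe()
    _demo_expert_choice_moe()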