| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
|
|
| import math |
| from typing import Callable, Dict, Optional, Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor |
|
|
| from .moe_layer import fused_cumsum_sub_one, has_tutel |
|
|
| |
# Temperature constant for a routing-temperature variant of the auxiliary
# loss; not referenced anywhere in this file — presumably used by a caller
# or a sibling module, TODO confirm before removing.
TEMPERATURE_FOR_L_UAX = 0.07

# At evaluation time, each gate's default token capacity is this fraction of
# the total number of tokens (see the eval branch in top1gating/top2gating).
EVAL_CAPACITY_TOKEN_FRACTION = 0.25

# Fraction of experts summarized by the expert*_balance_top/bottom metrics.
SAMPLE_FRACTION = 0.2
|
|
|
|
def top1gating(
    logits: torch.Tensor,
    input_mask: Optional[torch.Tensor] = None,
    use_fp32=False,
    capacity_factor=1.0,
    eval_mode=False,
    moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION,
    use_xmoe=False,
    gate_obj=None,
) -> Tuple[Tensor, Tensor, Tensor, Dict]:
    """Route each token to its single highest-scoring expert (top-1 gating).

    Args:
        logits: (num_tokens, num_experts) raw router scores.
        input_mask: optional boolean mask; True marks padding tokens whose
            routing is zeroed out.
        use_fp32: perform the gating math in float32 and cast the combine
            weights back to the logits' original dtype at the end.
        capacity_factor: multiplier on the per-expert token capacity.
        eval_mode: True when the enclosing module is not training.
        moe_eval_capacity_token_fraction: in eval mode (when > 0), capacity
            is this fraction of the number of tokens instead.
        use_xmoe: unused in this function; kept for interface compatibility.
        gate_obj: unused in this function; kept for interface compatibility.

    Returns:
        (l_aux, combine_weights, dispatch_mask, metadata), or a tutel-shaped
        tuple (l_aux, metadata, capacity, num_experts, indices, locations,
        gates) when the tutel kernels are available.
    """
    metadata = {}
    if use_fp32:
        orig_dtype = logits.dtype
        logits = logits.float()

    probs = F.softmax(logits, dim=1)
    metadata["entropy_gating"] = entropy(probs=probs).mean().detach()

    num_tokens, num_experts = probs.shape
    if moe_eval_capacity_token_fraction > 0.0 and eval_mode:
        capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens)
    else:
        # capacity_factor * (tokens per expert, rounded up)
        capacity = int(capacity_factor * math.ceil(num_tokens / num_experts))

    # One-hot mask of each token's best expert.
    top_expert = torch.argmax(probs, dim=1)
    mask1 = one_hot(top_expert, num_classes=num_experts, unsqueeze_indices=True)
    if input_mask is not None and input_mask.any():
        # Padding tokens are routed nowhere.
        keep = ~input_mask
        mask1 = mask1 * keep.unsqueeze(-1).to(mask1.dtype)

    # Load-balance statistics: percentage of tokens assigned to each expert.
    expert1_hist = (
        100
        * torch.histc((top_expert.squeeze() + 1), bins=num_experts, min=1, max=num_experts)
        / num_tokens
    )
    metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
    expert1_hist = (
        torch.sort(expert1_hist, dim=0, descending=True).values
        + torch.finfo(torch.float32).tiny
    )

    sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
    metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
    metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum()

    gates1_s = (probs * mask1).sum(dim=1)

    # Each token's position inside its chosen expert's buffer.
    locations1 = fused_cumsum_sub_one(mask1)

    # GShard auxiliary load-balancing loss: mean(router prob) * mean(assignment).
    me = torch.mean(probs, dim=0)
    ce = torch.mean(mask1.to(probs.dtype), dim=0)
    l_aux = torch.mean(me * ce) * num_experts * num_experts

    if has_tutel:
        locations1_s = torch.sum(locations1 * mask1, dim=1)
        return (
            l_aux,
            metadata,
            capacity,
            num_experts,
            [top_expert],
            [locations1_s],
            [gates1_s],
        )

    # Drop tokens that overflow their expert's capacity.
    mask1 = mask1 * torch.lt(locations1, capacity)
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Combine weights of shape (tokens, experts, capacity).
    gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype)
    locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
    combine1_sec = torch.bmm(
        gates1.unsqueeze(-1),
        locations1_sc.to(gates1.dtype).unsqueeze(1),
    )
    dispatch_mask = combine1_sec.bool()
    if use_fp32:
        return l_aux, combine1_sec.to(orig_dtype), dispatch_mask, metadata
    return l_aux, combine1_sec, dispatch_mask, metadata
|
|
|
|
class Top1Gate(torch.nn.Module):
    """Gate module which implements Top-1 gating as described in Gshard_.
    ::

        gate = Top1Gate(model_dim, num_experts)
        l_aux, combine_weights, dispatch_mask = gate(input)

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        model_dim (int):
            size of model embedding dimension
        num_experts (int):
            number of experts in model
    """

    # Router projection. NOTE(review): when use_xmoe=True this attribute is
    # actually a (num_experts, 16) Parameter, not a Linear.
    wg: torch.nn.Linear

    def __init__(
        self,
        model_dim: int,
        num_experts: int,
        use_fp32=False,
        input_noise_type=None,
        capacity_factor=1.0,
        moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION,
        use_xmoe=False,
    ) -> None:
        super().__init__()

        if not use_xmoe:
            # Plain learned routing: one logit per expert.
            self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
        else:
            # x-MoE routing: project tokens to a 16-dim space and score them
            # by cosine similarity against learned expert embeddings.
            self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False)
            wg = torch.empty(num_experts, 16)
            torch.nn.init.orthogonal_(wg, gain=0.32)
            self.register_parameter("wg", torch.nn.Parameter(wg))

        self.use_xmoe = use_xmoe
        self.use_fp32 = use_fp32
        # Stored but not read in this class; kept for config compatibility.
        self.input_noise_type = input_noise_type
        self.capacity_factor = capacity_factor
        self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction

    def forward(self, input, mask=None):
        """Compute routing logits for `input` and delegate to top1gating."""
        if self.use_xmoe:
            input = self.wg_reduction(input)
            with torch.no_grad():
                # Rescale the expert embeddings to a fixed L2 norm of 1.5.
                wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True)
                self.wg.mul_(1.5 / wg_norm)
            logits = self._cosine(input, self.wg)
            logits = self._make_finite(logits)
        else:
            logits = self.wg(input)

        return top1gating(
            logits,
            mask,
            use_fp32=self.use_fp32,
            capacity_factor=self.capacity_factor,
            eval_mode=not self.training,
            moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction,
            use_xmoe=self.use_xmoe,
            gate_obj=self,
        )

    def _make_finite(self, scores):
        """Replace non-finite logits, in place, with the smallest finite logit."""
        ok = scores.isfinite()
        if not ok.all():
            # NaN/inf can come out of the fp32 cosine similarity.
            scores[~ok] = scores[ok].min()
        return scores

    def _get_gating_temperature(self, eps=1e-4):
        """Return the gating temperature, floored at `eps`.

        NOTE(review): `self.gating_t` is never assigned anywhere in this
        class, so calling this raises AttributeError — looks like leftover
        code from a learned-temperature variant; confirm before removing.
        """
        if self.gating_t.data.item() < eps:
            return eps
        return self.gating_t

    def _cosine(self, mat1, mat2, eps=1e-4):
        """Return mat1 @ normalize(mat2).T in fp32, cast back to mat1's dtype.

        Only mat2 (the expert embeddings) is L2-normalized; mat1 keeps its
        magnitude, so this is a scaled cosine similarity.
        """
        assert mat1.dim() == 2
        assert mat2.dim() == 2
        mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
        return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)
|
|
|
|
# Per-device cache of Gumbel(0, 1) rsample callables.
gumbel_map: Dict[torch.device, Callable] = {}


def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    """Draw a reparameterized Gumbel(0, 1) sample of `shape` on `device`.

    The distribution's rsample callable is cached per device so the loc and
    scale tensors are allocated only once per device.
    """
    sampler = gumbel_map.get(device)
    if sampler is None:
        loc = torch.tensor(0.0, device=device)
        scale = torch.tensor(1.0, device=device)
        sampler = torch.distributions.gumbel.Gumbel(loc, scale).rsample
        gumbel_map[device] = sampler
    return sampler(shape)
|
|
|
|
def one_hot(indices: torch.Tensor, num_classes: int, unsqueeze_indices=False) -> Tensor:
    """Return a one-hot encoding of `indices` over a trailing class dimension.

    Args:
        indices: integer tensor whose last dimension has size 1. With
            `unsqueeze_indices=True` a trailing singleton dim is appended
            first, so any integer tensor is accepted.
        num_classes: size of the one-hot (class) dimension.
        unsqueeze_indices: append a trailing singleton dim before encoding.

    Returns:
        Tensor of shape ``indices.shape[:-1] + (num_classes,)`` with the same
        dtype and device as `indices`.
    """
    if unsqueeze_indices:
        indices = indices.unsqueeze(-1)
    # Fixed grammar of the original message ("must be have size 1").
    assert indices.shape[-1] == 1, "last dimension of indices must have size 1"
    output = torch.zeros(
        indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype
    )
    output.scatter_(-1, indices, 1)
    return output
|
|
|
|
def entropy(probs):
    """Shannon entropy (in nats) along the last dimension of `probs`."""
    log_probs = torch.distributions.utils.probs_to_logits(probs)
    return -(probs * log_probs).sum(-1)
|
|
|
|
def top2gating(
    logits: torch.Tensor,
    input_mask: Optional[torch.Tensor] = None,
    use_fp32=False,
    second_expert_policy="sampling",
    normalize_gate_prob_before_dropping=False,
    eval_mode=False,
    moe_eval_capacity_token_fraction=0.25,
    batch_prioritized_routing=False,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Route each token to its two highest-scoring experts (top-2 gating).

    Args:
        logits: (num_tokens, num_experts) raw router scores.
        input_mask: optional boolean mask; True marks padding tokens whose
            routing is zeroed out.
        use_fp32: perform the gating math in float32, cast back at the end.
        second_expert_policy: "sampling" adds Gumbel noise before picking the
            second expert; "random" keeps the second expert stochastically
            with probability proportional to its gate; anything else picks
            the runner-up deterministically.
        normalize_gate_prob_before_dropping: renormalize the two gate values
            before (rather than after) capacity-based token dropping.
        eval_mode: True when the enclosing module is not training.
        moe_eval_capacity_token_fraction: in eval mode (when > 0), capacity
            is this fraction of the number of tokens.
        batch_prioritized_routing: give buffer slots to the most confident
            tokens first instead of in batch order.

    Returns:
        (l_aux, combine_weights, dispatch_mask, metadata), or a tutel-shaped
        tuple when the tutel kernels are available.
    """
    metadata = {}
    if use_fp32:
        orig_dtype = logits.dtype
        logits = logits.float()
    gates = F.softmax(logits, dim=1)
    metadata["entropy_gating"] = entropy(probs=gates).mean().detach()

    num_tokens, num_experts = gates.shape
    if moe_eval_capacity_token_fraction > 0.0 and eval_mode:
        capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens)
    else:
        # Two slots per token: capacity = 2 * tokens-per-expert.
        capacity = 2 * math.ceil(num_tokens / num_experts)

    # First expert: plain argmax over the softmax probabilities.
    indices1_s = torch.argmax(gates, dim=1, keepdim=True)
    mask1 = one_hot(indices1_s, num_experts)

    # Second expert: argmax over (optionally Gumbel-noised) logits with the
    # first expert masked out.
    if second_expert_policy == "sampling":
        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    else:
        logits_w_noise = logits
    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1, keepdim=True)
    mask2 = one_hot(indices2_s, num_experts)

    gates1_s = (gates * mask1).sum(dim=1)
    gates2_s = (gates * mask2).sum(dim=1)

    if normalize_gate_prob_before_dropping:
        # Renormalize the pair of gate values before any capacity dropping.
        denom_s = torch.clamp(gates1_s + gates2_s, min=torch.finfo(gates1_s.dtype).eps)
        gates1_s = gates1_s / denom_s
        gates2_s = gates2_s / denom_s

    if second_expert_policy == "random":
        # Keep the second expert with probability min(1, 2 * gate2).
        sampled = (2 * gates2_s) > torch.rand_like(gates2_s)
        mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0)

    # Padding tokens are routed nowhere.
    if input_mask is not None and input_mask.any():
        nonpadding = ~input_mask
        mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
        mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype)

    if batch_prioritized_routing:
        # Assign buffer slots to the most confident tokens first: sort by
        # descending max gate, take running counts, then undo the sort.
        importance_scores = -1 * gates.max(dim=1)[0]
        sort_order = importance_scores.argsort(dim=0)
        inverse_order = sort_order.argsort(dim=0)

        sorted_mask1 = mask1[sort_order]
        sorted_cumsum1 = fused_cumsum_sub_one(sorted_mask1) * sorted_mask1
        locations1 = sorted_cumsum1[inverse_order]

        sorted_mask2 = mask2[sort_order]
        sorted_cumsum2 = fused_cumsum_sub_one(sorted_mask2) * sorted_mask2
        locations2 = sorted_cumsum2[inverse_order]

        # Second-choice slots start after all first-choice assignments.
        locations2 += torch.sum(mask1, dim=0, keepdim=True)
    else:
        locations1 = fused_cumsum_sub_one(mask1)
        locations2 = fused_cumsum_sub_one(mask2)
        locations2 += torch.sum(mask1, dim=0, keepdim=True)

    # GShard auxiliary load-balancing loss (based on first-choice assignment).
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.to(gates.dtype), dim=0)
    l_aux = torch.mean(me * ce)
    l_aux = l_aux * num_experts * num_experts

    # Percentage of routed tokens that overflowed expert capacity.
    metadata["overflow_expert1"] = (
        100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
    )
    metadata["overflow_expert2"] = (
        100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)
    )

    # Keep pre-drop masks for the tutel path, then drop overflow tokens.
    mask1_, mask2_ = mask1, mask2
    mask1 = mask1 * torch.lt(locations1, capacity)
    mask2 = mask2 * torch.lt(locations2, capacity)

    # Expert balance histograms (percent of tokens per expert, pre-drop).
    expert1_hist = (
        100
        * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts)
        / num_tokens
    )
    metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
    expert1_hist = (
        torch.sort(expert1_hist, dim=0, descending=True).values
        + torch.finfo(torch.float32).tiny
    )

    expert2_hist = (
        100
        * torch.histc((indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts)
        / num_tokens
    )
    metadata["unused_expert2_count"] = (expert2_hist == 0).sum()
    expert2_hist = (
        torch.sort(expert2_hist, dim=0, descending=True).values
        + torch.finfo(torch.float32).tiny
    )

    sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
    metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
    metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum()
    metadata["expert2_balance_top"] = expert2_hist[:sample_count].sum()
    metadata["expert2_balance_bottom"] = expert2_hist[-sample_count:].sum()

    if not normalize_gate_prob_before_dropping:
        # Recompute gate values on the post-drop masks, then renormalize.
        gates1_s = (gates * mask1).sum(dim=1)
        gates2_s = (gates * mask2).sum(dim=1)
        denom_s = torch.clamp(gates1_s + gates2_s, min=torch.finfo(gates1_s.dtype).eps)
        gates1_s /= denom_s
        gates2_s /= denom_s

    if has_tutel:
        locations1_s = torch.sum(locations1 * mask1_, dim=1)
        locations2_s = torch.sum(locations2 * mask2_, dim=1)
        return (
            l_aux,
            metadata,
            capacity,
            num_experts,
            [indices1_s, indices2_s],
            [locations1_s, locations2_s],
            [gates1_s, gates2_s],
        )

    # Each token's slot index inside its experts' buffers.
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)

    # Combine weights of shape (tokens, experts, capacity).
    gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype)
    gates2 = gates2_s.unsqueeze(-1) * mask2.to(gates2_s.dtype)
    locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
    locations2_sc = one_hot(locations2_s, num_classes=capacity, unsqueeze_indices=True)
    combine1_sec = torch.bmm(
        gates1.unsqueeze(-1),
        locations1_sc.to(gates1.dtype).unsqueeze(1),
    )
    combine2_sec = torch.bmm(
        gates2.unsqueeze(-1),
        locations2_sc.to(gates2.dtype).unsqueeze(1),
    )
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = combine_weights.bool()
    if use_fp32:
        return l_aux, combine_weights.to(orig_dtype), dispatch_mask, metadata
    return l_aux, combine_weights, dispatch_mask, metadata
|
|
|
|
class Top2Gate(torch.nn.Module):
    """Gate module which implements Top2Gating as described in Gshard_.
    ::

        gate = Top2Gate(model_dim, num_experts)
        l_aux, combine_weights, dispatch_mask = gate(input)

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        model_dim (int):
            size of model embedding dimension
        num_experts (ints):
            number of experts in model
    """

    # Router projection; with use_xmoe this is a learned expert-embedding
    # Parameter rather than a Linear.
    wg: torch.nn.Linear

    def __init__(
        self,
        model_dim: int,
        num_experts: int,
        use_fp32=False,
        second_expert_policy="sampling",
        normalize_gate_prob_before_dropping=False,
        moe_eval_capacity_token_fraction=0.25,
        batch_prioritized_routing=False,
        use_xmoe=False,
    ) -> None:
        super().__init__()
        if use_xmoe:
            # x-MoE routing: reduce tokens to 16 dims and route by cosine
            # similarity against learned expert embeddings.
            self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False)
            expert_embeddings = torch.empty(num_experts, 16)
            torch.nn.init.orthogonal_(expert_embeddings, gain=0.32)
            self.register_parameter("wg", torch.nn.Parameter(expert_embeddings))
        else:
            # Plain learned routing: one logit per expert.
            self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
        self.use_fp32 = use_fp32
        self.second_expert_policy = second_expert_policy
        self.normalize_gate_prob_before_dropping = normalize_gate_prob_before_dropping
        self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction
        self.batch_prioritized_routing = batch_prioritized_routing
        self.use_xmoe = use_xmoe

    def forward(self, input, mask=None):
        """Compute routing logits for `input` and delegate to top2gating."""
        if not self.use_xmoe:
            logits = self.wg(input)
        else:
            input = self.wg_reduction(input)
            with torch.no_grad():
                # Rescale the expert embeddings to a fixed L2 norm of 1.5.
                wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True)
                self.wg.mul_(1.5 / wg_norm)
            logits = self._make_finite(self._cosine(input, self.wg))
        return top2gating(
            logits,
            mask,
            use_fp32=self.use_fp32,
            second_expert_policy=self.second_expert_policy,
            normalize_gate_prob_before_dropping=self.normalize_gate_prob_before_dropping,
            eval_mode=not self.training,
            moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction,
            batch_prioritized_routing=self.batch_prioritized_routing,
        )

    def _cosine(self, mat1, mat2, eps=1e-4):
        """Return mat1 @ normalize(mat2).T in fp32, cast back to mat1's dtype.

        Only mat2 (the expert embeddings) is L2-normalized; mat1 keeps its
        magnitude, so this is a scaled cosine similarity.
        """
        assert mat1.dim() == 2
        assert mat2.dim() == 2
        unit_rows = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
        return mat1.float().matmul(unit_rows.transpose(0, 1)).type_as(mat1)

    def _make_finite(self, scores):
        """Replace non-finite logits, in place, with the smallest finite logit."""
        finite = scores.isfinite()
        if not finite.all():
            # NaN/inf can come out of the fp32 cosine similarity.
            scores[~finite] = scores[finite].min()
        return scores
|
|