UniMoE-Audio-Preview / deepspeed_utils.py
foggyforest's picture
Upload 14 files
e576ca4 verified
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
import deepspeed
import torch
import torch.nn.functional as F
from deepspeed import comm as dist
from deepspeed.moe.sharded_moe import _capacity, _one_hot_to_float, einsum, gumbel_rsample
from torch import Tensor
try:
    # To enable Tutel MoE optimizations:
    #   python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
    from tutel import moe as tutel_moe
except Exception:
    # Tutel is optional: fail silently so we don't spam logs when the user is
    # not using it. Catching `Exception` (rather than a bare `except:`) keeps
    # this best-effort for any import-time failure while still letting
    # KeyboardInterrupt / SystemExit propagate.
    TUTEL_INSTALLED = False
else:
    TUTEL_INSTALLED = True
# =============================================================================
# DeepSpeed MoE Inference Utilities
# =============================================================================
def _AllToAll_forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
    """Pass-through forward used in place of ``_AllToAll.forward``.

    Records ``group`` on ``ctx`` and hands back ``input`` in contiguous
    layout. This function itself issues no communication call.

    Args:
        ctx: Context object; receives the ``group`` attribute.
        group: Process group, stored on ``ctx`` and otherwise unused here.
        input: Tensor to return.

    Returns:
        ``input.contiguous()``.
    """
    ctx.group = group
    return input.contiguous()
def gate_forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
    """Replacement forward for DeepSpeed's ``MOELayer``.

    Flattens the hidden states to 2-D, asks ``self.gate`` for routing
    information, dispatches the tokens to ``self.experts`` and recombines the
    expert outputs. This function itself issues no communication call.

    Args:
        *input: ``input[0]`` is the hidden-state tensor; its last dimension is
            treated as the model width. ``input[1]``, when supplied, is passed
            unchanged to ``self.gate`` as its second positional argument;
            ``None`` is passed when it is omitted.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        The combined expert output, reshaped to the leading dimensions of
        ``input[0]`` with the last dimension inferred.

    Side effects:
        Sets ``self.l_aux`` and ``self.exp_counts`` from the gate outputs and,
        on the Tutel path, lazily creates ``self._tutel_dispatcher``.
    """
    hidden_states = input[0]
    # Indexing `input[1]` unconditionally raised IndexError for callers that
    # pass only the hidden states; forward `None` to the gate in that case.
    used_token = input[1] if len(input) > 1 else None

    d_model = hidden_states.shape[-1]
    reshaped_input = hidden_states.reshape(-1, d_model)

    if self.use_tutel:
        self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, used_token, True)
        M = reshaped_input.size(1)
        # The dispatcher is built once and re-used; only its routing is refreshed per call.
        if not hasattr(self, "_tutel_dispatcher"):
            self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
        self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
        dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
    else:
        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, used_token)
        # (token, expert, slot) mask x (token, model) -> (expert, slot, model)
        dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)

    # Split the expert axis into (ep_size, num_local_experts) for the experts module.
    dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
    expert_output = self.experts(dispatched_input)
    # Merge the two expert axes back into one before combining.
    expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, dispatched_input.shape[2], -1)

    if self.use_tutel:
        combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
    else:
        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)

    # Restore the caller's leading dimensions.
    return combined_output.reshape(hidden_states.size()[:-1] + (-1,))
def top2gating(
    logits: Tensor,
    capacity_factor: float,
    min_capacity: int,
    drop_tokens: bool = True,
    ep_group: Union[torch.distributed.ProcessGroup, None] = None,
    top2_2nd_expert_sampling: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits.

    Args:
        logits: 2-D router scores indexed ``[token, expert]``. The tensor is
            not modified.
        capacity_factor: Passed (doubled) to ``_capacity`` when
            ``drop_tokens`` is true.
        min_capacity: Passed to ``_capacity`` when ``drop_tokens`` is true.
        drop_tokens: When true, assignments whose slot index is not below the
            computed capacity are masked out. When false, the capacity is the
            largest per-expert assignment count, so nothing is dropped.
        ep_group: Accepted for signature compatibility; not used here.
        top2_2nd_expert_sampling: When true, noise from ``gumbel_rsample`` is
            added to the logits before the second expert is chosen.

    Returns:
        ``(l_aux, combine_weights, dispatch_mask, exp_counts)`` where
        ``l_aux`` is ``mean(me * ce) * num_experts**2``, ``combine_weights``
        is indexed ``[token, expert, slot]``, ``dispatch_mask`` is
        ``combine_weights.bool()`` and ``exp_counts`` is the detached
        per-expert count of first plus second choices, taken before any
        capacity-based dropping.
    """
    gates = F.softmax(logits, dim=1)

    # First choice: the expert with the highest gate probability.
    indices1_s = torch.argmax(gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    if top2_2nd_expert_sampling:
        # Out-of-place so the caller's tensor is left untouched; casting back
        # to the input dtype yields the same values the in-place add produced.
        noise = gumbel_rsample(logits.shape, device=logits.device)
        logits = (logits + noise).to(logits.dtype)

    # Second choice: best remaining expert once the first choice is excluded.
    logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1)
    mask2 = F.one_hot(indices2_s, num_classes=num_experts)

    # Slot index of every token inside its expert's buffer; second choices
    # are placed after all first choices of the same expert.
    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations2 = torch.cumsum(mask2, dim=0) - 1
    locations2 += torch.sum(mask1, dim=0, keepdim=True)

    # Auxiliary loss from mean gate probability and first-choice frequency.
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.mean(me * ce) * num_experts * num_experts

    # The masks derive from `logits`, so the counts are already on its device.
    exp_counts = torch.sum(mask1 + mask2, dim=0).detach()

    if drop_tokens:
        capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
        mask1 *= torch.lt(locations1, capacity)
        mask2 *= torch.lt(locations2, capacity)
    else:
        # Size the buffers for the busiest expert so no token is dropped.
        capacity = torch.max(exp_counts)

    # Per-token slot index for each of its two choices.
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)

    # Renormalise the two selected gate values so they sum to one per token.
    mask1_float = mask1.float()
    mask2_float = mask2.float()
    gates1_s = einsum("se,se->s", gates, mask1_float)
    gates2_s = einsum("se,se->s", gates, mask2_float)
    denom_s = gates1_s + gates2_s
    # Guard against division by zero when both choices were masked out.
    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
    gates1_s /= denom_s
    gates2_s /= denom_s

    # Scatter the normalised weights into [token, expert, slot].
    gates1 = einsum("s,se->se", gates1_s, mask1_float)
    gates2 = einsum("s,se->se", gates2_s, mask2_float)
    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    locations2_sc = _one_hot_to_float(locations2_s, capacity)
    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
# Apply the modifications to deepspeed.
# These assignments run at import time and replace the DeepSpeed
# implementations process-wide with the functions defined in this file.
deepspeed.moe.sharded_moe.MOELayer.forward = gate_forward
deepspeed.moe.sharded_moe.top2gating = top2gating
# `_AllToAll` is a torch.autograd.Function, whose `forward` is expected to be
# a static method; wrap the replacement so the attribute stays one.
deepspeed.moe.sharded_moe._AllToAll.forward = staticmethod(_AllToAll_forward)