# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# NOTE: This is a mirror of the code in
# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe

import logging
import time
from typing import Any, Tuple, cast

import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Module, ModuleList

try:
    from fairseq.modules.moe import MOELayer

    has_fairseq = True
    Base = MOELayer
except ModuleNotFoundError:
    Base = Module
    has_fairseq = False

try:
    # To enable Tutel MoE optimizations:
    #   python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
    from tutel import moe as tutel_moe

    has_tutel, fused_cumsum_sub_one = True, tutel_moe.fast_cumsum_sub_one
except ModuleNotFoundError:
    has_tutel, fused_cumsum_sub_one = False, lambda mask: torch.cumsum(mask, dim=0) - 1

logger = logging.getLogger(__name__)


# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.


# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
    """Autograd-aware wrapper around ``dist.all_to_all_single``.

    The backward pass applies the same exchange to the incoming gradient.
    When ``torch.distributed`` is not initialized the function is an identity
    on the (contiguous) input.
    """

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
        ctx.group = group
        input = input.contiguous()
        if not torch.distributed.is_initialized():
            # No process group: nothing to exchange, so skip the output
            # buffer allocation entirely and hand the input back.
            assert group is None
            return input
        output = torch.empty_like(input)
        dist.all_to_all_single(output, input, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        # ``None`` is the gradient slot for the non-tensor ``group`` argument.
        return (None, _AllToAll.apply(ctx.group, *grad_output))


def _find_my_group_index(grouped_ranks):
    """Return the index of the rank list in *grouped_ranks* containing this rank.

    Raises:
        RuntimeError: if the current rank appears in none of the lists.
    """
    my_rank = dist.get_rank()
    for i, group in enumerate(grouped_ranks):
        if my_rank in group:
            return i
    raise RuntimeError(f"rank {my_rank} is not a member of any group in {grouped_ranks}")


def get_moe_group(moe_expert_count):
    """Return the expert process group of the current rank.

    Returns ``None`` when ``torch.distributed`` is not initialized.

    The rank layout and the process groups are built on the first call and
    cached as attributes of this function; later calls reuse that cache
    regardless of the ``moe_expert_count`` they are given.
    """
    if not dist.is_initialized():
        return None

    if not hasattr(get_moe_group, "_moe_groups"):
        world_size = dist.get_world_size()

        if world_size <= moe_expert_count:
            # At least one expert per rank: every rank is its own group.
            assert moe_expert_count % world_size == 0
            moe_groups = [[i] for i in range(world_size)]
        else:
            # More ranks than experts: group i holds the ranks
            # i, i + moe_expert_count, i + 2 * moe_expert_count, ...
            assert world_size % moe_expert_count == 0
            ranks_per_group = world_size // moe_expert_count
            moe_groups = [
                [i + j * moe_expert_count for j in range(ranks_per_group)]
                for i in range(moe_expert_count)
            ]

        get_moe_group._moe_group_idx = moe_groups
        get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]

    my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
    return get_moe_group._moe_groups[my_group_idx]


def get_all2all_group(moe_expert_count):
    """Return the all-to-all process group of the current rank.

    Returns ``None`` when ``torch.distributed`` is not initialized.

    The rank layout and the process groups are built on the first call and
    cached as attributes of this function; later calls reuse that cache
    regardless of the ``moe_expert_count`` they are given.
    """
    if not dist.is_initialized():
        return None

    if not hasattr(get_all2all_group, "_all2all_groups"):
        world_size = dist.get_world_size()

        # more experts than world size
        if world_size <= moe_expert_count:
            assert moe_expert_count % world_size == 0
            all2all_groups = [list(range(world_size))]

        # larger world than num experts
        else:
            assert world_size % moe_expert_count == 0
            ranks_per_group = world_size // moe_expert_count
            all2all_groups = [
                [i * moe_expert_count + j for j in range(moe_expert_count)]
                for i in range(ranks_per_group)
            ]

        get_all2all_group._all2all_group_idx = all2all_groups
        get_all2all_group._all2all_groups = [dist.new_group(g) for g in all2all_groups]

    my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
    return get_all2all_group._all2all_groups[my_group_idx]


class MOELayer(Base):
    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
    ::

        gate = Top2Gate(model_dim, num_experts)
        moe = MOELayer(gate, expert, args)
        output, l_aux = moe(input)

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        gate (torch.nn.Module):
            gate network
        experts (torch.nn.Module or torch.nn.ModuleList):
            local expert network(s); a single module is wrapped in a
            one-element ``ModuleList``
        args:
            configuration object; ``moe_expert_count`` is required, while
            ``batch_size``, ``batch_size_valid``, ``dummy_a2a`` and
            ``record_a2a_perf_stats`` are read with ``getattr`` defaults
    """

    def __init__(self, gate, experts, args):
        if has_fairseq:
            # Skip the fairseq MOELayer constructor and run the one that
            # follows it in the MRO.
            super(Base, self).__init__()
        else:
            super().__init__()
        self.gate = gate
        # isinstance (not an exact type match) so ModuleList subclasses are
        # used as-is instead of being wrapped as a single expert.
        if isinstance(experts, ModuleList):
            self.experts = cast(ModuleList, experts)
        else:
            self.experts = ModuleList([experts])
        self.expert_group = get_moe_group(args.moe_expert_count)
        self.all2all_group = get_all2all_group(args.moe_expert_count)
        if dist.is_initialized():
            self.world_size = dist.get_world_size(group=self.expert_group)
            self.all2all_size = dist.get_world_size(group=self.all2all_group)
        else:
            # Without a process group both helpers above return None and
            # this process is the only participant.
            self.world_size = 1
            self.all2all_size = 1
        for p in experts.parameters():
            p.expert = True  # type: ignore
        self.num_local_experts = len(self.experts)
        self.args = args
        self.in_generation = False
        self.a2a_cuda_event_intervals = []
        self.a2a_cpu_time_ms = 0.0

    def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tuple[Tensor, Tensor]:
        """Route tokens through the gate and the local experts.

        Args:
            *input: exactly one 3-D tensor whose last dimension is the model
                dimension.
            input_padding_mask: optional 2-D tensor matching the first two
                dimensions of the input; it is flattened and passed to the
                gate.

        Returns:
            ``(combined_output, l_aux)``: the combined expert output, sliced
            back to the first dimension of the original input, and the
            auxiliary loss returned by the gate.
        """
        assert len(input) == 1, "only single input Tensor supported"
        input = input[0]
        assert len(input.shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
        if input_padding_mask is not None:
            assert len(input_padding_mask.shape) == 2, "input Tensor must have dimensions: (s)equence, (t)oken"
            assert input_padding_mask.shape[0] == input.shape[0]
            assert input_padding_mask.shape[1] == input.shape[1]
        # assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts"

        # Implement Algorithm 2 from GShard paper.
        d_model = input.shape[2]
        # Pad to expected batch size
        input_shape = list(input.shape)
        expected_bsz = (
            getattr(self.args, "batch_size", 0)
            if self.training
            else getattr(self.args, "batch_size_valid", 0)
        )
        # This indicates that --batch-size or --max-sentences is not specified
        if expected_bsz is None:
            expected_bsz = 0
        # Note: Padding is not necessary at generation time at present
        # because all DDP workers process the same batch. Also, batch size at generation time
        # can be different from that present in the checkpoint state
        if not self.in_generation and expected_bsz != 0 and input_shape[0] != expected_bsz:
            logger.warning(f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})")
            assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}"
            padded_input = torch.zeros(
                (expected_bsz, input_shape[1], input_shape[2]),
                dtype=input.dtype,
                layout=input.layout,
                device=input.device,
            )
            padded_input[: input_shape[0], :, :] = input
            input = padded_input

            # Rows added by the padding are marked True in the mask.
            padded_input_padding_mask = torch.ones(
                (
                    expected_bsz,
                    input_shape[1],
                ),
                dtype=torch.bool,
                device=input.device,
            )
            if input_padding_mask is not None:
                padded_input_padding_mask[: input_shape[0], :] = input_padding_mask
            else:
                padded_input_padding_mask[: input_shape[0], :] = False
            input_padding_mask = padded_input_padding_mask

        # Reshape into S tokens by dropping sequence dimension.
        reshaped_input = input.reshape(-1, d_model)
        reshaped_input_shape = reshaped_input.shape
        reshaped_input_padding_mask = input_padding_mask.reshape(-1) if input_padding_mask is not None else None

        # Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences
        # Pro of --max-tokens: more flexible for MT variable sequence lengths
        # Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM
        if expected_bsz == 0:
            expected_dim = reshaped_input_shape[0] * torch.ones((1,), dtype=torch.long, device=input.device)
            if dist.is_initialized():
                # Agree on the largest token count across all workers; with
                # no process group the local count is used as-is.
                dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX)
            expected_dim = int(expected_dim.item())
            padded_input = torch.zeros(
                (expected_dim, reshaped_input_shape[1]),
                dtype=input.dtype,
                layout=input.layout,
                device=input.device,
            )
            padded_input[: reshaped_input_shape[0], :] = reshaped_input
            reshaped_input = padded_input

            padded_input_padding_mask = torch.ones((expected_dim,), dtype=torch.bool, device=padded_input.device)
            if reshaped_input_padding_mask is not None:
                padded_input_padding_mask[: reshaped_input_shape[0]] = reshaped_input_padding_mask
            else:
                padded_input_padding_mask[: reshaped_input_shape[0]] = False
            reshaped_input_padding_mask = padded_input_padding_mask

        if has_tutel:
            l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(
                reshaped_input, reshaped_input_padding_mask
            )
            S, M = reshaped_input.size(0), reshaped_input.size(1)

            if not hasattr(self, "_tutel_dispatcher"):
                self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
            self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
            dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
        else:
            l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(
                reshaped_input, reshaped_input_padding_mask
            )

            dispatch_mask = dispatch_mask.to(input.dtype).permute(1, 2, 0)  # S,E,C -> E,C,S
            E, C, S = dispatch_mask.size()
            M = reshaped_input.size(1)
            assert reshaped_input.size() == (S, M)
            # einsum("sec,sm->ecm")
            dispatched_input = torch.mm(dispatch_mask.view(E * C, S), reshaped_input)  # -> (E*C),M

        if self.all2all_size > 1:
            dispatched_input = self.all_to_all_wrapper(dispatched_input)

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_input = dispatched_input.reshape(self.all2all_size, self.num_local_experts, -1, d_model)
        chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
        expert_outputs = [expert(chunk) for chunk, expert in zip(chunks, self.experts)]
        expert_output = torch.cat(expert_outputs, dim=1)

        if self.all2all_size > 1:
            expert_output = self.all_to_all_wrapper(expert_output)

        # Re-shape back: gecm -> ecm
        expert_output = expert_output.reshape(self.all2all_size * self.num_local_experts, -1, d_model)

        if has_tutel:
            combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
        else:
            # einsum("sec,ecm->sm")
            combined_output = combine_weights.view(S, E * C).mm(expert_output.view(E * C, M))

        # Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences
        combined_output = combined_output[: reshaped_input_shape[0], :]
        combined_output = combined_output.reshape(input.shape)
        combined_output = combined_output[: input_shape[0], :, :]

        self.record_all_to_all_stats()

        return combined_output, l_aux

    def prepare_for_inference_(self):
        """Flag the layer as being in generation, which disables batch padding in ``forward``."""
        self.in_generation = True

    def all_to_all_wrapper(self, input: Tensor):
        """Run the all-to-all exchange on *input* and record its timing.

        When ``args.dummy_a2a`` is truthy the exchange is skipped and the
        contiguous input is returned unchanged.
        """
        dummy_a2a = getattr(self.args, "dummy_a2a", False)
        if dummy_a2a:
            return input.contiguous()
        # always record times, since it is not a lot of overhead
        # if we do not log it we simply clear it off in record_all_to_all_stats
        cuda_start = torch.cuda.Event(enable_timing=True)
        cuda_end = torch.cuda.Event(enable_timing=True)
        cpu_start = time.time() * 1000
        cuda_start.record()
        output = _AllToAll.apply(self.all2all_group, input)
        cuda_end.record()
        cpu_end = time.time() * 1000
        self.a2a_cpu_time_ms += cpu_end - cpu_start
        self.a2a_cuda_event_intervals.append((cuda_start, cuda_end))
        return output

    def record_all_to_all_stats(self):
        """Publish the accumulated all-to-all timings into ``self.metadata`` and reset them.

        The timings are only published when ``args.record_a2a_perf_stats`` is
        truthy; the accumulators are reset either way.
        """
        # controlled via an argument as we want to minimize any impact from torch.cuda.synchronize()
        record_a2a_perf_stats = getattr(self.args, "record_a2a_perf_stats", False)
        if record_a2a_perf_stats:
            torch.cuda.synchronize()
            self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms
            a2a_cuda_time_ms = 0.0
            for ev_start, ev_end in self.a2a_cuda_event_intervals:
                a2a_cuda_time_ms += ev_start.elapsed_time(ev_end)
            self.metadata["all_to_all_cuda_time_ms"] = a2a_cuda_time_ms
        # reset stats
        self.a2a_cpu_time_ms = 0.0
        self.a2a_cuda_event_intervals = []