| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
|
|
| import logging |
| import time |
| from typing import Any, Tuple, cast |
|
|
| import torch |
| import torch.distributed as dist |
| from torch import Tensor |
| from torch.nn import Module, ModuleList |
|
|
# Optional fairseq integration: when fairseq is installed, derive from its
# MOELayer so this implementation can be dropped into fairseq models;
# otherwise fall back to a plain torch.nn.Module base class.
try:
    from fairseq.modules.moe import MOELayer

    has_fairseq = True
    Base = MOELayer
except ModuleNotFoundError:
    Base = Module
    has_fairseq = False

# Optional tutel acceleration: use its fused cumsum kernel when available;
# otherwise fall back to an equivalent plain-torch cumsum (minus one).
try:

    from tutel import moe as tutel_moe

    has_tutel, fused_cumsum_sub_one = True, tutel_moe.fast_cumsum_sub_one
except ModuleNotFoundError:
    has_tutel, fused_cumsum_sub_one = False, lambda mask: torch.cumsum(mask, dim=0) - 1

logger = logging.getLogger(__name__)
|
|
|
|
| |
| |
|
|
|
|
| |
| class _AllToAll(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: |
| ctx.group = group |
| input = input.contiguous() |
| output = torch.empty_like(input) |
| if torch.distributed.is_initialized(): |
| dist.all_to_all_single(output, input, group=group) |
| else: |
| assert group is None |
| output = input |
| return output |
|
|
| @staticmethod |
| def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]: |
| return (None, _AllToAll.apply(ctx.group, *grad_output)) |
|
|
|
|
def _find_my_group_index(grouped_ranks):
    """Return the index of the group in ``grouped_ranks`` containing this rank.

    Args:
        grouped_ranks: list of lists of global rank ids, one inner list per
            process group.

    Returns:
        Index of the first inner list containing ``dist.get_rank()``.

    Raises:
        RuntimeError: if the current rank appears in none of the groups.
            (Previously raised with no message; now includes the rank and
            layout for easier debugging of bad group partitions.)
    """
    my_rank = dist.get_rank()
    for i, group in enumerate(grouped_ranks):
        if my_rank in group:
            return i
    raise RuntimeError(f"rank {my_rank} not found in any group of {grouped_ranks}")
|
|
|
|
def get_moe_group(moe_expert_count):
    """Return (creating lazily, once) the expert-parallel group for this rank.

    The world is partitioned so that ranks sharing the same expert shard end
    up in the same group.  Groups are cached as attributes on the function
    object, so the first call fixes the layout.
    NOTE(review): later calls with a *different* ``moe_expert_count`` silently
    reuse the cached layout; returns None implicitly when torch.distributed
    is not initialized — confirm callers rely on that.
    """
    if dist.is_initialized():
        if not hasattr(get_moe_group, "_moe_groups"):
            world_size = dist.get_world_size()

            if world_size <= moe_expert_count:
                # At least one expert shard per rank: every rank forms its
                # own singleton group.
                assert moe_expert_count % world_size == 0
                moe_groups = [[i] for i in range(world_size)]

            else:
                # Several ranks share each expert: ranks are grouped with a
                # stride of moe_expert_count.
                assert world_size % moe_expert_count == 0
                ranks_per_group = world_size // moe_expert_count
                moe_groups = [
                    [i + j * moe_expert_count for j in range(ranks_per_group)] for i in range(moe_expert_count)
                ]

            # dist.new_group is a collective: every rank must create every
            # group in the same order, which this single pass guarantees.
            get_moe_group._moe_group_idx = moe_groups
            get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]

        my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
        return get_moe_group._moe_groups[my_group_idx]
|
|
|
|
def get_all2all_group(moe_expert_count):
    """Return (creating lazily, once) the all-to-all group for this rank.

    Counterpart of :func:`get_moe_group`: groups span the ranks that
    exchange tokens via all-to-all during MoE dispatch.  Cached on the
    function object after the first call.
    NOTE(review): like get_moe_group, a later call with a different
    ``moe_expert_count`` reuses the cached layout, and the function returns
    None implicitly when torch.distributed is not initialized.
    """
    if dist.is_initialized():
        if not hasattr(get_all2all_group, "_all2all_groups"):
            world_size = dist.get_world_size()

            # At least one expert per rank: all ranks exchange with each
            # other in a single world-wide group.
            if world_size <= moe_expert_count:
                assert moe_expert_count % world_size == 0
                all2all_groups = [[i for i in range(world_size)]]

            # Several ranks per expert: one all-to-all group per replica,
            # each containing moe_expert_count consecutive ranks.
            else:
                assert world_size % moe_expert_count == 0
                ranks_per_group = world_size // moe_expert_count
                all2all_groups = [
                    [i * moe_expert_count + j for j in range(moe_expert_count)] for i in range(ranks_per_group)
                ]

            # dist.new_group is a collective: creation order must match on
            # every rank, which this single pass guarantees.
            get_all2all_group._all2all_group_idx = all2all_groups
            get_all2all_group._all2all_groups = [dist.new_group(g) for g in all2all_groups]

        my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
        return get_all2all_group._all2all_groups[my_group_idx]
|
|
|
|
class MOELayer(Base):
    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
    ::

        gate = Top2Gate(model_dim, num_experts)
        moe = MOELayer(gate, expert)
        output = moe(input)
        l_aux = moe.l_aux

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        gate (torch.nn.Module):
            gate network
        expert (torch.nn.Module):
            expert network
    """

    def __init__(self, gate, experts, args):
        # When fairseq's MOELayer is the base class, bypass its __init__
        # (different signature) and initialize nn.Module directly;
        # otherwise do a normal super().__init__().
        if has_fairseq:
            super(Base, self).__init__()
        else:
            super().__init__()
        self.gate = gate
        # NOTE(review): `type(...) ==` rather than isinstance — a ModuleList
        # subclass would be wrapped in a second ModuleList here.
        if type(experts) == ModuleList:
            self.experts = cast(ModuleList, experts)
        else:
            self.experts = ModuleList([experts])
        self.expert_group = get_moe_group(args.moe_expert_count)
        self.all2all_group = get_all2all_group(args.moe_expert_count)
        self.world_size = dist.get_world_size(group=self.expert_group)
        self.all2all_size = dist.get_world_size(group=self.all2all_group)
        # Tag expert parameters so downstream code can distinguish sharded
        # expert state from replicated (dense) parameters.
        for p in experts.parameters():
            p.expert = True
        self.num_local_experts = len(self.experts)
        self.args = args
        # Set by prepare_for_inference_(); disables batch-size padding.
        self.in_generation = False
        # Perf bookkeeping for all_to_all_wrapper / record_all_to_all_stats.
        self.a2a_cuda_event_intervals = []
        self.a2a_cpu_time_ms = 0.0

    # NOTE(review): annotated -> Tensor but actually returns a
    # (combined_output, l_aux) tuple — matches the usage in the class
    # docstring; annotation is inaccurate.
    def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tensor:
        assert len(input) == 1, "only single input Tensor supported"
        input = input[0]
        assert len(input.shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
        if input_padding_mask is not None:
            assert len(input_padding_mask.shape) == 2, "input Tensor must have dimensions: (s)equence, (t)oken"
            assert input_padding_mask.shape[0] == input.shape[0]
            assert input_padding_mask.shape[1] == input.shape[1]

        d_model = input.shape[2]
        # Pad the batch dimension up to the configured batch size so every
        # rank contributes identically shaped tensors to the all-to-all.
        input_shape = list(input.shape)
        expected_bsz = (
            getattr(self.args, "batch_size", 0) if self.training else getattr(self.args, "batch_size_valid", 0)
        )
        if expected_bsz is None:
            expected_bsz = 0

        # Not needed during generation (in_generation) or when no fixed
        # batch size is configured (expected_bsz == 0).
        if not self.in_generation and expected_bsz != 0 and input_shape[0] != expected_bsz:
            logger.warning(f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})")
            # Only under-sized batches can be padded up.
            assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}"
            padded_input = torch.zeros(
                (expected_bsz, input_shape[1], input_shape[2]),
                dtype=input.dtype,
                layout=input.layout,
                device=input.device,
            )
            padded_input[: input_shape[0], :, :] = input
            input = padded_input

            # Padded rows are marked True (padding) in the mask.
            padded_input_padding_mask = torch.ones(
                (
                    expected_bsz,
                    input_shape[1],
                ),
                dtype=torch.bool,
                device=input.device,
            )
            if input_padding_mask is not None:
                padded_input_padding_mask[: input_shape[0], :] = input_padding_mask
            else:
                padded_input_padding_mask[: input_shape[0], :] = False
            input_padding_mask = padded_input_padding_mask

        # Flatten (batch, tokens, d_model) -> (batch*tokens, d_model) for
        # token-level dispatch.
        reshaped_input = input.reshape(-1, d_model)
        reshaped_input_shape = reshaped_input.shape
        reshaped_input_padding_mask = input_padding_mask.reshape(-1) if input_padding_mask is not None else None

        # With no configured batch size, ranks may still disagree on token
        # counts: all-reduce the max count and pad every rank up to it so
        # the all-to-all stays uniform.
        if expected_bsz == 0:
            expected_dim = reshaped_input_shape[0] * torch.ones((1,), dtype=torch.long, device=input.device)
            dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX)
            expected_dim = int(expected_dim.item())
            padded_input = torch.zeros(
                (expected_dim, reshaped_input_shape[1]),
                dtype=input.dtype,
                layout=input.layout,
                device=input.device,
            )
            padded_input[: reshaped_input_shape[0], :] = reshaped_input
            reshaped_input = padded_input

            padded_input_padding_mask = torch.ones((expected_dim,), dtype=torch.bool, device=padded_input.device)
            if reshaped_input_padding_mask is not None:
                padded_input_padding_mask[: reshaped_input_shape[0]] = reshaped_input_padding_mask
            else:
                padded_input_padding_mask[: reshaped_input_shape[0]] = False
            reshaped_input_padding_mask = padded_input_padding_mask

        if has_tutel:
            # Tutel fast path: the gate returns routing indices/locations
            # and a cached fast_dispatcher performs the encode.
            l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(
                reshaped_input, reshaped_input_padding_mask
            )
            S, M = reshaped_input.size(0), reshaped_input.size(1)

            if not hasattr(self, "_tutel_dispatcher"):
                self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
            self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
            dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
        else:
            # Reference path: dense dispatch via matmul with the one-hot
            # dispatch mask.
            l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(
                reshaped_input, reshaped_input_padding_mask
            )

            dispatch_mask = dispatch_mask.to(input.dtype).permute(1, 2, 0)  # S,E,C -> E,C,S
            E, C, S = dispatch_mask.size()
            M = reshaped_input.size(1)
            assert reshaped_input.size() == (S, M)
            # einsum("sec,sm->ecm")
            dispatched_input = torch.mm(dispatch_mask.view(E * C, S), reshaped_input)

        if self.all2all_size > 1:
            dispatched_input = self.all_to_all_wrapper(dispatched_input)

        # Re-shape after all-to-all: (workers, local experts, capacity, d_model)
        dispatched_input = dispatched_input.reshape(self.all2all_size, self.num_local_experts, -1, d_model)
        chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
        expert_outputs = []
        for chunk, expert in zip(chunks, self.experts):
            expert_outputs += [expert(chunk)]
        expert_output = torch.cat(expert_outputs, dim=1)

        if self.all2all_size > 1:
            expert_output = self.all_to_all_wrapper(expert_output)

        # Re-shape back: (workers * local experts, capacity, d_model)
        expert_output = expert_output.reshape(self.all2all_size * self.num_local_experts, -1, d_model)

        if has_tutel:
            combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
        else:
            # einsum("sec,ecm->sm"): weight each expert's output by the gate.
            combined_output = combine_weights.view(S, E * C).mm(expert_output.view(E * C, M))

        # Strip the padding added above and restore the original shape.
        combined_output = combined_output[: reshaped_input_shape[0], :]
        combined_output = combined_output.reshape(input.shape)
        combined_output = combined_output[: input_shape[0], :, :]

        self.record_all_to_all_stats()

        return combined_output, l_aux

    def prepare_for_inference_(self):
        # Disables the fixed-batch-size padding in forward().
        self.in_generation = True

    def all_to_all_wrapper(self, input: Tensor):
        """Run _AllToAll over the all2all group, timing CPU and CUDA cost."""
        dummy_a2a = getattr(self.args, "dummy_a2a", False)
        if dummy_a2a:
            input = input.contiguous()
            # NOTE(review): `output` is computed but `input` is returned —
            # matches the upstream implementation, but the clone is dead
            # code; confirm whether `return output` was intended.
            output = input.detach().clone()
            return input

        # Always record times, whether or not stats are consumed later.
        cuda_start = torch.cuda.Event(enable_timing=True)
        cuda_end = torch.cuda.Event(enable_timing=True)
        cpu_start = time.time() * 1000
        cuda_start.record()
        output = _AllToAll.apply(self.all2all_group, input)
        cuda_end.record()
        cpu_end = time.time() * 1000
        self.a2a_cpu_time_ms += cpu_end - cpu_start
        self.a2a_cuda_event_intervals.append((cuda_start, cuda_end))
        return output

    def record_all_to_all_stats(self):
        """Flush accumulated all-to-all timings into self.metadata."""
        # Timings are exposed only when explicitly enabled via args.
        record_a2a_perf_stats = getattr(self.args, "record_a2a_perf_stats", False)
        if record_a2a_perf_stats:
            # CUDA events require a sync before elapsed_time is valid.
            torch.cuda.synchronize()
            self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms
            a2a_cuda_time_ms = 0.0
            for ev_start, ev_end in self.a2a_cuda_event_intervals:
                a2a_cuda_time_ms += ev_start.elapsed_time(ev_end)
            self.metadata["all_to_all_cuda_time_ms"] = a2a_cuda_time_ms
        # Reset stats unconditionally so timings never accumulate across
        # forward passes.
        self.a2a_cpu_time_ms = 0.0
        self.a2a_cuda_event_intervals = []
|
|