File size: 5,031 Bytes
e576ca4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

import deepspeed
import torch
import torch.nn.functional as F
from deepspeed import comm as dist
from deepspeed.moe.sharded_moe import _capacity, _one_hot_to_float, einsum, gumbel_rsample
from torch import Tensor

# Tutel is an optional accelerator for MoE dispatch/combine. Record whether it
# could be imported so callers can pick the Tutel path only when it is available.
try:
    # To enable Tutel MoE optimizations:
    #   python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
    from tutel import moe as tutel_moe
except Exception:
    # Fail silently so we don't spam logs unnecessarily if user isn't using tutel.
    # `Exception` (not a bare `except:`) keeps this best-effort for any import-time
    # failure while still letting KeyboardInterrupt / SystemExit propagate.
    TUTEL_INSTALLED = False
else:
    TUTEL_INSTALLED = True


# =============================================================================
# DeepSpeed MoE Inference Utilities
# =============================================================================

def _AllToAll_forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
    """Pass-through replacement for the all-to-all forward step.

    No communication is performed: the process group is only remembered on
    ``ctx`` and the input is handed back in contiguous memory layout.

    Args:
        ctx: Context object; receives a ``group`` attribute.
        group: Process group to record on ``ctx`` (not otherwise used here).
        input: Tensor to return.

    Returns:
        ``input.contiguous()``.
    """
    ctx.group = group
    return input.contiguous()


def gate_forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
    """Forward pass for the MoE layer that routes tokens without all-to-all.

    The activations in ``input[0]`` are flattened to ``(tokens, d_model)``,
    routed to the experts according to ``self.gate``, processed by
    ``self.experts`` and recombined into the original leading shape.

    Args:
        *input: ``input[0]`` is the activation tensor (last dim is the model
            width). ``input[1]`` is forwarded to ``self.gate`` as its second
            argument (presumably a used-token mask; confirm against the gate).
        **kwargs: Accepted for signature compatibility; not used.

    Returns:
        Tensor whose leading dimensions equal those of ``input[0]``.

    Side effects:
        Sets ``self.l_aux`` and ``self.exp_counts`` from the gate output and,
        on the Tutel path, lazily creates ``self._tutel_dispatcher``.
    """
    hidden_states = input[0]
    model_dim = hidden_states.shape[-1]
    tokens = hidden_states.reshape(-1, model_dim)

    if self.use_tutel:
        # With the extra ``True`` argument the gate returns the Tutel routing tuple.
        gate_out = self.gate(tokens, input[1], True)
        self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = gate_out
        M = tokens.size(1)

        # The dispatcher is built once and then refreshed with each call's routing.
        if not hasattr(self, "_tutel_dispatcher"):
            self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=tokens.dtype)
        self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
        expert_inputs = self._tutel_dispatcher.encode(tokens)
    else:
        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(tokens, input[1])
        # Scatter every token into its (expert, capacity-slot) position.
        expert_inputs = einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), tokens)

    # Group as (ep_size, local experts, slots, d_model) before calling the experts.
    expert_inputs = expert_inputs.reshape(self.ep_size, self.num_local_experts, -1, model_dim)
    expert_outputs = self.experts(expert_inputs)
    expert_outputs = expert_outputs.reshape(
        self.ep_size * self.num_local_experts, expert_inputs.shape[2], -1
    )

    if self.use_tutel:
        merged = self._tutel_dispatcher.decode(expert_outputs.view(E * C, M))
    else:
        # Gather expert results back per token, weighted by the gate.
        merged = einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_outputs)

    return merged.reshape(hidden_states.size()[:-1] + (-1,))


def top2gating(
    logits: Tensor,
    capacity_factor: float,
    min_capacity: int,
    drop_tokens: bool = True,
    ep_group: Union[torch.distributed.ProcessGroup, None] = None,
    top2_2nd_expert_sampling: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits.

    Each token (row of ``logits``) is assigned a first expert by arg-max of the
    softmax gates and a second expert by arg-max over the remaining logits,
    optionally perturbed with noise from ``gumbel_rsample``.

    Args:
        logits: 2-D router scores indexed ``[token, expert]``. The tensor is
            not modified.
        capacity_factor: Doubled and passed to ``_capacity`` together with
            ``min_capacity`` when ``drop_tokens`` is true.
        min_capacity: Passed to ``_capacity`` when ``drop_tokens`` is true.
        drop_tokens: If true, assignments whose slot index is not below the
            capacity are masked out. If false, the capacity is the largest
            per-expert assignment count, so nothing is dropped.
        ep_group: Accepted for signature compatibility; unused in this version
            (no cross-rank reduction of the capacity is performed).
        top2_2nd_expert_sampling: If true, add ``gumbel_rsample`` noise to the
            logits before picking the second expert.

    Returns:
        ``(l_aux, combine_weights, dispatch_mask, exp_counts)`` where
        ``l_aux`` is the scalar auxiliary loss, ``combine_weights`` is indexed
        ``[token, expert, slot]``, ``dispatch_mask`` is its boolean view, and
        ``exp_counts`` is the per-expert count of first plus second choices
        taken before any capacity-based dropping.
    """
    gates = F.softmax(logits, dim=1)

    # First choice: highest gate probability per token.
    indices1_s = torch.argmax(gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    if top2_2nd_expert_sampling:
        # Out-of-place add: an in-place `+=` would overwrite the caller's tensor
        # and fails outright when `logits` is a leaf that requires grad.
        logits = logits + gumbel_rsample(logits.shape, device=logits.device)

    # Second choice: best remaining expert once the first choice is excluded.
    logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1)
    mask2 = F.one_hot(indices2_s, num_classes=num_experts)

    # Slot index of each token inside its expert's buffer; second choices are
    # placed after all first choices of the same expert.
    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations2 = torch.cumsum(mask2, dim=0) - 1
    locations2 += torch.sum(mask1, dim=0, keepdim=True)

    # Auxiliary loss from mean gate probability and mean first-choice assignment.
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.mean(me * ce) * num_experts * num_experts
    exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)

    if drop_tokens:
        capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
        # Drop assignments that land beyond the expert's capacity.
        mask1 *= torch.lt(locations1, capacity)
        mask2 *= torch.lt(locations2, capacity)
    else:
        # Size the buffers for the busiest expert so every assignment fits.
        capacity = torch.max(exp_counts)

    # Collapse the per-expert slot matrices to one slot index per token.
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)
    mask1_float = mask1.float()
    mask2_float = mask2.float()

    # Gate value of each surviving choice, renormalised so the pair sums to 1.
    gates1_s = einsum("se,se->s", gates, mask1_float)
    gates2_s = einsum("se,se->s", gates, mask2_float)
    denom_s = gates1_s + gates2_s

    # Guard against division by zero when both choices were dropped.
    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
    gates1_s /= denom_s
    gates2_s /= denom_s

    # Expand to [token, expert, slot] combine weights.
    gates1 = einsum("s,se->se", gates1_s, mask1_float)
    gates2 = einsum("s,se->se", gates2_s, mask2_float)
    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    locations2_sc = _one_hot_to_float(locations2_s, capacity)
    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts


# Install the replacements defined above into DeepSpeed's sharded MoE module.
# The three assignments are independent of one another.
# All-to-all forward becomes a pass-through that only makes the input contiguous.
deepspeed.moe.sharded_moe._AllToAll.forward = _AllToAll_forward
# Module-level top-2 gating function is swapped for the local implementation.
deepspeed.moe.sharded_moe.top2gating = top2gating
# MoE layer forward is replaced by the local gate/dispatch/combine routine.
deepspeed.moe.sharded_moe.MOELayer.forward = gate_forward