robinfaro committed on
Commit
84ca0f0
·
verified ·
1 Parent(s): 0455e1f

Adding files from hf_modeling_btm_log_prob_mixing

Browse files
Files changed (1) hide show
  1. moe.py +134 -0
moe.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple MoE routing implementations that replace the MLP block in a standard transformer.
3
+ References:
4
+ 1) Mistral Source for Mixtral MoEs:
5
+ https://github.com/mistralai/mistral-src
6
+ 2) ST-MoE:
7
+ https://arxiv.org/abs/2202.08906
8
+ 3) Our notepad of MoE resources:
9
+ https://docs.google.com/document/d/1NuQ5jr7V-Jv1ui7p4KrxO_JTz-7bpYcYMmh49EeJ-QA/edit?usp=sharing
10
+ """
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import bisect
16
+ import math
17
+
18
class MoE(nn.Module):
    """Mixture-of-Experts layer with a linear router; drop-in MLP replacement.

    Routing is implemented as a plain loop over the experts: each expert
    processes only the tokens routed to it and the weighted results are
    scattered back.  This is not the fastest dispatch, but it avoids the
    large memory overhead of a fused implementation and never drops tokens
    (no capacity factor is required).
    """

    def __init__(self, config, mlp):
        super().__init__()
        assert config.moe_num_experts > 0
        # One independent MLP instance per expert.
        self.experts = nn.ModuleList(
            mlp(config=config) for _ in range(config.moe_num_experts)
        )
        self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False)
        self.top_k = config.moe_num_experts_per_tok
        self.softmax_order = config.moe_softmax_order

    def forward(self, inputs: torch.Tensor):
        """Route every token to its top-k experts and mix their outputs.

        Returns the mixed output with the same shape as ``inputs``, plus a
        dict carrying the raw router logits and the selected expert indices
        (used later for auxiliary load-balancing losses).
        """
        # Flatten to [batch_size * sequence_length, n_embd].
        flat = inputs.view(-1, inputs.shape[-1])
        # [batch_size * sequence_length, num_experts]
        router_logits = self.router(flat)

        # Softmax is monotonic, so the selected experts are identical for
        # both orders -- only the mixing weights differ.
        if self.softmax_order == "softmax_topk":
            weights, selected_experts = torch.topk(
                F.softmax(router_logits, dim=1), self.top_k
            )
        elif self.softmax_order == "topk_softmax":
            top_logits, selected_experts = torch.topk(router_logits, self.top_k)
            weights = F.softmax(top_logits, dim=-1)
        else:
            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")

        mixed = torch.zeros_like(flat)
        # Naive per-expert loop: gather this expert's tokens, run them
        # through the expert, and accumulate the weighted result.
        for expert_id, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.where(selected_experts == expert_id)
            expert_out, _ = expert(flat[token_idx])
            mixed[token_idx] += weights[token_idx, slot_idx, None] * expert_out.squeeze(0)

        # Router stats are returned for aux loss calculation later.
        return mixed.view_as(inputs), {
            "router_logits": router_logits,
            "selected_experts": selected_experts,
        }
68
+
69
+
70
class DummyExpert(nn.Module):
    """A no-op expert that always contributes zeros.

    Appended by MaskedMoE as an always-routable "null" expert, so that every
    token has at least one valid target even when all real experts are
    masked out.
    """

    def __init__(self, output_size: int):
        super().__init__()
        # Width of the zero output, normally config.n_embd.
        self._output_size = output_size

    def forward(self, inputs: torch.Tensor):
        # Match the (output, aux_dict) contract of a real expert.
        # Fix: the original returned a single (output_size,) float32 vector
        # regardless of how many tokens came in (and regardless of the input
        # dtype), relying on broadcasting in the caller's weighted sum.
        # Producing per-token zeros on the input's device *and* dtype keeps
        # the result identical (still all zeros) while matching the shape
        # contract of the other experts.
        out = torch.zeros(
            inputs.shape[:-1] + (self._output_size,),
            device=inputs.device,
            dtype=inputs.dtype,
        )
        return out, {}
77
+
78
+
79
+
80
class MaskedMoE(MoE):
    """MoE variant where a per-sample 0/1 mask restricts which experts each
    token may be routed to.

    A DummyExpert (always routable, outputs zeros) is appended as the last
    expert so every token has at least one valid target even when all real
    experts are masked.
    """

    def __init__(self, config, mlp):
        super().__init__(config, mlp)
        self._sequence_length = config.sequence_length
        self.experts.append(DummyExpert(config.n_embd))
        # Recreate the router with one extra output column for the dummy
        # expert (replaces the router built by MoE.__init__).
        self.router = nn.Linear(config.n_embd, config.moe_num_experts + 1, bias=False)

    def forward(self, inputs: torch.Tensor, mask: torch.Tensor):
        """Route tokens among the experts allowed by ``mask``.

        mask: [batch_size, moe_num_experts] tensor of 0/1 flags; a row is
        expanded to every token of the corresponding sample (assumes inputs
        holds exactly ``sequence_length`` tokens per batch element --
        TODO confirm against callers).
        """
        inputs_squashed = inputs.view(-1, inputs.shape[-1])
        router_logits = self.router(inputs_squashed)

        # The dummy expert (last column) is always allowed.
        mask = torch.cat(
            (mask, torch.ones((mask.shape[0], 1), device=mask.device)),
            dim=1,
        )
        # One mask row per token.
        mask = mask.repeat_interleave(self._sequence_length, dim=0)
        # BUG FIX: the original did `router_logits * mask`, which merely sets
        # masked logits to 0.  A zero logit can still out-rank the negative
        # logits of allowed experts, and softmax assigns it nonzero
        # probability -- so masked experts could still be selected.  Setting
        # masked logits to -inf guarantees they are never picked by top-k
        # and receive exactly zero softmax weight.
        router_logits = router_logits.masked_fill(mask == 0, float("-inf"))

        # Selected experts are identical for both orders (softmax is
        # monotonic); only the mixing weights differ.
        if self.softmax_order == "softmax_topk":
            all_probs = F.softmax(router_logits, dim=1)
            weights, selected_experts = torch.topk(all_probs, self.top_k)
        elif self.softmax_order == "topk_softmax":
            weights, selected_experts = torch.topk(router_logits, self.top_k)
            weights = F.softmax(weights, dim=-1)
        else:
            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")

        results = torch.zeros_like(inputs_squashed)
        # Naive per-expert loop: gather, run, scatter weighted output back.
        for i, expert in enumerate(self.experts):
            batch_idx, nth_expert = torch.where(selected_experts == i)
            expert_input = inputs_squashed[batch_idx]
            output, _ = expert(expert_input)
            results[batch_idx] += weights[batch_idx, nth_expert, None] * output.squeeze(0)

        # Return results and router logits (for aux loss calculation later).
        # NOTE(review): masked positions in router_logits are now -inf rather
        # than 0 -- downstream aux-loss code should apply softmax/logsumexp,
        # which handle -inf correctly.
        return results.view_as(inputs), {
            "router_logits": router_logits,
            "selected_experts": selected_experts,
        }
122
+
123
+
124
class TimeDependantMoE(nn.Module):
    """Wraps MaskedMoE with a time-derived mask: a sample with date value
    ``d`` may only route to experts with index < ``d`` (experts are
    presumably ordered chronologically -- TODO confirm with training setup).
    """

    def __init__(self, config, mlp):
        super().__init__()
        self._num_experts = config.moe_num_experts
        self._mask_moe = MaskedMoE(config, mlp)

    def forward(self, x, date):
        """Compute the per-sample expert mask from ``date`` and delegate.

        date: [batch_size] tensor; expert i is allowed for sample b iff
        i < date[b].
        """
        # Fix: the original allocated a zeros mask that was immediately
        # overwritten (dead code) and built the range tensor on CPU before
        # moving it; build it directly on the input's device instead.
        expert_ids = torch.arange(self._num_experts, device=x.device).unsqueeze(0)
        mask_date = (expert_ids < date.unsqueeze(1)).float()
        return self._mask_moe(x, mask_date)