szxllm committed on
Commit
958b4f3
·
verified ·
1 Parent(s): d16a3f0

Update moe.py

Browse files
Files changed (1) hide show
  1. moe.py +322 -459
moe.py CHANGED
@@ -1,460 +1,323 @@
1
- """
2
- 优化的混合专家系统 (Mixture of Experts)
3
- 基于Mixtral、Switch Transformer、GLaM的最佳实践
4
- """
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from typing import Tuple, Optional, List
9
- import math
10
-
11
- class Expert(nn.Module):
12
- """
13
- 单个专家网络
14
- 使用SwiGLU激活函数以获得更好的性能
15
- """
16
-
17
- def __init__(
18
- self,
19
- dim: int,
20
- hidden_dim: int,
21
- dropout: float = 0.0,
22
- bias: bool = False
23
- ):
24
- super().__init__()
25
- self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
26
- self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
27
- self.w3 = nn.Linear(dim, hidden_dim, bias=bias)
28
- self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
29
-
30
- self._init_weights()
31
-
32
- def _init_weights(self):
33
- """改进的权重初始化"""
34
- for module in [self.w1, self.w2, self.w3]:
35
- nn.init.normal_(module.weight, mean=0.0, std=0.02)
36
- if module.bias is not None:
37
- nn.init.zeros_(module.bias)
38
-
39
- def forward(self, x: torch.Tensor) -> torch.Tensor:
40
- """
41
- 前向传播
42
- SwiGLU: (Swish(W1·x) W3·x) W2
43
- """
44
- return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
45
-
46
- class TopKRouter(nn.Module):
47
- """
48
- Top-K路由器 - 改进版
49
- 改进点:
50
- 1. 专家容量管理
51
- 2. 负载均衡
52
- 3. 训练时的噪声注入
53
- 4. Z-loss防止logits爆炸
54
-
55
- 参考:
56
- - Switch Transformer
57
- - Mixtral 8x7B
58
- - ST-MoE
59
- """
60
-
61
- def __init__(
62
- self,
63
- dim: int,
64
- num_experts: int,
65
- top_k: int = 2,
66
- capacity_factor: float = 1.25,
67
- noise_std: float = 1.0,
68
- use_expert_capacity: bool = True,
69
- router_z_loss_coef: float = 0.001,
70
- router_aux_loss_coef: float = 0.01
71
- ):
72
- super().__init__()
73
- self.num_experts = num_experts
74
- self.top_k = top_k
75
- self.capacity_factor = capacity_factor
76
- self.noise_std = noise_std
77
- self.use_expert_capacity = use_expert_capacity
78
- self.router_z_loss_coef = router_z_loss_coef
79
- self.router_aux_loss_coef = router_aux_loss_coef
80
-
81
- self.gate = nn.Linear(dim, num_experts, bias=False)
82
-
83
- nn.init.normal_(self.gate.weight, mean=0.0, std=0.02)
84
-
85
- def _compute_routing_weights(
86
- self,
87
- logits: torch.Tensor,
88
- use_noise: bool = True
89
- ) -> Tuple[torch.Tensor, torch.Tensor]:
90
- """
91
- 计算路由权重
92
-
93
- Args:
94
- logits: 路由logits [batch*seq_len, num_experts]
95
- use_noise: 是否添加噪声
96
-
97
- Returns:
98
- top_k_gates: Top-K门控值 [batch*seq_len, top_k]
99
- top_k_indices: Top-K专家索引 [batch*seq_len, top_k]
100
- """
101
- if use_noise and self.training:
102
- noise = torch.randn_like(logits) * self.noise_std
103
- logits = logits + noise
104
-
105
- top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
106
-
107
- top_k_gates = F.softmax(top_k_logits, dim=-1)
108
-
109
- return top_k_gates, top_k_indices
110
-
111
- def _compute_auxiliary_loss(
112
- self,
113
- logits: torch.Tensor,
114
- top_k_indices: torch.Tensor
115
- ) -> Tuple[torch.Tensor, torch.Tensor]:
116
- """
117
- 计算辅助损失
118
-
119
- 包括:
120
- 1. 负载均衡损失(确保专家被均匀使用)
121
- 2. Z-loss(防止logits过大)
122
-
123
- Args:
124
- logits: 路由logits [batch*seq_len, num_experts]
125
- top_k_indices: 选中的专家索引 [batch*seq_len, top_k]
126
-
127
- Returns:
128
- load_balance_loss: 负载均衡损失
129
- z_loss: Z-loss
130
- """
131
- num_tokens = logits.shape[0]
132
-
133
- router_probs = F.softmax(logits, dim=-1)
134
-
135
- expert_probs = router_probs.mean(dim=0)
136
-
137
- expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
138
- expert_freq = expert_mask.sum(dim=[0, 1]) / (num_tokens * self.top_k)
139
-
140
- load_balance_loss = self.num_experts * torch.sum(expert_probs * expert_freq)
141
-
142
- z_loss = torch.mean(logits ** 2)
143
-
144
- return load_balance_loss, z_loss
145
-
146
- def forward(
147
- self,
148
- x: torch.Tensor
149
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
150
- """
151
- 前向传播
152
-
153
- Args:
154
- x: 输入 [batch*seq_len, dim]
155
-
156
- Returns:
157
- top_k_gates: 门控权重 [batch*seq_len, top_k]
158
- top_k_indices: 专家索引 [batch*seq_len, top_k]
159
- auxiliary_loss: 辅助损失(标量)
160
- """
161
- logits = self.gate(x)
162
-
163
- top_k_gates, top_k_indices = self._compute_routing_weights(
164
- logits, use_noise=self.training
165
- )
166
-
167
- if self.training:
168
- load_balance_loss, z_loss = self._compute_auxiliary_loss(logits, top_k_indices)
169
- auxiliary_loss = (
170
- self.router_aux_loss_coef * load_balance_loss +
171
- self.router_z_loss_coef * z_loss
172
- )
173
- else:
174
- auxiliary_loss = torch.tensor(0.0, device=x.device)
175
-
176
- return top_k_gates, top_k_indices, auxiliary_loss
177
-
178
- class MixtureOfExperts(nn.Module):
179
- """
180
- 混合专家层 - 优化版
181
- 改进点:
182
- 1. 高效的token分发和聚合
183
- 2. 专家容量管理
184
- 3. 改进的负载均衡
185
- 4. 支持专家并行
186
-
187
- 参考:
188
- - Mixtral 8x7B
189
- - Switch Transformer
190
- - GShard
191
- """
192
-
193
- def __init__(
194
- self,
195
- dim: int,
196
- num_experts: int = 8,
197
- expert_hidden_dim: Optional[int] = None,
198
- top_k: int = 2,
199
- dropout: float = 0.0,
200
- capacity_factor: float = 1.25,
201
- use_expert_capacity: bool = True,
202
- router_z_loss_coef: float = 0.001,
203
- router_aux_loss_coef: float = 0.01,
204
- noise_std: float = 1.0,
205
- ffn_dim_multiplier: Optional[float] = None
206
- ):
207
- super().__init__()
208
- self.num_experts = num_experts
209
- self.top_k = top_k
210
- self.capacity_factor = capacity_factor
211
- self.use_expert_capacity = use_expert_capacity
212
-
213
- if expert_hidden_dim is None:
214
- if ffn_dim_multiplier is not None:
215
- expert_hidden_dim = int(dim * ffn_dim_multiplier)
216
- else:
217
- expert_hidden_dim = int(2 * dim * 4 / 3)
218
- expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)
219
-
220
- self.experts = nn.ModuleList([
221
- Expert(dim, expert_hidden_dim, dropout, bias=False)
222
- for _ in range(num_experts)
223
- ])
224
-
225
- self.router = TopKRouter(
226
- dim=dim,
227
- num_experts=num_experts,
228
- top_k=top_k,
229
- capacity_factor=capacity_factor,
230
- noise_std=noise_std,
231
- use_expert_capacity=use_expert_capacity,
232
- router_z_loss_coef=router_z_loss_coef,
233
- router_aux_loss_coef=router_aux_loss_coef
234
- )
235
-
236
- def _compute_expert_capacity(self, num_tokens: int) -> int:
237
- """计算每个专家的容量"""
238
- if not self.use_expert_capacity:
239
- return num_tokens
240
-
241
- capacity = int(
242
- (num_tokens / self.num_experts) * self.capacity_factor * self.top_k
243
- )
244
- return max(capacity, 1)
245
-
246
- def forward(
247
- self,
248
- x: torch.Tensor
249
- ) -> Tuple[torch.Tensor, torch.Tensor]:
250
- """
251
- 前向传播
252
-
253
- Args:
254
- x: 输入 [batch, seq_len, dim]
255
-
256
- Returns:
257
- output: 输出 [batch, seq_len, dim]
258
- auxiliary_loss: 辅助损失
259
- """
260
- B, T, D = x.shape
261
- num_tokens = B * T
262
-
263
- x_flat = x.view(-1, D)
264
-
265
- top_k_gates, top_k_indices, auxiliary_loss = self.router(x_flat)
266
-
267
- output = torch.zeros_like(x_flat)
268
-
269
- expert_capacity = self._compute_expert_capacity(num_tokens)
270
-
271
- for expert_idx, expert in enumerate(self.experts):
272
- expert_mask = (top_k_indices == expert_idx)
273
-
274
- token_indices, topk_positions = torch.where(expert_mask)
275
-
276
- if len(token_indices) == 0:
277
- continue
278
-
279
- if self.use_expert_capacity and len(token_indices) > expert_capacity:
280
- perm = torch.randperm(len(token_indices), device=x.device)[:expert_capacity]
281
- token_indices = token_indices[perm]
282
- topk_positions = topk_positions[perm]
283
-
284
- expert_input = x_flat[token_indices]
285
- expert_gates = top_k_gates[token_indices, topk_positions]
286
-
287
- expert_output = expert(expert_input)
288
-
289
- expert_output = expert_output * expert_gates.unsqueeze(-1)
290
-
291
- output.index_add_(0, token_indices, expert_output)
292
-
293
- output = output.view(B, T, D)
294
-
295
- return output, auxiliary_loss
296
-
297
- class SparseDispatcher:
298
- """
299
- 稀疏分发器 - 用于高效的MoE计算
300
- 管理tokens到专家的分配和聚合
301
- 这是一个可选的辅助类,用于更高效的实现
302
- """
303
-
304
- def __init__(
305
- self,
306
- num_experts: int,
307
- gates: torch.Tensor,
308
- expert_indices: torch.Tensor
309
- ):
310
- """
311
- Args:
312
- num_experts: 专家数量
313
- gates: 门控权重 [batch_size, num_experts]
314
- expert_indices: 专家索引 [batch_size]
315
- """
316
- self.num_experts = num_experts
317
- self._gates = gates
318
- self._expert_indices = expert_indices
319
-
320
- self._expert_masks = []
321
- for i in range(num_experts):
322
- self._expert_masks.append((expert_indices == i).nonzero(as_tuple=True)[0])
323
-
324
- def dispatch(self, inp: torch.Tensor) -> List[torch.Tensor]:
325
- """
326
- 将输入分发给各个专家
327
-
328
- Args:
329
- inp: 输入张量 [batch_size, dim]
330
-
331
- Returns:
332
- expert_inputs: 每个专家的输入列表
333
- """
334
- expert_inputs = []
335
- for mask in self._expert_masks:
336
- if len(mask) > 0:
337
- expert_inputs.append(inp[mask])
338
- else:
339
- expert_inputs.append(
340
- torch.empty(0, inp.size(-1), device=inp.device, dtype=inp.dtype)
341
- )
342
- return expert_inputs
343
-
344
- def combine(self, expert_outputs: List[torch.Tensor]) -> torch.Tensor:
345
- """
346
- 组合专家输出
347
-
348
- Args:
349
- expert_outputs: 每个专家的输出列表
350
-
351
- Returns:
352
- output: 组合后的输出 [batch_size, dim]
353
- """
354
- output_shape = (self._gates.size(0), expert_outputs[0].size(-1))
355
- output = torch.zeros(
356
- output_shape,
357
- device=self._gates.device,
358
- dtype=expert_outputs[0].dtype
359
- )
360
-
361
- for expert_idx, expert_out in enumerate(expert_outputs):
362
- mask = self._expert_masks[expert_idx]
363
- if len(mask) > 0:
364
- weighted_output = expert_out * self._gates[mask, expert_idx].unsqueeze(-1)
365
- output[mask] += weighted_output
366
-
367
- return output
368
-
369
- def expert_to_gates(self) -> List[torch.Tensor]:
370
- """
371
- 返回每个专家对应的门控权重
372
-
373
- Returns:
374
- gates_per_expert: 每个专家的门控权重列表
375
- """
376
- gates_per_expert = []
377
- for expert_idx in range(self.num_experts):
378
- mask = self._expert_masks[expert_idx]
379
- if len(mask) > 0:
380
- gates_per_expert.append(self._gates[mask, expert_idx])
381
- else:
382
- gates_per_expert.append(torch.empty(0, device=self._gates.device))
383
- return gates_per_expert
384
-
385
- class MoELayer(nn.Module):
386
- """
387
- MoE层的另一种实现方式
388
- 使用SparseDispatcher进行更高效的计算
389
- """
390
- def __init__(
391
- self,
392
- dim: int,
393
- num_experts: int = 8,
394
- expert_hidden_dim: Optional[int] = None,
395
- top_k: int = 2,
396
- dropout: float = 0.0,
397
- capacity_factor: float = 1.25
398
- ):
399
- super().__init__()
400
- self.num_experts = num_experts
401
- self.top_k = top_k
402
-
403
- if expert_hidden_dim is None:
404
- expert_hidden_dim = int(2 * dim * 4 / 3)
405
- expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)
406
-
407
- self.experts = nn.ModuleList([
408
- Expert(dim, expert_hidden_dim, dropout)
409
- for _ in range(num_experts)
410
- ])
411
-
412
- self.gate = nn.Linear(dim, num_experts, bias=False)
413
- nn.init.normal_(self.gate.weight, std=0.02)
414
-
415
- self.capacity_factor = capacity_factor
416
-
417
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
418
- """
419
- 前向传播使用SparseDispatcher
420
-
421
- Args:
422
- x: 输入 [batch, seq_len, dim]
423
-
424
- Returns:
425
- output: 输出 [batch, seq_len, dim]
426
- aux_loss: 辅助损失
427
- """
428
- B, T, D = x.shape
429
- x_flat = x.view(-1, D)
430
-
431
- gates = F.softmax(self.gate(x_flat), dim=-1)
432
-
433
- top_k_gates, top_k_indices = torch.topk(gates, self.top_k, dim=-1)
434
- top_k_gates = F.softmax(top_k_gates, dim=-1)
435
-
436
- expert_probs = gates.mean(dim=0)
437
- expert_counts = F.one_hot(top_k_indices, self.num_experts).float().sum(dim=[0, 1])
438
- expert_counts = expert_counts / (B * T * self.top_k)
439
- aux_loss = self.num_experts * torch.sum(expert_probs * expert_counts)
440
-
441
- output = torch.zeros_like(x_flat)
442
-
443
- for expert_idx, expert in enumerate(self.experts):
444
- expert_mask = (top_k_indices == expert_idx)
445
- token_indices, topk_positions = torch.where(expert_mask)
446
-
447
- if len(token_indices) == 0:
448
- continue
449
-
450
- expert_input = x_flat[token_indices]
451
- expert_gates = top_k_gates[token_indices, topk_positions]
452
-
453
- expert_output = expert(expert_input)
454
- expert_output = expert_output * expert_gates.unsqueeze(-1)
455
-
456
- output.index_add_(0, token_indices, expert_output)
457
-
458
- output = output.view(B, T, D)
459
-
460
  return output, aux_loss
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Tuple, Optional, List
5
+ import math
6
+
7
class Expert(nn.Module):
    """Single gated feed-forward expert (SwiGLU-style).

    Computes ``dropout(w2(silu(w1(x)) * w3(x)))``: ``w1`` and ``w3`` project
    the input from ``dim`` to ``hidden_dim``, their gated product is projected
    back to ``dim`` by ``w2``.

    Args:
        dim: Size of the last dimension of the input and of the output.
        hidden_dim: Size of the intermediate (gated) representation.
        dropout: Dropout probability applied to the output. Values ``<= 0``
            disable dropout (an ``nn.Identity`` is used instead).
        bias: Whether the three linear projections carry a bias term.
        init_std: Standard deviation of the zero-mean normal distribution
            used to initialise the projection weights. Defaults to ``0.02``,
            the value that was previously hard-coded.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
        bias: bool = False,
        init_std: float = 0.02,
    ):
        super().__init__()
        self.init_std = init_std
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self._init_weights()

    def _init_weights(self) -> None:
        """Draw weights from N(0, init_std^2) and zero any biases."""
        for projection in (self.w1, self.w2, self.w3):
            nn.init.normal_(projection.weight, mean=0.0, std=self.init_std)
            if projection.bias is not None:
                nn.init.zeros_(projection.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to ``x``.

        Args:
            x: Tensor whose last dimension has size ``dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.dropout(self.w2(gated))
32
+
33
class TopKRouter(nn.Module):
    """Top-k token router with noisy gating and auxiliary losses.

    A bias-free linear gate scores every token against every expert. The
    ``top_k`` highest-scoring experts are selected per token and their scores
    are softmax-normalised among themselves. In training mode Gaussian noise
    is added to the scores before selection and an auxiliary loss
    (load balancing + router z-loss) is returned.

    Args:
        dim: Feature size of the incoming tokens.
        num_experts: Number of experts to route between.
        top_k: Number of experts selected per token; must satisfy
            ``1 <= top_k <= num_experts``.
        capacity_factor: Stored on the module; not used by the router itself.
        noise_std: Standard deviation of the training-time routing noise.
        use_expert_capacity: Stored on the module; not used by the router
            itself.
        router_z_loss_coef: Weight of the z-loss in the auxiliary loss.
        router_aux_loss_coef: Weight of the load-balancing loss in the
            auxiliary loss.

    Raises:
        ValueError: If ``top_k`` is outside ``[1, num_experts]``.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int = 2,
        capacity_factor: float = 1.25,
        noise_std: float = 1.0,
        use_expert_capacity: bool = True,
        router_z_loss_coef: float = 0.001,
        router_aux_loss_coef: float = 0.01
    ):
        super().__init__()
        # Fail fast here instead of with an opaque error inside torch.topk.
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )

        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.noise_std = noise_std
        self.use_expert_capacity = use_expert_capacity
        self.router_z_loss_coef = router_z_loss_coef
        self.router_aux_loss_coef = router_aux_loss_coef

        self.gate = nn.Linear(dim, num_experts, bias=False)
        nn.init.normal_(self.gate.weight, mean=0.0, std=0.02)

    def _compute_routing_weights(
        self,
        logits: torch.Tensor,
        use_noise: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select the top-k experts per token and normalise their gates.

        Args:
            logits: Router scores; the last dimension indexes experts.
            use_noise: Whether to perturb the scores with Gaussian noise.
                Noise is only ever applied while the module is in training
                mode.

        Returns:
            A pair ``(top_k_gates, top_k_indices)`` whose last dimension has
            size ``top_k``; the gates of each token sum to one.
        """
        if use_noise and self.training and self.noise_std != 0:
            logits = logits + torch.randn_like(logits) * self.noise_std

        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
        top_k_gates = F.softmax(top_k_logits, dim=-1)
        return top_k_gates, top_k_indices

    def _compute_auxiliary_loss(
        self,
        logits: torch.Tensor,
        top_k_indices: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the load-balancing loss and the router z-loss.

        Args:
            logits: Noise-free router scores, one row per token.
            top_k_indices: Expert indices selected for every token.

        Returns:
            A pair ``(load_balance_loss, z_loss)`` of scalar tensors.
        """
        num_tokens = logits.shape[0]

        # Mean routing probability per expert.
        expert_probs = F.softmax(logits, dim=-1).mean(dim=0)

        # Fraction of all (token, slot) assignments that went to each expert.
        expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
        expert_freq = expert_mask.sum(dim=[0, 1]) / (num_tokens * self.top_k)

        # Equals 1.0 for perfectly uniform routing, grows as routing collapses.
        load_balance_loss = self.num_experts * torch.sum(expert_probs * expert_freq)

        # Router z-loss: squared log-partition function averaged over tokens.
        # (The plain mean of squared logits used previously is a different
        # regulariser and does not match the ST-MoE definition.)
        z_loss = torch.logsumexp(logits, dim=-1).pow(2).mean()

        return load_balance_loss, z_loss

    def forward(
        self,
        x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route a batch of tokens.

        Args:
            x: Token features whose last dimension has size ``dim``; the
                auxiliary loss treats the first dimension as the token axis.

        Returns:
            ``(top_k_gates, top_k_indices, auxiliary_loss)``. The auxiliary
            loss is the weighted sum of both router losses in training mode
            and a zero scalar otherwise.
        """
        logits = self.gate(x)

        top_k_gates, top_k_indices = self._compute_routing_weights(
            logits, use_noise=self.training
        )

        if self.training:
            load_balance_loss, z_loss = self._compute_auxiliary_loss(logits, top_k_indices)
            auxiliary_loss = (
                self.router_aux_loss_coef * load_balance_loss
                + self.router_z_loss_coef * z_loss
            )
        else:
            # Same device and dtype as the logits so callers can accumulate
            # the loss identically in train and eval mode.
            auxiliary_loss = logits.new_zeros(())

        return top_k_gates, top_k_indices, auxiliary_loss
113
+
114
class MixtureOfExperts(nn.Module):
    """Sparse mixture-of-experts layer driven by a ``TopKRouter``.

    Every token is sent to its ``top_k`` experts; the expert outputs are
    scaled by the router gates and summed back into the token's position.
    When ``use_expert_capacity`` is set, an expert that receives more
    assignments than its capacity processes a random subset of them and the
    remaining assignments contribute nothing to the output.

    Args:
        dim: Feature size of the tokens.
        num_experts: Number of experts.
        expert_hidden_dim: Hidden size of each expert. When ``None`` it is
            derived from ``dim`` (see ``ffn_dim_multiplier``) and rounded up
            to a multiple of 256.
        top_k: Number of experts evaluated per token.
        dropout: Dropout probability inside each expert.
        capacity_factor: Multiplier on the even-split expert load used to
            compute the per-expert capacity.
        use_expert_capacity: Whether to enforce the per-expert capacity.
        router_z_loss_coef: Forwarded to the router.
        router_aux_loss_coef: Forwarded to the router.
        noise_std: Forwarded to the router.
        ffn_dim_multiplier: If given (and ``expert_hidden_dim`` is ``None``),
            the hidden size is ``int(dim * ffn_dim_multiplier)`` before
            rounding; otherwise ``int(2 * dim * 4 / 3)``.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int = 8,
        expert_hidden_dim: Optional[int] = None,
        top_k: int = 2,
        dropout: float = 0.0,
        capacity_factor: float = 1.25,
        use_expert_capacity: bool = True,
        router_z_loss_coef: float = 0.001,
        router_aux_loss_coef: float = 0.01,
        noise_std: float = 1.0,
        ffn_dim_multiplier: Optional[float] = None
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.use_expert_capacity = use_expert_capacity

        if expert_hidden_dim is None:
            if ffn_dim_multiplier is not None:
                expert_hidden_dim = int(dim * ffn_dim_multiplier)
            else:
                expert_hidden_dim = int(2 * dim * 4 / 3)
            # Round the derived size up to the next multiple of 256.
            # NOTE(review): indentation was lost in the available source; the
            # rounding is assumed to apply only to the derived default, not
            # to an explicitly supplied expert_hidden_dim - confirm.
            expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)

        self.experts = nn.ModuleList([
            Expert(dim, expert_hidden_dim, dropout, bias=False)
            for _ in range(num_experts)
        ])

        self.router = TopKRouter(
            dim=dim,
            num_experts=num_experts,
            top_k=top_k,
            capacity_factor=capacity_factor,
            noise_std=noise_std,
            use_expert_capacity=use_expert_capacity,
            router_z_loss_coef=router_z_loss_coef,
            router_aux_loss_coef=router_aux_loss_coef
        )

    def _compute_expert_capacity(self, num_tokens: int) -> int:
        """Return the maximum number of assignments one expert may process.

        Without capacity enforcement this is simply ``num_tokens``; otherwise
        it is ``num_tokens / num_experts * capacity_factor * top_k``
        truncated to an int, with a floor of 1.
        """
        if not self.use_expert_capacity:
            return num_tokens

        capacity = int(
            (num_tokens / self.num_experts) * self.capacity_factor * self.top_k
        )
        return max(capacity, 1)

    def forward(
        self,
        x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the mixture on ``x``.

        Args:
            x: Tensor whose last dimension is the feature dimension. All
                leading dimensions are flattened into a token axis, so the
                usual ``(batch, seq_len, dim)`` layout works unchanged.

        Returns:
            ``(output, auxiliary_loss)`` where ``output`` has the shape of
            ``x`` and ``auxiliary_loss`` is whatever the router returned.
        """
        input_shape = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted instead of raising.
        x_flat = x.reshape(-1, input_shape[-1])
        num_tokens = x_flat.shape[0]

        top_k_gates, top_k_indices, auxiliary_loss = self.router(x_flat)

        output = torch.zeros_like(x_flat)
        expert_capacity = self._compute_expert_capacity(num_tokens)

        for expert_idx, expert in enumerate(self.experts):
            # Rows are token indices, columns are the top-k slot in which
            # this expert was chosen for that token.
            token_indices, topk_positions = torch.where(top_k_indices == expert_idx)

            num_assigned = token_indices.numel()
            if num_assigned == 0:
                continue

            if self.use_expert_capacity and num_assigned > expert_capacity:
                # Over capacity: keep a uniformly random subset. This also
                # happens in eval mode, so outputs can vary between calls.
                keep = torch.randperm(num_assigned, device=x.device)[:expert_capacity]
                token_indices = token_indices[keep]
                topk_positions = topk_positions[keep]

            expert_output = expert(x_flat[token_indices])
            expert_gates = top_k_gates[token_indices, topk_positions]
            weighted_output = expert_output * expert_gates.unsqueeze(-1)

            # index_add_ requires matching dtypes; the expert output may be
            # in a different precision than the input (e.g. under autocast).
            output.index_add_(0, token_indices, weighted_output.to(output.dtype))

        return output.reshape(input_shape), auxiliary_loss
208
+
209
class SparseDispatcher:
    """Helper that scatters tokens to experts and gathers their outputs.

    The token-to-expert assignment is fixed at construction time: for every
    expert the indices of the tokens routed to it are precomputed once and
    reused by ``dispatch``, ``combine`` and ``expert_to_gates``.

    Args:
        num_experts: Number of experts.
        gates: Gate weights, indexed as ``gates[token, expert]``.
        expert_indices: Expert chosen for each token; entries equal to ``i``
            mark the tokens handled by expert ``i``.
            # assumes one expert index per token (1-D) - TODO confirm
    """

    def __init__(
        self,
        num_experts: int,
        gates: torch.Tensor,
        expert_indices: torch.Tensor
    ):
        self.num_experts = num_experts
        self._gates = gates
        self._expert_indices = expert_indices

        # Per-expert index tensor of the tokens assigned to that expert.
        self._expert_masks = [
            (expert_indices == expert_idx).nonzero(as_tuple=True)[0]
            for expert_idx in range(num_experts)
        ]

    def dispatch(self, inp: torch.Tensor) -> List[torch.Tensor]:
        """Split ``inp`` into one tensor per expert.

        Args:
            inp: Token features, indexed by token along the first dimension.

        Returns:
            A list of length ``num_experts``; entry ``i`` holds the rows of
            ``inp`` routed to expert ``i`` (zero rows for an idle expert).
        """
        # Indexing with an empty index tensor already yields an empty tensor
        # with the right trailing shape, dtype and device.
        return [inp[token_ids] for token_ids in self._expert_masks]

    def combine(self, expert_outputs: List[torch.Tensor]) -> torch.Tensor:
        """Merge per-expert outputs back into token order.

        Each expert output row is scaled by the corresponding gate weight and
        accumulated at the position of the token it belongs to.

        Args:
            expert_outputs: Expert outputs in the order produced by
                ``dispatch``.

        Returns:
            Tensor with one row per token; its dtype and feature size are
            taken from ``expert_outputs[0]``.
        """
        output = torch.zeros(
            (self._gates.size(0), expert_outputs[0].size(-1)),
            device=self._gates.device,
            dtype=expert_outputs[0].dtype
        )

        for expert_idx, expert_out in enumerate(expert_outputs):
            token_ids = self._expert_masks[expert_idx]
            if token_ids.numel() == 0:
                continue
            weighted_output = expert_out * self._gates[token_ids, expert_idx].unsqueeze(-1)
            # Cast back before accumulating: gates and expert outputs may
            # differ in precision and index_add_ needs matching dtypes.
            output.index_add_(0, token_ids, weighted_output.to(output.dtype))

        return output

    def expert_to_gates(self) -> List[torch.Tensor]:
        """Return, for every expert, the gate weights of its tokens.

        Returns:
            A list of length ``num_experts``; entry ``i`` contains
            ``gates[token, i]`` for every token routed to expert ``i``. Idle
            experts get an empty tensor with the same dtype and device as the
            gates.
        """
        return [
            self._gates[self._expert_masks[expert_idx], expert_idx]
            for expert_idx in range(self.num_experts)
        ]
261
+
262
class MoELayer(nn.Module):
    """Self-contained top-k mixture-of-experts layer with its own gate.

    Unlike ``MixtureOfExperts`` this layer routes with a plain linear gate
    (no noise, no capacity limit) and always returns the unweighted
    load-balancing loss, in both training and eval mode.

    Args:
        dim: Feature size of the tokens.
        num_experts: Number of experts.
        expert_hidden_dim: Hidden size of each expert. When ``None`` it is
            ``int(2 * dim * 4 / 3)`` rounded up to a multiple of 256.
        top_k: Number of experts evaluated per token; must satisfy
            ``1 <= top_k <= num_experts``.
        dropout: Dropout probability inside each expert.
        capacity_factor: Stored on the module; not used by ``forward``.

    Raises:
        ValueError: If ``top_k`` is outside ``[1, num_experts]``.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int = 8,
        expert_hidden_dim: Optional[int] = None,
        top_k: int = 2,
        dropout: float = 0.0,
        capacity_factor: float = 1.25
    ):
        super().__init__()
        # Fail fast here instead of with an opaque error inside torch.topk.
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )

        self.num_experts = num_experts
        self.top_k = top_k

        if expert_hidden_dim is None:
            expert_hidden_dim = int(2 * dim * 4 / 3)
            # Round the derived size up to the next multiple of 256.
            # NOTE(review): indentation was lost in the available source; the
            # rounding is assumed to apply only to the derived default, not
            # to an explicitly supplied expert_hidden_dim - confirm.
            expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)

        self.experts = nn.ModuleList([
            Expert(dim, expert_hidden_dim, dropout)
            for _ in range(num_experts)
        ])

        self.gate = nn.Linear(dim, num_experts, bias=False)
        nn.init.normal_(self.gate.weight, std=0.02)

        self.capacity_factor = capacity_factor

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the mixture on ``x``.

        Args:
            x: Tensor whose last dimension is the feature dimension. All
                leading dimensions are flattened into a token axis, so the
                usual ``(batch, seq_len, dim)`` layout works unchanged.

        Returns:
            ``(output, aux_loss)`` where ``output`` has the shape of ``x``
            and ``aux_loss`` is the scalar load-balancing loss
            ``num_experts * sum(mean_prob_i * assignment_fraction_i)``.
        """
        input_shape = x.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        x_flat = x.reshape(-1, input_shape[-1])
        num_tokens = x_flat.shape[0]

        router_probs = F.softmax(self.gate(x_flat), dim=-1)
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)

        # Renormalise the selected probabilities so they sum to one. This is
        # equivalent to a softmax over the selected logits. Applying softmax
        # to the probabilities again (as was done before) squashes the gates
        # towards uniform and ignores the router's actual preferences.
        top_k_gates = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Load-balancing loss: mean routing probability per expert times the
        # fraction of (token, slot) assignments that expert received.
        expert_probs = router_probs.mean(dim=0)
        expert_counts = F.one_hot(top_k_indices, self.num_experts).float().sum(dim=[0, 1])
        expert_freq = expert_counts / (num_tokens * self.top_k)
        aux_loss = self.num_experts * torch.sum(expert_probs * expert_freq)

        output = torch.zeros_like(x_flat)

        for expert_idx, expert in enumerate(self.experts):
            # Rows are token indices, columns are the top-k slot in which
            # this expert was chosen for that token.
            token_indices, topk_positions = torch.where(top_k_indices == expert_idx)

            if token_indices.numel() == 0:
                continue

            expert_output = expert(x_flat[token_indices])
            expert_gates = top_k_gates[token_indices, topk_positions]
            weighted_output = expert_output * expert_gates.unsqueeze(-1)

            # index_add_ requires matching dtypes; the expert output may be
            # in a different precision than the input (e.g. under autocast).
            output.index_add_(0, token_indices, weighted_output.to(output.dtype))

        return output.reshape(input_shape), aux_loss