sohv committed
Commit f52cfc5 · verified · 1 Parent(s): 8ad8b73

Upload src/moe.py

Files changed (1)
  1. src/moe.py +163 -0
src/moe.py ADDED
@@ -0,0 +1,163 @@
+ """
+ Mixture of Experts (MoE) Implementation for nanoKimi
+
+ This module implements the MoE layer used in Kimi-K2, which allows
+ for efficient scaling by routing tokens to different expert networks.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class MoELayer(nn.Module):
+     """
+     Mixture of Experts Layer
+
+     Routes input tokens to different expert networks based on a learned gating function.
+     Only the top-k experts are activated for each token, making the computation sparse.
+
+     Args:
+         n_embd: embedding dimension
+         num_experts: number of expert networks
+         expert_capacity: capacity of each expert (max tokens per expert)
+         top_k: number of experts to route each token to
+         dropout: dropout probability
+         bias: whether to use bias in linear layers
+     """
+
+     def __init__(self, n_embd, num_experts=8, expert_capacity=32, top_k=2, dropout=0.0, bias=True):
+         super().__init__()
+
+         self.n_embd = n_embd
+         self.num_experts = num_experts
+         self.expert_capacity = expert_capacity
+         self.top_k = top_k
+
+         # Gating network - decides which experts to use
+         self.gate = nn.Linear(n_embd, num_experts, bias=bias)
+
+         # Expert networks - simple FFN for each expert
+         self.experts = nn.ModuleList([
+             ExpertFFN(n_embd, dropout=dropout, bias=bias)
+             for _ in range(num_experts)
+         ])
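+         # Note: this stores num_experts complete FFNs (roughly
+         # num_experts * 8 * n_embd^2 weights in total), but each token is only
+         # processed by top_k of them, so per-token compute stays close to that
+         # of a single dense FFN.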
+
+         # Load balancing loss coefficient
+         self.load_balance_loss_coef = 0.01
+
+     def forward(self, x):
+         B, T, C = x.shape
+
+         # Flatten to (B*T, C) for easier processing
+         x_flat = x.view(-1, C)
+
+         # Compute gating scores
+         gate_logits = self.gate(x_flat)  # (B*T, num_experts)
+         gate_scores = F.softmax(gate_logits, dim=-1)
+
+         # Select top-k experts for each token
+         top_k_scores, top_k_indices = torch.topk(gate_scores, self.top_k, dim=-1)
+
+         # Normalize top-k scores
+         top_k_scores = top_k_scores / top_k_scores.sum(dim=-1, keepdim=True)
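+         # Worked example (num_experts=4, top_k=2): a token with gate_scores
+         # [0.1, 0.4, 0.3, 0.2] selects experts [1, 2] with scores [0.4, 0.3],
+         # renormalized to [0.4/0.7, 0.3/0.7] ~= [0.57, 0.43], so the two expert
+         # outputs are mixed with weights that sum to 1.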
+
+         # Initialize output
+         output = torch.zeros_like(x_flat)
+
+         # Process each expert
+         for expert_idx in range(self.num_experts):
+             # Find tokens assigned to this expert
+             expert_mask = (top_k_indices == expert_idx).any(dim=-1)
+
+             if expert_mask.sum() == 0:
+                 continue
+
+             # Get tokens for this expert
+             expert_tokens = x_flat[expert_mask]
+
+             # Apply capacity constraint
+             if expert_tokens.size(0) > self.expert_capacity:
+                 # Random sampling if too many tokens; keep the permutation on
+                 # the same device as the activations
+                 perm = torch.randperm(expert_tokens.size(0), device=x_flat.device)[:self.expert_capacity]
+                 expert_tokens = expert_tokens[perm]
+                 expert_mask_indices = torch.where(expert_mask)[0][perm]
+             else:
+                 expert_mask_indices = torch.where(expert_mask)[0]
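+             # Tokens dropped by the capacity limit simply lose this expert's
+             # term in their weighted sum; they still receive output from any
+             # other top-k expert that has room.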
+
+             # Process through expert
+             expert_output = self.experts[expert_idx](expert_tokens)
+
+             # Weight by gating scores and add to output
+             for i, token_idx in enumerate(expert_mask_indices):
+                 # Find which position in top_k this expert is for this token
+                 expert_positions = (top_k_indices[token_idx] == expert_idx).nonzero(as_tuple=True)[0]
+                 if len(expert_positions) > 0:
+                     weight = top_k_scores[token_idx, expert_positions[0]]
+                     output[token_idx] += weight * expert_output[i]
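+             # Note: this per-token Python loop is written for clarity; an
+             # equivalent vectorized combine could gather each selected token's
+             # gate weight for this expert into a (n_selected,) tensor `weights`
+             # (hypothetical name) and do
+             # output.index_add_(0, expert_mask_indices, weights.unsqueeze(-1) * expert_output).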
+
+         # Reshape back to original shape
+         output = output.view(B, T, C)
+
+         # Compute load balancing loss
+         load_balance_loss = self._compute_load_balance_loss(gate_scores)
+
+         return output, load_balance_loss
+
+     def _compute_load_balance_loss(self, gate_scores):
+         """
+         Compute load balancing loss to encourage equal usage of experts
+         """
+         # Mean gate probability mass per expert (a soft proxy for the fraction
+         # of tokens routed to each expert)
+         expert_usage = gate_scores.mean(dim=0)  # (num_experts,)
+
+         # Target is uniform distribution
+         target_usage = torch.ones_like(expert_usage) / self.num_experts
+
+         # L2 loss between actual and target usage
+         load_balance_loss = F.mse_loss(expert_usage, target_usage)
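+         # Note: this is a simplified L2 penalty on the mean gate probabilities.
+         # A common alternative (the Switch Transformer auxiliary loss) instead
+         # penalizes the dot product of the per-expert fraction of routed tokens
+         # and the mean gate probability; either fits this interface.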
+
+         return self.load_balance_loss_coef * load_balance_loss
+
+
+ class ExpertFFN(nn.Module):
+     """
+     Expert Feed-Forward Network
+
+     A simple two-layer MLP that serves as an expert in the MoE layer.
+     """
+
+     def __init__(self, n_embd, dropout=0.0, bias=True):
+         super().__init__()
+
+         # Typical GPT-style FFN with 4x expansion
+         self.fc1 = nn.Linear(n_embd, 4 * n_embd, bias=bias)
+         self.fc2 = nn.Linear(4 * n_embd, n_embd, bias=bias)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = F.gelu(x)
+         x = self.fc2(x)
+         x = self.dropout(x)
+         return x
+
+
+ class StandardFFN(nn.Module):
+     """
+     Standard Feed-Forward Network for comparison with MoE
+     """
+
+     def __init__(self, n_embd, dropout=0.0, bias=True):
+         super().__init__()
+
+         self.fc1 = nn.Linear(n_embd, 4 * n_embd, bias=bias)
+         self.fc2 = nn.Linear(4 * n_embd, n_embd, bias=bias)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = F.gelu(x)
+         x = self.fc2(x)
+         x = self.dropout(x)
+         return x, 0.0  # Return 0 load balance loss for consistency
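
A minimal usage sketch (illustrative only, not part of the committed file), assuming the module is importable as src.moe:

    import torch
    from src.moe import MoELayer, StandardFFN

    layer = MoELayer(n_embd=64, num_experts=8, expert_capacity=32, top_k=2)
    x = torch.randn(4, 16, 64)            # (B, T, C)
    out, aux_loss = layer(x)              # out: (4, 16, 64); aux_loss: scalar tensor
    # During training, fold the balance penalty into the objective:
    # total_loss = task_loss + aux_loss
    dense = StandardFFN(64)
    dense_out, zero_loss = dense(x)       # drop-in dense baseline, same interface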