caixiaoshun commited on
Commit
c8167b3
·
verified ·
1 Parent(s): 1897029

Create mini_moe.py

Browse files
Files changed (1) hide show
  1. mini_moe.py +374 -0
mini_moe.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn as nn
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from transformers import PretrainedConfig, PreTrainedModel
5
+
6
+
7
+ class MiniMoEConfig(PretrainedConfig):
8
+ model_type = "mini-moe"
9
+
10
+ def __init__(
11
+ self,
12
+ vocab_size=32000,
13
+ num_layers=12,
14
+ dim=1024,
15
+ rope_base=10000,
16
+ num_attention_q_heads=16,
17
+ num_attention_kv_heads=8,
18
+ num_expert=8,
19
+ top_k=4,
20
+ qkv_bias=False,
21
+ drop_rate=0.0,
22
+ use_aux_loss=True,
23
+ **kwargs,
24
+ ):
25
+ super().__init__(**kwargs)
26
+ self.vocab_size = vocab_size
27
+ self.num_layers = num_layers
28
+ self.dim = dim
29
+ self.rope_base = rope_base
30
+ self.num_attention_q_heads = num_attention_q_heads
31
+ self.num_attention_kv_heads = num_attention_kv_heads
32
+ self.qkv_bias = qkv_bias
33
+ self.drop_rate = drop_rate
34
+ self.num_expert = num_expert
35
+ self.top_k = top_k
36
+ self.use_aux_loss = use_aux_loss
37
+ self.auto_map = {
38
+ "AutoConfig": "mini_moe.MiniMoEConfig",
39
+ "AutoModelForCausalLM": "mini_moe.MiniMoE",
40
+ }
41
+
42
+
43
+ class RMSNorm(nn.Module):
44
+ def __init__(self, dim):
45
+ super().__init__()
46
+ self.weight = nn.Parameter(torch.ones(dim))
47
+
48
+ def forward(self, x: torch.Tensor):
49
+ norm_x = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-8)
50
+ output = self.weight * norm_x
51
+ return output
52
+
53
+
54
+ class RopePositionEmbedding(nn.Module):
55
+ def __init__(self, dim: int, base=10000):
56
+ super().__init__()
57
+ inv_freq = 1 / base ** (torch.arange(0, dim, 2).float() / dim)
58
+ inv_freq = inv_freq.unsqueeze(0)
59
+ self.register_buffer("inv_freq", inv_freq)
60
+
61
+ def rotate_half(self, x: torch.Tensor):
62
+ odd = x[..., 1::2]
63
+ even = x[..., 0::2]
64
+ return torch.stack((-odd, even), dim=-1).flatten(-2)
65
+
66
+ def apply_rope(self, x: torch.Tensor):
67
+ x_len = x.shape[2]
68
+ t = torch.arange(0, x_len, device=x.device, dtype=torch.float32).unsqueeze(1)
69
+ freq = t * self.inv_freq
70
+ freq = torch.repeat_interleave(freq, repeats=2, dim=-1)[None, None, :, :]
71
+ xf = x.float()
72
+ y = xf * freq.cos() + self.rotate_half(xf) * freq.sin()
73
+ return y.to(x.dtype)
74
+
75
+ def forward(self, q: torch.Tensor, k: torch.Tensor):
76
+ return self.apply_rope(q), self.apply_rope(k)
77
+
78
+
79
+ class GroupQueryAttention(nn.Module):
80
+ def __init__(
81
+ self,
82
+ num_attention_q_heads,
83
+ num_attention_kv_heads,
84
+ dim,
85
+ qkv_bias,
86
+ drop_rate,
87
+ rope_base,
88
+ ):
89
+ super().__init__()
90
+
91
+ self.head_dim = dim // num_attention_q_heads
92
+
93
+ assert dim % num_attention_q_heads == 0, "dim 必须被 Q 头数整除"
94
+ assert (
95
+ num_attention_q_heads % num_attention_kv_heads == 0
96
+ ), "Q头数必须是KV头数的整数倍"
97
+ assert self.head_dim % 2 == 0, "head_dim 必须为偶数以应用 RoPE"
98
+
99
+ self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
100
+ self.k_proj = nn.Linear(
101
+ dim, self.head_dim * num_attention_kv_heads, bias=qkv_bias
102
+ )
103
+ self.v_proj = nn.Linear(
104
+ dim, self.head_dim * num_attention_kv_heads, bias=qkv_bias
105
+ )
106
+ self.out_proj = nn.Linear(dim, dim, bias=qkv_bias)
107
+
108
+ self.num_repeat_kv = num_attention_q_heads // num_attention_kv_heads
109
+ self.drop = nn.Dropout(drop_rate)
110
+
111
+ self.position_embedding = RopePositionEmbedding(self.head_dim, rope_base)
112
+
113
+ self.num_attention_q_heads = num_attention_q_heads
114
+ self.num_attention_kv_heads = num_attention_kv_heads
115
+ self.drop_rate = drop_rate
116
+
117
+ def repeat_kv(self, k: torch.Tensor, v: torch.Tensor):
118
+ k = k.repeat_interleave(self.num_repeat_kv, dim=1)
119
+ v = v.repeat_interleave(self.num_repeat_kv, dim=1)
120
+ return k, v
121
+
122
+ def forward(self, x: torch.Tensor):
123
+ batch_size, seq_len, dim = x.shape
124
+ Q = (
125
+ self.q_proj(x)
126
+ .reshape(batch_size, seq_len, self.num_attention_q_heads, self.head_dim)
127
+ .transpose(1, 2)
128
+ )
129
+ K = (
130
+ self.k_proj(x)
131
+ .reshape(batch_size, seq_len, self.num_attention_kv_heads, self.head_dim)
132
+ .transpose(1, 2)
133
+ )
134
+ V = (
135
+ self.v_proj(x)
136
+ .reshape(batch_size, seq_len, self.num_attention_kv_heads, self.head_dim)
137
+ .transpose(1, 2)
138
+ )
139
+
140
+ Q, K = self.position_embedding(Q, K)
141
+
142
+ K, V = self.repeat_kv(K, V)
143
+
144
+ out = F.scaled_dot_product_attention(
145
+ Q, K, V, dropout_p=self.drop_rate if self.training else 0.0, is_causal=True
146
+ )
147
+ out = out.transpose(1, 2).reshape(batch_size, seq_len, dim)
148
+ out = self.out_proj(out)
149
+ out = self.drop(out)
150
+ return out
151
+
152
+
153
+ class Expert(nn.Module):
154
+ def __init__(self, dim, drop_rate):
155
+ super().__init__()
156
+ self.ffn = nn.Sequential(
157
+ nn.Linear(dim, dim * 4),
158
+ nn.SiLU(),
159
+ nn.Linear(dim * 4, dim),
160
+ nn.Dropout(drop_rate),
161
+ )
162
+
163
+ def forward(self, x):
164
+ return self.ffn(x)
165
+
166
+
167
+ class NoiseRouter(nn.Module):
168
+ def __init__(self, num_expert, top_k, dim):
169
+ super().__init__()
170
+ self.gate = nn.Linear(dim, num_expert)
171
+ self.noise_gate = nn.Linear(dim, num_expert)
172
+ self.top_k = top_k
173
+
174
+ def forward(self, x):
175
+ gate = self.gate(x)
176
+ logits = gate + torch.randn_like(gate) + self.noise_gate(x)
177
+
178
+ top_k_val, top_k_ids = torch.topk(logits, k=self.top_k, dim=-1)
179
+ scores = torch.full_like(logits, -torch.inf)
180
+ scores.scatter_(dim=-1, index=top_k_ids, src=top_k_val)
181
+ scores = scores.softmax(dim=-1)
182
+ return scores, top_k_ids
183
+
184
+
185
+ class SparseMoe(nn.Module):
186
+ def __init__(self, num_expert, top_k, dim, drop_rate, use_aux_loss=True):
187
+ super().__init__()
188
+ self.route = NoiseRouter(num_expert=num_expert, top_k=top_k, dim=dim)
189
+ self.experts = nn.ModuleList(
190
+ [Expert(dim=dim, drop_rate=drop_rate) for _ in range(num_expert)]
191
+ )
192
+ self.use_aux_loss = use_aux_loss
193
+ self.num_expert = num_expert
194
+
195
+ def forward(self, x: torch.Tensor):
196
+ batch_size, seq_len, dim = x.shape
197
+
198
+ scores, indices = self.route(x)
199
+ flatten_x = x.reshape(-1, dim)
200
+ flatten_scores = scores.reshape(-1, scores.shape[-1])
201
+
202
+ final_out = torch.zeros_like(flatten_x)
203
+
204
+ for i, expert in enumerate(self.experts):
205
+ expert_mask = (indices == i).any(dim=-1)
206
+ expert_mask = expert_mask.reshape(-1)
207
+ if expert_mask.any():
208
+ expert_in = flatten_x[expert_mask]
209
+ expert_out = expert(expert_in)
210
+ expert_weight = flatten_scores[expert_mask, i].unsqueeze(1)
211
+ expert_out = expert_weight * expert_out
212
+
213
+ final_out[expert_mask] += expert_out
214
+
215
+ final_out = final_out.reshape(batch_size, seq_len, dim)
216
+
217
+ if self.use_aux_loss:
218
+ importance = flatten_scores.mean(dim=0).float()
219
+ uniform = torch.full_like(importance, fill_value=1.0 / self.num_expert).float()
220
+
221
+ importance_log = (importance + 1e-8).log()
222
+ uniform_log = uniform.log()
223
+
224
+ aux_loss = F.kl_div(
225
+ input=importance_log, target=uniform_log, log_target=True, reduction="sum"
226
+ )
227
+ return final_out, aux_loss
228
+ return final_out
229
+
230
+
231
+ class DecoderLayer(nn.Module):
232
+ def __init__(
233
+ self,
234
+ num_attention_q_heads,
235
+ num_attention_kv_heads,
236
+ dim,
237
+ qkv_bias,
238
+ drop_rate,
239
+ rope_base,
240
+ num_expert,
241
+ top_k,
242
+ use_aux_loss,
243
+ ):
244
+ super().__init__()
245
+ self.norm1 = RMSNorm(dim=dim)
246
+ self.attn = GroupQueryAttention(
247
+ num_attention_q_heads=num_attention_q_heads,
248
+ num_attention_kv_heads=num_attention_kv_heads,
249
+ dim=dim,
250
+ qkv_bias=qkv_bias,
251
+ drop_rate=drop_rate,
252
+ rope_base=rope_base,
253
+ )
254
+ self.norm2 = RMSNorm(dim=dim)
255
+ self.moe = SparseMoe(
256
+ num_expert=num_expert,
257
+ top_k=top_k,
258
+ dim=dim,
259
+ drop_rate=drop_rate,
260
+ use_aux_loss=use_aux_loss,
261
+ )
262
+ self.use_aux_loss = use_aux_loss
263
+
264
+ def forward(self, x):
265
+ x = x + self.attn(self.norm1(x))
266
+ hidden_state = self.moe(self.norm2(x))
267
+ if self.use_aux_loss:
268
+ x = x + hidden_state[0]
269
+ aux_loss = hidden_state[1]
270
+
271
+ return x, aux_loss
272
+ else:
273
+ x = x + hidden_state
274
+ return x
275
+
276
+
277
+ class MiniMoE(PreTrainedModel):
278
+ model_type = "mini-moe"
279
+ config_class = MiniMoEConfig
280
+
281
+ def __init__(self, config: MiniMoEConfig, pretrain_ckpt=None):
282
+ super().__init__(config)
283
+ self.embedding = nn.Embedding(config.vocab_size, config.dim)
284
+ self.layers = nn.ModuleList([])
285
+ for _ in range(config.num_layers):
286
+ self.layers.append(
287
+ DecoderLayer(
288
+ num_attention_q_heads=config.num_attention_q_heads,
289
+ num_attention_kv_heads=config.num_attention_kv_heads,
290
+ dim=config.dim,
291
+ qkv_bias=config.qkv_bias,
292
+ drop_rate=config.drop_rate,
293
+ rope_base=config.rope_base,
294
+ num_expert=config.num_expert,
295
+ top_k=config.top_k,
296
+ use_aux_loss=config.use_aux_loss,
297
+ )
298
+ )
299
+ self.norm = RMSNorm(dim=config.dim)
300
+ self.head = nn.Linear(config.dim, config.vocab_size, bias=False)
301
+ self.apply(self.init_weight)
302
+ self.head.weight = self.embedding.weight
303
+ self.use_aux_loss = config.use_aux_loss
304
+ if pretrain_ckpt is not None:
305
+ self.load_ckpt(pretrain_ckpt)
306
+
307
+ def load_ckpt(self, ckpt_path):
308
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
309
+ state_dict = ckpt["state_dict"]
310
+ new_state_dict = {}
311
+ for k, v in state_dict.items():
312
+ new_k = k[len("net._orig_mod.") :]
313
+ new_state_dict[new_k] = v
314
+ self.load_state_dict(new_state_dict, strict=True)
315
+ print(f"load state dict from {ckpt_path}")
316
+
317
+ def init_weight(self, m):
318
+ if isinstance(m, nn.Linear):
319
+ nn.init.normal_(m.weight, mean=0, std=0.02)
320
+ if m.bias is not None:
321
+ nn.init.constant_(m.bias, 0)
322
+ elif isinstance(m, RMSNorm):
323
+ nn.init.constant_(m.weight, 1)
324
+ elif isinstance(m, nn.Embedding):
325
+ nn.init.normal_(m.weight, mean=0, std=0.02)
326
+
327
+ def forward(self, input_ids: torch.Tensor):
328
+ hidden_state = self.embedding(input_ids)
329
+ aux_loss = None
330
+ for layer in self.layers:
331
+ hidden_state = layer(hidden_state)
332
+ if self.use_aux_loss:
333
+ if aux_loss is None:
334
+ aux_loss = hidden_state[1]
335
+ else:
336
+ aux_loss += hidden_state[1]
337
+ hidden_state = hidden_state[0]
338
+
339
+ hidden_state = self.norm(hidden_state)
340
+ logits = self.head(hidden_state)
341
+ if self.use_aux_loss:
342
+ return logits, aux_loss
343
+ return logits
344
+
345
+ def top_k_sample(self, logits, top_k=5):
346
+
347
+ weights, indices = torch.topk(logits, k=top_k, dim=-1)
348
+
349
+ probs = torch.softmax(weights, dim=-1)
350
+ chosssed_id = torch.multinomial(probs, num_samples=1)
351
+ new_token = torch.gather(indices, dim=-1, index=chosssed_id)
352
+ return new_token
353
+
354
+ @torch.no_grad()
355
+ def chat(self, conversations, tokenizer, max_new_token=256, top_k=5):
356
+ ids = tokenizer.apply_chat_template(
357
+ conversations, add_generation_prompt=True, tokenize=True
358
+ )
359
+ eos_ids = tokenizer.eos_token_id
360
+ input_ids = torch.tensor(ids, dtype=torch.long).unsqueeze(0)
361
+ for _ in range(max_new_token):
362
+
363
+ logits = self(input_ids) # batch, seq_len, dim
364
+ last_logits = logits[:, -1] # batch, dim
365
+ new_token = self.top_k_sample(last_logits, top_k=top_k)
366
+ input_ids = torch.cat((input_ids, new_token), dim=-1)
367
+
368
+ if new_token.detach()[0].cpu().item() == eos_ids:
369
+ break
370
+
371
+ output_id = input_ids.detach().cpu()[0].tolist()
372
+ output_id = output_id[len(ids) :]
373
+ answer = tokenizer.decode(output_id, skip_special_tokens=True)
374
+ return answer