studyinglover committed on
Commit 34eb6cc · verified · 1 Parent(s): 10549c7

Upload model.py

Files changed (1): model.py (+483, −0)
model.py ADDED
import math

from .config import ModelConfig
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast


class RMSNorm(torch.nn.Module):
    """Root-mean-square LayerNorm: rescales by the RMS of the activations,
    with no mean-centering and no bias."""

    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_pos_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute RoPE rotations as unit complex numbers, shape (end, dim // 2)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64: e^{i*t*freq}
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        # Reshape pos_cis from (seqlen, head_dim // 2) so it broadcasts over
        # the batch and head dimensions of x.
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    # View the last dimension as (real, imag) pairs and rotate each pair.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


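# Shape sketch for the RoPE helpers above (illustrative values): with
# head_dim = 64 and max_seq_len = 512, precompute_pos_cis(64, 512) yields a
# complex64 tensor of shape (512, 32); apply_rotary_emb rotates each
# (real, imag) pair of q and k by the angle for its position, so relative
# offsets survive in the q·k inner products without additive embeddings.
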
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, dim=2, repeats=n_rep)."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


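# Grouped-query attention sketch for repeat_kv above (illustrative numbers):
# with n_heads = 8 and n_kv_heads = 2, n_rep = 4, so each cached K/V head
# serves 4 query heads; (bs, slen, 2, head_dim) expands to
# (bs, slen, 8, head_dim), and expand() stays a view until reshape() copies.
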
class Attention(nn.Module):
    def __init__(self, args: ModelConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads: int = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.k_cache, self.v_cache = None, None
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = (
            hasattr(torch.nn.functional, "scaled_dot_product_attention")
            and args.flash_attn
        )

        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        # Causal mask: -inf above the diagonal so a position only attends backwards.
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
        bsz, seqlen, _ = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, pos_cis)

        # More efficient kv_cache implementation: append to the cache only in
        # inference mode.
        if kv_cache and not self.training:
            if seqlen == 1 and all(
                cache is not None for cache in (self.k_cache, self.v_cache)
            ):
                xk = torch.cat((self.k_cache, xk), dim=1)
                xv = torch.cat((self.v_cache, xv), dim=1)
            self.k_cache, self.v_cache = xk, xv

        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        if self.flash and seqlen != 1:
            output = torch.nn.functional.scaled_dot_product_attention(
                xq,
                xk,
                xv,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            scores = (
                scores + self.mask[:, :, :seqlen, :seqlen]
            )  # (bs, n_local_heads, seqlen, cache_len + seqlen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        output = self.wo(output)
        output = self.resid_dropout(output)
        return output


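# Decode-path sketch for Attention above (illustrative): the first generate()
# call runs the whole prompt through the flash branch (seqlen > 1,
# is_causal=True); each later call passes a single token, taking the manual
# branch and attending over the concatenated k_cache/v_cache. There the mask
# slice self.mask[:, :, :1, :1] is 0, which is correct: a decode step may
# attend to every cached position.
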
class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
        super().__init__()
        if hidden_dim is None:
            # SwiGLU uses three projections instead of two, so shrink the
            # default 4*dim hidden size by 2/3, rounded up to multiple_of.
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # SwiGLU: silu(w1(x)) gates w3(x); w2 projects back to dim.
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


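# Worked sizing example for FeedForward above (illustrative numbers): with
# dim = 512 and multiple_of = 64, hidden_dim defaults to 4 * 512 = 2048, is
# cut to int(2 * 2048 / 3) = 1365, then rounded up to the next multiple of
# 64: ((1365 + 63) // 64) * 64 = 1408.
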
class MoEGate(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape

        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(
                f"unsupported scoring function for MoE gating: {self.scoring_func}"
            )

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            # Renormalize the selected gate weights so they sum to 1 per token.
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        if self.training and self.alpha > 0.0:
            # Load-balancing auxiliary loss, per sequence (seq_aux) or batch-wide.
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(
                    bsz, self.n_routed_experts, device=hidden_states.device
                )
                ce.scatter_add_(
                    1,
                    topk_idx_for_aux_loss,
                    torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
                ).div_(seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
                    dim=1
                ).mean() * self.alpha
            else:
                mask_ce = F.one_hot(
                    topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
                )
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = None
        return topk_idx, topk_weight, aux_loss


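# Balance-loss sketch for MoEGate above: in the batch-level branch,
# fi = n_routed_experts * (fraction of routed tokens assigned to expert i)
# and Pi = mean gate probability of expert i, so
# aux_loss = alpha * sum_i(fi * Pi), a Switch-Transformer-style
# load-balancing term that is smallest when routing is uniform.
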
class MOEFeedForward(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList(
            [
                FeedForward(
                    dim=config.dim,
                    hidden_dim=config.hidden_dim,
                    multiple_of=config.multiple_of,
                    dropout=config.dropout,
                )
                for _ in range(config.n_routed_experts)
            ]
        )

        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(
                dim=config.dim,
                hidden_dim=config.hidden_dim,
                multiple_of=config.multiple_of,
                dropout=config.dropout,
            )

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape

        # Select experts via the gating mechanism.
        topk_idx, topk_weight, aux_loss = self.gate(x)

        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)

        if self.training:
            # Training mode: replicate each token once per selected expert.
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x)  # output buffer, same dtype/device as x
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i])
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            # Inference mode: dispatch tokens to their selected experts only.
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(
                *orig_shape
            )

        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)

        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # Example: if tokens_per_expert = [6, 15, 20, 26, 33, 38, 46, 52]
        # and token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...],
        # then token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are handled by
        # expert 0, token_idxs[6:15] by expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            # Accumulate the weighted expert outputs with scatter_add_.
            expert_cache.scatter_add_(
                0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out
            )

        return expert_cache


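# Dispatch sketch for moe_infer above (illustrative): with
# num_experts_per_tok = 2, a token routed to experts {1, 3} with weights
# (0.7, 0.3) owns two rows of the flattened index list; rows are sorted by
# expert id, each expert runs once over its contiguous slice, outputs are
# scaled by their gate weights, and scatter_add_ folds both partial results
# back into that token's row of expert_cache.
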
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelConfig):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=args.hidden_dim,
            multiple_of=args.multiple_of,
            dropout=args.dropout,
        )

    def forward(self, x, pos_cis, kv_cache=False):
        # Pre-norm residual layout: normalize, transform, then add back.
        h = x + self.attention(self.attention_norm(x), pos_cis, kv_cache)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Transformer(PreTrainedModel):
    config_class = ModelConfig
    last_loss: Optional[torch.Tensor]

    def __init__(self, params: ModelConfig = None):
        if params is None:
            params = ModelConfig()
        super().__init__(params)
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(self.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        # Weight tying: the token embedding and output head share one matrix.
        self.tok_embeddings.weight = self.output.weight
        pos_cis = precompute_pos_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len
        )
        self.register_buffer("pos_cis", pos_cis, persistent=False)

        self.apply(self._init_weights)

        # Scaled (GPT-2-style) init for the residual-path output projections.
        for pn, p in self.named_parameters():
            if pn.endswith("w3.weight") or pn.endswith("wo.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers)
                )

        self.last_loss = None
        self.OUT = CausalLMOutputWithPast()
        self._no_split_modules = [name for name, _ in self.named_modules()]

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        tokens: Optional[torch.Tensor] = None,
        targets: Optional[torch.Tensor] = None,
        kv_cache=False,
        **keyargs,
    ):
        current_idx = 0
        if "input_ids" in keyargs:
            tokens = keyargs["input_ids"]
        if "attention_mask" in keyargs:
            targets = keyargs["attention_mask"]
        if "current_idx" in keyargs:
            current_idx = int(keyargs["current_idx"])

        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)
        # Slice the rotary table at the decode offset so cached generation
        # keeps absolute positions consistent.
        pos_cis = self.pos_cis[current_idx : current_idx + seqlen]
        for idx, layer in enumerate(self.layers):
            h = layer(h, pos_cis, kv_cache)

        h = self.norm(h)

        if targets is not None:
            logits = self.output(h)
            self.last_loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=0,
                reduction="none",
            )
        else:
            # Inference: only the last position's logits are needed.
            logits = self.output(h[:, [-1], :])
            self.last_loss = None

        self.OUT.__setitem__("logits", logits)
        self.OUT.__setitem__("last_loss", self.last_loss)
        return self.OUT
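
    # forward() sketch (illustrative, comments only): with targets supplied,
    # last_loss is an unreduced per-token CE vector of shape (bsz * seqlen,),
    # so a caller can apply its own mask before reducing, e.g.
    #   loss_mask = targets.view(-1) != 0        # hypothetical masking choice
    #   loss = out.last_loss[loss_mask].mean()
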
    @torch.inference_mode()
    def generate(
        self,
        idx,
        eos,
        max_new_tokens,
        temperature=0.7,
        top_k=8,
        stream=True,
        rp=1.0,
        kv_cache=True,
    ):
        # rp: repetition penalty (logits of seen tokens are divided by rp).
        # Note: the loop bounds the total sequence length, prompt included,
        # by max_new_tokens - 1.
        index = idx.shape[1]
        init_inference = True
        while idx.shape[1] < max_new_tokens - 1:
            if init_inference or not kv_cache:
                # First step (or no cache): run the full prefix.
                inference_res, init_inference = self(idx, kv_cache=kv_cache), False
            else:
                # Cached decode: feed only the newest token with its position.
                inference_res = self(
                    idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1
                )

            logits = inference_res.logits
            logits = logits[:, -1, :]

            # Apply the repetition penalty to tokens already in the sequence.
            for token in set(idx.tolist()[0]):
                logits[:, token] /= rp

            if temperature == 0.0:
                # Greedy decoding.
                _, idx_next = torch.topk(logits, k=1, dim=-1)
            else:
                logits = logits / temperature
                if top_k is not None:
                    # Keep only the top_k logits; mask the rest to -inf.
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("Inf")

                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1, generator=None)

            if idx_next == eos:
                break

            idx = torch.cat((idx, idx_next), dim=1)
            if stream:
                yield idx[:, index:]

        if not stream:
            yield idx[:, index:]

    @torch.inference_mode()
    def eval_answer(self, idx):
        # Crop the context to the last max_seq_len tokens if it is too long.
        idx_cond = (
            idx
            if idx.size(1) <= self.params.max_seq_len
            else idx[:, -self.params.max_seq_len :]
        )
        inference_res = self(idx_cond)
        logits = inference_res.logits
        logits = logits[:, -1, :]
        return logits
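
# End-to-end sketch (illustrative; assumes ModelConfig provides usable
# defaults such as vocab_size and max_seq_len, per its use above; eos=2 is a
# made-up token id):
#   cfg = ModelConfig()
#   model = Transformer(cfg).eval()
#   prompt = torch.randint(1, cfg.vocab_size, (1, 8))
#   for chunk in model.generate(prompt, eos=2, max_new_tokens=64):
#       pass                    # chunk: (1, n_new) tokens generated so far
#   next_logits = model.eval_answer(prompt)   # (1, vocab_size)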