talphaidze commited on
Commit
e10528d
·
verified ·
1 Parent(s): 35e90b2

Upload 4 files

Browse files
Files changed (4) hide show
  1. aux_losses.py +88 -0
  2. configuration.py +51 -0
  3. modeling.py +481 -0
  4. moe.py +145 -0
aux_losses.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
def log_mean(x, dim):
    """Return log(mean(exp(x))) along `dim`, computed stably in log space.

    Args:
        x: log-space tensor (e.g. log-probabilities).
        dim: dimension to average over.

    Returns:
        Tensor with `dim` reduced: logsumexp(x, dim) - log(N).
    """
    # BUG FIX: the count was previously wrapped in torch.tensor(...) with a
    # hard-coded float32 dtype on the CPU, causing a cross-device op for CUDA
    # inputs and a silent downcast for float64 inputs. new_tensor inherits
    # x's device and dtype.
    return torch.logsumexp(x, dim=dim) - torch.log(x.new_tensor(x.shape[dim]))
10
+
11
+
12
def entropy_reg(logits: torch.Tensor, mean_over_batch: bool = True):
    """Entropy regularization for the router.

    Args:
        logits: router logits, [batch_size * sequence_length, num_experts].
        mean_over_batch: if True, average the probabilities over the batch
            (in log space) before computing the entropy.

    Returns:
        Scalar: the negated entropy, so minimizing this loss maximizes
        router entropy (encourages spread-out expert usage).
    """
    # softmax over experts, kept in log space for stability
    logprobs = F.log_softmax(logits, dim=-1)
    if mean_over_batch:
        # take mean probability over batch
        logprobs = log_mean(logprobs, 0)

    # H = -sum(p * log p); previously written as a lambda assigned to a name
    # (PEP8 E731) — inlined here.
    entropy = -(logprobs * logprobs.exp()).sum(-1)
    return -entropy.mean()
24
+
25
+
26
# two losses below are adapted from
# https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/routing.py
def load_balancing_loss(logits: torch.Tensor, expert_indices: torch.Tensor) -> float:
    """Switch-Transformer auxiliary load-balancing loss.

    Implements equations (4)-(6) of https://arxiv.org/abs/2101.03961,
    penalizing unbalanced token-to-expert routing.

    Args:
        logits: <float32>[batch_size * sequence_length, num_experts] router
            logits per token.
        expert_indices: <int>[batch_size * sequence_length, num_selected_experts]
            top-k expert ids chosen for each token.

    Returns:
        Scalar auxiliary loss.
    """
    num_token, num_experts = logits.shape

    # One-hot the chosen ids ([num_tokens, num_selected, num_experts]) and
    # collapse over the selection dim: entry (t, e) is 1 iff token t was
    # routed to expert e.
    routed_mask, _ = torch.max(F.one_hot(expert_indices, num_experts), dim=-2)

    # Fraction of tokens dispatched to each expert, shape [num_experts].
    tokens_per_expert = torch.mean(routed_mask, dim=0, dtype=torch.float32)

    # Mean router probability per expert, computed in log space for stability.
    # (Inlined log-mean: log(mean p) = logsumexp(log p, 0) - log(N).)
    logprobs = F.log_softmax(logits, dim=-1)
    mean_logprobs = torch.logsumexp(logprobs, dim=0) - torch.log(
        torch.tensor(logprobs.shape[0], dtype=torch.float32)
    )
    router_prob_per_expert = torch.exp(mean_logprobs)

    # Mean over experts of (dispatch fraction * router probability), scaled by
    # num_experts so the balanced optimum is independent of the expert count.
    balance = tokens_per_expert * router_prob_per_expert
    return torch.mean(balance, dtype=torch.float32) * num_experts
69
+
70
+
71
def router_z_loss(router_logits: torch.Tensor) -> float:
    """Compute the router z-loss.

    Introduced in Designing Effective Sparse Expert Models
    (https://arxiv.org/abs/2202.08906); it keeps router logits small to
    improve training stability.

    Args:
        router_logits: <float>[batch_size * sequence_length, num_experts]
            router logits.

    Returns:
        Scalar router z-loss.
    """
    num_tokens, _ = router_logits.shape
    # log Z per token: logsumexp over the expert dimension, then squared.
    squared_log_z = torch.logsumexp(router_logits, dim=-1) ** 2
    # Average over tokens.
    return torch.sum(squared_log_z, dtype=torch.float32) / num_tokens
configuration.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class MoEGPTConfig(PretrainedConfig):
    """Configuration for MoE-GPT causal language models.

    Holds the GPT backbone hyper-parameters (vocab/embedding/layer sizes)
    plus the mixture-of-experts routing options consumed by the modeling code.
    """

    model_type = "moegpt"

    def __init__(
        self,
        vocab_size=50304,
        n_embd=768,
        n_layer=12,
        n_head=12,
        sequence_length=1024,
        moe=False,
        moe_routing="standard_gating",
        moe_num_experts=4,
        moe_num_experts_per_tok=2,
        moe_softmax_order="softmax_topk",
        moe_router_loss="load_balancing_z_loss",
        moe_aux_loss_factor=0.01,
        moe_z_loss_factor=1.0,
        mlp_dim_exp_factor=1.0,
        dropout=0.0,
        bias=False,
        architectures=None,
        auto_map=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # BUG FIX: `architectures` and `auto_map` previously used a list/dict
        # literal as their default value — a mutable default shared by every
        # MoEGPTConfig instance. None sentinels restore per-instance defaults
        # while keeping the observable behavior for callers identical.
        if architectures is None:
            architectures = ["MoEGPTForCausalLM"]
        if auto_map is None:
            auto_map = {
                "AutoConfig": "configuration.MoEGPTConfig",
                "AutoModelForCausalLM": "modeling.MoEGPTForCausalLM",
                "AutoTokenizer": "GPT2TokenizerFast",
            }
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.sequence_length = sequence_length
        self.moe = moe
        self.moe_routing = moe_routing
        self.moe_num_experts = moe_num_experts
        self.moe_num_experts_per_tok = moe_num_experts_per_tok
        self.moe_softmax_order = moe_softmax_order
        self.moe_router_loss = moe_router_loss
        self.moe_aux_loss_factor = moe_aux_loss_factor
        self.moe_z_loss_factor = moe_z_loss_factor
        self.mlp_dim_exp_factor = mlp_dim_exp_factor
        self.dropout = dropout
        self.bias = bias
        self.architectures = architectures
        self.auto_map = auto_map
51
+
modeling.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from configuration import MoEGPTConfig
3
+ # importa anche MoE, MaskedMoE, TimeDependantMoE ecc.
4
+ import math
5
+ import inspect
6
+ from typing import Optional, Dict, Any
7
+ from dataclasses import dataclass
8
+ import tiktoken
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from huggingface_hub import PyTorchModelHubMixin
13
+ from transformers.utils import ModelOutput
14
+
15
+
16
+ from .moe import (
17
+ #ExpertChoiceMoE,
18
+ MaskedMoE,
19
+ TimeDependantMoE,
20
+ MoE,
21
+ )
22
+
23
+ from .aux_losses import (
24
+ entropy_reg,
25
+ load_balancing_loss,
26
+ router_z_loss,
27
+ )
28
+
29
+ # class Output(ModelOutput):
30
+ # def __init__(self, logits, loss=None, aux_losses=None, router_logits=None):
31
+ # self.logits = logits
32
+ # self.loss = loss
33
+ # self.aux_losses = aux_losses
34
+ # self.router_logits = router_logits
35
@dataclass
class Output(ModelOutput):
    """Model output container: language-model logits, optional LM loss,
    per-name auxiliary router losses, and stacked per-layer router logits."""

    logits: torch.FloatTensor = None
    loss: Optional[torch.FloatTensor] = None
    aux_losses: Optional[Dict[str, torch.FloatTensor]] = None
    router_logits: Optional[torch.FloatTensor] = None

    def __repr__(self):
        # Assemble field=value fragments, then join — same string as before.
        fields = (
            f"logits={self.logits}",
            f"loss={self.loss}",
            f"aux_losses={self.aux_losses}",
            f"router_logits={self.router_logits}",
        )
        return "Output(" + ", ".join(fields) + ")"
44
+
45
class LayerNorm(nn.Module):
    """LayerNorm with an optional bias (torch's built-in only supports bias=True)."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        # Normalize over the trailing dims given by weight's shape, eps=1e-5.
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
55
+
56
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a Flash Attention fast path.

    Uses torch's fused scaled_dot_product_attention when available
    (PyTorch >= 2.0), otherwise falls back to a manual masked-softmax
    implementation with an explicit causal mask buffer.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, computed in one matmul
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash:
            print(
                "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0"
            )
            # causal mask to ensure that attention is only applied to the left
            # in the input sequence (only needed on the slow path)
            self.register_buffer(
                "bias",
                torch.tril(
                    torch.ones(config.sequence_length, config.sequence_length)
                ).view(1, 1, config.sequence_length, config.sequence_length),
            )

    def forward(self, x):
        # batch size, sequence length, embedding dimensionality (n_embd)
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head
        # forward to be the batch dim: (B, T, C) -> 3 x (B, nh, T, hs)
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # causal self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # BUG FIX: F.scaled_dot_product_attention is a functional API and
            # does not see this module's training flag; passing self.dropout
            # unconditionally applied attention dropout even in eval mode.
            y = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
121
+
122
+
123
class MLP(nn.Module):
    """Standard transformer MLP: linear up-projection, GELU, linear
    down-projection, dropout. Returns (output, {}) so it is interchangeable
    with the MoE blocks, which also return a routing-info dict."""

    def __init__(self, config):
        super().__init__()
        # expansion multiplier for the hidden layer (4x by default)
        self.dim_exp_factor = int(config.mlp_dim_exp_factor * 4)
        hidden = self.dim_exp_factor * config.n_embd

        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        out = self.dropout(self.c_proj(self.activation(self.c_fc(x))))
        # second element mirrors the MoE interface; empty for a dense MLP
        return out, {}
144
+
145
+
146
class Block(nn.Module):
    """Pre-norm transformer block: causal self-attention followed by either a
    dense MLP or an MoE layer, each with a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        # remember the routing mode so forward() knows whether the MLP takes `date`
        self.moe_config = config.moe_routing
        if config.moe:
            if config.moe_routing == "standard_gating":
                self.mlp = MoE(config, MLP)
            elif config.moe_routing == "masked":
                # NOTE(review): "masked" instantiates TimeDependantMoE even
                # though MaskedMoE is imported in this module — confirm which
                # implementation is intended here.
                self.mlp = TimeDependantMoE(config, MLP)
            # elif config.moe_routing == "expert_choice":
            #     self.mlp = ExpertChoiceMoE(config, MLP)
            else:
                # BUG FIX: previously formatted f"...{config.routing}" — the
                # config has no `routing` attribute, so reporting the error
                # itself raised AttributeError and masked the real message.
                raise ValueError(f"Unknown routing: {config.moe_routing}")
        else:
            self.mlp = MLP(config)

    def forward(self, x, date, *args, **kwargs):
        # BUG FIX: *args/**kwargs were forwarded into ln_1/ln_2, but
        # LayerNorm.forward accepts a single tensor, so any extra argument
        # raised TypeError. They remain in the signature for caller
        # compatibility but are intentionally unused.
        x = x + self.attn(self.ln_1(x))
        if self.moe_config == "masked":
            # time-dependent routing needs the date signal
            x_, logits_and_experts = self.mlp(self.ln_2(x), date)
        else:
            x_, logits_and_experts = self.mlp(self.ln_2(x))
        x = x + x_
        # logits_and_experts is {} for a dense MLP, or contains
        # "router_logits"/"selected_experts" for MoE layers.
        return x, logits_and_experts
173
+
174
+
175
class MoEGPTForCausalLM(PreTrainedModel):
    """GPT-style causal language model whose MLP blocks may be MoE layers.

    Structure follows nanoGPT (token + learned position embeddings, pre-norm
    blocks, tied lm_head) with optional mixture-of-experts routing and the
    associated auxiliary router losses.
    """

    # Hugging Face hook pairing this model class with its config class.
    config_class = MoEGPTConfig

    def __init__(self, config):
        super().__init__(config)
        assert config.vocab_size is not None
        assert config.sequence_length is not None
        self.config = config
        # GPT-2 BPE tokenizer, used by generate()/generate_from_string()
        self.tokenizer = tiktoken.get_encoding("gpt2")
        self.base_model_prefix = "timoe"

        self.transformer = nn.ModuleDict(
            dict(
                # token embeddings (weight-tied to lm_head below)
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                # learned absolute position embeddings
                wpe=nn.Embedding(config.sequence_length, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=LayerNorm(config.n_embd, bias=config.bias),
            )
        )

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = (
            self.lm_head.weight
        )  # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )
            if pn.endswith("router.weight"):
                # special scaled init to moe router?
                # NOTE(review): normalizes router weights to sum to 1 along one
                # dim, then rescales back to the pre-normalization std —
                # confirm this matches the intended router init scheme.
                with torch.no_grad():
                    dim = 1 if config.moe_routing == "standard_gating" else 0
                    std = p.std()
                    p.div_(p.sum(dim=dim, keepdim=True))
                    p.mul_(std / p.std())

    def get_router_losses(self, logits, selected_experts, eval=False):
        """Return the configured auxiliary router losses for one layer.

        logits: (b * seq_len, n_experts) router logits.
        selected_experts: (b * seq_len, topk) chosen expert ids.
        eval: when True, compute all three losses regardless of config
        (used for logging at evaluation time).
        """
        if eval:  # eval mode, compute all losses
            return {
                "moe_entropy_loss": entropy_reg(logits),
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
                "moe_z_loss": router_z_loss(logits),
            }
        if self.config.moe_router_loss == "entropy":
            return {
                "moe_entropy_loss": entropy_reg(logits),
            }
        elif self.config.moe_router_loss == "load_balancing_only":
            return {
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
            }
        elif self.config.moe_router_loss == "load_balancing_z_loss":
            return {
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
                "moe_z_loss": router_z_loss(logits),
            }
        # unknown moe_router_loss: silently no auxiliary losses
        return {}

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        # GPT-2 style init: N(0, 0.02) for linear/embedding weights, zero biases.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, date=None, targets=None, attention_mask=None, get_logits=True, moe=False):
        """Run the transformer; optionally compute LM + auxiliary router losses.

        idx: (b, t) token ids. targets: (b, t) labels or None.
        date: scalar year (>= 2013) or None; mapped to a small bucket index
        that time-dependent MoE layers consume. attention_mask is accepted
        for HF-API compatibility but unused. moe gates the router-loss
        computation. Returns an Output dataclass.
        """
        device = idx.device
        b, t = idx.size()
        assert (
            t <= self.config.sequence_length
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
        # shape (1, t)
        if date is None:
            # set all the date to 6 (the default bucket)
            date = torch.full((1, b), 6, dtype=torch.long, device=device).squeeze(0)
        else:
            # bucket years into 2-year bins starting at 2013
            # NOTE(review): torch.full expects a scalar fill value — `date`
            # must be a plain number here, not a tensor; confirm callers.
            date = (date - 2013) // 2 + 1
            date = torch.full((1, b), date, dtype=torch.long, device=device).squeeze(0)
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(
            pos
        )  # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)

        # router logits is a list for each layer's routing, each of shape (b * seq_len, n_experts)
        router_logits = []
        # experts is a list for each layer's selected experts, shape (b * seq_len, topk)
        experts = []

        # forward pass through all the transformer blocks
        for block in self.transformer.h:
            x, logits_and_experts = block(x, date)
            # dense-MLP blocks return an empty dict; MoE blocks return routing info
            if len(logits_and_experts) > 0:
                router_logits.append(logits_and_experts["router_logits"])
                experts.append(logits_and_experts["selected_experts"])
        x = self.transformer.ln_f(x)

        # aux_losses is a dict with keys for different auxiliary losses
        aux_losses = {}

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )
            if moe and (self.config.moe_routing == "standard_gating" or self.config.moe_routing == "masked"):
                # calculate the router losses per layer
                for logit, expert_choice in zip(router_logits, experts):
                    router_losses = self.get_router_losses(
                        logit, expert_choice, eval=not self.training
                    )
                    for k, v in router_losses.items():
                        # accumulate for logging regardless of train/eval
                        aux_losses[k] = aux_losses.get(k, 0.0) + v
                        if self.training:
                            # weight each loss by its configured factor
                            # (e.g. moe_aux_loss -> config.moe_aux_loss_factor),
                            # averaged over layers
                            loss += (
                                v
                                * getattr(self.config, k + "_factor")
                                / self.config.n_layer
                            )
        else:
            # inference: no loss; lm_head applied to the full sequence
            # (the commented slice would restrict to the last position only)
            logits = self.lm_head(
                #x[:, [-1], :]
                x
            )  # note: using list [-1] to preserve the time dim
            loss = None
        logits = logits if get_logits else None
        # stack per-layer router logits into one tensor for the Output
        router_logits = (
            torch.stack(router_logits, dim=0) if len(router_logits) > 0 else None
        )
        return Output(logits = logits, loss = loss, aux_losses = aux_losses, router_logits = router_logits)

    def crop_sequence_length(self, sequence_length):
        # model surgery to decrease the block size if necessary
        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
        # but want to use a smaller block size for some smaller, simpler model
        # NOTE(review): block.attn.bias only exists on the non-flash path —
        # this will fail for flash-attention models; confirm intended usage.
        assert sequence_length <= self.config.sequence_length
        self.config.sequence_length = sequence_length
        self.transformer.wpe.weight = nn.Parameter(
            self.transformer.wpe.weight[:sequence_length]
        )
        for block in self.transformer.h:
            block.attn.bias = block.attn.bias[:, :, :sequence_length, :sequence_length]

    def get_parameter_group_specs(self):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.

        NOTE(review): the returned "params" lists contain parameter NAMES
        (strings), not tensors — presumably the caller resolves names to
        tensors before building the optimizer; confirm.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear,)

        BLACKLIST_WEIGHT_MODULES = (
            torch.nn.LayerNorm,
            LayerNorm,
            torch.nn.Embedding,
        )

        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
                # random note: because named_modules and named_parameters are recursive
                # we will see the same tensors p many many times. but doing it this way
                # allows us to know which parent module any tensor p belongs to...
                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, BLACKLIST_WEIGHT_MODULES):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
        # will appear in the no_decay and decay sets respectively after the above.
        # In addition, because named_parameters() doesn't return duplicates, it
        # will only return the first occurence, key'd by 'transformer.wte.weight', below.
        # so let's manually remove 'lm_head.weight' from decay set. This will include
        # this tensor into optimization via transformer.wte.weight only, and not decayed.
        decay.remove("lm_head.weight")

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert (
            len(inter_params) == 0
        ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert (
            len(param_dict.keys() - union_params) == 0
        ), "parameters %s were not separated into either decay/no_decay set!" % (
            str(param_dict.keys() - union_params),
        )

        # create the pytorch optimizer object
        return [
            {"params": sorted(list(decay))},
            {"params": sorted(list(no_decay)), "weight_decay": 0.0},
        ]

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, date = None, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        NOTE(review): the end-of-text check uses .item(), which assumes b == 1.
        """
        idx = input_ids
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at sequence_length
            idx_cond = (
                idx
                if idx.size(1) <= self.config.sequence_length
                else idx[:, -self.config.sequence_length :]
            )
            # forward the model to get the logits for the index in the sequence
            logits = self(idx_cond, date, get_logits=True).logits
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)
            # check if we hit the end of the sequence
            if idx_next.item() == self.tokenizer.eot_token:
                break

        return idx

    @torch.no_grad()
    def generate_from_string(self, in_str, max_new_tokens, date = None, temperature=1.0, top_k=None):
        """Encode `in_str` with the GPT-2 tokenizer, generate, and return only
        the newly generated text (everything after the prompt)."""
        idx = (
            torch.tensor(
                self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
            )
            .view(1, -1)
            .to(self.lm_head.weight.device)
        )
        out_idx = (
            self.generate(idx, max_new_tokens, date, temperature, top_k)
            .view(-1)
            .to("cpu")
            .numpy()
        )
        # split on the prompt so only the continuation is returned
        return self.tokenizer.decode(out_idx).split(in_str)[-1]

    def get_input_embeddings(self):
        return self.transformer.wte

    def set_input_embeddings(self, new_embeddings):
        self.transformer.wte = new_embeddings
        # reset the lm_head to use the new embeddings
        # this is necessary because the lm_head is tied to the input embeddings
        # NOTE(review): the fresh lm_head is NOT re-tied to the new embedding
        # weights (the tying code below is commented out) — after this call
        # the model has untied, freshly-initialized output weights; confirm.
        self.lm_head = nn.Linear(
            self.config.n_embd, new_embeddings.weight.shape[0] , bias=False
        )
        #self.transformer.wte.weight = (
        #    self.lm_head.weight
        #)
moe.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple MoE routing implementations that replace the MLP block in a standard transformer.
3
+ References:
4
+ 1) Mistral Source for Mixtral MoEs:
5
+ https://github.com/mistralai/mistral-src
6
+ 2) ST-MoE:
7
+ https://arxiv.org/abs/2202.08906
8
+ 3) Our notepad of MoE resources:
9
+ https://docs.google.com/document/d/1NuQ5jr7V-Jv1ui7p4KrxO_JTz-7bpYcYMmh49EeJ-QA/edit?usp=sharing
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+
17
class MoE(nn.Module):
    """
    Simplest MoE implementation with a linear router and softmax over experts.

    Note that in this implementation, we simply loop over the experts and
    aggregate the results. This is not the most efficient way to do it, but
    it also avoids the large memory overhead _and_ has no token dropping
    (because we do not need the capacity factor).
    """

    def __init__(self, config, mlp):
        super().__init__()
        assert config.moe_num_experts > 0
        self.experts = nn.ModuleList(
            [mlp(config=config) for _ in range(config.moe_num_experts)]
        )
        # linear gate: one logit per expert for every token
        self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False)
        self.top_k = config.moe_num_experts_per_tok
        self.softmax_order = config.moe_softmax_order

    def forward(self, inputs: torch.Tensor):
        # flatten to [batch_size * sequence_length, n_embd]
        flat_inputs = inputs.view(-1, inputs.shape[-1])
        # [batch_size * sequence_length, num_experts]
        router_logits = self.router(flat_inputs)

        # The selected experts are identical for both orders (softmax is
        # monotonic, so it never changes the top-k) — only the weights differ.
        if self.softmax_order == "softmax_topk":
            all_probs = F.softmax(router_logits, dim=1, dtype=torch.float32)
            weights, selected_experts = torch.topk(all_probs, self.top_k)
        elif self.softmax_order == "topk_softmax":
            weights, selected_experts = torch.topk(router_logits, self.top_k)
            weights = F.softmax(weights, dim=-1, dtype=torch.float32)
        else:
            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")

        # Accumulate each expert's weighted contribution; looping over experts
        # keeps memory bounded and drops no tokens.
        combined = torch.zeros_like(flat_inputs)
        for expert_id, expert in enumerate(self.experts):
            # token_rows: which tokens picked this expert;
            # slot_cols: at which top-k slot, so we can fetch the right weight
            token_rows, slot_cols = torch.where(selected_experts == expert_id)
            expert_out, _ = expert(flat_inputs[token_rows])
            combined[token_rows] += weights[token_rows, slot_cols, None] * expert_out

        # return results and router logits (for aux loss calculation later)
        return combined.view_as(inputs), {
            "router_logits": router_logits,
            "selected_experts": selected_experts,
        }
66
+
67
+
68
class ExpertChoiceMoE(nn.Module):
    """
    MoE using the expert-choice method from https://arxiv.org/pdf/2202.09368v2.pdf.

    The router takes the top-k over TOKENS rather than experts: each expert
    chooses the tokens it processes, instead of each token choosing experts.
    For the same capacity factor this uses the same compute as standard top-k
    routing, but (unlike the loop-based MoE above) it cannot avoid the
    capacity factor.
    """

    def __init__(self, config, mlp):
        super().__init__()
        assert config.moe_num_experts > 0
        self.n_experts = config.moe_num_experts
        self.experts = nn.ModuleList(
            [mlp(config=config) for _ in range(config.moe_num_experts)]
        )
        self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False)
        self.capacity_factor = config.capacity_factor
        self.softmax_order = config.moe_softmax_order
        # per-expert token budget at the configured (full) batch size
        self.top_k = int(
            self.capacity_factor
            * config.batch_size
            * config.sequence_length
            / config.moe_num_experts
        )

    def forward(self, inputs: torch.Tensor):
        # flatten to [batch_size * sequence_length, n_embd]
        flat_inputs = inputs.view(-1, inputs.shape[-1])
        num_tokens = flat_inputs.shape[0]
        # shrink the budget when the actual batch is smaller than configured
        top_k = min(self.top_k, int(self.capacity_factor * num_tokens / self.n_experts))
        # [batch_size * sequence_length, num_experts]
        router_logits = self.router(flat_inputs)

        # The selected tokens are identical for both orders (softmax is
        # monotonic) — only the weights differ.
        if self.softmax_order == "softmax_topk":
            all_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
            # transpose so the top-k runs over tokens: [num_experts, top_k]
            weights, selected_tokens = torch.topk(all_probs.T, top_k)
        elif self.softmax_order == "topk_softmax":
            # weights and selected tokens: [num_experts, top_k]
            weights, selected_tokens = torch.topk(router_logits.T, top_k)
            weights = F.softmax(weights, dim=-1, dtype=torch.float32)
        else:
            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")

        # Naive per-expert loop (a fully parallel one-hot/einsum variant is
        # possible but OOMs quickly); each expert processes its chosen tokens.
        combined = torch.zeros_like(flat_inputs)
        for expert_id, expert in enumerate(self.experts):
            # [top_k] token indices chosen by this expert
            chosen = selected_tokens[expert_id]
            # [top_k, n_embd]
            expert_out, _ = expert(flat_inputs[chosen])
            combined[chosen] += weights[expert_id, :, None] * expert_out

        # return results and router logits (for aux loss calculation later)
        return combined.view_as(inputs), {
            "router_logits": router_logits,
            "selected_experts": selected_tokens,
        }