robinfaro committed on
Commit
88d6e74
·
verified ·
1 Parent(s): a750462

Upload custom config and model files

Browse files
Files changed (12) hide show
  1. README.md +10 -0
  2. __init__.py +2 -0
  3. aux_losses.py +88 -0
  4. config.json +89 -0
  5. configuration.py +51 -0
  6. merges.txt +0 -0
  7. modeling.py +514 -0
  8. moe.py +134 -0
  9. special_tokens_map.json +5 -0
  10. tokenizer.json +0 -0
  11. tokenizer_config.json +20 -0
  12. vocab.json +0 -0
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - model_hub_mixin
4
+ - pytorch_model_hub_mixin
5
+ ---
6
+
7
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
8
+ - Code: [More Information Needed]
9
+ - Paper: [More Information Needed]
10
+ - Docs: [More Information Needed]
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .configuration import MoEGPTConfig
2
+ from .modeling import MoEGPTForCausalLM
aux_losses.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
def log_mean(x, dim):
    """Log of the mean of ``exp(x)`` along ``dim``, computed without leaving log-space.

    Equivalent to ``torch.log(x.exp().mean(dim))`` but numerically stable,
    since it relies on ``logsumexp`` instead of exponentiating first.
    """
    count = torch.tensor(x.shape[dim], dtype=torch.float32)
    return torch.logsumexp(x, dim=dim) - torch.log(count)
10
+
11
+
12
def entropy_reg(logits: torch.Tensor, mean_over_batch: bool = True):
    """Entropy regularization for the router.

    Computes the negative entropy of the router's softmax distribution over
    experts (so that *minimizing* this term *increases* routing entropy).

    Args:
        logits: [batch_size * sequence_length, num_experts] router logits.
        mean_over_batch: if True, average the probabilities over the batch
            (in log-space) before computing the entropy.
    """
    log_p = F.log_softmax(logits, dim=-1)
    if mean_over_batch:
        # average distribution over all tokens, still in log-space
        log_p = log_mean(log_p, 0)
    # negative entropy: sum_e p_e * log p_e
    neg_entropy = (log_p * log_p.exp()).sum(-1)
    return neg_entropy.mean()
24
+
25
+
26
# two losses below are adapted from
# https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/routing.py
def load_balancing_loss(logits: torch.Tensor, expert_indices: torch.Tensor) -> float:
    """Computes auxiliary load balancing loss as in Switch Transformer.

    See Switch Transformer (https://arxiv.org/abs/2101.03961); this implements
    the loss of equations (4) - (6), which penalizes unbalanced routing of
    tokens across experts.

    Args:
        logits: logits assigned to each expert per token. Shape:
            <float32>[batch_size * sequence_length, num_experts].
        expert_indices: <int>[batch_size * sequence_length, num_selected_experts]
            indices identifying the top num_selected_experts for a given token.

    Returns:
        The auxiliary loss.
    """
    _, num_experts = logits.shape

    # [tokens, num_selected_experts, num_experts] one-hot of the selections
    selection_one_hot = F.one_hot(expert_indices, num_experts)
    # collapse the top-k dimension: 1 iff the token was routed to the expert
    # at all, shape [tokens, num_experts]
    routed_mask = selection_one_hot.max(dim=-2).values

    # fraction of tokens dispatched to each expert, shape [num_experts]
    tokens_per_expert = routed_mask.to(torch.float32).mean(dim=0)

    # mean router probability per expert, averaged in log-space for stability
    mean_logprob = log_mean(F.log_softmax(logits, dim=-1), dim=0)
    router_prob_per_expert = mean_logprob.exp()

    # dot product between dispatch fractions and router probabilities,
    # scaled so that a perfectly uniform router yields a loss of 1
    balance = tokens_per_expert * router_prob_per_expert
    return balance.mean(dtype=torch.float32) * num_experts
69
+
70
+
71
def router_z_loss(router_logits: torch.Tensor) -> float:
    """Compute router z-loss.

    Introduced in Designing Effective Sparse Expert Models
    (https://arxiv.org/abs/2202.08906); it encourages router logits to stay
    small, which improves training stability.

    Args:
        router_logits: <float>[batch_size * sequence_length, num_experts]
            router logits.

    Returns:
        Scalar router z-loss (mean over tokens of logsumexp(logits)^2).
    """
    n_tokens = router_logits.shape[0]
    log_z = torch.logsumexp(router_logits, dim=-1)
    return (log_z * log_z).sum(dtype=torch.float32) / n_tokens
config.json ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "return_dict": true,
3
+ "output_hidden_states": false,
4
+ "output_attentions": false,
5
+ "torchscript": false,
6
+ "torch_dtype": null,
7
+ "use_bfloat16": false,
8
+ "tf_legacy_loss": false,
9
+ "pruned_heads": {},
10
+ "tie_word_embeddings": true,
11
+ "chunk_size_feed_forward": 0,
12
+ "is_encoder_decoder": false,
13
+ "is_decoder": false,
14
+ "cross_attention_hidden_size": null,
15
+ "add_cross_attention": false,
16
+ "tie_encoder_decoder": false,
17
+ "max_length": 20,
18
+ "min_length": 0,
19
+ "do_sample": false,
20
+ "early_stopping": false,
21
+ "num_beams": 1,
22
+ "num_beam_groups": 1,
23
+ "diversity_penalty": 0.0,
24
+ "temperature": 1.0,
25
+ "top_k": 50,
26
+ "top_p": 1.0,
27
+ "typical_p": 1.0,
28
+ "repetition_penalty": 1.0,
29
+ "length_penalty": 1.0,
30
+ "no_repeat_ngram_size": 0,
31
+ "encoder_no_repeat_ngram_size": 0,
32
+ "bad_words_ids": null,
33
+ "num_return_sequences": 1,
34
+ "output_scores": false,
35
+ "return_dict_in_generate": false,
36
+ "forced_bos_token_id": null,
37
+ "forced_eos_token_id": null,
38
+ "remove_invalid_values": false,
39
+ "exponential_decay_length_penalty": null,
40
+ "suppress_tokens": null,
41
+ "begin_suppress_tokens": null,
42
+ "architectures": [
43
+ "MoEGPTForCausalLM"
44
+ ],
45
+ "finetuning_task": null,
46
+ "id2label": {
47
+ "0": "LABEL_0",
48
+ "1": "LABEL_1"
49
+ },
50
+ "label2id": {
51
+ "LABEL_0": 0,
52
+ "LABEL_1": 1
53
+ },
54
+ "tokenizer_class": null,
55
+ "prefix": null,
56
+ "bos_token_id": null,
57
+ "pad_token_id": null,
58
+ "eos_token_id": null,
59
+ "sep_token_id": null,
60
+ "decoder_start_token_id": null,
61
+ "task_specific_params": null,
62
+ "problem_type": null,
63
+ "_name_or_path": "",
64
+ "_attn_implementation_autoset": false,
65
+ "transformers_version": "4.51.3",
66
+ "shared_attention": true,
67
+ "vocab_size": 50304,
68
+ "n_embd": 1152,
69
+ "n_layer": 24,
70
+ "n_head": 16,
71
+ "sequence_length": 1024,
72
+ "moe": false,
73
+ "moe_routing": null,
74
+ "moe_num_experts": 1,
75
+ "moe_num_experts_per_tok": 2,
76
+ "moe_softmax_order": "softmax_topk",
77
+ "moe_router_loss": "load_balancing_z_loss",
78
+ "moe_aux_loss_factor": 0.01,
79
+ "moe_z_loss_factor": 1.0,
80
+ "mlp_dim_exp_factor": 1.0,
81
+ "dropout": 0.0,
82
+ "bias": false,
83
+ "auto_map": {
84
+ "AutoConfig": "configuration.MoEGPTConfig",
85
+ "AutoModelForCausalLM": "modeling.MoEGPTForCausalLM",
86
+ "AutoTokenizer": "GPT2TokenizerFast"
87
+ },
88
+ "model_type": "moegpt"
89
+ }
configuration.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class MoEGPTConfig(PretrainedConfig):
    """Configuration for MoE-GPT: a GPT-2-style decoder with optional
    mixture-of-experts MLP / attention blocks.

    Args:
        vocab_size: size of the token vocabulary.
        n_embd: embedding / hidden dimension.
        n_layer: number of transformer blocks.
        n_head: number of attention heads.
        sequence_length: maximum context length (size of positional embeddings).
        moe: whether MoE blocks are used at all.
        moe_routing: routing scheme ("standard_gating" or "masked").
        moe_num_experts: number of experts per MoE block.
        moe_num_experts_per_tok: top-k experts selected per token.
        moe_softmax_order: order of softmax vs. top-k in the router.
        moe_router_loss: which auxiliary router losses to apply.
        moe_aux_loss_factor: weight of the load-balancing auxiliary loss.
        moe_z_loss_factor: weight of the router z-loss.
        mlp_dim_exp_factor: multiplier on the standard 4x MLP expansion.
        dropout: dropout probability.
        bias: whether Linear/LayerNorm layers carry a bias term.
        architectures: model classes exposed on the Hub.
        auto_map: auto-class mapping used by trust_remote_code loading.
    """

    model_type = "moegpt"

    def __init__(
        self,
        vocab_size=50304,
        n_embd=768,
        n_layer=12,
        n_head=12,
        sequence_length=1024,
        moe=False,
        moe_routing="standard_gating",
        moe_num_experts=4,
        moe_num_experts_per_tok=2,
        moe_softmax_order="softmax_topk",
        moe_router_loss="load_balancing_z_loss",
        moe_aux_loss_factor=0.01,
        moe_z_loss_factor=1.0,
        mlp_dim_exp_factor=1.0,
        dropout=0.0,
        bias=False,
        architectures=None,
        auto_map=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.sequence_length = sequence_length
        self.moe = moe
        self.moe_routing = moe_routing
        self.moe_num_experts = moe_num_experts
        self.moe_num_experts_per_tok = moe_num_experts_per_tok
        self.moe_softmax_order = moe_softmax_order
        self.moe_router_loss = moe_router_loss
        self.moe_aux_loss_factor = moe_aux_loss_factor
        self.moe_z_loss_factor = moe_z_loss_factor
        self.mlp_dim_exp_factor = mlp_dim_exp_factor
        self.dropout = dropout
        self.bias = bias
        # BUG FIX: the list/dict defaults used to be mutable default arguments,
        # shared across every config instance (mutating one config's auto_map
        # would silently mutate the default for all future configs). Build a
        # fresh object per instance instead; a caller-supplied value is used
        # as-is, preserving the previous behavior for explicit arguments.
        self.architectures = (
            architectures if architectures is not None else ["MoEGPTForCausalLM"]
        )
        self.auto_map = (
            auto_map
            if auto_map is not None
            else {
                "AutoConfig": "configuration.MoEGPTConfig",
                "AutoModelForCausalLM": "modeling.MoEGPTForCausalLM",
                "AutoTokenizer": "GPT2TokenizerFast",
            }
        )
51
+
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from .configuration import MoEGPTConfig
3
+ # importa anche MoE, MaskedMoE, TimeDependantMoE ecc.
4
+ import math
5
+ import inspect
6
+ from typing import Optional, Dict, Any
7
+ from dataclasses import dataclass
8
+ import tiktoken
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from huggingface_hub import PyTorchModelHubMixin
13
+ from transformers.utils import ModelOutput
14
+
15
+
16
+ from .moe import (
17
+ #ExpertChoiceMoE,
18
+ MaskedMoE,
19
+ TimeDependantMoE,
20
+ MoE,
21
+ )
22
+
23
+ from .aux_losses import (
24
+ entropy_reg,
25
+ load_balancing_loss,
26
+ router_z_loss,
27
+ )
28
+
29
@dataclass
class Output(ModelOutput):
    """Model output: LM logits/loss plus per-family (MLP vs. attention)
    MoE router diagnostics and auxiliary losses."""

    # next-token logits, shape (batch, seq, vocab)
    logits: torch.FloatTensor = None
    # cross-entropy (+ weighted aux losses during training); None at inference
    loss: Optional[torch.FloatTensor] = None
    # detached scalar of the pure cross-entropy term, for logging
    loss_to_log: Optional[float] = None
    # accumulated router losses of the MLP experts, keyed by loss name
    aux_losses_mlp: Optional[Dict[str, torch.FloatTensor]] = None
    # stacked per-layer MLP router logits, or None when not using MoE
    mlp_router_logits: Optional[torch.FloatTensor] = None
    # accumulated router losses of the attention experts, keyed by loss name
    aux_losses_attn: Optional[Dict[str, torch.FloatTensor]] = None
    # stacked per-layer attention router logits, or None when not using MoE
    attn_router_logits: Optional[torch.FloatTensor] = None

    def __repr__(self):
        # BUG FIX: previous version referenced non-existent attributes
        # (self.aux_losses, self.router_logits) and raised AttributeError
        # on every repr() call; report the fields that actually exist.
        return (
            f"Output(logits={self.logits}, loss={self.loss}, "
            f"loss_to_log={self.loss_to_log}, "
            f"aux_losses_mlp={self.aux_losses_mlp}, "
            f"mlp_router_logits={self.mlp_router_logits}, "
            f"aux_losses_attn={self.aux_losses_attn}, "
            f"attn_router_logits={self.attn_router_logits})"
        )
41
+
42
class LayerNorm(nn.Module):
    """Layer normalization with a learnable scale and an *optional* bias.

    Exists because ``torch.nn.LayerNorm`` does not support ``bias=False``
    directly (in the PyTorch version this file targets).
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        # normalize over the trailing `ndim` dimensions, eps matches nn.LayerNorm
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, self.weight, self.bias, 1e-5)
52
+
53
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention, GPT-2 style (fused QKV projection).

    Uses PyTorch's fused ``scaled_dot_product_attention`` when available
    (PyTorch >= 2.0), otherwise falls back to a manual masked-softmax
    implementation driven by a precomputed lower-triangular mask buffer.

    ``forward`` returns ``(output, {})`` — the empty dict mirrors the MoE
    blocks, which return router diagnostics in the second slot.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash:
            print(
                "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0"
            )
        # causal mask to ensure that attention is only applied to the left in
        # the input sequence.
        # NOTE(review): the buffer is registered unconditionally, even when the
        # flash path never reads it — it costs sequence_length^2 floats of
        # (non-trainable) state in every checkpoint.
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones(config.sequence_length, config.sequence_length)
            ).view(1, 1, config.sequence_length, config.sequence_length),
        )

    def forward(self, x):
        if x.ndim != 3:
            x = x.unsqueeze(0)  # handles the router input, since it previously squashed the batch dim
        # batch size, sequence length, embedding dimensionality (n_embd)
        (
            B,
            T,
            C,
        ) = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, nh, hs)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels;
            # dropout is only applied in training mode by SDPA itself
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True
            )
        else:
            # manual implementation of attention: scaled dot product,
            # causal mask, softmax, dropout, weighted sum of values
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = (
            y.transpose(1, 2).contiguous().view(B, T, C)
        )  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y, {}
120
+
121
+
122
class MLP(nn.Module):
    """Standard transformer feed-forward block: Linear -> GELU -> Linear -> Dropout.

    The hidden width is ``4 * mlp_dim_exp_factor * n_embd`` (the classic 4x
    expansion, optionally rescaled). ``forward`` returns ``(output, {})`` so
    that dense and MoE blocks share the same call contract.
    """

    def __init__(self, config):
        super().__init__()
        hidden_mult = int(config.mlp_dim_exp_factor * 4)
        self.dim_exp_factor = hidden_mult

        self.c_fc = nn.Linear(
            config.n_embd, hidden_mult * config.n_embd, bias=config.bias
        )
        self.c_proj = nn.Linear(
            hidden_mult * config.n_embd, config.n_embd, bias=config.bias
        )
        self.dropout = nn.Dropout(config.dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        hidden = self.activation(self.c_fc(x))
        out = self.dropout(self.c_proj(hidden))
        # second element mirrors the MoE router-info dict; empty for dense MLP
        return out, {}
143
+
144
+
145
class Block(nn.Module):
    """One transformer block: pre-LN causal self-attention and a pre-LN
    feed-forward (dense or MoE), each with a residual connection.

    Depending on the config:
      - ``moe=False``: dense attention + dense MLP (requires shared_attention).
      - ``moe_routing="standard_gating"``: token-routed ``MoE`` for the MLP
        (and for attention too when ``shared_attention`` is False).
      - ``moe_routing="masked"``: ``TimeDependantMoE``, which additionally
        receives a ``date`` signal in ``forward``.

    ``forward`` returns ``(x, mlp_router_info, attn_router_info)`` where the
    router-info dicts are empty for dense sub-blocks.
    """

    def __init__(self, config):
        super().__init__()
        self.moe_config = config.moe_routing
        self.shared_attention = config.shared_attention
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        if not config.moe and not config.shared_attention:
            raise ValueError(
                "If not using MoE, shared attention must be set to True"
            )

        if self.shared_attention:
            self.attn = CausalSelfAttention(config)

        if config.moe:
            if config.moe_routing == "standard_gating":
                self.mlp = MoE(config, MLP)
                if not self.shared_attention:
                    self.attn = MoE(config, CausalSelfAttention)
            elif config.moe_routing == "masked":
                self.mlp = TimeDependantMoE(config, MLP)
                if not self.shared_attention:
                    self.attn = TimeDependantMoE(config, CausalSelfAttention)
            else:
                # BUG FIX: this used to read ``config.routing``, an attribute
                # that does not exist on MoEGPTConfig, so an unknown routing
                # raised AttributeError instead of this intended ValueError.
                raise ValueError(f"Unknown routing: {config.moe_routing}")
        else:
            self.mlp = MLP(config)

    def forward(self, x, date, *args, **kwargs):
        # ``date`` is only consumed by the time-dependent ("masked") MoE
        # sub-blocks; dense and standard-gating blocks ignore it.
        if self.moe_config == "masked":
            if self.shared_attention:
                attn_output, attn_logits_and_experts = self.attn(self.ln_1(x, *args, **kwargs))
            else:
                attn_output, attn_logits_and_experts = self.attn(self.ln_1(x, *args, **kwargs), date)
            x = x + attn_output
            x_, mlp_logits_and_experts = self.mlp(self.ln_2(x, *args, **kwargs), date)
        else:
            attn_output, attn_logits_and_experts = self.attn(self.ln_1(x, *args, **kwargs))
            x = x + attn_output
            x_, mlp_logits_and_experts = self.mlp(self.ln_2(x, *args, **kwargs))
        x = x + x_
        return x, mlp_logits_and_experts, attn_logits_and_experts
188
+
189
+
190
class MoEGPTForCausalLM(PreTrainedModel):
    """GPT-2-style causal language model with optional MoE blocks.

    Wraps token + learned positional embeddings, ``config.n_layer`` ``Block``
    modules, a final LayerNorm, and a weight-tied LM head. During training,
    MoE router auxiliary losses can be added to the cross-entropy loss.
    """

    config_class = MoEGPTConfig

    def __init__(self, config):
        super().__init__(config)
        assert config.vocab_size is not None
        assert config.sequence_length is not None
        self.config = config
        # tiktoken tokenizer used by generate_from_string (GPT-2 BPE)
        self.tokenizer = tiktoken.get_encoding("gpt2")

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.sequence_length, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=LayerNorm(config.n_embd, bias=config.bias),
            )
        )

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = (
            self.lm_head.weight
        )  # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )
            if pn.endswith("router.weight"):
                # special scaled init to moe router?
                # Normalizes router rows/columns to sum to 1, then rescales to
                # the pre-normalization std (the dim depends on routing mode).
                with torch.no_grad():
                    dim = 1 if config.moe_routing == "standard_gating" else 0
                    std = p.std()
                    p.div_(p.sum(dim=dim, keepdim=True))
                    p.mul_(std / p.std())

    def get_router_losses(self, logits, selected_experts, eval=False):
        """Return the configured router auxiliary losses for one layer.

        logits: (b * seq_len, n_experts) router logits.
        selected_experts: (b * seq_len, topk) chosen expert indices.
        eval: when True, compute all three losses regardless of config
              (used for logging; they are not added to the training loss).
        NOTE(review): the parameter name ``eval`` shadows the builtin.
        """
        if eval:  # eval mode, compute all losses
            return {
                "moe_entropy_loss": entropy_reg(logits),
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
                "moe_z_loss": router_z_loss(logits),
            }
        if self.config.moe_router_loss == "entropy":
            # NOTE(review): during training the caller looks up
            # config.moe_entropy_loss_factor, which MoEGPTConfig does not
            # define — this branch would raise AttributeError; confirm.
            return {
                "moe_entropy_loss": entropy_reg(logits),
            }
        elif self.config.moe_router_loss == "load_balancing_only":
            return {
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
            }
        elif self.config.moe_router_loss == "load_balancing_z_loss":
            return {
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
                "moe_z_loss": router_z_loss(logits),
            }
        # unknown moe_router_loss: no auxiliary losses
        return {}

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        # GPT-2 style init: N(0, 0.02) for Linear/Embedding weights, zero biases
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, date=None, labels=None, attention_mask=None, get_logits=True, moe=False, inputs_embeds=None, output_hidden_states=False, output_attentions=False, return_dict=True):
        """Run the transformer and (optionally) compute the LM loss.

        input_ids: (b, t) token ids. labels: (b, t) targets (ignore_index=-1).
        date: scalar year used by time-dependent routing; None selects a
        default bucket. moe: when True (and routing is enabled), router
        auxiliary losses are accumulated and, in training, added to the loss.
        NOTE(review): attention_mask / inputs_embeds / output_* / return_dict
        are accepted for HF-API compatibility but ignored here.
        """
        device = input_ids.device
        b, t = input_ids.size()
        assert (
            t <= self.config.sequence_length
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
        # shape (1, t)
        if date is None:
            # set all the date to 6
            date = torch.full((1, b), 6, dtype=torch.long, device=device).squeeze(0)
        else:
            # presumably maps a calendar year (>= 2013) into a 2-year bucket
            # index starting at 1 — TODO confirm against the training setup
            date = (date - 2013) // 2 + 1
            date = torch.full((1, b), date, dtype=torch.long, device=device).squeeze(0)

        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(input_ids)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(
            pos
        )  # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)

        # router logits is a list for each layer's routing, each of shape (b * seq_len, n_experts)
        mlp_router_logits = []
        attn_router_logits = []
        # experts is a list for each layer's selected experts, shape (b * seq_len, topk)
        mlp_experts = []
        attn_experts = []

        # forward pass through all the transformer blocks, collecting any
        # router diagnostics the MoE sub-blocks emit
        for block in self.transformer.h:
            x, mlp_logits_and_experts, attn_logits_and_experts = block(x, date)
            if len(mlp_logits_and_experts) > 0:
                mlp_router_logits.append(mlp_logits_and_experts["router_logits"])
                mlp_experts.append(mlp_logits_and_experts["selected_experts"])
            if len(attn_logits_and_experts) > 0:
                attn_router_logits.append(attn_logits_and_experts["router_logits"])
                attn_experts.append(attn_logits_and_experts["selected_experts"])
        x = self.transformer.ln_f(x)

        # aux_losses is a dict with keys for different auxiliary losses
        aux_losses_mlp = {}
        aux_losses_attn = {}

        if labels is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1
            )
            loss_to_log = loss.item()
            if moe and (self.config.moe_routing == "standard_gating" or self.config.moe_routing == "masked"):
                # calculate the router losses per layer; each loss is averaged
                # over layers via the division by n_layer and weighted by its
                # config factor (e.g. "moe_aux_loss" -> moe_aux_loss_factor)
                for logit, expert_choice in zip(mlp_router_logits, mlp_experts):
                    router_losses = self.get_router_losses(
                        logit, expert_choice, eval=not self.training
                    )
                    for k, v in router_losses.items():
                        aux_losses_mlp[k] = aux_losses_mlp.get(k, 0.0) + v
                        if self.training:
                            loss += (
                                v
                                * getattr(self.config, k + "_factor")
                                / self.config.n_layer
                            )
                for logit, expert_choice in zip(attn_router_logits, attn_experts):
                    router_losses = self.get_router_losses(
                        logit, expert_choice, eval=not self.training
                    )
                    for k, v in router_losses.items():
                        aux_losses_attn[k] = aux_losses_attn.get(k, 0.0) + v
                        if self.training:
                            loss += (
                                v
                                * getattr(self.config, k + "_factor")
                                / self.config.n_layer
                            )
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            # NOTE(review): the last-position slice is commented out, so the
            # head currently runs on the full sequence — confirm intent.
            logits = self.lm_head(
                #x[:, [-1], :]
                x
            )  # note: using list [-1] to preserve the time dim
            loss = None
            loss_to_log = None
        logits = logits if get_logits else None
        mlp_router_logits = (
            torch.stack(mlp_router_logits, dim=0) if len(mlp_router_logits) > 0 else None
        )
        attn_router_logits = (
            torch.stack(attn_router_logits, dim=0) if len(attn_router_logits) > 0 else None
        )

        return Output(logits = logits, loss = loss, loss_to_log= loss_to_log,
                      aux_losses_mlp=aux_losses_mlp, mlp_router_logits=mlp_router_logits,
                      aux_losses_attn=aux_losses_attn, attn_router_logits=attn_router_logits)

    def crop_sequence_length(self, sequence_length):
        # model surgery to decrease the block size if necessary
        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
        # but want to use a smaller block size for some smaller, simpler model
        # NOTE(review): accesses block.attn.bias directly, which assumes the
        # attention module is a CausalSelfAttention (shared attention) —
        # confirm this is never called on MoE-attention models.
        assert sequence_length <= self.config.sequence_length
        self.config.sequence_length = sequence_length
        self.transformer.wpe.weight = nn.Parameter(
            self.transformer.wpe.weight[:sequence_length]
        )
        for block in self.transformer.h:
            block.attn.bias = block.attn.bias[:, :, :sequence_length, :sequence_length]

    def get_parameter_group_specs(self):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear,)

        BLACKLIST_WEIGHT_MODULES = (
            torch.nn.LayerNorm,
            LayerNorm,
            torch.nn.Embedding,
        )

        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
                # random note: because named_modules and named_parameters are recursive
                # we will see the same tensors p many many times. but doing it this way
                # allows us to know which parent module any tensor p belongs to...
                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, BLACKLIST_WEIGHT_MODULES):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
        # will appear in the no_decay and decay sets respectively after the above.
        # In addition, because named_parameters() doesn't return duplicates, it
        # will only return the first occurence, key'd by 'transformer.wte.weight', below.
        # so let's manually remove 'lm_head.weight' from decay set. This will include
        # this tensor into optimization via transformer.wte.weight only, and not decayed.
        decay.remove("lm_head.weight")

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert (
            len(inter_params) == 0
        ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert (
            len(param_dict.keys() - union_params) == 0
        ), "parameters %s were not separated into either decay/no_decay set!" % (
            str(param_dict.keys() - union_params),
        )

        # create the pytorch optimizer object
        return [
            {"params": sorted(list(decay))},
            {"params": sorted(list(no_decay)), "weight_decay": 0.0},
        ]

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, date = None, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        idx = input_ids
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at sequence_length
            idx_cond = (
                idx
                if idx.size(1) <= self.config.sequence_length
                else idx[:, -self.config.sequence_length :]
            )
            # forward the model to get the logits for the index in the sequence
            logits = self(idx_cond, date, get_logits=True).logits
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)
            # check if we hit the end of the sequence
            # NOTE(review): 50526 exceeds vocab_size (50304), so this break can
            # never trigger — GPT-2's <|endoftext|> is 50256; confirm intent.
            if idx_next.item() == 50526:
                break

        return idx

    @torch.no_grad()
    def generate_from_string(self, in_str, max_new_tokens, date = None, temperature=1.0, top_k=None):
        # encode with tiktoken (allowing the <|endoftext|> special token),
        # generate, then decode the whole sequence back to a string
        idx = (
            torch.tensor(
                self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
            )
            .view(1, -1)
            .to(self.lm_head.weight.device)
        )
        out_idx = (
            self.generate(idx, max_new_tokens, date, temperature, top_k)
            .view(-1)
            .to("cpu")
            .numpy()
        )
        return self.tokenizer.decode(out_idx)

    def get_input_embeddings(self):
        return self.transformer.wte

    def set_input_embeddings(self, new_embeddings):
        self.transformer.wte = new_embeddings
        # reset the lm_head to use the new embeddings
        # this is necessary because the lm_head is tied to the input embeddings
        # NOTE(review): this creates a *fresh, untied* Linear — the tie to
        # new_embeddings.weight is not re-established here; confirm intent.
        self.lm_head = nn.Linear(
            self.config.n_embd, new_embeddings.weight.shape[0] , bias=False
        )
moe.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple MoE routing implementations that replace the MLP block in a standard transformer.
3
+ References:
4
+ 1) Mistral Source for Mixtral MoEs:
5
+ https://github.com/mistralai/mistral-src
6
+ 2) ST-MoE:
7
+ https://arxiv.org/abs/2202.08906
8
+ 3) Our notepad of MoE resources:
9
+ https://docs.google.com/document/d/1NuQ5jr7V-Jv1ui7p4KrxO_JTz-7bpYcYMmh49EeJ-QA/edit?usp=sharing
10
+ """
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import bisect
16
+
17
+
18
+
19
class MoE(nn.Module):
    """Mixture-of-Experts layer: linear router + per-token top-k dispatch.

    Every token is routed to ``moe_num_experts_per_tok`` experts and the
    experts are evaluated one at a time in a Python loop.  That is slower
    than a batched dispatch, but it needs no capacity factor, so no token
    is ever dropped and there is no large routing-buffer memory overhead.
    """

    def __init__(self, config, mlp):
        super().__init__()
        assert config.moe_num_experts > 0
        # NOTE: experts are created before the router so that parameter-init
        # RNG consumption matches checkpoints produced with this layout.
        self.experts = nn.ModuleList(
            mlp(config=config) for _ in range(config.moe_num_experts)
        )
        self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False)
        self.top_k = config.moe_num_experts_per_tok
        self.softmax_order = config.moe_softmax_order

    def forward(self, inputs: torch.Tensor):
        """Route ``inputs`` through the experts.

        Returns the combined expert outputs (same shape as ``inputs``) and a
        dict with the raw router logits plus the selected expert indices,
        which the caller uses for auxiliary load-balancing losses.
        """
        # flatten to one row per token: [batch * seq, n_embd]
        flat = inputs.view(-1, inputs.shape[-1])
        # [batch * seq, num_experts]
        router_logits = self.router(flat)

        # top-k selection is identical for both orders (softmax is monotone);
        # only the mixing weights differ: softmax over all experts vs. over
        # just the chosen k
        if self.softmax_order == "softmax_topk":
            weights, selected_experts = torch.topk(
                F.softmax(router_logits, dim=1), self.top_k
            )
        elif self.softmax_order == "topk_softmax":
            top_logits, selected_experts = torch.topk(router_logits, self.top_k)
            weights = F.softmax(top_logits, dim=-1)
        else:
            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")

        combined = torch.zeros_like(flat)
        # naive per-expert loop: gather the tokens routed to expert `idx`,
        # run them through it, and scatter the weighted output back
        for idx, expert in enumerate(self.experts):
            token_rows, kth_choice = torch.where(selected_experts == idx)
            expert_out, _ = expert(flat[token_rows])
            combined[token_rows] += weights[token_rows, kth_choice, None] * expert_out

        # results plus router state (for aux loss calculation later)
        return combined.view_as(inputs), {
            "router_logits": router_logits,
            "selected_experts": selected_experts,
        }
68
+
69
+
70
class DummyExpert(nn.Module):
    """No-op expert: ignores its input and contributes a zero vector.

    Appended by ``MaskedMoE`` as an extra "null" expert so that tokens can
    always be routed somewhere even when real experts are masked out.
    """

    def __init__(self, output_size: int):
        super().__init__()
        # width of the zero vector returned by forward (the model's n_embd)
        self._output_size = output_size

    def forward(self, inputs: torch.Tensor) -> tuple:
        """Return a zero vector of length ``output_size`` and an empty aux dict.

        The zeros are 1-D and rely on broadcasting when the caller scales them
        per token.  Fix: the original annotated the return as a bare Tensor,
        but a ``(tensor, dict)`` pair is what matches the real expert API.
        """
        out = torch.zeros((self._output_size,), device=inputs.device)
        return out, {}
77
+
78
+
79
+
80
class MaskedMoE(MoE):
    """MoE whose router logits are gated by a per-sample expert mask.

    One extra ``DummyExpert`` (index ``moe_num_experts``) is appended so a
    token whose real experts are all masked still has somewhere to go; its
    mask column is always 1.
    """

    def __init__(self, config, mlp):
        super().__init__(config, mlp)
        self._sequence_length = config.sequence_length
        self.experts.append(DummyExpert(config.n_embd))
        # re-create the router with one extra output for the dummy expert
        self.router = nn.Linear(config.n_embd, config.moe_num_experts+1, bias=False)

    def forward(self, inputs: torch.Tensor, mask: torch.Tensor):
        """Route ``inputs`` with per-sample expert availability in ``mask``.

        Args:
            inputs: [batch, seq, n_embd] activations.
            mask: [batch, moe_num_experts] 0/1 availability per sample
                (assumed from the TimeDependantMoE caller — TODO confirm).
        """
        seq_len = inputs.shape[1]
        flat = inputs.view(-1, inputs.shape[-1])
        router_logits = self.router(flat)

        # always allow the dummy expert, then expand the per-sample mask to
        # one row per token
        ones_col = torch.ones((mask.shape[0], 1), device=mask.device)
        token_mask = torch.cat((mask, ones_col), dim=1)
        token_mask = token_mask.repeat_interleave(seq_len, dim=0)
        # NOTE(review): masking by multiplication zeroes the logit rather than
        # setting it to -inf, so a masked expert can still win top-k against
        # experts with negative logits — confirm this is intended.
        router_logits = router_logits*token_mask

        # top-k selection is identical for both orders; only the mixing
        # weights differ
        if self.softmax_order == "softmax_topk":
            weights, selected_experts = torch.topk(
                F.softmax(router_logits, dim=1), self.top_k
            )
        elif self.softmax_order == "topk_softmax":
            top_logits, selected_experts = torch.topk(router_logits, self.top_k)
            weights = F.softmax(top_logits, dim=-1)
        else:
            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")

        combined = torch.zeros_like(flat)
        # naive per-expert loop (includes the trailing dummy expert)
        for idx, expert in enumerate(self.experts):
            token_rows, kth_choice = torch.where(selected_experts == idx)
            expert_out, _ = expert(flat[token_rows])
            combined[token_rows] += weights[token_rows, kth_choice, None] * expert_out.squeeze(0)

        # results plus router state (for aux loss calculation later)
        return combined.view_as(inputs), {
            "router_logits": router_logits,
            "selected_experts": selected_experts,
        }
122
+
123
+
124
class TimeDependantMoE(nn.Module):
    """Wraps ``MaskedMoE`` so a sample dated ``d`` may only use experts 0..d-1.

    The per-sample date is turned into a binary availability mask: expert
    index ``i`` is enabled iff ``i < date`` for that sample.
    """

    def __init__(self, config, mlp):
        super().__init__()
        self._num_experts = config.moe_num_experts
        self._mask_moe = MaskedMoE(config, mlp)

    def forward(self, x, date):
        """Route ``x`` through the masked MoE with a date-derived expert mask.

        Args:
            x: [batch, seq, n_embd] activations.
            date: [batch] integer tensor — number of experts available per
                sample (assumed — TODO confirm units with the caller).
        """
        # Fix: the original pre-allocated a zeros mask that was immediately
        # overwritten by the comparison below; the dead allocation is removed.
        expert_ids = torch.arange(self._num_experts).unsqueeze(0).to(x.device)
        mask_date = (expert_ids < date.unsqueeze(1)).float()
        return self._mask_moe(x, mask_date)
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "unk_token": "<|endoftext|>"
5
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "50256": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ }
12
+ },
13
+ "bos_token": "<|endoftext|>",
14
+ "clean_up_tokenization_spaces": false,
15
+ "eos_token": "<|endoftext|>",
16
+ "extra_special_tokens": {},
17
+ "model_max_length": 1024,
18
+ "tokenizer_class": "GPT2Tokenizer",
19
+ "unk_token": "<|endoftext|>"
20
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff