robinfaro commited on
Commit
9cba846
·
verified ·
1 Parent(s): e2a666f

Upload custom config and model files

Browse files
Files changed (13) hide show
  1. README.md +10 -0
  2. __init__.py +2 -0
  3. aux_losses.py +88 -0
  4. config.json +88 -0
  5. configuration.py +50 -0
  6. merges.txt +0 -0
  7. model.safetensors +3 -0
  8. modeling.py +453 -0
  9. moe.py +134 -0
  10. special_tokens_map.json +5 -0
  11. tokenizer.json +0 -0
  12. tokenizer_config.json +20 -0
  13. vocab.json +0 -0
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - model_hub_mixin
4
+ - pytorch_model_hub_mixin
5
+ ---
6
+
7
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
8
+ - Code: [More Information Needed]
9
+ - Paper: [More Information Needed]
10
+ - Docs: [More Information Needed]
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .configuration import MoEGPTConfig
2
+ from .modeling import MoEGPTForCausalLM
aux_losses.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
def log_mean(x, dim):
    """Log of the arithmetic mean of exp(x) along ``dim``.

    Equivalent to ``log(exp(x).mean(dim))`` but computed via logsumexp for
    numerical stability.
    """
    count = torch.tensor(x.shape[dim], dtype=torch.float32)
    return torch.logsumexp(x, dim=dim) - torch.log(count)
10
+
11
+
12
def entropy_reg(logits: torch.Tensor, mean_over_batch: bool = True):
    """Entropy regularization for the router.

    Args:
        logits: [batch_size * sequence_length, num_experts] router logits.
        mean_over_batch: if True, average the probabilities over the batch
            (in log space) before computing the entropy.

    Returns:
        Negated entropy (a scalar), so that *minimizing* this loss
        *maximizes* the router's entropy.
    """
    # Softmax over the expert dimension, kept in log space.
    logprobs = F.log_softmax(logits, dim=-1)
    if mean_over_batch:
        # Mean probability over the batch, computed stably in log space.
        logprobs = log_mean(logprobs, 0)

    # Shannon entropy H = -sum(p * log p), per row.
    entropy = -(logprobs * logprobs.exp()).sum(-1)
    return -entropy.mean()
24
+
25
+
26
+ # two losses below are adapted from
27
+ # https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/routing.py
28
# two losses below are adapted from
# https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/routing.py
def load_balancing_loss(logits: torch.Tensor, expert_indices: torch.Tensor) -> float:
    """Computes auxiliary load balancing loss as in Switch Transformer.

    See Switch Transformer (https://arxiv.org/abs/2101.03961), equations
    (4) - (6). Penalizes unbalanced routing of tokens across experts; the
    loss equals 1 under perfectly uniform routing.

    Args:
        logits: logits assigned to each expert per token. Shape:
            <float32>[batch_size * sequence_length, num_experts].
        expert_indices: <int>[batch_size * sequence_length, num_selected_experts]
            indices identifying the top num_selected_experts for a given token.

    Returns:
        The scalar auxiliary loss.
    """
    num_experts = logits.shape[-1]

    # One-hot over selected experts, then collapse the top-k axis so each
    # token gets a {0,1} indicator per expert:
    # [num_tokens, num_selected_experts, num_experts] -> [num_tokens, num_experts]
    one_hot = F.one_hot(expert_indices, num_experts)
    routed_mask = one_hot.max(dim=-2).values

    # Fraction of tokens dispatched to each expert; shape [num_experts].
    tokens_per_expert = routed_mask.float().mean(dim=0)

    # Mean router probability per expert, averaged over the batch in log
    # space for numerical stability (the log_mean helper, inlined here);
    # shape [num_experts].
    logprobs = F.log_softmax(logits, dim=-1)
    mean_logprobs = torch.logsumexp(logprobs, dim=0) - torch.log(
        torch.tensor(logprobs.shape[0], dtype=torch.float32)
    )
    router_prob_per_expert = mean_logprobs.exp()

    # Mean over experts, scaled by num_experts per the paper.
    per_expert = tokens_per_expert * router_prob_per_expert
    return per_expert.mean(dtype=torch.float32) * num_experts
69
+
70
+
71
def router_z_loss(router_logits: torch.Tensor) -> float:
    """Compute router z-loss.

    Introduced in Designing Effective Sparse Expert Models
    (https://arxiv.org/abs/2202.08906); it encourages the router logits to
    remain small, improving training stability.

    Args:
        router_logits: <float>[batch_size * sequence_length, num_experts]
            router logits.

    Returns:
        Scalar router z-loss.
    """
    num_tokens = router_logits.shape[0]
    # Penalize the squared log-partition function of each token's logits.
    log_z = torch.logsumexp(router_logits, dim=-1)
    return (log_z * log_z).sum(dtype=torch.float32) / num_tokens
config.json ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "return_dict": true,
3
+ "output_hidden_states": false,
4
+ "output_attentions": false,
5
+ "torchscript": false,
6
+ "torch_dtype": null,
7
+ "use_bfloat16": false,
8
+ "tf_legacy_loss": false,
9
+ "pruned_heads": {},
10
+ "tie_word_embeddings": true,
11
+ "chunk_size_feed_forward": 0,
12
+ "is_encoder_decoder": false,
13
+ "is_decoder": false,
14
+ "cross_attention_hidden_size": null,
15
+ "add_cross_attention": false,
16
+ "tie_encoder_decoder": false,
17
+ "max_length": 20,
18
+ "min_length": 0,
19
+ "do_sample": false,
20
+ "early_stopping": false,
21
+ "num_beams": 1,
22
+ "num_beam_groups": 1,
23
+ "diversity_penalty": 0.0,
24
+ "temperature": 1.0,
25
+ "top_k": 50,
26
+ "top_p": 1.0,
27
+ "typical_p": 1.0,
28
+ "repetition_penalty": 1.0,
29
+ "length_penalty": 1.0,
30
+ "no_repeat_ngram_size": 0,
31
+ "encoder_no_repeat_ngram_size": 0,
32
+ "bad_words_ids": null,
33
+ "num_return_sequences": 1,
34
+ "output_scores": false,
35
+ "return_dict_in_generate": false,
36
+ "forced_bos_token_id": null,
37
+ "forced_eos_token_id": null,
38
+ "remove_invalid_values": false,
39
+ "exponential_decay_length_penalty": null,
40
+ "suppress_tokens": null,
41
+ "begin_suppress_tokens": null,
42
+ "architectures": [
43
+ "MoEGPTForCausalLM"
44
+ ],
45
+ "finetuning_task": null,
46
+ "id2label": {
47
+ "0": "LABEL_0",
48
+ "1": "LABEL_1"
49
+ },
50
+ "label2id": {
51
+ "LABEL_0": 0,
52
+ "LABEL_1": 1
53
+ },
54
+ "tokenizer_class": null,
55
+ "prefix": null,
56
+ "bos_token_id": null,
57
+ "pad_token_id": null,
58
+ "eos_token_id": null,
59
+ "sep_token_id": null,
60
+ "decoder_start_token_id": null,
61
+ "task_specific_params": null,
62
+ "problem_type": null,
63
+ "_name_or_path": "",
64
+ "_attn_implementation_autoset": false,
65
+ "transformers_version": "4.51.0",
66
+ "batch_size": 16,
67
+ "vocab_size": 50304,
68
+ "n_embd": 1152,
69
+ "n_layer": 24,
70
+ "n_head": 16,
71
+ "sequence_length": 1024,
72
+ "moe": true,
73
+ "moe_routing": "masked",
74
+ "moe_num_experts": 6,
75
+ "moe_num_experts_per_tok": 2,
76
+ "moe_softmax_order": "softmax_topk",
77
+ "moe_router_loss": "load_balancing_z_loss",
78
+ "moe_aux_loss_factor": 0.01,
79
+ "moe_z_loss_factor": 1.0,
80
+ "mlp_dim_exp_factor": 1.0,
81
+ "dropout": 0.0,
82
+ "bias": false,
83
+ "auto_map": {
84
+ "AutoConfig": "configuration.MoEGPTConfig",
85
+ "AutoModelForCausalLM": "modeling.MoEGPTForCausalLM"
86
+ },
87
+ "model_type": "moegpt"
88
+ }
configuration.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class MoEGPTConfig(PretrainedConfig):
    """Configuration for MoE-GPT.

    GPT-2 style hyperparameters (vocab/embedding/layer/head sizes, sequence
    length, dropout, bias) plus mixture-of-experts routing options (routing
    scheme, expert counts, softmax order, router-loss selection and factors).
    """

    model_type = "moegpt"

    def __init__(
        self,
        vocab_size=50304,
        n_embd=768,
        n_layer=12,
        n_head=12,
        sequence_length=1024,
        moe=False,
        moe_routing="standard_gating",
        moe_num_experts=4,
        moe_num_experts_per_tok=2,
        moe_softmax_order="softmax_topk",
        moe_router_loss="load_balancing_z_loss",
        moe_aux_loss_factor=0.01,
        moe_z_loss_factor=1.0,
        mlp_dim_exp_factor=1.0,
        dropout=0.0,
        bias=False,
        architectures=None,
        auto_map=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.sequence_length = sequence_length
        self.moe = moe
        self.moe_routing = moe_routing
        self.moe_num_experts = moe_num_experts
        self.moe_num_experts_per_tok = moe_num_experts_per_tok
        self.moe_softmax_order = moe_softmax_order
        self.moe_router_loss = moe_router_loss
        self.moe_aux_loss_factor = moe_aux_loss_factor
        self.moe_z_loss_factor = moe_z_loss_factor
        self.mlp_dim_exp_factor = mlp_dim_exp_factor
        self.dropout = dropout
        self.bias = bias
        # BUG FIX: `architectures` and `auto_map` previously defaulted to a
        # shared mutable list/dict (the mutable-default-argument pitfall:
        # one mutation would leak into every later instance). Use None
        # sentinels and build fresh objects; defaults are unchanged.
        self.architectures = (
            ["MoEGPTForCausalLM"] if architectures is None else architectures
        )
        self.auto_map = (
            {
                "AutoConfig": "configuration.MoEGPTConfig",
                "AutoModelForCausalLM": "modeling.MoEGPTForCausalLM",
            }
            if auto_map is None
            else auto_map
        )
50
+
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:660e72ffc148cf135d809e75ac0b918f019d2923a0331753cebd76dd490bc5be
3
+ size 6862473440
modeling.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from .configuration import MoEGPTConfig
3
+ # also imports MoE, MaskedMoE, TimeDependantMoE, etc.
4
+ import math
5
+ import inspect
6
+
7
+ import tiktoken
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+ from huggingface_hub import PyTorchModelHubMixin
12
+
13
+
14
+
15
+ from .moe import (
16
+ #ExpertChoiceMoE,
17
+ MaskedMoE,
18
+ TimeDependantMoE,
19
+ MoE,
20
+ )
21
+
22
+ from .aux_losses import (
23
+ entropy_reg,
24
+ load_balancing_loss,
25
+ router_z_loss,
26
+ )
27
+
28
class Output:
    """Simple container for the results of a forward pass."""

    def __init__(self, logits, loss=None, aux_losses=None, router_logits=None):
        self.logits = logits                # lm_head outputs (None when get_logits=False)
        self.loss = loss                    # scalar LM loss; None at inference
        self.aux_losses = aux_losses        # dict of router auxiliary losses
        self.router_logits = router_logits  # stacked per-layer router logits, or None

    def __repr__(self):
        parts = (
            f"logits={self.logits}",
            f"loss={self.loss}",
            f"aux_losses={self.aux_losses}",
            f"router_logits={self.router_logits}",
        )
        return f"Output({', '.join(parts)})"
37
+
38
class LayerNorm(nn.Module):
    """LayerNorm with an optional bias; PyTorch's nn.LayerNorm doesn't support bias=False."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            # No bias parameter is registered at all in this case.
            self.bias = None

    def forward(self, input):
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, self.weight, self.bias, 1e-5)
48
+
49
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with fused QKV projection.

    Uses torch's fused scaled_dot_product_attention when available
    (PyTorch >= 2.0); otherwise falls back to a manual implementation with
    an explicit lower-triangular mask buffer.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash:
            print(
                "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0"
            )
            # causal mask to ensure that attention is only applied to the left in the input sequence
            # NOTE: this buffer is named "bias" for GPT-2 checkpoint
            # compatibility; it is a 0/1 lower-triangular mask, not an
            # additive bias. It is only registered on the slow path.
            self.register_buffer(
                "bias",
                torch.tril(
                    torch.ones(config.sequence_length, config.sequence_length)
                ).view(1, 1, config.sequence_length, config.sequence_length),
            )

    def forward(self, x):
        """Apply causal self-attention to x of shape (B, T, C); returns (B, T, C)."""
        # batch size, sequence length, embedding dimensionality (n_embd)
        (
            B,
            T,
            C,
        ) = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, nh, hs)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True
            )
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = (
            y.transpose(1, 2).contiguous().view(B, T, C)
        )  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
114
+
115
+
116
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout.

    Returns an (output, {}) pair so it is interchangeable with the MoE
    blocks, which return routing metadata in the second slot.
    """

    def __init__(self, config):
        super().__init__()
        # Hidden width is 4 * n_embd, scaled by the configured expansion factor.
        self.dim_exp_factor = int(config.mlp_dim_exp_factor * 4)
        hidden = self.dim_exp_factor * config.n_embd

        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        h = self.activation(self.c_fc(x))
        out = self.dropout(self.c_proj(h))
        # Empty dict mirrors the MoE blocks' router-metadata return slot.
        return out, {}
137
+
138
+
139
class Block(nn.Module):
    """Pre-LayerNorm transformer block: self-attention + MLP/MoE, both residual.

    Returns the new hidden states together with the routing metadata dict
    produced by the feed-forward sub-block (empty for a dense MLP).
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.moe_config = config.moe_routing
        if config.moe:
            if config.moe_routing == "standard_gating":
                self.mlp = MoE(config, MLP)
            elif config.moe_routing == "masked":
                self.mlp = TimeDependantMoE(config, MLP)
            # elif config.moe_routing == "expert_choice":
            #     self.mlp = ExpertChoiceMoE(config, MLP)
            else:
                # BUG FIX: was f"...{config.routing}" — that attribute does
                # not exist, so the error path raised AttributeError instead
                # of the intended ValueError.
                raise ValueError(f"Unknown routing: {config.moe_routing}")
        else:
            self.mlp = MLP(config)

    def forward(self, x, date, *args, **kwargs):
        """Run attention then feed-forward; `date` is only used by masked routing.

        *args/**kwargs are accepted for interface compatibility but are
        deliberately not forwarded. BUG FIX: they were previously passed into
        LayerNorm, whose forward() accepts a single tensor, so any extra
        argument raised a TypeError.
        """
        x = x + self.attn(self.ln_1(x))
        if self.moe_config == "masked":
            x_, logits_and_experts = self.mlp(self.ln_2(x), date)
        else:
            x_, logits_and_experts = self.mlp(self.ln_2(x))
        x = x + x_
        return x, logits_and_experts
166
+
167
+
168
class MoEGPTForCausalLM(PreTrainedModel):
    """GPT-style causal language model with optional mixture-of-experts MLPs.

    Wraps a stack of `Block`s (attention + dense MLP or MoE) behind the
    Hugging Face `PreTrainedModel` interface, with GPT-2 weight tying between
    the token embedding and the LM head.
    """

    config_class = MoEGPTConfig

    def __init__(self, config):
        super().__init__(config)
        assert config.vocab_size is not None
        assert config.sequence_length is not None
        self.config = config
        self.tokenizer = tiktoken.get_encoding("gpt2")

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.sequence_length, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=LayerNorm(config.n_embd, bias=config.bias),
            )
        )

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = (
            self.lm_head.weight
        )  # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )
            if pn.endswith("router.weight"):
                # special scaled init for the MoE router: normalize along one
                # axis, then rescale so the original std is preserved.
                # NOTE(review): the axis choice per routing mode looks
                # empirical — confirm against training experiments.
                with torch.no_grad():
                    dim = 1 if config.moe_routing == "standard_gating" else 0
                    std = p.std()
                    p.div_(p.sum(dim=dim, keepdim=True))
                    p.mul_(std / p.std())

    def get_router_losses(self, logits, selected_experts, eval=False):
        """Return the configured router auxiliary losses for one layer.

        Args:
            logits: (b * seq_len, n_experts) router logits.
            selected_experts: (b * seq_len, topk) chosen expert indices.
            eval: when True, compute all three losses regardless of config.
        """
        if eval:  # eval mode, compute all losses
            return {
                "moe_entropy_loss": entropy_reg(logits),
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
                "moe_z_loss": router_z_loss(logits),
            }
        if self.config.moe_router_loss == "entropy":
            return {
                "moe_entropy_loss": entropy_reg(logits),
            }
        elif self.config.moe_router_loss == "load_balancing_only":
            return {
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
            }
        elif self.config.moe_router_loss == "load_balancing_z_loss":
            return {
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
                "moe_z_loss": router_z_loss(logits),
            }
        return {}

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        """GPT-2 style init: N(0, 0.02) for Linear/Embedding weights, zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, date=None, targets=None, get_logits=True, moe=False):
        """Run the model.

        Args:
            idx: <long>[b, t] input token indices.
            date: optional <long>[b] per-sample date used by masked routing.
                Defaults to 6 for every sample (NOTE(review): presumably the
                number of time periods — confirm against the training code).
            targets: optional <long>[b, t] labels; when given, the LM loss
                (and, during training with moe=True, the router losses) is
                computed. Label index -1 is ignored.
            get_logits: when False, `Output.logits` is None.
            moe: when True and routing is standard/masked, router auxiliary
                losses are accumulated (added to the loss during training).

        Returns:
            Output(logits, loss, aux_losses, router_logits).
        """
        device = idx.device
        b, t = idx.size()
        assert (
            t <= self.config.sequence_length
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
        if date is None:
            # default every sample's date to 6
            date = torch.full((1, b), 6, dtype=torch.long, device=device).squeeze(0)
        # positions, shape (1, t)
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)

        # router_logits: one entry per MoE layer, each (b * seq_len, n_experts)
        router_logits = []
        # experts: one entry per MoE layer, each (b * seq_len, topk)
        experts = []

        # forward pass through all the transformer blocks
        for block in self.transformer.h:
            x, logits_and_experts = block(x, date)
            if len(logits_and_experts) > 0:
                router_logits.append(logits_and_experts["router_logits"])
                experts.append(logits_and_experts["selected_experts"])
        x = self.transformer.ln_f(x)

        # aux_losses is a dict with keys for different auxiliary losses
        aux_losses = {}

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )
            if moe and (
                self.config.moe_routing == "standard_gating"
                or self.config.moe_routing == "masked"
            ):
                # accumulate the router losses layer by layer
                for logit, expert_choice in zip(router_logits, experts):
                    router_losses = self.get_router_losses(
                        logit, expert_choice, eval=not self.training
                    )
                    for k, v in router_losses.items():
                        aux_losses[k] = aux_losses.get(k, 0.0) + v
                        if self.training:
                            # weight each loss by its config factor, e.g.
                            # "moe_aux_loss" -> config.moe_aux_loss_factor.
                            # NOTE(review): "moe_entropy_loss" would need a
                            # moe_entropy_loss_factor in the config — confirm
                            # it is set when moe_router_loss == "entropy".
                            loss += (
                                v
                                * getattr(self.config, k + "_factor")
                                / self.config.n_layer
                            )
        else:
            # inference: forward the lm_head on the full sequence
            # (the last-position-only optimization is intentionally disabled)
            logits = self.lm_head(x)
            loss = None
        logits = logits if get_logits else None
        router_logits = (
            torch.stack(router_logits, dim=0) if len(router_logits) > 0 else None
        )
        return Output(
            logits=logits,
            loss=loss,
            aux_losses=aux_losses,
            router_logits=router_logits,
        )

    def crop_sequence_length(self, sequence_length):
        """Model surgery to decrease the block size if necessary.

        E.g. we may load a checkpoint with block size 1024 but want a smaller
        block size for a smaller, simpler model.
        NOTE(review): `block.attn.bias` only exists when flash attention is
        unavailable — confirm before calling this on a flash-enabled build.
        """
        assert sequence_length <= self.config.sequence_length
        self.config.sequence_length = sequence_length
        self.transformer.wpe.weight = nn.Parameter(
            self.transformer.wpe.weight[:sequence_length]
        )
        for block in self.transformer.h:
            block.attn.bias = block.attn.bias[:, :, :sequence_length, :sequence_length]

    def get_parameter_group_specs(self):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear,)

        BLACKLIST_WEIGHT_MODULES = (
            torch.nn.LayerNorm,
            LayerNorm,
            torch.nn.Embedding,
        )

        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
                # random note: because named_modules and named_parameters are recursive
                # we will see the same tensors p many many times. but doing it this way
                # allows us to know which parent module any tensor p belongs to...
                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, BLACKLIST_WEIGHT_MODULES):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
        # will appear in the no_decay and decay sets respectively after the above.
        # In addition, because named_parameters() doesn't return duplicates, it
        # will only return the first occurence, key'd by 'transformer.wte.weight', below.
        # so let's manually remove 'lm_head.weight' from decay set. This will include
        # this tensor into optimization via transformer.wte.weight only, and not decayed.
        decay.remove("lm_head.weight")

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert (
            len(inter_params) == 0
        ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert (
            len(param_dict.keys() - union_params) == 0
        ), "parameters %s were not separated into either decay/no_decay set!" % (
            str(param_dict.keys() - union_params),
        )

        # create the pytorch optimizer parameter groups
        return [
            {"params": sorted(list(decay))},
            {"params": sorted(list(no_decay)), "weight_decay": 0.0},
        ]

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at sequence_length
            idx_cond = (
                idx
                if idx.size(1) <= self.config.sequence_length
                else idx[:, -self.config.sequence_length :]
            )
            # BUG FIX: forward() returns an Output object, not a dict, so the
            # previous `self(idx_cond, get_logits=True)["logits"]` raised a
            # TypeError on every generation step. Use attribute access.
            logits = self(idx_cond, get_logits=True).logits
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

    @torch.no_grad()
    def generate_from_string(self, in_str, max_new_tokens, temperature=1.0, top_k=None):
        """Encode `in_str` with the GPT-2 tokenizer, generate, and decode back to text."""
        idx = (
            torch.tensor(
                self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
            )
            .view(1, -1)
            .to(self.lm_head.weight.device)
        )
        out_idx = (
            self.generate(idx, max_new_tokens, temperature, top_k)
            .view(-1)
            .to("cpu")
            .numpy()
        )
        return self.tokenizer.decode(out_idx)
452
+
453
+ # copy your GPTBase implementation in here, adapting everything to `self.config`
moe.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple MoE routing implementations that replace the MLP block in a standard transformer.
3
+ References:
4
+ 1) Mistral Source for Mixtral MoEs:
5
+ https://github.com/mistralai/mistral-src
6
+ 2) ST-MoE:
7
+ https://arxiv.org/abs/2202.08906
8
+ 3) Our notepad of MoE resources:
9
+ https://docs.google.com/document/d/1NuQ5jr7V-Jv1ui7p4KrxO_JTz-7bpYcYMmh49EeJ-QA/edit?usp=sharing
10
+ """
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import bisect
16
+
17
+
18
+
19
class MoE(nn.Module):
    """
    Simplest MoE implementation with a linear router and softmax over experts.

    Note that in this implementation, we simply loop over the experts and
    aggregate the results. This is not the most efficient way to do it, but
    it also avoids the large memory overhead _and_ has no token dropping
    (because we do not need the capacity factor).
    """

    def __init__(self, config, mlp):
        super().__init__()
        assert config.moe_num_experts > 0
        # one expert module per expert; each is an instance of the given mlp class
        self.experts = nn.ModuleList(
            [mlp(config=config) for _ in range(config.moe_num_experts)]
        )
        # linear router producing one logit per expert for every token
        self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False)
        # number of experts each token is dispatched to
        self.top_k = config.moe_num_experts_per_tok
        # whether softmax is applied before or after the top-k selection
        self.softmax_order = config.moe_softmax_order

    def forward(self, inputs: torch.Tensor):
        """Route each token to its top-k experts and sum their weighted outputs.

        Returns the combined output (same shape as `inputs`) plus a dict with
        the raw router logits and selected expert indices for aux losses.
        """
        # [batch_size * sequence_length, n_embd]
        inputs_squashed = inputs.view(-1, inputs.shape[-1])
        # [batch_size * sequence_length, num_experts]
        router_logits = self.router(inputs_squashed)

        # note that selected experts will be the same for all orders:
        # softmax doesnt change top-k, but the weights are different
        if self.softmax_order == "softmax_topk":
            all_probs = F.softmax(router_logits, dim=1)
            weights, selected_experts = torch.topk(all_probs, self.top_k)
        elif self.softmax_order == "topk_softmax":
            weights, selected_experts = torch.topk(router_logits, self.top_k)
            weights = F.softmax(weights, dim=-1)
        else:
            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")

        results = torch.zeros_like(inputs_squashed)
        # naive looping over experts: for each expert, find the (token, slot)
        # pairs that selected it and add its weighted output in place
        for i, expert in enumerate(self.experts):
            batch_idx, nth_expert = torch.where(selected_experts == i)
            output, _ = expert(inputs_squashed[batch_idx])
            results[batch_idx] += weights[batch_idx, nth_expert, None] * output

        # return results and router logits (for aux loss calculation later)
        return results.view_as(inputs), {
            "router_logits": router_logits,
            "selected_experts": selected_experts,
        }
68
+
69
+
70
class DummyExpert(nn.Module):
    """No-op expert that always outputs zeros.

    Used by MaskedMoE as an always-available extra slot. The zero vector of
    shape (output_size,) broadcasts against the per-token routing weights in
    the MoE aggregation loop, contributing nothing to the result.
    """

    def __init__(self, output_size: int):
        super().__init__()
        self._output_size = output_size

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        zeros = torch.zeros((self._output_size,), device=inputs.device)
        # Same (output, metadata) pair shape the real experts return.
        return zeros, {}
77
+
78
+
79
+
80
class MaskedMoE(MoE):
    """MoE whose router logits are gated by an external per-sample expert mask.

    An extra DummyExpert (zero output) is appended and its mask column is
    always enabled, so every token has at least one available routing target.
    """

    def __init__(self, config, mlp):
        super().__init__(config, mlp)
        self._sequence_length = config.sequence_length
        # extra always-available zero-output expert
        self.experts.append(DummyExpert(config.n_embd))
        # re-create the router with one extra output for the dummy expert
        self.router = nn.Linear(config.n_embd, config.moe_num_experts+1, bias=False)


    def forward(self, inputs: torch.Tensor, mask: torch.Tensor):
        # inputs: (batch, seq_len, n_embd); mask: (batch, num_experts) of {0,1}
        seq_len = inputs.shape[1]
        inputs_squashed = inputs.view(-1, inputs.shape[-1])
        router_logits = self.router(inputs_squashed)
        # append an always-on column for the dummy expert
        mask = torch.cat(
            (mask, torch.ones((mask.shape[0], 1), device=mask.device)),
            dim=1
        )
        # expand the per-sample mask to every token position of that sample
        mask = mask.repeat_interleave(seq_len, dim=0)
        # NOTE(review): masking multiplies logits by 0 rather than setting
        # them to -inf, so a "masked" expert keeps logit 0 and can still win
        # top-k when unmasked logits are negative — confirm this is intended.
        router_logits = router_logits*mask

        # note that selected experts will be the same for all orders:
        # softmax doesnt change top-k, but the weights are different
        if self.softmax_order == "softmax_topk":
            all_probs = F.softmax(router_logits, dim=1)
            weights, selected_experts = torch.topk(all_probs, self.top_k)
        elif self.softmax_order == "topk_softmax":
            weights, selected_experts = torch.topk(router_logits, self.top_k)
            weights = F.softmax(weights, dim=-1)
        else:
            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")

        results = torch.zeros_like(inputs_squashed)
        # naive looping over experts
        for i, expert in enumerate(self.experts):
            batch_idx, nth_expert = torch.where(selected_experts == i)
            output, _ = expert(inputs_squashed[batch_idx])
            results[batch_idx] += weights[batch_idx, nth_expert, None] * output

        # return results and router logits (for aux loss calculation later)
        return results.view_as(inputs), {
            "router_logits": router_logits,
            "selected_experts": selected_experts,
        }
122
+
123
+
124
class TimeDependantMoE(nn.Module):
    """Wraps a MaskedMoE, deriving the expert mask from a per-sample `date`.

    A sample with date d may route only to experts 0..d-1 (plus the always-on
    dummy expert that MaskedMoE appends internally).
    """

    def __init__(self, config, mlp):
        super().__init__()
        self._num_experts = config.moe_num_experts
        self._mask_moe = MaskedMoE(config, mlp)

    def forward(self, x, date):
        """x: (batch, seq_len, n_embd); date: <int>[batch] per-sample cutoff."""
        # expert i is enabled iff i < date; shape (batch, num_experts).
        # FIX (minor): removed a dead zeros-allocation of mask_date that was
        # immediately overwritten by the comparison below.
        expert_ids = torch.arange(self._num_experts, device=x.device).unsqueeze(0)
        mask_date = (expert_ids < date.unsqueeze(1)).float()
        return self._mask_moe(x, mask_date)
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "unk_token": "<|endoftext|>"
5
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "50256": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ }
12
+ },
13
+ "bos_token": "<|endoftext|>",
14
+ "clean_up_tokenization_spaces": false,
15
+ "eos_token": "<|endoftext|>",
16
+ "extra_special_tokens": {},
17
+ "model_max_length": 1024,
18
+ "tokenizer_class": "GPT2Tokenizer",
19
+ "unk_token": "<|endoftext|>"
20
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff