talphaidze committed on
Commit
1b8457b
·
verified ·
1 Parent(s): 898c171

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +396 -103
modeling.py CHANGED
@@ -1,142 +1,420 @@
1
  from transformers import PreTrainedModel
2
- from .configuration import MoLMConfig
 
 
 
 
 
 
3
  import torch
4
  import torch.nn as nn
5
  from torch.nn import functional as F
 
6
  from transformers.utils import ModelOutput
7
- from .gpt import GPTBase
8
- from typing import Optional, List
9
- from dataclasses import dataclass
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  @dataclass
13
  class Output(ModelOutput):
14
  logits: torch.FloatTensor = None
15
  loss: Optional[torch.FloatTensor] = None
16
- expert_losses: Optional[List] = None
17
- loss_to_log: Optional[float] = None
18
 
 
 
19
 
20
- class MoLM(PreTrainedModel):
21
- config_class = MoLMConfig
22
 
23
- def __init__(self, config, expert_weights=None, dropout=0.1, use_router=False):
24
- """
25
- Constructor for the MoLM (Mixture of Language Models) class.
 
26
 
27
- :param config: The configuration of the model (should be a PretrainedConfig object)
28
- :param expert_weights: (Optional) A list of weights for each expert to load pre-trained weights (should match the number of experts)
29
- :param dropout: Dropout rate for the model
30
- :param use_router: Flag to indicate whether to use routing (currently not implemented)
31
- """
32
- super(MoLM, self).__init__(config)
33
-
34
- # Number of experts
35
- self.num_experts = config.num_experts
36
- print(f"Number of experts: {self.num_experts}")
37
- print(f"Expert configurations: {config.expert_configs}")
38
- assert len(config.expert_configs) == self.num_experts, "Number of expert configurations must match num_experts in config."
39
- self.expert_configs = config.expert_configs
40
-
41
- # Flag for routing (not implemented yet)
42
- self.use_router = use_router
43
-
44
- # Initialize experts using the provided configurations
45
- self.experts = nn.ModuleList([GPTBase(config=self.expert_configs[i]) for i in range(self.num_experts)])
46
-
47
- # Load pre-trained weights if provided
48
- if expert_weights is not None:
49
- for i, expert in enumerate(self.experts):
50
- expert.load_state_dict(expert_weights[i], strict=False)
51
- expert.transformer.wte.weight = torch.nn.Parameter(expert.transformer.wte.weight.clone())
52
- for param in expert.parameters():
53
- param.requires_grad = False
54
-
55
- def forward(self, input_ids, attention_mask=None, targets=None, date=None, masking_enabled=True, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  """
57
- Forward pass for the MoLM model, passing input through all experts and averaging their outputs.
58
-
59
- :param input_ids: Input token IDs (batch_size, seq_len)
60
- :param attention_mask: Attention mask (batch_size, seq_len)
61
- :param targets: Target labels for calculating loss (batch_size, seq_len)
62
- :param date: A tensor indicating which experts to use. Each sample in the batch can have a different date.
63
- :param masking_enabled: Whether or not to perform expert masking (True/False)
64
- :param kwargs: Additional arguments
65
- :return: The averaged output of all active experts up to the specified date for each sample in the batch
66
  """
67
- device = input_ids.device
68
- b, t = input_ids.size()
 
 
69
 
70
- # Ensure the sequence length doesn't exceed the configured block size
71
- assert t <= self.config.sequence_length, f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
 
 
 
 
 
72
 
73
- # If date is None, set a default value (e.g., 6 for all samples)
 
 
 
 
 
 
74
  if date is None:
 
75
  date = torch.full((1, b), 6, dtype=torch.long, device=device).squeeze(0)
76
- elif isinstance(date, int):
77
- # If date is an integer, set it for all samples in the batch
78
  date = (date - 2013) // 2 + 1
79
  date = torch.full((1, b), date, dtype=torch.long, device=device).squeeze(0)
80
- elif isinstance(date, torch.Tensor):
81
- # Ensure the tensor has the correct shape (batch_size,)
82
- assert date.size(0) == b, "The size of date tensor must match the batch size."
83
- date = date.to(device)
84
 
85
- # Get outputs from each expert
86
- expert_outputs = []
87
- expert_losses = []
 
 
 
88
 
89
- # Track the number of active experts for each sample in the batch
90
- active_experts_count = torch.zeros(b, dtype=torch.long, device=device)
 
 
91
 
92
- # Pass input through each expert
93
- for i, expert in enumerate(self.experts):
94
- # Masking logic based on date (for each sample in the batch)
95
- expert_mask = date >= i # Mask experts where date < i (i.e., deactivate them)
 
 
 
96
 
97
- # Expand the expert_mask to match the logits shape (batch_size, 1, 1)
98
- expert_mask_expanded = expert_mask.unsqueeze(-1).unsqueeze(-1).float()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- expert_output = expert(input_ids, targets=targets, date=date, get_logits=True, **kwargs)
 
 
 
 
 
 
 
 
 
 
101
 
102
- logits = expert_output["logits"]
103
- loss_to_log = expert_output["loss_to_log"]
104
 
105
- # Mask out the outputs for deactivated experts
106
- logits = logits * expert_mask_expanded # Apply the mask (zero out logits for inactive experts)
 
 
 
 
 
107
 
108
- # Only append logits from active experts
109
- expert_outputs.append(logits)
110
- expert_losses.append(loss_to_log)
 
111
 
112
- # Update active expert count for each sample
113
- active_experts_count += expert_mask.long() # Ensure type consistency by converting `expert_mask` to Long
 
 
 
114
 
115
- # Stack the logits and calculate the mean for each sample across the active experts
116
- expert_outputs = torch.stack(expert_outputs, dim=0) # Shape: (num_experts, batch_size, seq_len, vocab_size)
117
-
118
- # Calculate the sum across the active experts for each sample and then average
119
- summed_logits = torch.sum(expert_outputs, dim=0) # Sum across active experts
120
- combined_logits = summed_logits / active_experts_count.unsqueeze(-1).unsqueeze(-1) # Divide by the number of active experts
 
 
 
 
 
 
 
 
 
121
 
122
- # Calculate the loss if targets are provided
123
- if targets is not None:
124
- loss = F.cross_entropy(combined_logits.view(-1, combined_logits.size(-1)), targets.view(-1), ignore_index=-1)
125
- loss_to_log = loss.item()
126
- else:
127
- loss = None
128
- loss_to_log = None
129
 
130
- return Output(
131
- logits=combined_logits,
132
- loss=loss,
133
- loss_to_log=loss_to_log,
134
- expert_losses=expert_losses
 
 
 
 
 
 
135
  )
136
 
 
 
 
 
 
137
 
138
  @torch.no_grad()
139
- def generate(self, input_ids, max_new_tokens, date=None, temperature=1.0, top_k=None):
140
  """
141
  Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
142
  the sequence max_new_tokens times, feeding the predictions back into the model each time.
@@ -165,13 +443,13 @@ class MoLM(PreTrainedModel):
165
  # append sampled index to the running sequence and continue
166
  idx = torch.cat((idx, idx_next), dim=1)
167
  # check if we hit the end of the sequence
168
- if idx_next.item() == 50526:
169
  break
170
 
171
  return idx
172
 
173
  @torch.no_grad()
174
- def generate_from_string(self, in_str, max_new_tokens, date=None, temperature=1.0, top_k=None):
175
  idx = (
176
  torch.tensor(
177
  self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
@@ -185,4 +463,19 @@ class MoLM(PreTrainedModel):
185
  .to("cpu")
186
  .numpy()
187
  )
188
- return self.tokenizer.decode(out_idx)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from transformers import PreTrainedModel
2
+ from configuration import MoEGPTConfig
3
+ # importa anche MoE, MaskedMoE, TimeDependantMoE ecc.
4
+ import math
5
+ import inspect
6
+ from typing import Optional, Dict, Any
7
+ from dataclasses import dataclass
8
+ import tiktoken
9
  import torch
10
  import torch.nn as nn
11
  from torch.nn import functional as F
12
+ from huggingface_hub import PyTorchModelHubMixin
13
  from transformers.utils import ModelOutput
 
 
 
14
 
15
 
16
+ from moe import (
17
+ #ExpertChoiceMoE,
18
+ MaskedMoE,
19
+ TimeDependantMoE,
20
+ MoE,
21
+ )
22
+
23
+ from aux_losses import (
24
+ entropy_reg,
25
+ load_balancing_loss,
26
+ router_z_loss,
27
+ )
28
+
29
@dataclass
class Output(ModelOutput):
    """Result container for MoEGPTForCausalLM.forward.

    Fields:
        logits: LM logits over the vocabulary (or None if get_logits=False).
        loss: cross-entropy loss when targets were supplied, else None.
        aux_losses: per-name router auxiliary losses (may be empty).
        router_logits: stacked per-layer router logits, or None without MoE.
    """

    logits: torch.FloatTensor = None
    loss: Optional[torch.FloatTensor] = None
    aux_losses: Optional[Dict[str, torch.FloatTensor]] = None
    router_logits: Optional[torch.FloatTensor] = None

    def __repr__(self):
        fields = (self.logits, self.loss, self.aux_losses, self.router_logits)
        return "Output(logits={}, loss={}, aux_losses={}, router_logits={})".format(*fields)
44
 
45
class LayerNorm(nn.Module):
    """Layer normalization with an optional bias term.

    torch.nn.LayerNorm cannot disable only its bias, hence this wrapper:
    the scale parameter is always present, the shift parameter is optional.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, x):
        # default eps of 1e-5, normalized over the trailing `ndim` dims
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
55
+
56
+ class CausalSelfAttention(nn.Module):
57
+ def __init__(self, config):
58
+ super().__init__()
59
+ assert config.n_embd % config.n_head == 0
60
+ # key, query, value projections for all heads, but in a batch
61
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
62
+ # output projection
63
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
64
+ # regularization
65
+ self.attn_dropout = nn.Dropout(config.dropout)
66
+ self.resid_dropout = nn.Dropout(config.dropout)
67
+ self.n_head = config.n_head
68
+ self.n_embd = config.n_embd
69
+ self.dropout = config.dropout
70
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
71
+ self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
72
+ if not self.flash:
73
+ print(
74
+ "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0"
75
+ )
76
+ # causal mask to ensure that attention is only applied to the left in the input sequence
77
+ self.register_buffer(
78
+ "bias",
79
+ torch.tril(
80
+ torch.ones(config.sequence_length, config.sequence_length)
81
+ ).view(1, 1, config.sequence_length, config.sequence_length),
82
+ )
83
+
84
+ def forward(self, x):
85
+ # batch size, sequence length, embedding dimensionality (n_embd)
86
+ (
87
+ B,
88
+ T,
89
+ C,
90
+ ) = x.size()
91
+
92
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
93
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
94
+ # (B, T, nh, hs)
95
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
96
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
97
+
98
+ # (B, nh, T, hs)
99
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
100
+
101
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
102
+ if self.flash:
103
+ # efficient attention using Flash Attention CUDA kernels
104
+ y = torch.nn.functional.scaled_dot_product_attention(
105
+ q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True
106
+ )
107
+ else:
108
+ # manual implementation of attention
109
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
110
+ att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
111
+ att = F.softmax(att, dim=-1)
112
+ att = self.attn_dropout(att)
113
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
114
+ y = (
115
+ y.transpose(1, 2).contiguous().view(B, T, C)
116
+ ) # re-assemble all head outputs side by side
117
+
118
+ # output projection
119
+ y = self.resid_dropout(self.c_proj(y))
120
+ return y
121
+
122
+
123
+ class MLP(nn.Module):
124
+ def __init__(self, config):
125
+ super().__init__()
126
+ self.dim_exp_factor = int(config.mlp_dim_exp_factor * 4)
127
+
128
+ self.c_fc = nn.Linear(
129
+ config.n_embd, self.dim_exp_factor * config.n_embd, bias=config.bias
130
+ )
131
+ self.c_proj = nn.Linear(
132
+ self.dim_exp_factor * config.n_embd, config.n_embd, bias=config.bias
133
+ )
134
+ self.dropout = nn.Dropout(config.dropout)
135
+ self.activation = nn.GELU()
136
+
137
+ def forward(self, x):
138
+ x = self.c_fc(x)
139
+ x = self.activation(x)
140
+ x = self.c_proj(x)
141
+ x = self.dropout(x)
142
+ # need to return same type as the MoE block, but in this case it's empty
143
+ return x, {}
144
+
145
+
146
class Block(nn.Module):
    """One transformer block: causal self-attention plus a feed-forward
    (dense MLP or MoE), each in a pre-LayerNorm residual connection.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.moe_config = config.moe_routing
        if config.moe:
            if config.moe_routing == "standard_gating":
                self.mlp = MoE(config, MLP)
            elif config.moe_routing == "masked":
                self.mlp = TimeDependantMoE(config, MLP)
            else:
                # BUG FIX: previously read `config.routing`, which does not
                # exist and raised AttributeError instead of this message.
                raise ValueError(f"Unknown routing: {config.moe_routing}")
        else:
            self.mlp = MLP(config)

    def forward(self, x, date, *args, **kwargs):
        """Return (hidden states, routing-info dict; empty for dense MLP)."""
        x = x + self.attn(self.ln_1(x, *args, **kwargs))
        if self.moe_config == "masked":
            # time-dependent MoE needs the per-sample date for expert masking
            x_, logits_and_experts = self.mlp(self.ln_2(x, *args, **kwargs), date)
        else:
            x_, logits_and_experts = self.mlp(self.ln_2(x, *args, **kwargs))
        x = x + x_
        return x, logits_and_experts
173
+
174
+
175
+ class MoEGPTForCausalLM(PreTrainedModel):
176
+ config_class = MoEGPTConfig
177
    def __init__(self, config):
        """Build the MoE-GPT: token/position embeddings, transformer blocks,
        and an LM head weight-tied to the token embeddings.

        :param config: MoEGPTConfig providing vocab_size, sequence_length,
            n_embd, n_layer, dropout, bias, moe_routing, etc.
        """
        super().__init__(config)
        assert config.vocab_size is not None
        assert config.sequence_length is not None
        self.config = config
        # GPT-2 BPE tokenizer, used by generate_from_string()
        self.tokenizer = tiktoken.get_encoding("gpt2")
        self.base_model_prefix = "timoe"

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.sequence_length, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=LayerNorm(config.n_embd, bias=config.bias),
            )
        )

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = (
            self.lm_head.weight
        )  # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )
            if pn.endswith("router.weight"):
                # special scaled init to moe router?
                # NOTE(review): normalizes the router weight to sum to 1 along
                # one dim, then rescales back to the pre-normalization std —
                # assumes MoE blocks expose a `router` submodule; confirm in moe.py.
                with torch.no_grad():
                    dim = 1 if config.moe_routing == "standard_gating" else 0
                    std = p.std()
                    p.div_(p.sum(dim=dim, keepdim=True))
                    p.mul_(std / p.std())
219
+
220
+ def get_router_losses(self, logits, selected_experts, eval=False):
221
+ # logits: (b * seq_len, n_experts)
222
+ # selected_experts: (b * seq_len, topk)
223
+ if eval: # eval mode, compute all losses
224
+ return {
225
+ "moe_entropy_loss": entropy_reg(logits),
226
+ "moe_aux_loss": load_balancing_loss(logits, selected_experts),
227
+ "moe_z_loss": router_z_loss(logits),
228
+ }
229
+ if self.config.moe_router_loss == "entropy":
230
+ return {
231
+ "moe_entropy_loss": entropy_reg(logits),
232
+ }
233
+ elif self.config.moe_router_loss == "load_balancing_only":
234
+ return {
235
+ "moe_aux_loss": load_balancing_loss(logits, selected_experts),
236
+ }
237
+ elif self.config.moe_router_loss == "load_balancing_z_loss":
238
+ return {
239
+ "moe_aux_loss": load_balancing_loss(logits, selected_experts),
240
+ "moe_z_loss": router_z_loss(logits),
241
+ }
242
+ return {}
243
+
244
+ def get_num_params(self, non_embedding=True):
245
  """
246
+ Return the number of parameters in the model.
247
+ For non-embedding count (default), the position embeddings get subtracted.
248
+ The token embeddings would too, except due to the parameter sharing these
249
+ params are actually used as weights in the final layer, so we include them.
 
 
 
 
 
250
  """
251
+ n_params = sum(p.numel() for p in self.parameters())
252
+ if non_embedding:
253
+ n_params -= self.transformer.wpe.weight.numel()
254
+ return n_params
255
 
256
+ def _init_weights(self, module):
257
+ if isinstance(module, nn.Linear):
258
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
259
+ if module.bias is not None:
260
+ torch.nn.init.zeros_(module.bias)
261
+ elif isinstance(module, nn.Embedding):
262
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
263
 
264
+ def forward(self, idx, date=None, targets=None, attention_mask=None, get_logits=True, moe=False):
265
+ device = idx.device
266
+ b, t = idx.size()
267
+ assert (
268
+ t <= self.config.sequence_length
269
+ ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
270
+ # shape (1, t)
271
  if date is None:
272
+ # set all the date to 6
273
  date = torch.full((1, b), 6, dtype=torch.long, device=device).squeeze(0)
274
+ else:
 
275
  date = (date - 2013) // 2 + 1
276
  date = torch.full((1, b), date, dtype=torch.long, device=device).squeeze(0)
277
+ pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
 
 
 
278
 
279
+ # forward the GPT model itself
280
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
281
+ pos_emb = self.transformer.wpe(
282
+ pos
283
+ ) # position embeddings of shape (1, t, n_embd)
284
+ x = self.transformer.drop(tok_emb + pos_emb)
285
 
286
+ # router logits is a list for each layer's routing, each of shape (b * seq_len, n_experts)
287
+ router_logits = []
288
+ # experts is a list for each layer's selected experts, shape (b * seq_len, topk)
289
+ experts = []
290
 
291
+ # forward pass through all the transformer blocks
292
+ for block in self.transformer.h:
293
+ x, logits_and_experts = block(x, date)
294
+ if len(logits_and_experts) > 0:
295
+ router_logits.append(logits_and_experts["router_logits"])
296
+ experts.append(logits_and_experts["selected_experts"])
297
+ x = self.transformer.ln_f(x)
298
 
299
+ # aux_losses is a dict with keys for different auxiliary losses
300
+ aux_losses = {}
301
+
302
+ if targets is not None:
303
+ # if we are given some desired targets also calculate the loss
304
+ logits = self.lm_head(x)
305
+ loss = F.cross_entropy(
306
+ logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
307
+ )
308
+ if moe and (self.config.moe_routing == "standard_gating" or self.config.moe_routing == "masked"):
309
+ # calculate the router losses per layer
310
+ for logit, expert_choice in zip(router_logits, experts):
311
+ router_losses = self.get_router_losses(
312
+ logit, expert_choice, eval=not self.training
313
+ )
314
+ for k, v in router_losses.items():
315
+ aux_losses[k] = aux_losses.get(k, 0.0) + v
316
+ if self.training:
317
+ loss += (
318
+ v
319
+ * getattr(self.config, k + "_factor")
320
+ / self.config.n_layer
321
+ )
322
+ else:
323
+ # inference-time mini-optimization: only forward the lm_head on the very last position
324
+ logits = self.lm_head(
325
+ #x[:, [-1], :]
326
+ x
327
+ ) # note: using list [-1] to preserve the time dim
328
+ loss = None
329
+ logits = logits if get_logits else None
330
+ router_logits = (
331
+ torch.stack(router_logits, dim=0) if len(router_logits) > 0 else None
332
+ )
333
+ # return {
334
+ # "logits": logits,
335
+ # "loss": loss,
336
+ # "aux_losses": aux_losses,
337
+ # "router_logits": router_logits,
338
+ # }
339
+ return Output(logits = logits, loss = loss, aux_losses = aux_losses, router_logits = router_logits)
340
 
341
+ def crop_sequence_length(self, sequence_length):
342
+ # model surgery to decrease the block size if necessary
343
+ # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
344
+ # but want to use a smaller block size for some smaller, simpler model
345
+ assert sequence_length <= self.config.sequence_length
346
+ self.config.sequence_length = sequence_length
347
+ self.transformer.wpe.weight = nn.Parameter(
348
+ self.transformer.wpe.weight[:sequence_length]
349
+ )
350
+ for block in self.transformer.h:
351
+ block.attn.bias = block.attn.bias[:, :, :sequence_length, :sequence_length]
352
 
 
 
353
 
354
+ def get_parameter_group_specs(self):
355
+ """
356
+ This long function is unfortunately doing something very simple and is being very defensive:
357
+ We are separating out all parameters of the model into two buckets: those that will experience
358
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
359
+ We are then returning the PyTorch optimizer object.
360
+ """
361
 
362
+ # separate out all parameters to those that will and won't experience regularizing weight decay
363
+ decay = set()
364
+ no_decay = set()
365
+ whitelist_weight_modules = (torch.nn.Linear,)
366
 
367
+ BLACKLIST_WEIGHT_MODULES = (
368
+ torch.nn.LayerNorm,
369
+ LayerNorm,
370
+ torch.nn.Embedding,
371
+ )
372
 
373
+ for mn, m in self.named_modules():
374
+ for pn, p in m.named_parameters():
375
+ fpn = "%s.%s" % (mn, pn) if mn else pn # full param name
376
+ # random note: because named_modules and named_parameters are recursive
377
+ # we will see the same tensors p many many times. but doing it this way
378
+ # allows us to know which parent module any tensor p belongs to...
379
+ if pn.endswith("bias"):
380
+ # all biases will not be decayed
381
+ no_decay.add(fpn)
382
+ elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
383
+ # weights of whitelist modules will be weight decayed
384
+ decay.add(fpn)
385
+ elif pn.endswith("weight") and isinstance(m, BLACKLIST_WEIGHT_MODULES):
386
+ # weights of blacklist modules will NOT be weight decayed
387
+ no_decay.add(fpn)
388
 
389
+ # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
390
+ # will appear in the no_decay and decay sets respectively after the above.
391
+ # In addition, because named_parameters() doesn't return duplicates, it
392
+ # will only return the first occurence, key'd by 'transformer.wte.weight', below.
393
+ # so let's manually remove 'lm_head.weight' from decay set. This will include
394
+ # this tensor into optimization via transformer.wte.weight only, and not decayed.
395
+ decay.remove("lm_head.weight")
396
 
397
+ # validate that we considered every parameter
398
+ param_dict = {pn: p for pn, p in self.named_parameters()}
399
+ inter_params = decay & no_decay
400
+ union_params = decay | no_decay
401
+ assert (
402
+ len(inter_params) == 0
403
+ ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
404
+ assert (
405
+ len(param_dict.keys() - union_params) == 0
406
+ ), "parameters %s were not separated into either decay/no_decay set!" % (
407
+ str(param_dict.keys() - union_params),
408
  )
409
 
410
+ # create the pytorch optimizer object
411
+ return [
412
+ {"params": sorted(list(decay))},
413
+ {"params": sorted(list(no_decay)), "weight_decay": 0.0},
414
+ ]
415
 
416
  @torch.no_grad()
417
+ def generate(self, input_ids, max_new_tokens, date = None, temperature=1.0, top_k=None):
418
  """
419
  Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
420
  the sequence max_new_tokens times, feeding the predictions back into the model each time.
 
443
  # append sampled index to the running sequence and continue
444
  idx = torch.cat((idx, idx_next), dim=1)
445
  # check if we hit the end of the sequence
446
+ if idx_next.item() == self.tokenizer.eot_token:
447
  break
448
 
449
  return idx
450
 
451
  @torch.no_grad()
452
+ def generate_from_string(self, in_str, max_new_tokens, date = None, temperature=1.0, top_k=None):
453
  idx = (
454
  torch.tensor(
455
  self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
 
463
  .to("cpu")
464
  .numpy()
465
  )
466
+ return self.tokenizer.decode(out_idx).split(in_str)[-1]
467
+
468
+
469
+ def get_input_embeddings(self):
470
+ return self.transformer.wte
471
+
472
+ def set_input_embeddings(self, new_embeddings):
473
+ self.transformer.wte = new_embeddings
474
+ # reset the lm_head to use the new embeddings
475
+ # this is necessary because the lm_head is tied to the input embeddings
476
+ self.lm_head = nn.Linear(
477
+ self.config.n_embd, new_embeddings.weight.shape[0] , bias=False
478
+ )
479
+ #self.transformer.wte.weight = (
480
+ # self.lm_head.weight
481
+ #)