robinfaro commited on
Commit
17b5121
·
verified ·
1 Parent(s): 7642de2

Adding files from hf_modeling_btm_log_prob_mixing

Browse files
Files changed (1) hide show
  1. gpt.py +495 -0
gpt.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Full definition of a GPT Language Model, all of it in this single file.
3
+ References:
4
+ 1) nanoGPT by Karpathy:
5
+ https://github.com/karpathy/nanoGPT/tree/eba36e84649f3c6d840a93092cb779a260544d08
6
+ 2) the official GPT-2 TensorFlow implementation released by OpenAI:
7
+ https://github.com/openai/gpt-2/blob/master/src/model.py
8
+ 3) huggingface/transformers PyTorch implementation:
9
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
10
+ """
11
+
12
+ import math
13
+ import inspect
14
+
15
+ import tiktoken
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch.nn import functional as F
19
+ from huggingface_hub import PyTorchModelHubMixin
20
+ from types import SimpleNamespace
21
+
22
+ from .moe import (
23
+ #ExpertChoiceMoE,
24
+ MaskedMoE,
25
+ TimeDependantMoE,
26
+ MoE,
27
+ )
28
+
29
+ from .aux_losses import (
30
+ entropy_reg,
31
+ load_balancing_loss,
32
+ router_z_loss,
33
+ )
34
+
35
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    Uses PyTorch's fused ``scaled_dot_product_attention`` (Flash Attention)
    when available (PyTorch >= 2.0), otherwise falls back to a manual
    masked-softmax implementation. Returns an ``(output, aux)`` tuple so the
    interface matches the MoE attention variants used elsewhere in this file
    (``aux`` is always empty here).
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash:
            print(
                "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0"
            )
            # causal mask to ensure that attention is only applied to the left
            # in the input sequence (only needed for the manual fallback)
            self.register_buffer(
                "bias",
                torch.tril(
                    torch.ones(config.sequence_length, config.sequence_length)
                ).view(1, 1, config.sequence_length, config.sequence_length),
            )

    def forward(self, x):
        if x.ndim != 3:
            # handles the router input, since it previously squashed the batch dim
            x = x.unsqueeze(0)
        # batch size, sequence length, embedding dimensionality (n_embd)
        (
            B,
            T,
            C,
        ) = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, nh, hs) -> (B, nh, T, hs)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            # BUGFIX: only apply attention dropout while training — the
            # functional API has no access to the module's training flag, so
            # passing self.dropout unconditionally applied dropout at eval time.
            y = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = (
            y.transpose(1, 2).contiguous().view(B, T, C)
        )  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        # empty dict keeps the return type aligned with the MoE variants
        return y, {}
class LayerNorm(nn.Module):
    """Layer normalization with a learnable scale and an optional learnable shift.

    Plain ``nn.LayerNorm`` cannot disable only the bias term, hence this thin
    wrapper around ``F.layer_norm``.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, self.weight, self.bias, 1e-5)
class MLP(nn.Module):
    """Position-wise feed-forward block: linear -> GELU -> linear -> dropout.

    The hidden width is ``int(mlp_dim_exp_factor * 4)`` times the embedding
    size. Returns an ``(output, aux)`` pair so the interface matches the MoE
    blocks; ``aux`` is always empty for this dense variant.
    """

    def __init__(self, config):
        super().__init__()
        self.dim_exp_factor = int(config.mlp_dim_exp_factor * 4)
        hidden_dim = self.dim_exp_factor * config.n_embd

        self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        hidden = self.activation(self.c_fc(x))
        out = self.dropout(self.c_proj(hidden))
        # need to return same type as the MoE block, but in this case it's empty
        return out, {}
class Block(nn.Module):
    """One transformer block: (optionally MoE) attention + (optionally MoE) MLP.

    Depending on ``config.moe`` and ``config.moe_routing``, the feed-forward
    (and, when ``shared_attention`` is False, the attention) sub-module is a
    mixture-of-experts wrapper. Sub-modules return ``(output, aux)`` where
    ``aux`` carries router logits / selected experts (empty for dense layers).
    ``forward`` returns ``(x, mlp_aux, attn_aux)``.
    """

    def __init__(self, config):
        super().__init__()
        self.moe_config = config.moe_routing
        self.shared_attention = config.shared_attention
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        if not config.moe and not config.shared_attention:
            raise ValueError(
                "If not using MoE, shared attention must be set to True"
            )

        if self.shared_attention:
            self.attn = CausalSelfAttention(config)

        if config.moe:
            if config.moe_routing == "standard_gating":
                self.mlp = MoE(config, MLP)
                if not self.shared_attention:
                    self.attn = MoE(config, CausalSelfAttention)
            elif config.moe_routing == "masked":
                self.mlp = TimeDependantMoE(config, MLP)
                if not self.shared_attention:
                    self.attn = TimeDependantMoE(config, CausalSelfAttention)
            else:
                # BUGFIX: previously formatted `config.routing`, an attribute
                # that does not exist (it is `moe_routing`), so this path
                # raised AttributeError instead of the intended ValueError.
                raise ValueError(f"Unknown routing: {config.moe_routing}")
        else:
            self.mlp = MLP(config)

    def forward(self, x, date, *args, **kwargs):
        # NOTE(review): *args/**kwargs are forwarded into LayerNorm.forward,
        # which accepts only the input tensor — callers must pass them empty.
        if self.moe_config == "masked":
            # time-dependent routing: the MoE modules receive the date signal
            if self.shared_attention:
                attn_output, attn_logits_and_experts = self.attn(self.ln_1(x, *args, **kwargs))
            else:
                attn_output, attn_logits_and_experts = self.attn(self.ln_1(x, *args, **kwargs), date)
            x = x + attn_output
            x_, mlp_logits_and_experts = self.mlp(self.ln_2(x, *args, **kwargs), date)
        else:
            attn_output, attn_logits_and_experts = self.attn(self.ln_1(x, *args, **kwargs))
            x = x + attn_output
            x_, mlp_logits_and_experts = self.mlp(self.ln_2(x, *args, **kwargs))
        x = x + x_
        return x, mlp_logits_and_experts, attn_logits_and_experts
class GPTBase(nn.Module, PyTorchModelHubMixin):
    """GPT-2 style decoder-only language model with optional MoE blocks.

    Architecture follows nanoGPT: token + learned positional embeddings, a
    stack of ``Block`` modules, a final LayerNorm, and a weight-tied LM head.
    ``forward`` additionally threads a ``date`` signal through the blocks
    (used by the time-dependent "masked" MoE routing) and aggregates the
    per-layer router logits and auxiliary router losses.
    """

    def __init__(self, config):
        """Build the model from ``config``.

        ``config`` may be a namespace-like object or a plain dict (as
        produced when loading through ``PyTorchModelHubMixin``); dicts are
        converted to a ``SimpleNamespace`` first.
        """
        super().__init__()
        if isinstance(config, dict):
            # if we are given a dict, convert it to a SimpleNamespace
            config = SimpleNamespace(**config)

        assert config.vocab_size is not None
        assert config.sequence_length is not None
        self.config = config
        self.tokenizer = tiktoken.get_encoding("gpt2")

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.sequence_length, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=LayerNorm(config.n_embd, bias=config.bias),
            )
        )

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = (
            self.lm_head.weight
        )  # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )
            if pn.endswith("router.weight"):
                # special scaled init to moe router: normalize rows/columns to
                # sum to 1, then rescale back to the original std.
                # NOTE(review): this assumes a `router.weight` parameter exists
                # in the MoE modules — defined in the sibling .moe module;
                # confirm the normalization dim matches its layout.
                with torch.no_grad():
                    dim = 1 if config.moe_routing == "standard_gating" else 0
                    std = p.std()
                    p.div_(p.sum(dim=dim, keepdim=True))
                    p.mul_(std / p.std())

    def get_router_losses(self, logits, selected_experts, eval=False):
        """Return a dict of auxiliary router losses for one layer's routing.

        logits: (b * seq_len, n_experts)
        selected_experts: (b * seq_len, topk)
        In eval mode all three losses are computed for logging; in training
        mode only the subset selected by ``config.moe_router_loss`` is
        returned (empty dict if the setting matches none of the cases).
        NOTE(review): the ``eval`` parameter shadows the builtin of the same
        name; renaming would change the keyword interface, so it is kept.
        """
        if eval:  # eval mode, compute all losses
            return {
                "moe_entropy_loss": entropy_reg(logits),
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
                "moe_z_loss": router_z_loss(logits),
            }
        if self.config.moe_router_loss == "entropy":
            return {
                "moe_entropy_loss": entropy_reg(logits),
            }
        elif self.config.moe_router_loss == "load_balancing_only":
            return {
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
            }
        elif self.config.moe_router_loss == "load_balancing_z_loss":
            return {
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
                "moe_z_loss": router_z_loss(logits),
            }
        return {}

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Standard GPT-2 init: N(0, 0.02) for Linear/Embedding weights, zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, date, targets=None, get_logits=False, moe=False):
        """Run the model on token indices ``idx`` of shape (b, t).

        ``date`` is forwarded to every block (consumed by the masked MoE
        routing; presumably a per-batch time signal — confirm against the
        .moe module). When ``targets`` is given, the cross-entropy loss is
        computed and, if ``moe`` is True and routing is enabled, the
        per-layer router losses are added (training mode only).
        Returns a dict with loss, logits (if ``get_logits``), per-kind aux
        losses, and stacked router logits (or None when unused).
        """
        device = idx.device
        b, t = idx.size()
        assert (
            t <= self.config.sequence_length
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
        # shape (1, t)
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(
            pos
        )  # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)

        # router logits is a list for each layer's routing, each of shape (b * seq_len, n_experts)
        mlp_router_logits = []
        attn_router_logits = []
        # experts is a list for each layer's selected experts, shape (b * seq_len, topk)
        mlp_experts = []
        attn_experts = []

        # forward pass through all the transformer blocks, collecting any
        # routing info the blocks emit (dense blocks emit empty dicts)
        for block in self.transformer.h:
            x, mlp_logits_and_experts, attn_logits_and_experts = block(x, date)
            if len(mlp_logits_and_experts) > 0:
                mlp_router_logits.append(mlp_logits_and_experts["router_logits"])
                mlp_experts.append(mlp_logits_and_experts["selected_experts"])
            if len(attn_logits_and_experts) > 0:
                attn_router_logits.append(attn_logits_and_experts["router_logits"])
                attn_experts.append(attn_logits_and_experts["selected_experts"])
        x = self.transformer.ln_f(x)

        # aux_losses is a dict with keys for different auxiliary losses
        aux_losses_mlp = {}
        aux_losses_attn = {}

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )
            # log the plain LM loss before any router losses are added below
            loss_to_log = loss.item()
            if moe and (self.config.moe_routing == "standard_gating" or self.config.moe_routing == "masked"):
                # calculate the router losses per layer; each loss factor is
                # read from config as "<loss_name>_factor" and averaged over layers
                for logit, expert_choice in zip(mlp_router_logits, mlp_experts):
                    router_losses = self.get_router_losses(
                        logit, expert_choice, eval=not self.training
                    )
                    for k, v in router_losses.items():
                        aux_losses_mlp[k] = aux_losses_mlp.get(k, 0.0) + v
                        if self.training:
                            loss += (
                                v
                                * getattr(self.config, k + "_factor")
                                / self.config.n_layer
                            )
                for logit, expert_choice in zip(attn_router_logits, attn_experts):
                    router_losses = self.get_router_losses(
                        logit, expert_choice, eval=not self.training
                    )
                    for k, v in router_losses.items():
                        aux_losses_attn[k] = aux_losses_attn.get(k, 0.0) + v
                        if self.training:
                            loss += (
                                v
                                * getattr(self.config, k + "_factor")
                                / self.config.n_layer
                            )
        else:
            # NOTE(review): the usual inference-time optimization of running
            # lm_head only on the last position is commented out, so logits
            # are computed for every position here.
            logits = self.lm_head(
                # x[:, [-1], :]
                x
            )
            loss = None
            loss_to_log = None
        logits = logits if get_logits else None
        mlp_router_logits = (
            torch.stack(mlp_router_logits, dim=0) if len(mlp_router_logits) > 0 else None
        )
        attn_router_logits = (
            torch.stack(attn_router_logits, dim=0) if len(attn_router_logits) > 0 else None
        )
        return {
            "loss_to_log": loss_to_log,
            "logits": logits,
            "loss": loss,
            "aux_losses_mlp": aux_losses_mlp,
            "aux_losses_attn": aux_losses_attn,
            "mlp_router_logits": mlp_router_logits,
            "attn_router_logits": attn_router_logits
        }

    def crop_sequence_length(self, sequence_length):
        # model surgery to decrease the block size if necessary
        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
        # but want to use a smaller block size for some smaller, simpler model
        assert sequence_length <= self.config.sequence_length
        self.config.sequence_length = sequence_length
        self.transformer.wpe.weight = nn.Parameter(
            self.transformer.wpe.weight[:sequence_length]
        )
        for block in self.transformer.h:
            # NOTE(review): assumes every block has a dense attention module
            # with a registered `bias` mask buffer — CausalSelfAttention only
            # registers it when flash attention is unavailable, and MoE
            # attention wrappers may not expose it at all; confirm.
            block.attn.bias = block.attn.bias[:, :, :sequence_length, :sequence_length]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        # TODO: not implemented; loading currently goes through PyTorchModelHubMixin
        pass

    def get_parameter_group_specs(self):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We then return the two parameter-group spec dicts to be passed to a
        PyTorch optimizer constructor.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear,)

        BLACKLIST_WEIGHT_MODULES = (
            torch.nn.LayerNorm,
            LayerNorm,
            torch.nn.Embedding,
        )

        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
                # random note: because named_modules and named_parameters are recursive
                # we will see the same tensors p many many times. but doing it this way
                # allows us to know which parent module any tensor p belongs to...
                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, BLACKLIST_WEIGHT_MODULES):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
        # will appear in the no_decay and decay sets respectively after the above.
        # In addition, because named_parameters() doesn't return duplicates, it
        # will only return the first occurence, key'd by 'transformer.wte.weight', below.
        # so let's manually remove 'lm_head.weight' from decay set. This will include
        # this tensor into optimization via transformer.wte.weight only, and not decayed.
        decay.remove("lm_head.weight")

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert (
            len(inter_params) == 0
        ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert (
            len(param_dict.keys() - union_params) == 0
        ), "parameters %s were not separated into either decay/no_decay set!" % (
            str(param_dict.keys() - union_params),
        )

        # create the parameter-group specs for the pytorch optimizer
        return [
            {"params": sorted(list(decay))},
            {"params": sorted(list(no_decay)), "weight_decay": 0.0},
        ]

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at sequence_length
            idx_cond = (
                idx
                if idx.size(1) <= self.config.sequence_length
                else idx[:, -self.config.sequence_length :]
            )
            # forward the model to get the logits for the index in the sequence
            # NOTE(review): forward() takes a required positional `date`
            # argument that is not supplied here — this call likely raises
            # TypeError; confirm and thread a date through generate().
            logits = self(idx_cond, get_logits=True)["logits"]
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

    @torch.no_grad()
    def generate_from_string(self, in_str, max_new_tokens, temperature=1.0, top_k=None):
        """Tokenize ``in_str`` with the GPT-2 tokenizer, generate, and decode back to text."""
        idx = (
            torch.tensor(
                self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
            )
            .view(1, -1)
            .to(self.lm_head.weight.device)
        )
        out_idx = (
            self.generate(idx, max_new_tokens, temperature, top_k)
            .view(-1)
            .to("cpu")
            .numpy()
        )
        return self.tokenizer.decode(out_idx)