yq committed on
Commit d5d2f03 · 1 Parent(s): 093b21e

using gpt2-124M

.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+out-stinfo/ckpt.pt filter=lfs diff=lfs merge=lfs -text
__pycache__/model.cpython-310.pyc ADDED
Binary file (13.4 kB).
 
config/eval_gpt2.py ADDED
@@ -0,0 +1,8 @@
+# evaluate the base gpt2
+# n_layer=12, n_head=12, n_embd=768
+# 124M parameters
+batch_size = 8
+eval_iters = 500 # use more iterations to get good estimate
+eval_only = True
+wandb_log = False
+init_from = 'gpt2'
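A usage sketch for this config (it relies on the configurator.py and train.py added later in this commit; the command lines are illustrative, not part of the commit): the file is passed to train.py as a positional argument, and eval_only = True makes the script exit right after the first evaluation.

$ python train.py config/eval_gpt2.py
$ python train.py config/eval_gpt2.py --batch_size=4   # any config key can still be overridden on the command line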
config/finetune_stinfo.py ADDED
@@ -0,0 +1,25 @@
+import time
+
+out_dir = 'out-stinfo'
+eval_interval = 5
+eval_iters = 10
+wandb_log = False # feel free to turn on
+wandb_project = 'stinfo'
+wandb_run_name = 'ft-' + str(time.time())
+
+dataset = 'stinfo'
+init_from = 'gpt2' # the 124M-parameter GPT-2 model (the smallest variant)
+
+# only save checkpoints if the validation loss improves
+always_save_checkpoint = False
+
+# the number of examples per iter:
+# 4 batch_size * 32 grad_accum * 1024 tokens = 131,072 tokens/iter
+# the stinfo train split has ~94,830 tokens, so 1 epoch ~= 0.7 iters
+batch_size = 4
+gradient_accumulation_steps = 32
+max_iters = 30
+
+# finetune at a small learning rate (note: LR decay is left enabled here)
+learning_rate = 3e-5
+decay_lr = True
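A quick sanity check of the token arithmetic in the comments above, as a sketch (the numbers come from this commit's config values and the size of data/stinfo/train.bin, not from a training run):

tokens_per_iter = 4 * 32 * 1024        # batch_size * gradient_accumulation_steps * block_size = 131,072
# note: train.py multiplies gradient_accumulation_steps by 8 on a single-GPU (non-DDP) run
train_tokens = 189_660 // 2            # train.bin stores uint16 tokens, 2 bytes each -> ~94,830
print(train_tokens / tokens_per_iter)  # ~0.72, so one pass over the stinfo train split is less than one iteration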
configurator.py ADDED
@@ -0,0 +1,47 @@
+"""
+Poor Man's Configurator. Probably a terrible idea. Example usage:
+$ python train.py config/override_file.py --batch_size=32
+this will first run config/override_file.py, then override batch_size to 32
+
+The code in this file will be run as follows from e.g. train.py:
+>>> exec(open('configurator.py').read())
+
+So it's not a Python module, it's just shuttling this code away from train.py
+The code in this script then overrides the globals()
+
+I know people are not going to love this, I just really dislike configuration
+complexity and having to prepend config. to every single variable. If someone
+comes up with a better simple Python solution I am all ears.
+"""
+
+import sys
+from ast import literal_eval
+
+for arg in sys.argv[1:]:
+    if '=' not in arg:
+        # assume it's the name of a config file
+        assert not arg.startswith('--')
+        config_file = arg
+        print(f"Overriding config with {config_file}:")
+        with open(config_file) as f:
+            print(f.read())
+        exec(open(config_file).read())
+    else:
+        # assume it's a --key=value argument
+        assert arg.startswith('--')
+        key, val = arg.split('=')
+        key = key[2:]
+        if key in globals():
+            try:
+                # attempt to eval it (e.g. if bool, number, or etc)
+                attempt = literal_eval(val)
+            except (SyntaxError, ValueError):
+                # if that goes wrong, just use the string
+                attempt = val
+            # ensure the types match ok
+            assert type(attempt) == type(globals()[key])
+            # cross fingers
+            print(f"Overriding: {key} = {attempt}")
+            globals()[key] = attempt
+        else:
+            raise ValueError(f"Unknown config key: {key}")
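To illustrate the override mechanics, a minimal sketch (the sample values are hypothetical): a --key=value argument is parsed with ast.literal_eval, so numbers and booleans keep their Python types, while anything that fails to parse falls back to a plain string.

from ast import literal_eval

for val in ['4', '3e-5', 'False', 'out-stinfo']:
    try:
        parsed = literal_eval(val)   # 4 (int), 3e-05 (float), False (bool)
    except (SyntaxError, ValueError):
        parsed = val                 # 'out-stinfo' stays a string
    print(val, '->', type(parsed).__name__)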
data/stinfo/input/ST_Engineering_info_formatted.txt ADDED
The diff for this file is too large to render.
 
data/stinfo/input/ar2021_50anni.txt ADDED
The diff for this file is too large to render.
 
data/stinfo/input/st_news.txt ADDED
The diff for this file is too large to render.
 
data/stinfo/prepare.py ADDED
@@ -0,0 +1,35 @@
+import os
+import tiktoken
+import numpy as np
+
+# read the ST Engineering text documents from the input folder
+
+# each file is concatenated into one string, then split 90/10 into train/val
+folder_path = 'input'
+files = os.listdir(folder_path)
+data = ''
+
+for i in files:
+    with open(folder_path + '/' + i, 'r') as f:
+        content = f.read()
+        if content:
+            data += content
+            data += '\n'
+
+n = len(data)
+
+train_data = data[:int(n*0.9)]
+val_data = data[int(n*0.9):]
+
+# encode with tiktoken gpt2 byte pair encoding
+enc = tiktoken.get_encoding("gpt2")
+train_ids = enc.encode_ordinary(train_data)
+val_ids = enc.encode_ordinary(val_data)
+print(f"train has {len(train_ids):,} tokens")
+print(f"val has {len(val_ids):,} tokens")
+
+# export to bin files
+train_ids = np.array(train_ids, dtype=np.uint16)
+val_ids = np.array(val_ids, dtype=np.uint16)
+train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
+val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))
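As a sketch of how the exported files are read back (mirroring the np.memmap loader in train.py below; the token counts are implied by the LFS file sizes, not measured here):

import numpy as np

# each token is stored as a uint16, so the byte size of the .bin file is 2x the token count
train_ids = np.memmap('data/stinfo/train.bin', dtype=np.uint16, mode='r')
val_ids = np.memmap('data/stinfo/val.bin', dtype=np.uint16, mode='r')
print(len(train_ids), len(val_ids))  # ~94,830 and ~10,159 given the sizes recorded below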
data/stinfo/train.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b63d839769b684c7ea0ffec4d2f400f621820ef92bec9661359cde1ab6510437
+size 189660
data/stinfo/val.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f745dc46c85ac88838dcd1721ca733e7a9543c037f538ad830356e22f848474e
+size 20318
model.py ADDED
@@ -0,0 +1,368 @@
+"""
+Full definition of a GPT Language Model, all of it in this single file.
+References:
+1) the official GPT-2 TensorFlow implementation released by OpenAI:
+https://github.com/openai/gpt-2/blob/master/src/model.py
+2) huggingface/transformers PyTorch implementation:
+https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
+"""
+
+import math
+import inspect
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
+def new_gelu(x):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
+    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+class LayerNorm(nn.Module):
+    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
+
+    def __init__(self, ndim, bias):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(ndim))
+        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+
+    def forward(self, input):
+        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
+
+class CausalSelfAttention(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        assert config.n_embd % config.n_head == 0
+        # key, query, value projections for all heads, but in a batch
+        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+        # output projection
+        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+        # regularization
+        self.attn_dropout = nn.Dropout(config.dropout)
+        self.resid_dropout = nn.Dropout(config.dropout)
+        self.n_head = config.n_head
+        self.n_embd = config.n_embd
+        self.dropout = config.dropout
+        # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
+        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0
+        if not self.flash:
+            print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
+            # causal mask to ensure that attention is only applied to the left in the input sequence
+            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                                        .view(1, 1, config.block_size, config.block_size))
+
+    def forward(self, x):
+        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+        if self.flash:
+            # efficient attention using Flash Attention CUDA kernels
+            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
+        else:
+            # manual implementation of attention
+            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+            att = F.softmax(att, dim=-1)
+            att = self.attn_dropout(att)
+            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+        # output projection
+        y = self.resid_dropout(self.c_proj(y))
+        return y
+
+class MLP(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, x):
+        x = self.c_fc(x)
+        x = new_gelu(x)
+        x = self.c_proj(x)
+        x = self.dropout(x)
+        return x
+
+class Block(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
+        self.attn = CausalSelfAttention(config)
+        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
+        self.mlp = MLP(config)
+
+    def forward(self, x):
+        x = x + self.attn(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+@dataclass
+class GPTConfig:
+    block_size: int = 1024
+    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
+    n_layer: int = 12
+    n_head: int = 12
+    n_embd: int = 768
+    dropout: float = 0.0
+    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
+
+class GPT(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        assert config.vocab_size is not None
+        assert config.block_size is not None
+        self.config = config
+
+        self.transformer = nn.ModuleDict(dict(
+            wte = nn.Embedding(config.vocab_size, config.n_embd),
+            wpe = nn.Embedding(config.block_size, config.n_embd),
+            drop = nn.Dropout(config.dropout),
+            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+            ln_f = LayerNorm(config.n_embd, bias=config.bias),
+        ))
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        # with weight tying when using torch.compile() some warnings get generated:
+        # "UserWarning: functional_call was passed multiple values for tied weights.
+        # This behavior is deprecated and will be an error in future versions"
+        # not 100% sure what this is, so far seems to be harmless. TODO investigate
+        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
+
+        # init all weights
+        self.apply(self._init_weights)
+        # apply special scaled init to the residual projections, per GPT-2 paper
+        for pn, p in self.named_parameters():
+            if pn.endswith('c_proj.weight'):
+                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
+        # report number of parameters
+        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
+
+    def get_num_params(self, non_embedding=True):
+        """
+        Return the number of parameters in the model.
+        For non-embedding count (default), the position embeddings get subtracted.
+        The token embeddings would too, except due to the parameter sharing these
+        params are actually used as weights in the final layer, so we include them.
+        """
+        n_params = sum(p.numel() for p in self.parameters())
+        if non_embedding:
+            n_params -= self.transformer.wpe.weight.numel()
+        return n_params
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(self, idx, targets=None):
+        device = idx.device
+        b, t = idx.size()
+        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
+
+        # forward the GPT model itself
+        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
+        x = self.transformer.drop(tok_emb + pos_emb)
+        for block in self.transformer.h:
+            x = block(x)
+        x = self.transformer.ln_f(x)
+
+        if targets is not None:
+            # if we are given some desired targets also calculate the loss
+            logits = self.lm_head(x)
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+        else:
+            # inference-time mini-optimization: only forward the lm_head on the very last position
+            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
+            loss = None
+
+        return logits, loss
+
+    def crop_block_size(self, block_size):
+        # model surgery to decrease the block size if necessary
+        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
+        # but want to use a smaller block size for some smaller, simpler model
+        assert block_size <= self.config.block_size
+        self.config.block_size = block_size
+        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
+        for block in self.transformer.h:
+            block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
+
+    @classmethod
+    def from_pretrained(cls, model_type, override_args=None):
+        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
+        override_args = override_args or {} # default to empty dict
+        # only dropout can be overridden see more notes below
+        assert all(k == 'dropout' for k in override_args)
+        from transformers import GPT2LMHeadModel
+        print("loading weights from pretrained gpt: %s" % model_type)
+
+        # n_layer, n_head and n_embd are determined from model_type
+        config_args = {
+            'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
+            'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
+            'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
+            'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
+        }[model_type]
+        print("forcing vocab_size=50257, block_size=1024, bias=True")
+        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
+        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
+        config_args['bias'] = True # always True for GPT model checkpoints
+        # we can override the dropout rate, if desired
+        if 'dropout' in override_args:
+            print(f"overriding dropout rate to {override_args['dropout']}")
+            config_args['dropout'] = override_args['dropout']
+        # create a from-scratch initialized minGPT model
+        config = GPTConfig(**config_args)
+        model = GPT(config)
+        sd = model.state_dict()
+        sd_keys = sd.keys()
+        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
+
+        # init a huggingface/transformers model
+        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
+        sd_hf = model_hf.state_dict()
+
+        # copy while ensuring all of the parameters are aligned and match in names and shapes
+        sd_keys_hf = sd_hf.keys()
+        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
+        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
+        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
+        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
+        # this means that we have to transpose these weights when we import them
+        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
+        for k in sd_keys_hf:
+            if any(k.endswith(w) for w in transposed):
+                # special treatment for the Conv1D weights we need to transpose
+                assert sd_hf[k].shape[::-1] == sd[k].shape
+                with torch.no_grad():
+                    sd[k].copy_(sd_hf[k].t())
+            else:
+                # vanilla copy over the other parameters
+                assert sd_hf[k].shape == sd[k].shape
+                with torch.no_grad():
+                    sd[k].copy_(sd_hf[k])
+
+        return model
+
+    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
+        """
+        This long function is unfortunately doing something very simple and is being very defensive:
+        We are separating out all parameters of the model into two buckets: those that will experience
+        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+        We are then returning the PyTorch optimizer object.
+        """
+
+        # separate out all parameters to those that will and won't experience regularizing weight decay
+        decay = set()
+        no_decay = set()
+        whitelist_weight_modules = (torch.nn.Linear, )
+        blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)
+        for mn, m in self.named_modules():
+            for pn, p in m.named_parameters():
+                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+                # random note: because named_modules and named_parameters are recursive
+                # we will see the same tensors p many many times. but doing it this way
+                # allows us to know which parent module any tensor p belongs to...
+                if pn.endswith('bias'):
+                    # all biases will not be decayed
+                    no_decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+                    # weights of whitelist modules will be weight decayed
+                    decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+                    # weights of blacklist modules will NOT be weight decayed
+                    no_decay.add(fpn)
+
+        # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
+        # will appear in the no_decay and decay sets respectively after the above.
+        # In addition, because named_parameters() doesn't return duplicates, it
+        # will only return the first occurrence, key'd by 'transformer.wte.weight', below.
+        # so let's manually remove 'lm_head.weight' from decay set. This will include
+        # this tensor into optimization via transformer.wte.weight only, and not decayed.
+        decay.remove('lm_head.weight')
+
+        # validate that we considered every parameter
+        param_dict = {pn: p for pn, p in self.named_parameters()}
+        inter_params = decay & no_decay
+        union_params = decay | no_decay
+        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+            % (str(param_dict.keys() - union_params), )
+
+        # create the pytorch optimizer object
+        optim_groups = [
+            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
+            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+        ]
+        # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
+        use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters)
+        print(f"using fused AdamW: {use_fused}")
+        extra_args = dict(fused=True) if use_fused else dict()
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+
+        return optimizer
+
+    def estimate_mfu(self, fwdbwd_per_iter, dt):
+        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
+        # first estimate the number of flops we do per iteration.
+        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
+        N = self.get_num_params()
+        cfg = self.config
+        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
+        flops_per_token = 6*N + 12*L*H*Q*T
+        flops_per_fwdbwd = flops_per_token * T
+        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
+        # express our flops throughput as ratio of A100 bfloat16 peak flops
+        flops_achieved = flops_per_iter * (1.0/dt) # per second
+        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
+        mfu = flops_achieved / flops_promised
+        return mfu
+
+    @torch.no_grad()
+    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+        """
+        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
+        the sequence max_new_tokens times, feeding the predictions back into the model each time.
+        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+        """
+        for _ in range(max_new_tokens):
+            # if the sequence context is growing too long we must crop it at block_size
+            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+            # forward the model to get the logits for the index in the sequence
+            logits, _ = self(idx_cond)
+            # pluck the logits at the final step and scale by desired temperature
+            logits = logits[:, -1, :] / temperature
+            # optionally crop the logits to only the top k options
+            if top_k is not None:
+                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                logits[logits < v[:, [-1]]] = -float('Inf')
+            # apply softmax to convert logits to (normalized) probabilities
+            probs = F.softmax(logits, dim=-1)
+            # sample from the distribution
+            idx_next = torch.multinomial(probs, num_samples=1)
+            # append sampled index to the running sequence and continue
+            idx = torch.cat((idx, idx_next), dim=1)
+
+        return idx
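A minimal usage sketch of this module (it assumes the GPT-2 124M weights are fetched through huggingface transformers, as from_pretrained does above; the prompt string is an arbitrary example):

import torch
import tiktoken
from model import GPT

model = GPT.from_pretrained('gpt2', dict(dropout=0.0)).eval()  # the 124M variant
enc = tiktoken.get_encoding("gpt2")
idx = torch.tensor([enc.encode("ST Engineering is")], dtype=torch.long)
out = model.generate(idx, max_new_tokens=20, temperature=0.6, top_k=200)
print(enc.decode(out[0].tolist()))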
out-stinfo/ckpt.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44b0ededf0f7fd3535e177396c44f9d334ca8aeab408dceca2d131dffe1401c0
+size 1543762794
requirements.txt ADDED
@@ -0,0 +1,167 @@
+# This file may be used to create an environment using:
+# $ conda create --name <env> --file <this file>
+# platform: linux-64
+_libgcc_mutex=0.1=main
+_openmp_mutex=5.1=1_gnu
+aiohttp=3.8.4=pypi_0
+aiosignal=1.3.1=pypi_0
+appdirs=1.4.4=pypi_0
+async-timeout=4.0.2=pypi_0
+attrs=22.2.0=pypi_0
+blas=1.0=mkl
+blobfile=2.0.1=pypi_0
+brotlipy=0.7.0=py310h7f8727e_1002
+bzip2=1.0.8=h7b6447c_0
+ca-certificates=2023.01.10=h06a4308_0
+certifi=2022.12.7=py310h06a4308_0
+cffi=1.15.1=py310h5eee18b_3
+charset-normalizer=2.0.4=pyhd3eb1b0_0
+click=8.1.3=pypi_0
+cryptography=38.0.4=py310h9ce1e76_0
+cuda=11.6.1=0
+cuda-cccl=11.6.55=hf6102b2_0
+cuda-command-line-tools=11.6.2=0
+cuda-compiler=11.6.2=0
+cuda-cudart=11.6.55=he381448_0
+cuda-cudart-dev=11.6.55=h42ad0f4_0
+cuda-cuobjdump=11.6.124=h2eeebcb_0
+cuda-cupti=11.6.124=h86345e5_0
+cuda-cuxxfilt=11.6.124=hecbf4f6_0
+cuda-driver-dev=11.6.55=0
+cuda-gdb=12.0.140=0
+cuda-libraries=11.6.1=0
+cuda-libraries-dev=11.6.1=0
+cuda-memcheck=11.8.86=0
+cuda-nsight=12.0.140=0
+cuda-nsight-compute=12.0.1=0
+cuda-nvcc=11.6.124=hbba6d2d_0
+cuda-nvdisasm=12.0.140=0
+cuda-nvml-dev=11.6.55=haa9ef22_0
+cuda-nvprof=12.0.146=0
+cuda-nvprune=11.6.124=he22ec0a_0
+cuda-nvrtc=11.6.124=h020bade_0
+cuda-nvrtc-dev=11.6.124=h249d397_0
+cuda-nvtx=11.6.124=h0630a44_0
+cuda-nvvp=12.0.146=0
+cuda-runtime=11.6.1=0
+cuda-samples=11.6.101=h8efea70_0
+cuda-sanitizer-api=12.0.140=0
+cuda-toolkit=11.6.1=0
+cuda-tools=11.6.1=0
+cuda-visual-tools=11.6.1=0
+datasets=2.9.0=pypi_0
+dill=0.3.6=pypi_0
+docker-pycreds=0.4.0=pypi_0
+docopt=0.6.2=pypi_0
+ffmpeg=4.3=hf484d3e_0
+filelock=3.9.0=pypi_0
+flit-core=3.6.0=pyhd3eb1b0_0
+freetype=2.12.1=h4a9f257_0
+frozenlist=1.3.3=pypi_0
+fsspec=2023.1.0=pypi_0
+gds-tools=1.5.1.14=0
+giflib=5.2.1=h5eee18b_1
+gitdb=4.0.10=pypi_0
+gitpython=3.1.31=pypi_0
+gmp=6.2.1=h295c915_3
+gnutls=3.6.15=he1e5248_0
+huggingface-hub=0.12.1=pypi_0
+idna=3.4=py310h06a4308_0
+intel-openmp=2021.4.0=h06a4308_3561
+jpeg=9e=h7f8727e_0
+lame=3.100=h7b6447c_0
+lcms2=2.12=h3be6417_0
+ld_impl_linux-64=2.38=h1181459_1
+lerc=3.0=h295c915_0
+libcublas=11.9.2.110=h5e84587_0
+libcublas-dev=11.9.2.110=h5c901ab_0
+libcufft=10.7.1.112=hf425ae0_0
+libcufft-dev=10.7.1.112=ha5ce4c0_0
+libcufile=1.5.1.14=0
+libcufile-dev=1.5.1.14=0
+libcurand=10.3.1.124=0
+libcurand-dev=10.3.1.124=0
+libcusolver=11.3.4.124=h33c3c4e_0
+libcusparse=11.7.2.124=h7538f96_0
+libcusparse-dev=11.7.2.124=hbbe9722_0
+libdeflate=1.8=h7f8727e_5
+libffi=3.4.2=h6a678d5_6
+libgcc-ng=11.2.0=h1234567_1
+libgomp=11.2.0=h1234567_1
+libiconv=1.16=h7f8727e_2
+libidn2=2.3.2=h7f8727e_0
+libnpp=11.6.3.124=hd2722f0_0
+libnpp-dev=11.6.3.124=h3c42840_0
+libnvjpeg=11.6.2.124=hd473ad6_0
+libnvjpeg-dev=11.6.2.124=hb5906b9_0
+libpng=1.6.37=hbc83047_0
+libstdcxx-ng=11.2.0=h1234567_1
+libtasn1=4.16.0=h27cfd23_0
+libtiff=4.5.0=h6a678d5_1
+libunistring=0.9.10=h27cfd23_0
+libuuid=1.41.5=h5eee18b_0
+libwebp=1.2.4=h11a3e52_0
+libwebp-base=1.2.4=h5eee18b_0
+lxml=4.9.2=pypi_0
+lz4-c=1.9.4=h6a678d5_0
+mkl=2021.4.0=h06a4308_640
+mkl-service=2.4.0=py310h7f8727e_0
+mkl_fft=1.3.1=py310hd6ae3a3_0
+mkl_random=1.2.2=py310h00e6091_0
+multidict=6.0.4=pypi_0
+multiprocess=0.70.14=pypi_0
+ncurses=6.4=h6a678d5_0
+nettle=3.7.3=hbbd107a_1
+nsight-compute=2022.4.1.6=0
+numpy=1.23.5=py310hd5efca6_0
+numpy-base=1.23.5=py310h8e6c178_0
+openh264=2.1.1=h4ff587b_0
+openssl=1.1.1t=h7f8727e_0
+packaging=23.0=pypi_0
+pandas=1.5.3=pypi_0
+pathtools=0.1.2=pypi_0
+pillow=9.3.0=py310h6a678d5_2
+pip=22.3.1=py310h06a4308_0
+pipreqs=0.4.11=pypi_0
+protobuf=4.22.0=pypi_0
+psutil=5.9.4=pypi_0
+pyarrow=11.0.0=pypi_0
+pycparser=2.21=pyhd3eb1b0_0
+pycryptodomex=3.17=pypi_0
+pyopenssl=22.0.0=pyhd3eb1b0_0
+pysocks=1.7.1=py310h06a4308_0
+python=3.10.9=h7a1cb2a_0
+python-dateutil=2.8.2=pypi_0
+pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0
+pytorch-cuda=11.6=h867d48c_1
+pytorch-mutex=1.0=cuda
+pytz=2022.7.1=pypi_0
+pyyaml=6.0=pypi_0
+readline=8.2=h5eee18b_0
+regex=2022.10.31=pypi_0
+requests=2.28.1=py310h06a4308_0
+responses=0.18.0=pypi_0
+sentry-sdk=1.15.0=pypi_0
+setproctitle=1.3.2=pypi_0
+setuptools=65.6.3=py310h06a4308_0
+six=1.16.0=pyhd3eb1b0_1
+smmap=5.0.0=pypi_0
+sqlite=3.40.1=h5082296_0
+tiktoken=0.2.0=pypi_0
+tk=8.6.12=h1ccaba5_0
+tokenizers=0.13.2=pypi_0
+torchaudio=0.13.1=py310_cu116
+torchvision=0.14.1=py310_cu116
+tqdm=4.64.1=pypi_0
+transformers=4.26.1=pypi_0
+typing_extensions=4.4.0=py310h06a4308_0
+tzdata=2022g=h04d1e81_0
+urllib3=1.26.14=py310h06a4308_0
+wandb=0.13.10=pypi_0
+wheel=0.38.4=py310h06a4308_0
+xxhash=3.2.0=pypi_0
+xz=5.2.10=h5eee18b_1
+yarg=0.1.9=pypi_0
+yarl=1.8.2=pypi_0
+zlib=1.2.13=h5eee18b_0
+zstd=1.5.2=ha4553b6_0
sample.py ADDED
@@ -0,0 +1,99 @@
+"""
+Sample from a trained model
+"""
+import os
+import pickle
+from contextlib import nullcontext
+import torch
+import tiktoken
+from model import GPTConfig, GPT
+
+# -----------------------------------------------------------------------------
+init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
+out_dir = 'out-stinfo' # ignored if init_from is not 'resume'
+start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
+num_samples = 10 # number of samples to draw
+max_new_tokens = 300 # number of tokens generated in each sample
+temperature = 0.6 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
+top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
+seed = 1337
+device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
+dtype = 'float16' # 'float32' or 'bfloat16' or 'float16'
+compile = False # use PyTorch 2.0 to compile the model to be faster
+exec(open('configurator.py').read()) # overrides from command line or config file
+# -----------------------------------------------------------------------------
+
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
+torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
+torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
+device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
+ptdtype = {
+    'float32': torch.float32,
+    'bfloat16': torch.bfloat16,
+    'float16': torch.float16
+}[dtype]
+ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
+    device_type=device_type, dtype=ptdtype)
+
+# model
+if init_from == 'resume':
+    # init from a model saved in a specific directory
+    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
+    checkpoint = torch.load(ckpt_path, map_location=device)
+    gptconf = GPTConfig(**checkpoint['model_args'])
+    model = GPT(gptconf)
+    state_dict = checkpoint['model']
+    unwanted_prefix = '_orig_mod.'
+    for k, v in list(state_dict.items()):
+        if k.startswith(unwanted_prefix):
+            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+    model.load_state_dict(state_dict)
+elif init_from.startswith('gpt2'):
+    # init from a given GPT-2 model
+    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
+
+model.eval()
+model.to(device)
+if compile:
+    model = torch.compile(model) # requires PyTorch 2.0 (optional)
+
+# look for the meta pickle in case it is available in the dataset folder
+load_meta = False
+if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint[
+        'config']: # older checkpoints might not have these...
+    meta_path = os.path.join('data', checkpoint['config']['dataset'],
+                             'meta.pkl')
+    load_meta = os.path.exists(meta_path)
+if load_meta:
+    print(f"Loading meta from {meta_path}...")
+    with open(meta_path, 'rb') as f:
+        meta = pickle.load(f)
+    # TODO want to make this more general to arbitrary encoder/decoder schemes
+    stoi, itos = meta['stoi'], meta['itos']
+    encode = lambda s: [stoi[c] for c in s]
+    decode = lambda l: ''.join([itos[i] for i in l])
+else:
+    # ok let's assume gpt-2 encodings by default
+    print("No meta.pkl found, assuming GPT-2 encodings...")
+    enc = tiktoken.get_encoding("gpt2")
+    encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
+    decode = lambda l: enc.decode(l)
+
+# encode the beginning of the prompt
+if start.startswith('FILE:'):
+    with open(start[5:], 'r', encoding='utf-8') as f:
+        start = f.read()
+start_ids = encode(start)
+x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+
+# run generation
+with torch.no_grad():
+    with ctx:
+        for k in range(num_samples):
+            y = model.generate(x,
+                               max_new_tokens,
+                               temperature=temperature,
+                               top_k=top_k)
+            print(decode(y[0].tolist()))
+            print('---------------')
train.py ADDED
@@ -0,0 +1,322 @@
+"""
+This training script can be run both on a single gpu in debug mode,
+and also in a larger training run with distributed data parallel (ddp).
+
+To run on a single GPU, example:
+$ python train.py --batch_size=32 --compile=False
+
+To run with DDP on 4 gpus on 1 node, example:
+$ torchrun --standalone --nproc_per_node=4 train.py
+
+To run with DDP on 4 gpus across 2 nodes, example:
+- Run on the first (master) node with example IP 123.456.123.456:
+$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
+- Run on the worker node:
+$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
+(If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
+"""
+
+import os
+import time
+import math
+import pickle
+from contextlib import nullcontext
+
+import numpy as np
+import torch
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed import init_process_group, destroy_process_group
+
+from model import GPTConfig, GPT
+
+# -----------------------------------------------------------------------------
+# default config values designed to train a gpt2 (124M) on OpenWebText
+# I/O
+out_dir = 'out'
+eval_interval = 2000
+log_interval = 1
+eval_iters = 200
+eval_only = False # if True, script exits right after the first eval
+always_save_checkpoint = True # if True, always save a checkpoint after each eval
+init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
+# wandb logging
+wandb_log = False # disabled by default
+wandb_project = 'owt'
+wandb_run_name = 'gpt2' # 'run' + str(time.time())
+# data
+dataset = 'openwebtext'
+gradient_accumulation_steps = 5 # used to simulate larger batch sizes
+batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
+block_size = 1024
+# model
+n_layer = 12
+n_head = 12
+n_embd = 768
+dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
+bias = False # do we use bias inside LayerNorm and Linear layers?
+# adamw optimizer
+learning_rate = 6e-4 # max learning rate
+max_iters = 600000 # total number of training iterations
+weight_decay = 1e-1
+beta1 = 0.9
+beta2 = 0.95
+grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
+# learning rate decay settings
+decay_lr = True # whether to decay the learning rate
+warmup_iters = 2000 # how many steps to warm up for
+lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
+min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
+# DDP settings
+backend = 'nccl' # 'nccl', 'gloo', etc.
+# system
+device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
+dtype = 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
+compile = False #True # use PyTorch 2.0 to compile the model to be faster
+# -----------------------------------------------------------------------------
+config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
+exec(open('configurator.py').read()) # overrides from command line or config file
+config = {k: globals()[k] for k in config_keys} # will be useful for logging
+# -----------------------------------------------------------------------------
+
+# various inits, derived attributes, I/O setup
+ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
+if ddp:
+    init_process_group(backend=backend)
+    ddp_rank = int(os.environ['RANK'])
+    ddp_local_rank = int(os.environ['LOCAL_RANK'])
+    device = f'cuda:{ddp_local_rank}'
+    torch.cuda.set_device(device)
+    master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
+    seed_offset = ddp_rank # each process gets a different seed
+else:
+    # if not ddp, we are running on a single gpu, and one process
+    master_process = True
+    seed_offset = 0
+    gradient_accumulation_steps *= 8 # simulate 8 gpus
+
+if master_process:
+    os.makedirs(out_dir, exist_ok=True)
+torch.manual_seed(1337 + seed_offset)
+torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
+torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
+device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
+# note: float16 data type will automatically use a GradScaler
+ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
+ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
+
+# poor man's data loader
+data_dir = os.path.join('data', dataset)
+train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
+val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
+def get_batch(split):
+    data = train_data if split == 'train' else val_data
+    ix = torch.randint(len(data) - block_size, (batch_size,))
+    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
+    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
+    if device_type == 'cuda':
+        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
+        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
+    else:
+        x, y = x.to(device), y.to(device)
+    return x, y
+
+# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
+iter_num = 0
+best_val_loss = 1e9
+
+# attempt to derive vocab_size from the dataset
+meta_path = os.path.join(data_dir, 'meta.pkl')
+meta_vocab_size = None
+if os.path.exists(meta_path):
+    with open(meta_path, 'rb') as f:
+        meta = pickle.load(f)
+    meta_vocab_size = meta['vocab_size']
+    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
+
+# model init
+model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
+                  bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
+if init_from == 'scratch':
+    # init a new model from scratch
+    print("Initializing a new model from scratch")
+    # determine the vocab size we'll use for from-scratch training
+    if meta_vocab_size is None:
+        print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
+    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
+    gptconf = GPTConfig(**model_args)
+    model = GPT(gptconf)
+elif init_from == 'resume':
+    print(f"Resuming training from {out_dir}")
+    # resume training from a checkpoint.
+    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
+    checkpoint = torch.load(ckpt_path, map_location=device)
+    checkpoint_model_args = checkpoint['model_args']
+    # force these config attributes to be equal otherwise we can't even resume training
+    # the rest of the attributes (e.g. dropout) can stay as desired from command line
+    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
+        model_args[k] = checkpoint_model_args[k]
+    # create the model
+    gptconf = GPTConfig(**model_args)
+    model = GPT(gptconf)
+    state_dict = checkpoint['model']
+    # fix the keys of the state dictionary :(
+    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
+    unwanted_prefix = '_orig_mod.'
+    for k,v in list(state_dict.items()):
+        if k.startswith(unwanted_prefix):
+            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+    model.load_state_dict(state_dict)
+    iter_num = checkpoint['iter_num']
+    best_val_loss = checkpoint['best_val_loss']
+elif init_from.startswith('gpt2'):
+    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
+    # initialize from OpenAI GPT-2 weights
+    override_args = dict(dropout=dropout)
+    model = GPT.from_pretrained(init_from, override_args)
+    # read off the created config params, so we can store them into checkpoint correctly
+    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
+        model_args[k] = getattr(model.config, k)
+# crop down the model block size if desired, using model surgery
+if block_size < model.config.block_size:
+    model.crop_block_size(block_size)
+    model_args['block_size'] = block_size # so that the checkpoint will have the right value
+model.to(device)
+
+# initialize a GradScaler. If enabled=False scaler is a no-op
+scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
+
+# optimizer
+optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
+if init_from == 'resume':
+    optimizer.load_state_dict(checkpoint['optimizer'])
+
+# compile the model
+if compile:
+    print("compiling the model... (takes a ~minute)")
+    unoptimized_model = model
+    model = torch.compile(model) # requires PyTorch 2.0
+
+# wrap model into DDP container
+if ddp:
+    model = DDP(model, device_ids=[ddp_local_rank])
+
+# helps estimate an arbitrarily accurate loss over either split using many batches
+@torch.no_grad()
+def estimate_loss():
+    out = {}
+    model.eval()
+    for split in ['train', 'val']:
+        losses = torch.zeros(eval_iters)
+        for k in range(eval_iters):
+            X, Y = get_batch(split)
+            with ctx:
+                logits, loss = model(X, Y)
+            losses[k] = loss.item()
+        out[split] = losses.mean()
+    model.train()
+    return out
+
+# learning rate decay scheduler (cosine with warmup)
+def get_lr(it):
+    # 1) linear warmup for warmup_iters steps
+    if it < warmup_iters:
+        return learning_rate * it / warmup_iters
+    # 2) if it > lr_decay_iters, return min learning rate
+    if it > lr_decay_iters:
+        return min_lr
+    # 3) in between, use cosine decay down to min learning rate
+    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
+    assert 0 <= decay_ratio <= 1
+    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
+    return min_lr + coeff * (learning_rate - min_lr)
+
+# logging
+if wandb_log and master_process:
+    import wandb
+    wandb.init(project=wandb_project, name=wandb_run_name, config=config)
+
+# training loop
+X, Y = get_batch('train') # fetch the very first batch
+t0 = time.time()
+local_iter_num = 0 # number of iterations in the lifetime of this process
+raw_model = model.module if ddp else model # unwrap DDP container if needed
+running_mfu = -1.0
+while True:
+
+    # determine and set the learning rate for this iteration
+    lr = get_lr(iter_num) if decay_lr else learning_rate
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+    # evaluate the loss on train/val sets and write checkpoints
+    if iter_num % eval_interval == 0 and master_process:
+        losses = estimate_loss()
+        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
+        if wandb_log:
+            wandb.log({
+                "iter": iter_num,
+                "train/loss": losses['train'],
+                "val/loss": losses['val'],
+                "lr": lr,
+                "mfu": running_mfu*100, # convert to percentage
+            })
+        if losses['val'] < best_val_loss or always_save_checkpoint:
+            best_val_loss = losses['val']
+            if iter_num > 0:
+                checkpoint = {
+                    'model': raw_model.state_dict(),
+                    'optimizer': optimizer.state_dict(),
+                    'model_args': model_args,
+                    'iter_num': iter_num,
+                    'best_val_loss': best_val_loss,
+                    'config': config,
+                }
+                print(f"saving checkpoint to {out_dir}")
+                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
+    if iter_num == 0 and eval_only:
+        break
+
+    # forward backward update, with optional gradient accumulation to simulate larger batch size
+    # and using the GradScaler if data type is float16
+    for micro_step in range(gradient_accumulation_steps):
+        if ddp:
+            # in DDP training we only need to sync gradients at the last micro step.
+            # the official way to do this is with model.no_sync() context manager, but
+            # I really dislike that this bloats the code and forces us to repeat code
+            # looking at the source of that context manager, it just toggles this variable
+            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
+        with ctx:
+            logits, loss = model(X, Y)
+        # immediately async prefetch next batch while model is doing the forward pass on the GPU
+        X, Y = get_batch('train')
+        # backward pass, with gradient scaling if training in fp16
+        scaler.scale(loss).backward()
+    # clip the gradient
+    if grad_clip != 0.0:
+        scaler.unscale_(optimizer)
+        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+    # step the optimizer and scaler if training in fp16
+    scaler.step(optimizer)
+    scaler.update()
+    # flush the gradients as soon as we can, no need for this memory anymore
+    optimizer.zero_grad(set_to_none=True)
+
+    # timing and logging
+    t1 = time.time()
+    dt = t1 - t0
+    t0 = t1
+    if iter_num % log_interval == 0 and master_process:
+        lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
+        if local_iter_num >= 5: # let the training loop settle a bit
+            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
+            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
+        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
+    iter_num += 1
+    local_iter_num += 1
+
+    # termination conditions
+    if iter_num > max_iters:
+        break
+
+if ddp:
+    destroy_process_group()
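Putting the files in this commit together, a plausible end-to-end flow would be the following (a sketch in the style of the docstring above, not a transcript of the actual run):

$ cd data/stinfo && python prepare.py && cd ../..     # writes train.bin / val.bin
$ python train.py config/finetune_stinfo.py           # finetunes gpt2 (124M) into out-stinfo/ckpt.pt
$ python sample.py --start="ST Engineering"           # sample.py defaults to resuming from out-stinfo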