abhishek4607 committed on
Commit 43ce548 · verified · 1 Parent(s): 2de7746

Upload 11 files

__init__.py ADDED
File without changes
dataloader.cpython-311.pyc ADDED
Binary file (4.39 kB).
 
dataloader.py ADDED
@@ -0,0 +1,49 @@
+ import os
+ import numpy as np
+ import torch
+
+ script_dir = os.path.dirname(__file__)
+
+
+ class DataLoaderLite:
+     """ A simple dataloader for the FineWebEdu-10B dataset """
+
+     def __init__(self, B, T, process_rank, num_processes, split='train'):
+         super().__init__()
+         self.B, self.T = B, T
+         self.process_rank = process_rank
+         self.num_processes = num_processes
+         assert split in {'train', 'val'}
+
+         # get the shard filenames
+         data_root = os.path.join(script_dir, "../data/edu_fineweb10B")
+         shard_filenames = os.listdir(data_root)
+         shard_filenames = sorted([filename for filename in shard_filenames if split in filename])
+         self.shard_filepaths = [os.path.join(data_root, filename) for filename in shard_filenames]
+         assert len(self.shard_filepaths) > 0, f'no shards found for split {split}'
+         master_process = process_rank == 0
+         if master_process:
+             print(f'found {len(self.shard_filepaths)} shards for split {split}')
+         self.reset()
+
+     def load_tokens(self, filepath):
+         tokens = torch.tensor(np.load(filepath).astype(np.int32), dtype=torch.long)
+         return tokens
+
+     def reset(self):
+         # state, init at shard 0
+         self.curr_shard = 0
+         self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard])
+         self.curr_pos = self.B * self.T * self.process_rank
+
+     def next_batch(self):
+         B, T = self.B, self.T
+         batch = self.tokens[self.curr_pos : self.curr_pos + B*T + 1]
+         x_batch = batch[:-1].view(B, T)
+         y_batch = batch[1:].view(B, T)
+         self.curr_pos += B * T * self.num_processes
+         if self.curr_pos + (B * T + 1) > len(self.tokens):
+             self.curr_shard = (self.curr_shard + 1) % len(self.shard_filepaths)
+             self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard])
+             self.curr_pos = self.B * self.T * self.process_rank
+         return x_batch, y_batch
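
A minimal smoke test of the striding scheme above (not part of this commit; it assumes shards already exist under ../data/edu_fineweb10B). Each rank starts reading at offset B*T*process_rank and advances by B*T*num_processes per batch, so the ranks consume disjoint, interleaved slices of every shard:

import torch
from dataloader import DataLoaderLite

loader = DataLoaderLite(B=4, T=32, process_rank=0, num_processes=1, split='val')
x, y = loader.next_batch()
assert x.shape == (4, 32) and y.shape == (4, 32)
# targets are the inputs shifted left by one token (next-token prediction)
assert torch.equal(x[:, 1:], y[:, :-1])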
hellaswag_eval.cpython-311.pyc ADDED
Binary file (12.8 kB).
 
hellaswag_eval.py ADDED
@@ -0,0 +1,197 @@
+ """
+ Downloads and evaluates HellaSwag in Python.
+ https://github.com/rowanz/hellaswag
+
+ Example HellaSwag json item:
+
+ {"ind": 24, "activity_label": "Roof shingle removal", "ctx_a": "A man is sitting on a roof.", "ctx_b": "he", "ctx": "A man is sitting on a roof. he", "split": "val", "split_type": "indomain", "label": 3, "endings": ["is using wrap to wrap a pair of skis.", "is ripping level tiles off.", "is holding a rubik's cube.", "starts pulling up roofing on a roof."], "source_id": "activitynet~v_-JhWjGDPHMY"}
+
+ ind: dataset ID
+ activity_label: The ActivityNet or WikiHow label for this example
+ context: There are two formats. The full context is in ctx. When the context ends in an (incomplete) noun phrase, like for ActivityNet, this incomplete noun phrase is in ctx_b, and the context up until then is in ctx_a. This can be useful for models such as BERT that need the last sentence to be complete. However, it's never required. If ctx_b is nonempty, then ctx is the same thing as ctx_a, followed by a space, then ctx_b.
+ endings: a list of 4 endings. The correct index is given by label (0, 1, 2, or 3)
+ split: train, val, or test.
+ split_type: indomain if the activity label is seen during training, else zeroshot
+ source_id: Which video or WikiHow article this example came from
+
+ gpt2 (124M)
+ - eleuther harness reports acc 28.92%, acc_norm 31.14% (multiple choice style)
+ - this script: 10042 acc: 0.2859 acc_norm: 0.2955 (completion style)
+
+ gpt2-xl (1558M)
+ - eleuther harness reports acc 40.04%, acc_norm 50.89% (multiple choice style)
+ - this script: 10042 acc: 0.3842 acc_norm: 0.4893 (completion style)
+
+ The validation set of HellaSwag has a total of 10,042 examples.
+ """
+
+ import os
+ import json
+ import requests
+ import tiktoken
+ from tqdm import tqdm
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from transformers import GPT2LMHeadModel
+
+ # -----------------------------------------------------------------------------
+ DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "hellaswag")
+
+ def download_file(url: str, fname: str, chunk_size=1024):
+     """Helper function to download a file from a given url"""
+     resp = requests.get(url, stream=True)
+     total = int(resp.headers.get("content-length", 0))
+     with open(fname, "wb") as file, tqdm(
+         desc=fname,
+         total=total,
+         unit="iB",
+         unit_scale=True,
+         unit_divisor=1024,
+     ) as bar:
+         for data in resp.iter_content(chunk_size=chunk_size):
+             size = file.write(data)
+             bar.update(size)
+
+ hellaswags = {
+     "train": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_train.jsonl",
+     "val": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl",
+     "test": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_test.jsonl",
+ }
+
+ enc = tiktoken.get_encoding("gpt2")
+
+ def download(split):
+     """Downloads HellaSwag to DATA_CACHE_DIR"""
+     os.makedirs(DATA_CACHE_DIR, exist_ok=True)
+     data_url = hellaswags[split]
+     data_filename = os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl")
+     if not os.path.exists(data_filename):
+         print(f"Downloading {data_url} to {data_filename}...")
+         download_file(data_url, data_filename)
+
+ def render_example(example):
+     """
+     Given the example as a dictionary, render it as three torch tensors:
+     - tokens (the tokens of context + completion, of size 4xN, as there are always 4 candidates)
+     - mask (is 1 in the region of the candidate completion, where we evaluate likelihoods)
+     - label (the index of the correct completion, which we hope has the highest likelihood)
+     """
+     ctx = example["ctx"]
+     label = example["label"]
+     endings = example["endings"]
+     # data needed to reproduce this eval on the C side
+     data = {
+         "label": label,
+         "ctx_tokens": None,
+         "ending_tokens": [],
+     }
+     # gather up all the tokens
+     ctx_tokens = enc.encode(ctx)
+     data["ctx_tokens"] = ctx_tokens
+     tok_rows = []
+     mask_rows = []
+     for end in endings:
+         end_tokens = enc.encode(" " + end) # note: prepending " " because GPT-2 tokenizer
+         tok_rows.append(ctx_tokens + end_tokens)
+         mask_rows.append([0]*len(ctx_tokens) + [1]*len(end_tokens))
+         data["ending_tokens"].append(end_tokens)
+
+     # have to be careful during the collation because the number of tokens in each row can differ
+     max_len = max(len(row) for row in tok_rows)
+     tokens = torch.zeros((4, max_len), dtype=torch.long)
+     mask = torch.zeros((4, max_len), dtype=torch.long)
+     for i, (tok_row, mask_row) in enumerate(zip(tok_rows, mask_rows)):
+         tokens[i, :len(tok_row)] = torch.tensor(tok_row)
+         mask[i, :len(mask_row)] = torch.tensor(mask_row)
+     return data, tokens, mask, label
+
+ def iterate_examples(split):
+     # there are 10,042 examples in total in val
+     download(split)
+     with open(os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl"), "r") as f:
+         for line in f:
+             example = json.loads(line)
+             yield example
+
+ @torch.no_grad()
+ def evaluate(model_type, device):
+     torch.set_float32_matmul_precision('high') # use tf32
+     model = GPT2LMHeadModel.from_pretrained(model_type)
+     model.to(device)
+     # model = torch.compile(model) # optionally torch compile the model
+     num_correct_norm = 0
+     num_correct = 0
+     num_total = 0
+     for example in iterate_examples("val"):
+         data, tokens, mask, label = render_example(example)
+         tokens = tokens.to(device)
+         mask = mask.to(device)
+
+         # get the logits
+         logits = model(tokens).logits
+         # evaluate the autoregressive loss at all positions
+         shift_logits = (logits[..., :-1, :]).contiguous()
+         shift_tokens = (tokens[..., 1:]).contiguous()
+         flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+         flat_shift_tokens = shift_tokens.view(-1)
+         shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
+         shift_losses = shift_losses.view(tokens.size(0), -1)
+         # now get the average loss just for the completion region (where mask == 1), in each row
+         shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
+         masked_shift_losses = shift_losses * shift_mask
+         # sum and divide by the number of 1s in the mask
+         sum_loss = masked_shift_losses.sum(dim=1)
+         avg_loss = sum_loss / shift_mask.sum(dim=1)
+         # now we have a loss for each of the 4 completions
+         # the one with the lowest loss should be the most likely
+         pred = sum_loss.argmin().item()
+         pred_norm = avg_loss.argmin().item()
+
+         # accumulate stats
+         num_total += 1
+         num_correct += int(pred == label)
+         num_correct_norm += int(pred_norm == label)
+         print(f"{num_total} acc_norm: {num_correct_norm}/{num_total}={num_correct_norm/num_total:.4f}")
+
+         # debug: pretty print a few examples, and the losses in each case
+         if num_total < 10:
+             print("---")
+             print(f"Context:\n {example['ctx']}")
+             print("Endings:")
+             for i, end in enumerate(example["endings"]):
+                 print(f"{i} (loss: {avg_loss[i].item():.4f}) {end}")
+             print(f"predicted: {pred_norm}, actual: {label}")
+
+
+ def get_most_likely_row(tokens, mask, logits):
+     """
+     helper function for HellaSwag eval. Takes tokens, mask, and logits,
+     returns the index of the completion with the lowest average loss
+     """
+     # evaluate the autoregressive loss at all positions
+     shift_logits = (logits[..., :-1, :]).contiguous()
+     shift_tokens = (tokens[..., 1:]).contiguous()
+     flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+     flat_shift_tokens = shift_tokens.view(-1)
+     shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
+     shift_losses = shift_losses.view(tokens.size(0), -1)
+     # now get the average loss just for the completion region (where mask == 1), in each row
+     shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
+     masked_shift_losses = shift_losses * shift_mask
+     # sum and divide by the number of 1s in the mask
+     sum_loss = masked_shift_losses.sum(dim=1)
+     avg_loss = sum_loss / shift_mask.sum(dim=1)
+     # now we have a loss for each of the 4 completions
+     # the one with the lowest loss should be the most likely
+     pred_norm = avg_loss.argmin().item()
+     return pred_norm
+
+
+ if __name__ == "__main__":
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument("-m", "--model_type", type=str, default="gpt2", help="the model type to use")
+     parser.add_argument("-d", "--device", type=str, default="cuda", help="the device to use")
+     args = parser.parse_args()
+     evaluate(args.model_type, args.device)
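
A quick sanity-check sketch of render_example, using the JSON item quoted in the module docstring (the dict below is copied from it; this runs with only tiktoken, no model download needed):

from hellaswag_eval import render_example

example = {
    "ctx": "A man is sitting on a roof. he",
    "label": 3,
    "endings": ["is using wrap to wrap a pair of skis.",
                "is ripping level tiles off.",
                "is holding a rubik's cube.",
                "starts pulling up roofing on a roof."],
}
data, tokens, mask, label = render_example(example)
print(tokens.shape)  # (4, N): 4 candidate rows, padded to the longest context+ending
print(mask[0])       # 0 over the shared context, 1 over the candidate ending
print(label)         # 3

Note the scoring asymmetry in evaluate(): pred is the argmin of the summed completion loss (reported as acc), while pred_norm is the argmin of the length-normalized loss (reported as acc_norm).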
inference.py ADDED
@@ -0,0 +1,93 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import tiktoken
+ from dataclasses import dataclass
+
+ from model import GPT
+
+
+ class GPT2Inference:
+     """ To generate text sequences using a trained GPT2 model """
+
+     def __init__(self, model, token_encoder, device):
+         self.model = model
+         self.token_encoder = token_encoder
+         self.device = device
+         self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
+
+     def generate_sequences(self, prompt, num_seq=5, max_tokens=50):
+         self.model.eval()
+         tokens = self.token_encoder.encode(prompt)
+         tokens = torch.tensor(tokens, dtype=torch.long) # (n,) n : current sequence length
+         tokens = tokens.unsqueeze(0).repeat(num_seq, 1) # (1,n) --> (num_seq, n)
+         gen_tokens = tokens.to(self.device)
+         # create a different rng generator so as not to impact the global rng state used for training
+         sample_rng = torch.Generator(device=self.device).manual_seed(42)
+
+         # generate new tokens one token at a time until the sequence length reaches 'max_tokens'
+         while gen_tokens.shape[-1] <= max_tokens:
+             with torch.no_grad():
+                 with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
+                     logits, loss = self.model(gen_tokens) # (num_seq, n, vocab_size)
+                 logits = logits[:, -1, :] # (num_seq, vocab_size)
+                 probs = F.softmax(logits, dim=-1) # (num_seq, vocab_size)
+                 # take top-k 50 probs
+                 topk_probs, topk_indices = torch.topk(probs, 50, dim=-1) # (num_seq, 50), (num_seq, 50)
+                 # sample a token from the top-50 probabilities
+                 ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng) # (num_seq, 1)
+                 next_tok = torch.gather(topk_indices, -1, ix) # (num_seq, 1)
+                 gen_tokens = torch.cat([gen_tokens, next_tok], dim=1)
+         # decode generated tokens and print generated text
+         for i in range(num_seq):
+             tokens = gen_tokens[i, :max_tokens].tolist()
+             gen_text = self.token_encoder.decode(tokens)
+             print(f"> sample {i}: {gen_text}")
+
+
+ def parse_args():
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--prompt', type=str, default="Hello, I am a language model,")
+     parser.add_argument('--num_seq', type=int, default=5)
+     parser.add_argument('--max_tokens', type=int, default=50)
+     args = parser.parse_args()
+     return args
+
+
+ @dataclass
+ class GPTConfig:
+     context_length: int = 1024 # max context / sequence length
+     vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
+     num_layers: int = 12
+     embd_size: int = 768 # embedding dim
+     num_heads: int = 12
+
+
+ def inference(args=None):
+     if args is None:
+         args = parse_args()
+
+     device = 'cpu'
+     if torch.cuda.is_available():
+         device = 'cuda'
+     elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+         device = 'mps' # for apple macbook GPUs
+     print(f'using device: {device}')
+
+     model_path = './logs/model_95364.pt'
+     checkpoint = torch.load(model_path, weights_only=False)
+     print(f"loaded model from: {model_path}")
+     # print(checkpoint['model'].keys())
+
+     model = GPT(config=checkpoint['config'])
+     model.load_state_dict(checkpoint['model'])
+     model = model.to(device)
+     token_encoder = tiktoken.get_encoding('gpt2')
+     generator = GPT2Inference(model, token_encoder, device)
+
+     generator.generate_sequences(args.prompt, args.num_seq, args.max_tokens)
+
+
+ if __name__ == '__main__':
+     inference()
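
The sampling loop above is standard top-k (k=50) sampling. A standalone sketch of just that step, with fake logits so it runs without a checkpoint:

import torch
import torch.nn.functional as F

logits = torch.randn(5, 50257)                 # (num_seq, vocab_size), fake last-position logits
probs = F.softmax(logits, dim=-1)
topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
g = torch.Generator().manual_seed(42)
ix = torch.multinomial(topk_probs, num_samples=1, generator=g)  # sample within the top 50
next_tok = torch.gather(topk_indices, -1, ix)  # map back to vocabulary ids
print(next_tok.shape)                          # torch.Size([5, 1])

Note that the checkpoint path './logs/model_95364.pt' is hard-coded in inference(), so running the script assumes that file exists.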
log.txt ADDED
File without changes
model.cpython-311.pyc ADDED
Binary file (16.6 kB).
 
model.py ADDED
@@ -0,0 +1,201 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from dataclasses import dataclass
+ import inspect
+
+
+ @dataclass
+ class GPTConfig:
+     context_length: int = 1024 # max context / sequence length
+     vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
+     num_layers: int = 12
+     embd_size: int = 768 # embedding dim
+     num_heads: int = 12
+
+
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         # 'embd_size' sized vector divided into 'num_heads' heads
+         assert config.embd_size % config.num_heads == 0, "embedding dim should be divisible by number of heads"
+         self.num_heads = config.num_heads
+         self.embd_size = config.embd_size
+         # batched key, query, and value projections for all heads
+         self.c_attn = nn.Linear(config.embd_size, 3 * config.embd_size)
+         self.c_proj = nn.Linear(config.embd_size, config.embd_size)
+         self.c_proj.SCALE_INIT = 1.0
+         # not really a bias, more of a mask, but following OpenAI/HF naming convention
+         # self.register_buffer("bias", torch.tril(torch.ones(config.context_length, config.context_length)).view(1, 1, config.context_length, config.context_length))
+
+     def forward(self, x):
+         B, T, C = x.shape
+         # calculate query, key, values for all heads in a batch and move head forward to be the batch dim
+         # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
+         # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels
+         qkv = self.c_attn(x) # (B, T, 3C)
+         q, k, v = qkv.split(self.embd_size, dim=-1) # (B,T,C), (B,T,C), (B,T,C)
+         q = q.view(B, T, self.num_heads, self.embd_size // self.num_heads).transpose(1, 2) # (B,nh,T,hs)
+         k = k.view(B, T, self.num_heads, self.embd_size // self.num_heads).transpose(1, 2) # (B,nh,T,hs)
+         v = v.view(B, T, self.num_heads, self.embd_size // self.num_heads).transpose(1, 2) # (B,nh,T,hs)
+         # attn = q @ k.transpose(-2, -1) / np.sqrt(k.shape[-1]) # (B,nh,T,hs) @ (B,nh,hs,T) --> (B,nh,T,T)
+         # attn = attn.masked_fill(self.bias[:,:,:T,:T] == 0, float("-inf"))
+         # attn = F.softmax(attn, dim=-1)
+         # out = attn @ v # (B,nh,T,T) @ (B,nh,T,hs) --> (B,nh,T,hs)
+         # flash attention (significantly faster, but logically the same as the 4 lines above)
+         out = F.scaled_dot_product_attention(q, k, v, is_causal=True) # (B,nh,T,hs)
+         out = out.transpose(1, 2).contiguous().view(B, T, C) # (B,nh,T,hs) --> (B,T,nh,hs) --> (B,T,C=nh*hs)
+         out = self.c_proj(out) # (B,T,C) --> (B,T,C)
+         return out
+
+
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.embd_size, 4 * config.embd_size)
+         self.gelu = nn.GELU(approximate='tanh') # approximate='tanh' used to reproduce the historical gpt2 setup
+         self.c_proj = nn.Linear(4 * config.embd_size, config.embd_size)
+         self.c_proj.SCALE_INIT = 1.0
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         return x
+
+
+ class Block(nn.Module):
+     """ Transformer decoder block (pre-LayerNorm: causal self-attention + MLP, each with a residual connection) """
+
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.embd_size)
+         self.attn = CausalSelfAttention(config)
+         self.ln_2 = nn.LayerNorm(config.embd_size)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ class GPT(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(self.config.vocab_size, self.config.embd_size),
+             wpe = nn.Embedding(self.config.context_length, self.config.embd_size),
+             h = nn.ModuleList([Block(self.config) for _ in range(self.config.num_layers)]),
+             ln_f = nn.LayerNorm(self.config.embd_size)
+         ))
+         # language modeling head
+         self.lm_head = nn.Linear(self.config.embd_size, self.config.vocab_size, bias=False)
+         # weight sharing scheme (saves 768*50257 ~ 40M params: fewer params, more efficient)
+         self.transformer.wte.weight = self.lm_head.weight
+         # init params (iterates over all submodules and applies _init_weights)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             if hasattr(module, 'SCALE_INIT'):
+                 std /= (2 * self.config.num_layers)**0.5
+             torch.nn.init.normal_(module.weight, mean=0, std=std) # as per openai gpt-2 source code
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0, std=0.02)
+
+     def forward(self, idx, targets=None):
+         B, T = idx.shape
+         assert T <= self.config.context_length, f'sequence length {T} should be <= {self.config.context_length}'
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # (T,)
+         pos_embd = self.transformer.wpe(pos) # (T, embd_size)
+         tok_embd = self.transformer.wte(idx) # (B, T, embd_size)
+         x = pos_embd + tok_embd # (B, T, embd_size)
+         for block in self.transformer.h:
+             x = block(x)
+         x = self.transformer.ln_f(x) # (B, T, embd_size)
+         logits = self.lm_head(x) # (B, T, vocab_size)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1))
+         return logits, loss
+
+     @classmethod
+     def from_pretrained(cls, model_type):
+         """ Loads pretrained GPT2 model weights from huggingface """
+         assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
+         from transformers import GPT2LMHeadModel
+         print(f"loading weights from pretrained gpt: {model_type}")
+
+         config_args = {
+             'gpt2': dict(num_layers=12, num_heads=12, embd_size=768), # 124M params
+             'gpt2-medium': dict(num_layers=24, num_heads=16, embd_size=1024), # 350M params
+             'gpt2-large': dict(num_layers=36, num_heads=20, embd_size=1280), # 774M params
+             'gpt2-xl': dict(num_layers=48, num_heads=25, embd_size=1600), # 1558M params
+         }[model_type]
+         config_args['vocab_size'] = 50257
+         config_args['context_length'] = 1024
+
+         # create a from-scratch minGPT model
+         config = GPTConfig(**config_args)
+         model = GPT(config)
+         sd = model.state_dict()
+         sd_keys = sd.keys()
+         sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')]
+
+         # init a huggingface transformers model
+         model_hf = GPT2LMHeadModel.from_pretrained(model_type)
+         sd_hf = model_hf.state_dict()
+         sd_keys_hf = sd_hf.keys()
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')]
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')]
+         transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
+
+         assert len(sd_keys) == len(sd_keys_hf), f"mismatched keys {len(sd_keys)} != {len(sd_keys_hf)}"
+
+         # copy while ensuring all parameters are aligned in names and shape
+         for k in sd_keys_hf:
+             if any(k.endswith(w) for w in transposed):
+                 # need to transpose Conv1D weights
+                 assert sd_hf[k].shape[::-1] == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k].T)
+             else:
+                 assert sd_hf[k].shape == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k])
+         return model
+
+     def configure_optimizers(self, weight_decay, lr, device_type, master_process):
+         """
+         Sets up AdamW with selective weight decay (a regularization tool: by decaying the weights,
+         we force the optimizer to spread work across more weights instead of letting any single weight dominate)
+         """
+         # start with all of the candidate params (that require gradient)
+         param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
+
+         # create optim groups: any parameters that are 2D will be weight decayed, otherwise no.
+         # i.e., all weight tensors in matmuls + embeddings will decay, whereas biases and layernorms won't be decayed
+         decay_params = [p for pn, p in param_dict.items() if p.dim() >= 2]
+         nodecay_params = [p for pn, p in param_dict.items() if p.dim() < 2]
+         optim_groups = [
+             {'params': decay_params, 'weight_decay': weight_decay},
+             {'params': nodecay_params, 'weight_decay': 0.0}
+         ]
+         num_decay_params = sum(p.numel() for p in decay_params)
+         num_nodecay_params = sum(p.numel() for p in nodecay_params)
+         if master_process:
+             print(f'num decay parameter tensors: {len(decay_params)} with {num_decay_params:,} parameters')
+             print(f'num nodecay parameter tensors: {len(nodecay_params)} with {num_nodecay_params:,} parameters')
+
+         # use the fused version of the AdamW optimizer when available (faster than the non-fused version)
+         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+         use_fused = fused_available and device_type == 'cuda'
+         if master_process:
+             print(f'using fused AdamW optimizer: {use_fused}')
+         optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
+         return optimizer
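
A small sanity-check sketch for the weight-sharing scheme and parameter count described in GPT.__init__ (assumes model.py is importable):

from model import GPT, GPTConfig

model = GPT(GPTConfig())  # defaults: 12 layers, 12 heads, 768-dim embeddings
# lm_head and wte reference the same tensor, so the ~40M-param embedding matrix is counted once
assert model.lm_head.weight is model.transformer.wte.weight
num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params/1e6:.1f}M parameters")  # ~124M, matching GPT-2 small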
prepare_dataset.py ADDED
@@ -0,0 +1,75 @@
+ import multiprocessing as mp
+ from datasets import load_dataset, DownloadConfig
+ import backoff
+ import os
+ from pathlib import Path
+ import numpy as np
+ import tiktoken
+
+ # Function to process individual dataset items
+ def process_data(item):
+     """
+     Process a single dataset item.
+     Replace this with your actual processing logic (e.g., tokenization).
+     """
+     # Example: Tokenize text using tiktoken (adjust based on your needs)
+     encoder = tiktoken.get_encoding('gpt2')
+     text = item.get('text', '') # Assuming the dataset has a 'text' field
+     tokens = encoder.encode(text)
+     return tokens
+
+ @backoff.on_exception(backoff.expo, Exception, max_tries=5)
+ def fetch_data(item):
+     """
+     Wrapper for process_data with exponential backoff for retries.
+     """
+     return process_data(item)
+
+ def main():
+     """
+     Main function to load and process the FineWeb-Edu dataset.
+     """
+     # Configuration
+     remote_name = "sample-10BT" # Dataset configuration name
+     output_dir = "./data" # Directory to save processed data
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Set up download config to handle rate limits and caching
+     download_config = DownloadConfig(
+         max_retries=5,
+         num_proc=4, # Limit to 4 processes to avoid HTTP 429
+         cache_dir=Path.home() / ".cache" / "huggingface" / "datasets"
+     )
+
+     try:
+         # Load dataset with caching
+         print("Loading dataset...")
+         dataset = load_dataset(
+             'HuggingFaceFW/fineweb-edu',
+             name=remote_name,
+             split='train',
+             download_mode="reuse_dataset_if_exists",
+             download_config=download_config
+         )
+         print(f"Dataset loaded with {len(dataset)} items.")
+
+         # Limit the number of processes to avoid overwhelming the Hugging Face Hub
+         nprocs = min(mp.cpu_count(), 4)
+         print(f"Using {nprocs} processes for multiprocessing.")
+
+         # Process dataset using multiprocessing
+         with mp.Pool(nprocs) as pool:
+             results = pool.map(fetch_data, dataset)
+
+         # Save processed results (example: save as numpy arrays)
+         output_path = os.path.join(output_dir, "processed_fineweb_edu.npy")
+         np.save(output_path, np.array(results, dtype=object)) # ragged token lists need an object array
+         print(f"Processed dataset saved to {output_path}")
+
+     except Exception as e:
+         print(f"Error loading or processing dataset: {e}")
+         raise
+
+ if __name__ == '__main__':
+     mp.freeze_support() # Required for Windows compatibility with executables
+     main()
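
For reference, a sketch of what process_data does for a single record (the ids shown are the standard GPT-2 BPE token ids):

import tiktoken

encoder = tiktoken.get_encoding('gpt2')
tokens = encoder.encode("Hello world")
print(tokens)  # [15496, 995]
print(encoder.decode(tokens))  # 'Hello world'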
train.py ADDED
@@ -0,0 +1,392 @@
+ import os
+ import math
+ import numpy as np
+ import time
+ from dataclasses import dataclass
+ import tiktoken
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ # import code; code.interact(local=locals())
+
+ from model import GPT
+ from dataloader import DataLoaderLite
+ from hellaswag_eval import render_example, iterate_examples, get_most_likely_row
+
+ torch.set_float32_matmul_precision('high') # enable TF32 precision
+
+ # set torch compile to True (if it doesn't throw any errors) to speed up training
+ use_torch_compile = False
+
+
+ class Trainer:
+     def __init__(
+         self,
+         model,
+         optimizer,
+         train_loader,
+         val_loader,
+         token_encoder,
+         eval_freq,
+         grad_accum_steps,
+         ddp,
+         ddp_rank,
+         ddp_world_size,
+         device,
+         logpath
+     ):
+         self.ddp = ddp
+         self.ddp_rank = ddp_rank
+         self.master_process = ddp_rank == 0
+         self.ddp_world_size = ddp_world_size
+
+         self.model = model
+         self.optimizer = optimizer
+         self.train_loader = train_loader
+         self.val_loader = val_loader
+         self.token_encoder = token_encoder
+
+         self.eval_freq = eval_freq
+         self.grad_accum_steps = grad_accum_steps
+         self.device = device
+         self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
+         self.logpath = logpath
+
+
+     def train(
+         self,
+         max_steps,
+         warmup_steps,
+         max_lr,
+         min_lr
+     ):
+         for step in range(max_steps):
+             t0 = time.time()
+             self.is_last_step = (step == max_steps - 1)
+
+             # evaluate validation loss
+             if step % self.eval_freq == 0 or self.is_last_step:
+                 self.evaluate_validation(step)
+
+             # evaluate model performance on HellaSwag every once in a while
+             if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
+                 self.evaluate_hellaswag(step)
+
+             # generate sequences from the model every once in a while
+             if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
+                 self.generate_sequences(num_seq=5, max_tokens=32)
+
+             # training loop starts here
+             self.model.train() # sets model to train mode
+             self.optimizer.zero_grad() # resets all gradients
+             batch_loss = 0.0
+
+             for mini_step in range(self.grad_accum_steps):
+                 inp, tar = self.train_loader.next_batch()
+                 inp, tar = inp.to(self.device), tar.to(self.device)
+
+                 # FORWARD PASS !!!
+                 # autocast to bfloat16 for faster compute and memory efficiency
+                 with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
+                     logits, loss = self.model(inp, tar)
+
+                 # loss is scaled to account for gradient accumulation, because the gradients just add
+                 # on each successive backward() call. Addition of gradients corresponds to a SUM in the
+                 # objective, but we want a MEAN instead of a SUM
+                 loss /= self.grad_accum_steps
+                 batch_loss += loss.detach()
+
+                 if self.ddp:
+                     # only sync and average gradients across all processes in the final mini_step
+                     # (the 'no_sync()' context manager can be used alternatively)
+                     self.model.require_backward_grad_sync = (mini_step == self.grad_accum_steps - 1)
+
+                 # each process accumulates gradients separately while 'require_backward_grad_sync' is False;
+                 # in the final 'mini_step', 'require_backward_grad_sync' becomes True, so the
+                 # gradients are averaged across all processes and shared among them by loss.backward()
+                 loss.backward()
+
+             if self.ddp:
+                 # 'batch_loss' is outside of the DDP container, so we need to perform 'all_reduce' to
+                 # average 'batch_loss' across all processes of all ranks; the 'batch_loss' tensor exists on all GPUs.
+                 # 'all_reduce' averages and deposits the result on all the processes
+                 dist.all_reduce(batch_loss, op=dist.ReduceOp.AVG)
+
+             # once gradients are computed, clip the global l2-norm of the gradient at 1.0
+             norm = nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) # monitor/print 'norm'
+
+             # determine learning rate with decay
+             lr = self.estimate_lr(step, warmup_steps, max_steps, max_lr, min_lr)
+             # set learning rate for this iteration
+             for param_group in self.optimizer.param_groups:
+                 param_group['lr'] = lr
+
+             self.optimizer.step()
+             if self.device_type == 'cuda':
+                 torch.cuda.synchronize() # wait for the GPU to finish work
+
+             dt = (time.time() - t0) * 1000.0 # in ms
+             tokens_processed = self.train_loader.B * self.train_loader.T * self.grad_accum_steps * self.ddp_world_size
+             tokens_per_sec = tokens_processed / (dt / 1000.0) # dt is in ms, so convert back to seconds
+
+             if self.master_process:
+                 print(f'step {step:4d} | loss: {batch_loss.item():.6f} | lr: {lr:.2e} | norm: {norm:.4f} | dt: {dt:.4f}ms | tok/sec: {tokens_per_sec:.4f}')
+                 with open(self.logpath, 'a') as f:
+                     f.write(f'{step} train {batch_loss.item():.6f}\n')
+
+
+     def evaluate_validation(self, step):
+         self.model.eval() # sets model to eval mode
+         self.val_loader.reset()
+         # evaluate the model on the validation set
+         with torch.no_grad():
+             val_loss_accum = 0.0
+             val_steps = 20
+             for _ in range(val_steps):
+                 inp, tar = self.val_loader.next_batch()
+                 inp, tar = inp.to(self.device), tar.to(self.device)
+                 with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
+                     logits, loss = self.model(inp, tar)
+                 loss /= val_steps
+                 val_loss_accum += loss.detach()
+
+         if self.ddp:
+             dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
+         if self.master_process:
+             print(f'Val loss: {val_loss_accum.item():.4f}')
+             with open(self.logpath, 'a') as f:
+                 f.write(f'{step} val {val_loss_accum.item():.4f}\n')
+
+             if step > 0 and (step % 10000 == 0 or self.is_last_step):
+                 raw_model = self.model.module if self.ddp else self.model
+                 logdir = os.path.dirname(self.logpath)
+                 ckpt_path = os.path.join(logdir, f'model_{step:05d}.pt')
+                 checkpoint = {
+                     'model': raw_model.state_dict(),
+                     'config': raw_model.config,
+                     'step': step,
+                     'val_loss': val_loss_accum.item()
+                 } # add optimizer.state_dict(), rng seeds, etc. if resuming training
+                 torch.save(checkpoint, ckpt_path)
+
+
+     def evaluate_hellaswag(self, step):
+         """
+         Construct a batch of 4 candidate sequences per example and score the
+         completions with our model.
+         """
+         n_total = 0
+         n_correct_norm = 0
+         for i, example in enumerate(iterate_examples('val')):
+             # only process examples where i % ddp_world_size == ddp_rank
+             if i % self.ddp_world_size != self.ddp_rank:
+                 continue
+             # render the example into tokens and labels
+             _, tokens, mask, label = render_example(example) # (4,N), (4,N), (4,N)
+             tokens, mask = tokens.to(self.device), mask.to(self.device)
+             with torch.no_grad():
+                 with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
+                     logits, loss = self.model(tokens)
+                 pred_norm = get_most_likely_row(tokens, mask, logits)
+             n_total += 1
+             n_correct_norm += int(pred_norm == label)
+         # reduce the stats across all processes
+         if self.ddp:
+             n_total = torch.tensor(n_total, device=self.device, dtype=torch.long)
+             n_correct_norm = torch.tensor(n_correct_norm, device=self.device, dtype=torch.long)
+             dist.all_reduce(n_total, op=dist.ReduceOp.SUM)
+             dist.all_reduce(n_correct_norm, op=dist.ReduceOp.SUM)
+             n_total = n_total.item()
+             n_correct_norm = n_correct_norm.item()
+         acc_norm = n_correct_norm / n_total
+         if self.master_process:
+             print(f'HellaSwag accuracy: {n_correct_norm}/{n_total}={acc_norm:.4f}')
+             with open(self.logpath, 'a') as f:
+                 f.write(f'{step} hellaswag {acc_norm:.4f}\n')
+
+
+     def generate_sequences(self, num_seq=4, max_tokens=32):
+         self.model.eval()
+         tokens = self.token_encoder.encode("Hello, I am a language model")
+         tokens = torch.tensor(tokens, dtype=torch.long) # (n,) n : current sequence length
+         tokens = tokens.unsqueeze(0).repeat(num_seq, 1) # (1,n) --> (num_seq, n)
+         gen_tokens = tokens.to(self.device)
+         # create a different rng generator so as not to impact the global rng state used for training
+         sample_rng = torch.Generator(device=self.device)
+         # adding 'ddp_rank' to the seed to generate different tokens on different rank processes
+         sample_rng.manual_seed(42 + self.ddp_rank)
+         # generate new tokens one token at a time until the sequence length reaches 'max_tokens'
+         while gen_tokens.shape[-1] <= max_tokens:
+             with torch.no_grad():
+                 with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
+                     logits, loss = self.model(gen_tokens) # (num_seq, n, vocab_size)
+                 logits = logits[:, -1, :] # (num_seq, vocab_size)
+                 probs = F.softmax(logits, dim=-1) # (num_seq, vocab_size)
+                 # take top-k 50 probs
+                 topk_probs, topk_indices = torch.topk(probs, 50, dim=-1) # (num_seq, 50), (num_seq, 50)
+                 # sample a token from the top-50 probabilities
+                 ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng) # (num_seq, 1)
+                 next_tok = torch.gather(topk_indices, -1, ix) # (num_seq, 1)
+                 gen_tokens = torch.cat([gen_tokens, next_tok], dim=1)
+         # decode generated tokens and print generated text
+         for i in range(num_seq):
+             tokens = gen_tokens[i, :max_tokens].tolist()
+             gen_text = self.token_encoder.decode(tokens)
+             print(f"> rank {self.ddp_rank} sample {i}: {gen_text}")
+
+
+     def estimate_lr(self, step, warmup_steps, max_steps, max_lr, min_lr):
+         """
+         Learning rate scheduler: cosine-decay schedule with linear warmup
+         """
+         # 1) linear warmup for 'warmup_steps' steps
+         if step < warmup_steps:
+             return max_lr * (step+1) / warmup_steps
+         # 2) if step > max_steps, return min lr
+         if step > max_steps:
+             return min_lr
+         # 3) in between, use cosine decay down to min lr
+         decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
+         assert 0 <= decay_ratio <= 1
+         coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0
+         return min_lr + coeff * (max_lr - min_lr)
+
+
+ @dataclass
+ class GPTConfig:
+     context_length: int = 1024 # max context / sequence length
+     vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
+     num_layers: int = 12
+     embd_size: int = 768 # embedding dim
+     num_heads: int = 12
+
+
+ def get_args():
+     import argparse
+     parser = argparse.ArgumentParser(description="Hyperparameter Configuration")
+     parser.add_argument("--total_batch_size", type=int, default=524288, help="number of tokens processed for each weight update") # =2^19 tokens/step update (~0.5M tokens, as used in the openai gpt3 paper)
+     parser.add_argument("--mini_batch_size", type=int, default=32, help="mini_batch_size is just a performance optimization: bigger gpu, bigger mini_batch_size")
+     parser.add_argument("--context_length", type=int, default=1024) # max sequence length (can also try 2048)
+     parser.add_argument("--num_layers", type=int, default=12)
+     parser.add_argument("--embd_size", type=int, default=768)
+     parser.add_argument("--num_heads", type=int, default=12)
+     parser.add_argument("--max_lr", type=float, default=1e-3)
+     parser.add_argument("--min_lr", type=float, default=1e-3 * 0.1)
+     parser.add_argument("--warmup_steps", type=int, default=715)
+     parser.add_argument("--weight_decay", type=float, default=0.1)
+     parser.add_argument("--num_epochs", type=int, default=5)
+     parser.add_argument("--steps_per_epoch", type=int, default=19073) # 10^10 / 2^19 ~ 19073 for 1 epoch on FineWebEdu-sample10BT
+     parser.add_argument("--eval_freq", type=int, default=250)
+     # parser.add_argument("--use_torch_compile", action='store_true') # default False
+     parser.add_argument("--seed", type=int, default=1337, help="Random seed for reproducibility")
+     parser.add_argument("--logdir", type=str, default="./logs/")
+     return parser.parse_args()
+
+
+ def main():
+     args = get_args()
+
+     # print the hyperparameters
+     print("Hyperparameter Configuration:")
+     for key, value in vars(args).items():
+         print(f"{key}: {value}")
+
+     # create the logs directory if it doesn't exist
+     os.makedirs(args.logdir, exist_ok=True)
+     logpath = os.path.join(args.logdir, 'log.txt')
+     with open(logpath, 'w') as f: # clear the log file
+         pass
+
+     # set up DDP (distributed data parallel)
+     # the 'torchrun' command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
+     # RANK and LOCAL_RANK are the same for (single node, multi-GPU) settings, but may differ
+     # for (multi-node, multi-GPU) settings.
+     ddp = int(os.environ.get('RANK', -1)) != -1 # whether this is a ddp run or not
+     if ddp:
+         # use of ddp requires CUDA
+         assert torch.cuda.is_available(), 'use of DDP requires CUDA'
+         dist.init_process_group(backend='nccl')
+         ddp_rank = int(os.environ['RANK'])
+         ddp_local_rank = int(os.environ['LOCAL_RANK'])
+         ddp_world_size = int(os.environ['WORLD_SIZE'])
+         device = f'cuda:{ddp_local_rank}'
+         torch.cuda.set_device(device)
+         # the master process (arbitrarily set to rank 0) will do printing, logging, checkpointing, etc.
+         master_process = ddp_rank == 0
+     else:
+         # not using ddp
+         ddp_rank = 0
+         ddp_local_rank = 0
+         ddp_world_size = 1
+         master_process = True # ddp_rank == 0
+         device = 'cpu'
+         if torch.cuda.is_available():
+             device = 'cuda'
+         elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+             device = 'mps' # for apple macbook GPUs
+         print(f'using device: {device}')
+
+     device_type = 'cuda' if device.startswith('cuda') else 'cpu'
+
+     # set seeds for reproducibility
+     np.random.seed(args.seed)
+     torch.manual_seed(args.seed) # sets seed for random number generation on CPU
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(args.seed) # sets seed for random number generation on GPU
+         torch.cuda.manual_seed_all(args.seed) # sets seed for all GPUs
+
+     assert args.total_batch_size % (args.mini_batch_size * args.context_length * ddp_world_size) == 0, 'ensure total_batch_size is divisible by B*T*ddp_world_size'
+     grad_accum_steps = args.total_batch_size // (args.mini_batch_size * args.context_length * ddp_world_size)
+     if master_process:
+         print(f'desired batch size (number of tokens): {args.total_batch_size}')
+         print(f'gradient accumulation steps: {grad_accum_steps}')
+         print(f'GPU: {ddp_rank}, {ddp_local_rank}')
+
+     train_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='train')
+     val_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='val')
+
+     # create the GPT model. each ddp process will create its own instance of the model, but since the seed is fixed,
+     # they will all create the same identical model
+     gpt_config = GPTConfig(vocab_size=50304, # 50304 (nice number, lots of powers of 2) used instead of 50257 (odd, "ugly" number)
+                            context_length=args.context_length,
+                            num_layers=args.num_layers,
+                            num_heads=args.num_heads,
+                            embd_size=args.embd_size
+                            )
+     model = GPT(config=gpt_config)
+     # model = GPT.from_pretrained('gpt2') # init from OpenAI GPT-2
+     model.to(device) # move model to device
+     if use_torch_compile:
+         # use torch compile almost always unless debugging (costs compilation time, but makes training faster)
+         # the speedup comes from reducing python overhead and GPU read/writes
+         model = torch.compile(model)
+
+     if ddp:
+         # wrap the model in the DDP container (the forward pass is unchanged, but after the backward pass,
+         # gradients computed across the processes are averaged by DDP using 'AllReduce' and shared across
+         # all processes so that each process has the same gradients)
+         model = DDP(model, device_ids=[ddp_local_rank])
+
+     raw_model = model.module if ddp else model
+     optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, lr=args.max_lr, device_type=device_type, master_process=master_process)
+     token_encoder = tiktoken.get_encoding('gpt2')
+
+     start_time = time.time()
+     # init the trainer object
+     trainer = Trainer(model, optimizer, train_loader, val_loader, token_encoder, args.eval_freq, grad_accum_steps,
+                       ddp, ddp_rank, ddp_world_size, device, logpath)
+
+     max_steps = args.steps_per_epoch * args.num_epochs
+     trainer.train(max_steps, args.warmup_steps, args.max_lr, args.min_lr)
+
+     dt = (time.time() - start_time) / (60*60)
+     print(f"Total training time: {dt:.4f}hr")
+
+     if ddp:
+         dist.destroy_process_group()
+
+
+ if __name__ == "__main__":
+     main()
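
A worked example of the gradient-accumulation arithmetic enforced by the assert in main(), under the assumption of an 8-GPU run with the default hyperparameters:

total_batch_size = 524288  # 2**19 tokens per optimizer step
mini_batch_size, context_length, ddp_world_size = 32, 1024, 8
grad_accum_steps = total_batch_size // (mini_batch_size * context_length * ddp_world_size)
print(grad_accum_steps)  # 2: each rank runs 2 forward/backward passes before optimizer.step()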