abhinavv3 commited on
Commit
ccfb646
·
1 Parent(s): 512c0f0

initial commit of MEMGPT

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ data/edu_fineweb10B
2
+ log/model*
3
+ !log/model_final.pt
Readme.md ADDED
Binary file (7.89 kB). View file
 
configs/config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": {
3
+ "block_size": 1024,
4
+ "vocab_size": 50304,
5
+ "n_layer": 12,
6
+ "n_head": 12,
7
+ "n_embd": 768
8
+ },
9
+ "training": {
10
+ "max_steps": 19073,
11
+ "log_dir": "log",
12
+ "total_batch_size": 524288,
13
+ "B": 64,
14
+ "T": 1024,
15
+ "max_lr": 0.0006,
16
+ "min_lr": 0.00006,
17
+ "warmup_steps": 715,
18
+ "weight_decay": 0.1,
19
+ "learning_rate": 0.0006
20
+ }
21
+ }
data/__init__.py ADDED
File without changes
data/fineweb.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FineWeb-Edu dataset (for srs pretraining)
3
+ https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu
4
+ Downloads and tokenizes the data and saves data shards to disk.
5
+ Will save shards to the local directory "edu_fineweb10B".
6
+ """
7
+ import os
8
+ import multiprocessing as mp
9
+ import numpy as np
10
+ import tiktoken
11
+ from datasets import load_dataset
12
+ from tqdm import tqdm
13
+
14
+
15
+ local_dir = "edu_fineweb10B"
16
+ remote_name = "sample-10BT"
17
+ shard_size = int(1e8) # 100M tokens per shard, total of 100 shards
18
+
19
+ DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)
20
+ os.makedirs(DATA_CACHE_DIR, exist_ok=True)
21
+ print("Shards will be saved to:",DATA_CACHE_DIR)
22
+
23
+ #dataset download
24
+ fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train")
25
+
26
+ #tokenizer
27
+ enc = tiktoken.get_encoding("gpt2")
28
+ eot = enc._special_tokens['<|endoftext|>'] # end of text token
29
+
30
+ def tokenize(doc):
31
+ # tokenizes a single document and returns a numpy array of uint16 tokens
32
+ tokens = [eot]
33
+ tokens.extend(enc.encode_ordinary(doc["text"]))
34
+ tokens_np = np.array(tokens)
35
+ assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16"
36
+ tokens_np_uint16 = tokens_np.astype(np.uint16)
37
+ return tokens_np_uint16
38
+
39
+ def write_datafile(filename, tokens_np):
40
+ np.save(filename, tokens_np)
41
+
42
+ nprocs = max(1, os.cpu_count()//2)
43
+ with mp.Pool(nprocs) as pool:
44
+ shard_index = 0
45
+ # preallocate buffer to hold current shard
46
+ all_tokens_np = np.empty((shard_size,), dtype=np.uint16)
47
+ token_count = 0
48
+ progress_bar = None
49
+ for tokens in pool.imap(tokenize, fw, chunksize=16):
50
+
51
+ # is there enough space in the current shard for the new tokens?
52
+ if token_count + len(tokens) < shard_size:
53
+ # simply append tokens to current shard
54
+ all_tokens_np[token_count:token_count+len(tokens)] = tokens
55
+ token_count += len(tokens)
56
+ # update progress bar
57
+ if progress_bar is None:
58
+ progress_bar = tqdm(total=shard_size, unit="tokens", desc=f"Shard {shard_index}")
59
+ progress_bar.update(len(tokens))
60
+ else:
61
+ # write the current shard and start a new one
62
+ split = "val" if shard_index == 0 else "train"
63
+ filename = os.path.join(DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}")
64
+ # split the document into whatever fits in this shard; the remainder goes to next one
65
+ remainder = shard_size - token_count
66
+ progress_bar.update(remainder)
67
+ all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
68
+ write_datafile(filename, all_tokens_np)
69
+ shard_index += 1
70
+ progress_bar = None
71
+ # populate the next shard with the leftovers of the current doc
72
+ all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:]
73
+ token_count = len(tokens)-remainder
74
+
75
+ # write any remaining tokens as the last shard
76
+ if token_count != 0:
77
+ split = "val" if shard_index == 0 else "train"
78
+ filename = os.path.join(DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}")
79
+ write_datafile(filename, all_tokens_np[:token_count])
evaluation/__init__.py ADDED
File without changes
evaluation/hellaswag.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Downloads and evaluates HellaSwag in Python.
3
+ https://github.com/rowanz/hellaswag
4
+
5
+ """
6
+ import os
7
+ import json
8
+ import requests
9
+ import tiktoken
10
+ from tqdm import tqdm
11
+ import torch
12
+ from torch.nn import functional as F
13
+
14
+ DATA_DOWNLOADED_PATH = '"data/hellaswag"'
15
+
16
+ def download_file(url:str, fname:str, chunk_size=1024):
17
+ resp = requests.get(url, stream=True)
18
+ total = int(resp.headers.get("content-length", 0 ))
19
+ with open(fname, "wb") as file, tqdm(
20
+ desc = fname,
21
+ total=total,
22
+ unit="iB",
23
+ unit_scale=True,
24
+ unit_divisor=1024
25
+ )as bar:
26
+ for data in resp.iter_content(chunk_size=chunk_size):
27
+ size = file.write(data)
28
+ bar.update(size)
29
+
30
+ hellaswags = {
31
+ "train": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_train.jsonl",
32
+ "val": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl",
33
+ "test": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_test.jsonl",
34
+ }
35
+
36
+ enc = tiktoken.get_encoding("gpt2")
37
+
38
+ def download(split):
39
+ """Downloads HellaSwag DATA_DOWNLOADED_PATH"""
40
+ os.makedirs(DATA_DOWNLOADED_PATH, exist_ok=True)
41
+ data_url = hellaswags[split]
42
+ data_filename = os.path.join(DATA_DOWNLOADED_PATH, f"hellaswag_{split}.jsonl")
43
+ if not os.path.exists(data_filename):
44
+ print(f"Downloading {data_url} to {data_filename}...")
45
+ download_file(data_url, data_filename)
46
+
47
+ def render_example(example):
48
+ """
49
+ Given the example as a dictionary, render it as three torch tensors:
50
+ - tokens (the tokens of context + completion, of size 4xN, as there are always 4 candidates)
51
+ - mask (is 1 in the region of the candidate completion, where we evaluate likelihoods)
52
+ - label (the index of the correct completion, which we hope has the highest likelihood)
53
+ """
54
+ ctx = example["ctx"]
55
+ label = example["label"]
56
+ endings = example["endings"]
57
+
58
+ # data needed to reproduce this eval on the C size
59
+ data = {
60
+ "label": label,
61
+ "ctx_tokens": None,
62
+ "ending_tokens": [],
63
+ }
64
+
65
+ # gather up all the tokens
66
+ ctx_tokens = enc.encode(ctx)
67
+ data["ctx_tokens"] = ctx_tokens
68
+ tok_rows = []
69
+ mask_rows = []
70
+ for end in endings:
71
+ end_tokens = enc.encode(" " + end) # note: prepending " " because GPT-2 tokenizer
72
+ tok_rows.append(ctx_tokens + end_tokens)
73
+ mask_rows.append([0]*len(ctx_tokens) + [1]*len(end_tokens))
74
+ data["ending_tokens"].append(end_tokens)
75
+
76
+ # have to be careful during the collation because the number of tokens in each row can differ
77
+ max_len = max(len(row) for row in tok_rows)
78
+ tokens = torch.zeros((4, max_len), dtype=torch.long)
79
+ mask = torch.zeros((4, max_len), dtype=torch.long)
80
+ for i, (tok_row, mask_row) in enumerate(zip(tok_rows, mask_rows)):
81
+ tokens[i, :len(tok_row)] = torch.tensor(tok_row)
82
+ mask[i, :len(mask_row)] = torch.tensor(mask_row)
83
+
84
+ return data, tokens, mask, label
85
+
86
+ def iterate_examples(split):
87
+ # there are 10,042 examples in total in val
88
+ download(split)
89
+ with open(os.path.join(DATA_DOWNLOADED_PATH, f"hellaswag_{split}.jsonl"), "r") as f:
90
+ for line in f:
91
+ example = json.loads(line)
92
+ yield example
93
+
94
+
95
+ def get_most_likely_row(tokens, mask, logits):
96
+ shift_logits = (logits[..., :-1, :]).contiguous() #this will be x for loss calculation
97
+ shift_tokens = (tokens[..., 1:]).contiguous() #this will be y for loss calculation
98
+ shift_mask = (mask[..., 1:]).contiguous() #shifting same as tokens shifted
99
+ flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
100
+ flat_shift_tokens = shift_tokens.view(-1)
101
+ shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
102
+ shift_losses = shift_losses.view(tokens.size(0), -1)
103
+ masked_shift_losses = shift_losses * shift_mask
104
+ sum_loss = masked_shift_losses.sum(dim=1)
105
+ avg_loss = sum_loss / shift_mask.sum(dim=1)
106
+ pred_norm = avg_loss.argmin().item() #taking the index of minimum loss
107
+ return pred_norm
108
+
109
+
110
+
111
+
112
+
113
+
evaluation/val_hellaswag.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ..hellaswag import render_example, iterate_examples, get_most_likely_row
3
+ import torch.distributed as dist
4
+ from torch.distributed import init_process_group, destroy_process_group
5
+ from torch.nn.parallel import DistributedDataParallel as DDP
6
+ import os
7
+ from ..ModelGPT2 import GPT,log_file
8
+
9
+ ddp = int(os.environ.get('RANK', -1)) != -1 #will be True if ddp run
10
+ if ddp:
11
+ assert torch.cuda.is_available()
12
+ init_process_group(backend='nccl')
13
+ ddp_rank = int(os.environ['RANK'])
14
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
15
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
16
+ device = f"cuda:{ddp_local_rank}"
17
+ torch.cuda.set_device(device)
18
+ master_process = ddp_rank == 0 #this is the process doing checkpoint,logging,etc
19
+ else:
20
+ ddp_rank = 0
21
+ ddp_local_rank = 0
22
+ ddp_world_size = 1
23
+ master_process = True
24
+ #attempt to autodetect the device
25
+ device = 'cpu'
26
+ if torch.cuda.is_available():
27
+ device = 'cuda'
28
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
29
+ device = "mps" #for mac users use apple silicon cpu which allready have gpu.mps is backend for apple silicon
30
+ print(f"Using device: {device}")
31
+ # device = "cpu" #OVERRIDE
32
+
33
+ device_type = "cuda" if device.startswith("cuda") else "cpu"
34
+
35
+ torch.manual_seed(1337)
36
+ if torch.cuda.is_available():
37
+ torch.cuda.manual_seed(1337)
38
+
39
+
40
+ #Creating model by loading the model weights
41
+ checkpoint_path = '../log/model_final.pt'
42
+ if master_process:
43
+ print(f"Loading checkpoint from {checkpoint_path}")
44
+
45
+ checkpoint = torch.load(checkpoint_path, map_location=device)
46
+
47
+ # Extract config and create model
48
+ model_config = checkpoint['config']
49
+ model_config.vocab_size = 50304 #for computational effciency(power of 2)
50
+ model = GPT(model_config)
51
+ # Load model state dict
52
+ model.load_state_dict(checkpoint['model'])
53
+ model = DDP(model, device_ids=[ddp_local_rank])
54
+ model.to(device)
55
+
56
+
57
+ def evaluate_hellaswag(model, device, device_type, ddp, ddp_rank, ddp_world_size, log_file, master_process):
58
+
59
+ num_correct_norm = 0
60
+ num_total = 0
61
+
62
+ for i, example in enumerate(iterate_examples("val")):
63
+ # only process example where i % ddp_world_size ==ddp_rank#this is for proper managemnt of which part is deal by which gpu
64
+ if ddp:
65
+ if i % ddp_world_size != ddp_rank:
66
+ continue
67
+ #rendering example into tokens and labels
68
+ _, tokens, mask, label = render_example(example)
69
+ tokens = tokens.to(device)
70
+ mask = mask.to(device)
71
+ #get the logits
72
+ with torch.no_grad():
73
+ with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
74
+ logits, loss = model(tokens)
75
+ pred_norm = get_most_likely_row(tokens, mask, logits)
76
+ num_total += 1
77
+ num_correct_norm += int(pred_norm == label)
78
+ #reduce the stats accross all process
79
+ if ddp:
80
+ num_total = torch.tensor(num_total, dtype=torch.long, device=device)
81
+ num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
82
+ dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
83
+ dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
84
+ num_total = num_total.item()
85
+ num_correct_norm = num_correct_norm.item()
86
+ acc_norm = num_correct_norm / num_total #accuracy of hellaswag
87
+ if master_process:
88
+ print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
89
+ with open(log_file, "a") as f:
90
+ f.write(f"Final Hellaswag accuracy: {acc_norm:.4f}\n")
91
+
92
+ evaluate_hellaswag(model, device, device_type, ddp, ddp_rank, ddp_world_size, log_file, master_process)
93
+ if ddp:
94
+ destroy_process_group()
log/log.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0 val 10.9528
model_core/__init__.py ADDED
File without changes
model_core/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (168 Bytes). View file
 
model_core/__pycache__/attention.cpython-311.pyc ADDED
Binary file (2.6 kB). View file
 
model_core/__pycache__/dataloader.cpython-311.pyc ADDED
Binary file (3.87 kB). View file
 
model_core/__pycache__/model.cpython-311.pyc ADDED
Binary file (10.5 kB). View file
 
model_core/__pycache__/training.cpython-311.pyc ADDED
Binary file (10.7 kB). View file
 
model_core/attention.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from torch.nn import functional as F
3
+
4
+
5
+ class CasualSelfAttention(nn.Module):
6
+
7
+ def __init__(self, config):
8
+ super().__init__()
9
+ assert config.n_embd % config.n_head == 0
10
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
11
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd)
12
+ self.c_proj.NANOGPT_SCALE_INIT = 1
13
+ self.n_head = config.n_head
14
+ self.n_embd = config.n_embd
15
+
16
+ def forward(self, x):
17
+ B, T, C = x.size()
18
+ qkv = self.c_attn(x)
19
+ q, k, v = qkv.split(self.n_embd, dim=2)
20
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # (B, nh, T, hs)
21
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # (B, nh, T, hs)
22
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # (B, nh, T, hs)
23
+
24
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=True) #flash attention
25
+
26
+ y = y.transpose(1,2).contiguous().view(B, T, C) # (B, T, C) basically the concat operation of differnt heads
27
+ y = self.c_proj(y)
28
+ return y
model_core/dataloader.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+
5
+ #Data loader
6
+ class DataLoaderLite:
7
+ def __init__(self, B, T, process_rank, num_processes, split, master_process):
8
+ self.B = B
9
+ self.T = T
10
+ self.process_rank = process_rank
11
+ self.num_processes = num_processes
12
+ assert split in {'train', 'val'}
13
+
14
+ #get the shard filenames
15
+ data_root = "data/edu_fineweb10B"
16
+ shards = os.listdir(data_root)
17
+ shards = [s for s in shards if split in s]
18
+ shards = sorted(shards)
19
+ shards = [os.path.join(data_root, s) for s in shards]
20
+ self.shards = shards
21
+ assert len(shards)> 0, f"no shards found for split {split}"
22
+ if master_process:
23
+ print(f"found {len(shards)} shards for split {split}")
24
+ self.reset()
25
+
26
+ def load_tokens(self, filename):
27
+ npt = np.load(filename)
28
+ npt = npt.astype(np.int32)
29
+ ptt = torch.tensor(npt, dtype=torch.long)
30
+ return ptt
31
+
32
+
33
+ def reset(self):
34
+ #state, init at shard 0
35
+ self.current_shard = 0
36
+ self.tokens = self.load_tokens(self.shards[self.current_shard])
37
+ self.current_position = self.B * self.T * self.process_rank
38
+
39
+ def next_batch(self):
40
+ B, T = self.B, self.T
41
+ buf = self.tokens[self.current_position:self.current_position + B*T+1]
42
+ x = (buf[:-1]).view(B,T) #input
43
+ y = (buf[1:]).view(B,T) #targets
44
+
45
+ self.current_position += B * T * self.num_processes
46
+
47
+ if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
48
+ self.current_shard = (self.current_shard + 1) % len(self.shards)
49
+ self.tokens = self.load_tokens(self.shards[self.current_shard])
50
+ self.current_position = B * T * self.process_rank
51
+ return x, y
52
+
53
+
model_core/model.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from dataclasses import dataclass
5
+ import inspect
6
+ from .attention import CasualSelfAttention
7
+
8
+ class MLP(nn.Module):
9
+
10
+ def __init__(self, config):
11
+ super().__init__()
12
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
13
+ self.gelu = nn.GELU(approximate='tanh')
14
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
15
+ self.c_proj.NANOGPT_SCALE_INIT = 1
16
+
17
+ def forward(self, x):
18
+ x = self.c_fc(x)
19
+ x = self.gelu(x)
20
+ x = self.c_proj(x)
21
+ return x
22
+
23
+
24
+ class Block(nn.Module):
25
+ def __init__(self, config):
26
+ super().__init__()
27
+ self.ln_1 = nn.LayerNorm(config.n_embd)
28
+ self.attn = CasualSelfAttention(config)
29
+ self.ln_2 = nn.LayerNorm(config.n_embd)
30
+ self.mlp = MLP(config)
31
+
32
+ def forward(self, x):
33
+ x = x + self.attn(self.ln_1(x))
34
+ x = x + self.mlp(self.ln_2(x))
35
+ return x
36
+
37
+
38
+ @dataclass
39
+ class GPTConfig:
40
+ block_size: int = 1024 #max sequence length
41
+ vocab_size: int = 50257 #number of tokens: 50000 BPE merges + 256 byte tokens +1 special token which is endoftext
42
+ n_layer: int = 12 #number of layers
43
+ n_head: int = 12 #number of heads
44
+ n_embd: int = 768 #embedding dimensions
45
+
46
+
47
+ class GPT(nn.Module):
48
+ def __init__(self, config):
49
+ super().__init__()
50
+ self.config = config
51
+
52
+ self.transformer = nn.ModuleDict(dict(
53
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
54
+ wpe = nn.Embedding(config.block_size, config.n_embd),
55
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
56
+ ln_f = nn.LayerNorm(config.n_embd),
57
+ ))
58
+
59
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
60
+
61
+ #Weight sharing scheme
62
+ self.transformer.wte.weight = self.lm_head.weight
63
+
64
+ # init params
65
+ self.apply(self._init_weights)
66
+
67
+ def _init_weights(self, module):
68
+ if isinstance(module, nn.Linear):
69
+ std = 0.02
70
+ if hasattr(module, 'NANOGPT_SCALE_INIT'):
71
+ std *= (2 * self.config.n_layer) ** -0.5
72
+ torch.nn.init.normal_(module.weight, mean = 0.0, std=std)
73
+ if module.bias is not None:
74
+ torch.nn.init.zeros_(module.bias)
75
+ elif isinstance(module, nn.Embedding):
76
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
77
+
78
+ def forward(self, idx, targets=None):
79
+ B, T = idx.size()
80
+ assert T <=self.config.block_size, f"Cannot forward sequence of length {T} ,block size is only {self.config.block_size}"
81
+
82
+ pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
83
+ pos_emb = self.transformer.wpe(pos)
84
+ tok_emb = self.transformer.wte(idx)
85
+ x = tok_emb + pos_emb
86
+
87
+ for block in self.transformer.h:
88
+ x = block(x)
89
+
90
+ x = self.transformer.ln_f(x)
91
+ logits = self.lm_head(x) #(B, T, vocab_size)
92
+ loss = None
93
+ if targets is not None:
94
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
95
+
96
+ return logits, loss
97
+
98
+ def configure_optimizers(self, weight_decay, learning_rate, device_type, master_process):
99
+ #taking all candidate parameters that require grad
100
+ param_dict = {pn:p for pn, p in self.named_parameters()}
101
+ param_dict = {pn:p for pn, p in param_dict.items() if p.requires_grad}
102
+ #creating Optim groups that any parameters that 2D will be weight decayed, otherwise no.
103
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
104
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
105
+ optim_groups = [{'params':decay_params, ' weight_decay': weight_decay},
106
+ {'params':nodecay_params, 'weight_decay': 0.0}
107
+ ]
108
+ num_decay_params = sum(p.numel() for p in decay_params)
109
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
110
+ if master_process:
111
+ print(f"num decayed parameters tensors: {len(decay_params)}, with{num_decay_params}:parameters")
112
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
113
+ # Create AdamW optimizer and use the fused version if it is available
114
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
115
+ use_fused = fused_available and device_type == "cuda" #Kernal fusion for optimizer calculations
116
+ if master_process:
117
+ print(f"using fused AdamW: {use_fused}")
118
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9,0.95), eps=1e-8, fused=use_fused)
119
+ return optimizer
120
+
model_core/training.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #Setting up DDP
3
+ #torchrun command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
4
+ #run training loop
5
+
6
+ from torch.distributed import init_process_group, destroy_process_group
7
+ from torch.nn.parallel import DistributedDataParallel as DDP
8
+ import torch.distributed as dist
9
+ import os
10
+ import torch
11
+ import time
12
+ import json
13
+ import math
14
+ from .model import GPT,GPTConfig
15
+
16
+
17
+ def train_memgpt(config_path,dataloader_class=None):
18
+
19
+ with open(config_path,'r') as f:
20
+ cfg = json.load(f)
21
+
22
+ model_cfg_params = cfg['model']
23
+ train_cfg_params = cfg['training']
24
+
25
+ ddp = int(os.environ.get('RANK', -1)) != -1
26
+ if ddp:
27
+ assert torch.cuda.is_available()
28
+ init_process_group(backend='nccl')
29
+ ddp_rank = int(os.environ['RANK'])
30
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
31
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
32
+ device = f"cuda:{ddp_local_rank}"
33
+ torch.cuda.set_device(device)
34
+ master_process = ddp_rank == 0
35
+ else:
36
+ ddp_rank = 0
37
+ ddp_local_rank = 0
38
+ ddp_world_size = 1
39
+ master_process = True
40
+ device = 'cpu'
41
+ if torch.cuda.is_available():
42
+ device = 'cuda'
43
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
44
+ device = "mps"
45
+ if master_process:
46
+ print(f"Using device: {device}")
47
+
48
+ device_type = "cuda" if device.startswith("cuda") else "cpu"
49
+
50
+ torch.manual_seed(1337)
51
+ if torch.cuda.is_available():
52
+ torch.cuda.manual_seed(1337)
53
+
54
+
55
+ # --- Use loaded training parameters ---
56
+ total_batch_size = train_cfg_params['total_batch_size']
57
+ B = train_cfg_params['B']
58
+ T = train_cfg_params['T']
59
+ max_steps = train_cfg_params['max_steps']
60
+ log_dir = train_cfg_params['log_dir']
61
+ max_lr = train_cfg_params['max_lr']
62
+ min_lr = train_cfg_params['min_lr']
63
+ warmup_steps = train_cfg_params['warmup_steps']
64
+ weight_decay = train_cfg_params['weight_decay']
65
+ base_learning_rate = train_cfg_params['learning_rate']
66
+
67
+ # total_batch_size = 524288
68
+ # B = 64
69
+ # T = 1024
70
+ assert total_batch_size % (B * T * ddp_world_size) == 0
71
+ grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
72
+ if master_process:
73
+ print(f"Total desired batch size: {total_batch_size}")
74
+ print(f"Calculated gradient accumulation steps: {grad_accum_steps}")
75
+
76
+ train_loader = dataloader_class(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="train",master_process=master_process)
77
+ val_loader = dataloader_class(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="val",master_process=master_process)
78
+
79
+ torch.set_float32_matmul_precision('high')
80
+
81
+ # Create Model
82
+ model = GPT(GPTConfig(**model_cfg_params))
83
+ model.to(device)
84
+ use_compile = False #True THIS SHOULD CHANGE TO TRUE BEFORE TRAIING#DEBUG
85
+ if use_compile:
86
+ model = torch.compile(model)
87
+ if ddp:
88
+ model = DDP(model, device_ids=[ddp_local_rank])
89
+ raw_model = model.module if ddp else model
90
+
91
+ def get_lr(it):
92
+ if it < warmup_steps:
93
+ return max_lr * (it + 1) / warmup_steps
94
+ if it > max_steps:
95
+ return min_lr
96
+ decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
97
+ assert 0 <= decay_ratio <= 1
98
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
99
+ return min_lr + coeff * (max_lr - min_lr)
100
+
101
+ optimizer = raw_model.configure_optimizers(weight_decay=weight_decay, learning_rate=base_learning_rate, device_type=device_type, master_process=master_process)
102
+
103
+ os.makedirs(log_dir, exist_ok=True)
104
+ log_file = os.path.join(log_dir, "log.txt")
105
+ with open(log_file, "w") as f:
106
+ pass
107
+
108
+ for step in range(max_steps):
109
+ t0 = time.time()
110
+ last_step = (step == max_steps - 1)
111
+
112
+ if step % 350 == 0 or last_step:
113
+ model.eval()
114
+ val_loader.reset()
115
+ with torch.no_grad():
116
+ val_loss_accum = 0.0
117
+ val_loss_steps = 20
118
+ for _ in range(val_loss_steps):
119
+ x, y = val_loader.next_batch()
120
+ x, y = x.to(device), y.to(device)
121
+ with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
122
+ logits, loss = model(x, y)
123
+ loss = loss / val_loss_steps
124
+ val_loss_accum += loss.detach()
125
+ if ddp:
126
+ dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
127
+ if master_process:
128
+ print(f"Validation loss: {val_loss_accum.item():.4f}")
129
+ with open(log_file, "a") as f:
130
+ f.write(f"{step} val {val_loss_accum.item():.4f}\n")
131
+
132
+ checkpoint_name = f"model_final.pt" if last_step else f"model_{step:05d}.pt"
133
+ checkpoint_path = os.path.join(log_dir, checkpoint_name)
134
+
135
+ checkpoint = {
136
+ 'model': raw_model.state_dict(),
137
+ 'optimizer': optimizer.state_dict(),
138
+ 'step': step,
139
+ 'val_loss': val_loss_accum.item(),
140
+ 'config': raw_model.config
141
+ }
142
+ torch.save(checkpoint, checkpoint_path)
143
+
144
+
145
+ model.train()
146
+ optimizer.zero_grad()
147
+ loss_accum = 0.0
148
+ for micro_step in range(grad_accum_steps):
149
+ x, y = train_loader.next_batch()
150
+ x, y = x.to(device), y.to(device)
151
+ if ddp:
152
+ model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
153
+ with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
154
+ logits, loss = model(x, y)
155
+ loss = loss / grad_accum_steps
156
+ loss_accum += loss.detach()
157
+ loss.backward()
158
+
159
+ if ddp:
160
+ dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
161
+
162
+ norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
163
+ lr = get_lr(step)
164
+ for param_group in optimizer.param_groups:
165
+ param_group['lr'] = lr
166
+ optimizer.step()
167
+ if device_type == 'cuda':
168
+ torch.cuda.synchronize()
169
+ t1 = time.time()
170
+ dt = (t1 - t0) * 1000
171
+ tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size
172
+ tokens_per_sec = tokens_processed / dt
173
+ if master_process:
174
+ print(f"Step:{step:5d} | Loss: {loss_accum.item():.6f} | lr: {lr:.4e} | Norm:{norm:.4f} | dt: {dt:.2f}ms | Tok/sec: {tokens_per_sec:.2f}")
175
+ with open(log_file, 'a') as f:
176
+ f.write(f"{step} train {loss_accum.item():.6f}\n")
177
+
178
+ if ddp:
179
+ destroy_process_group()
requirement.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
+
3
+ safetensors==0.5.3
4
+ tiktoken==0.9.0
5
+ tokenizers==0.21.1
6
+ transformers==4.50.1
7
+ tqdm==4.67.1
8
+ requests==2.32.3
9
+ numpy<1.27,>=1.22
10
+ torch==2.3.1+cu121
rough_work.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import torch
2
+ print(torch.__version__)
3
+ print("CUDA available:", torch.cuda.is_available())
scripts/evaluate.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #To run all evaluation at once
2
+ #Code yet to be added
scripts/generate.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import tiktoken
4
+ from model import GPT
5
+
6
+ def generate_text(model, prompt, num_return_sequences=4, max_length=32, device='cuda'):
7
+ model.eval()
8
+ enc = tiktoken.get_encoding('gpt2')
9
+ tokens = enc.encode(prompt)
10
+ tokens = torch.tensor(tokens, dtype=torch.long)
11
+ tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
12
+ xgen = tokens.to(device)
13
+ sample_rng = torch.Generator(device=device)
14
+ sample_rng.manual_seed(42)
15
+
16
+ while xgen.size(1) < max_length:
17
+ with torch.no_grad():
18
+ logits, loss = model(xgen) # (B, T, vocab_size)
19
+ logits = logits[:, -1, :] # (B, vocab_size)
20
+ probs = F.softmax(logits, dim=-1) # get probabilities
21
+ topk_probs, topk_indices = torch.topk(probs, 50, dim=-1) # topk sampling for top 50 probabilities
22
+ ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B,1), selecting a token from topk
23
+ xcol = torch.gather(topk_indices, -1, ix) # gathering corresponding indices
24
+ xgen = torch.cat((xgen, xcol), dim=1) # append to sequence
25
+
26
+ generated_texts = []
27
+ for i in range(num_return_sequences):
28
+ tokens = xgen[i, :max_length].tolist()
29
+ decoded = enc.decode(tokens)
30
+ generated_texts.append(decoded)
31
+ print(f"Sample {i + 1}: {decoded}")
32
+
33
+
34
+ return generated_texts
35
+
36
+
37
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
38
+ print(f"running with {device}")
39
+
40
+ # Extract config and create model
41
+ checkpoint_path = 'log/model_final.pt'
42
+
43
+ print(f"Loading checkpoint from {checkpoint_path}")
44
+ checkpoint = torch.load(checkpoint_path,map_location=device)
45
+ model_config = checkpoint['config']
46
+ model_config.vocab_size = 50304 #for computational effciency(power of 2)
47
+ model = GPT(model_config)
48
+
49
+ #Load model state dict
50
+ model.load_state_dict(checkpoint['model'])
51
+ model.to(device)
52
+
53
+
54
+
55
+ prompt = "Hello, I'm a language model,"
56
+
57
+ generated_texts = generate_text(
58
+ model=model,
59
+ prompt=prompt,
60
+ num_return_sequences=4,
61
+ max_length=32,
62
+ device=device
63
+ )
scripts/train.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4
+
5
+ from model_core.training import train_memgpt
6
+ from model_core.dataloader import DataLoaderLite
7
+
8
+ if __name__ == "__main__":
9
+ config_path = "configs/config.json"
10
+ print("Training starter")
11
+ train_memgpt(config_path=config_path,dataloader_class=DataLoaderLite)