AliMuhammad73 committed on
Commit
4912248
·
1 Parent(s): 1ec86a4
Naive_gpt/__init__.py DELETED
File without changes
Naive_gpt/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (143 Bytes)
 
Naive_gpt/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (183 Bytes)
 
Naive_gpt/__pycache__/config.cpython-310.pyc DELETED
Binary file (489 Bytes)
 
Naive_gpt/__pycache__/config.cpython-312.pyc DELETED
Binary file (629 Bytes)
 
Naive_gpt/__pycache__/data_loader.cpython-310.pyc DELETED
Binary file (3.13 kB)
 
Naive_gpt/__pycache__/data_loader.cpython-312.pyc DELETED
Binary file (5.69 kB)
 
Naive_gpt/__pycache__/model.cpython-310.pyc DELETED
Binary file (6.99 kB)
 
Naive_gpt/__pycache__/model.cpython-312.pyc DELETED
Binary file (11.9 kB)
 
Naive_gpt/__pycache__/train.cpython-310.pyc DELETED
Binary file (3.39 kB)
 
Naive_gpt/__pycache__/train.cpython-312.pyc DELETED
Binary file (1.9 kB)
 
Naive_gpt/__pycache__/train_autoshardload_ddp.cpython-310.pyc DELETED
Binary file (5.87 kB)
 
Naive_gpt/__pycache__/train_ddp.cpython-310.pyc DELETED
Binary file (4.39 kB)
 
Naive_gpt/config.py DELETED
@@ -1,16 +0,0 @@
-# Naive_gpt/config.py
-import torch
-
-# Define hyperparameters and constants
-BATCH_SIZE = 16
-BLOCK_SIZE = 1024
-MAX_ITERS = 5
-EVAL_INTERVAL = 500
-LEARNING_RATE = 6e-4
-DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-EVAL_ITERS = 200
-N_EMBD = 768
-N_HEAD = 12
-N_LAYER = 12
-DROPOUT = 0.2
-MODEL_PATH = "Naive_gpt/model_weights_llama"  # Where to save weights (forward slash keeps the path portable across OSes)
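These four config files sweep the GPT-2 model family: 768/12/12 above is the "small" setting (~124M parameters with GPT-2's 50k vocabulary), and config_355.py, config_770.py, and config_1.5B.py below scale N_EMBD/N_HEAD/N_LAYER toward roughly 355M, 770M, and 1.5B. A minimal sketch to sanity-check those sizes, assuming the weight-tied architecture in Naive_gpt/model.py and an illustrative vocabulary of 20,000 tokens (the real count comes from the tokenizer):

    # Back-of-the-envelope parameter estimate; illustrative only, not part of the repo.
    def approx_params(n_embd, n_layer, vocab_size=20_000, block_size=1024):
        attn = 4 * n_embd * n_embd                 # c_attn (3x n_embd^2) + c_proj (1x)
        mlp = 8 * n_embd * n_embd                  # 4x expansion up, then back down
        emb = (vocab_size + block_size) * n_embd   # token + position tables (lm_head is tied)
        return n_layer * (attn + mlp) + emb

    for name, (e, l) in {"config": (768, 12), "config_355": (1024, 24),
                         "config_770": (1280, 36), "config_1.5B": (1600, 48)}.items():
        print(f"{name}: ~{approx_params(e, l) / 1e6:.0f}M parameters")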
 
Naive_gpt/config_1.5B.py DELETED
@@ -1,16 +0,0 @@
-# Naive_gpt/config_1.5B.py
-import torch
-
-# Define hyperparameters and constants
-BATCH_SIZE = 16
-BLOCK_SIZE = 1024
-MAX_ITERS = 5
-EVAL_INTERVAL = 500
-LEARNING_RATE = 6e-4
-DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-EVAL_ITERS = 200
-N_EMBD = 1600
-N_HEAD = 25
-N_LAYER = 48
-DROPOUT = 0.2
-MODEL_PATH = "Naive_gpt/model_weights_llama"  # Where to save weights
 
Naive_gpt/config_355.py DELETED
@@ -1,16 +0,0 @@
-# Naive_gpt/config_355.py
-import torch
-
-# Define hyperparameters and constants
-BATCH_SIZE = 16
-BLOCK_SIZE = 1024
-MAX_ITERS = 5
-EVAL_INTERVAL = 500
-LEARNING_RATE = 6e-4
-DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-EVAL_ITERS = 200
-N_EMBD = 1024
-N_HEAD = 16
-N_LAYER = 24
-DROPOUT = 0.2
-MODEL_PATH = "Naive_gpt/model_weights_llama"  # Where to save weights
 
Naive_gpt/config_770.py DELETED
@@ -1,16 +0,0 @@
-# Naive_gpt/config_770.py
-import torch
-
-# Define hyperparameters and constants
-BATCH_SIZE = 16
-BLOCK_SIZE = 1024
-MAX_ITERS = 5
-EVAL_INTERVAL = 500
-LEARNING_RATE = 6e-4
-DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-EVAL_ITERS = 200
-N_EMBD = 1280
-N_HEAD = 20
-N_LAYER = 36
-DROPOUT = 0.2
-MODEL_PATH = "Naive_gpt/model_weights_llama"  # Where to save weights
 
Naive_gpt/data_loader.py DELETED
@@ -1,252 +0,0 @@
-# Naive_gpt/data_loader.py
-import torch
-import os
-import logging
-
-# Set up logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
-class TextDataLoader:
-    def __init__(self, file_path, batch_size, block_size, tokenizer, chunk_size=10**4):
-        self.file_path = file_path
-        self.batch_size = batch_size
-        self.block_size = block_size
-        self.tokenizer = tokenizer
-        self.chunk_size = chunk_size  # NOTE: unused; load_chunk() currently reads the whole file at once
-        self.file = open(self.file_path, 'r', encoding='utf-8')
-        self.data = None
-        self.end_of_file = False
-
-        # Load the initial chunk of data
-        self.load_chunk()
-
-    def load_chunk(self):
-        """Read the remaining file contents, encode them, and handle end-of-file conditions."""
-        text = self.file.read()
-        if not text:
-            self.end_of_file = True
-            logging.info("End of file reached.")
-        else:
-            try:
-                # Encode the text using the tokenizer
-                encoded = self.tokenizer.encode(text)
-                if len(encoded) > 0:
-                    self.data = torch.tensor(encoded, dtype=torch.long)
-                    logging.info(f"Loaded new data chunk of size: {len(self.data)} tokens.")
-                    # Save the encoded data to a file
-                    torch.save(self.data, "encoded_data.pth")
-            except Exception as e:
-                logging.error(f"Error encoding text chunk: {e}")
-                self.end_of_file = True
-
-    def num_batches(self):
-        """Calculate the total number of batches in the current chunk."""
-        if self.data is not None:
-            return (len(self.data) - 1) // self.block_size  # Total batches in the current chunk
-        return 0
-
-    def get_batch(self):
-        """Retrieve a random batch of input/target blocks from the current chunk."""
-        if self.end_of_file:
-            return None, None  # Return None when no data is left
-
-        # Sample random offsets; targets are the inputs shifted right by one token
-        ix = torch.randint(len(self.data) - self.block_size, (self.batch_size,))
-        x = torch.stack([self.data[i:i+self.block_size] for i in ix])
-        y = torch.stack([self.data[i+1:i+self.block_size+1] for i in ix])
-        return x, y
-
-    def reset(self):
-        """Reset the file and flags for a new epoch."""
-        self.file.seek(0)
-        self.end_of_file = False
-        logging.info("Resetting file for a new epoch.")
-        self.load_chunk()
-
-    def close(self):
-        """Clean up file resources when done."""
-        self.file.close()
-        logging.info("File closed.")
-
-    def __iter__(self):
-        """Make the data loader iterable so it can be used in a loop."""
-        while not self.end_of_file:
-            x, y = self.get_batch()
-            if x is None or y is None:
-                break  # Stop iteration if there's no more data
-
-            yield x, y  # Yield a batch of data
-
-        # Once iteration is done, close the file
-        self.close()
-
-# Before parallelizing:
-# Set up logging
-# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
-# class TextDataLoader:
-#     def __init__(self, file_path, batch_size, block_size, tokenizer, device='cpu', chunk_size=10**4):
-#         self.file_path = file_path
-#         self.batch_size = batch_size
-#         self.block_size = block_size
-#         self.tokenizer = tokenizer
-#         self.device = device
-#         self.chunk_size = chunk_size
-#         self.file = open(self.file_path, 'r', encoding='utf-8')
-#         self.data = None
-#         self.end_of_file = False
-
-#         # Load the initial chunk of data
-#         self.load_chunk()
-
-#     def load_chunk(self):
-#         """Load a chunk from the file, encode it, and handle end-of-file conditions."""
-#         text = self.file.read()
-#         if not text:
-#             self.end_of_file = True
-#             logging.info("End of file reached.")
-#         else:
-#             try:
-#                 # Encode the text using the tokenizer
-#                 encoded = self.tokenizer.encode(text)
-#                 if len(encoded) > 0:
-#                     self.data = torch.tensor(encoded, dtype=torch.long).to(self.device)
-#                     logging.info(f"Loaded new data chunk of size: {len(self.data)} tokens.")
-#             except Exception as e:
-#                 logging.error(f"Error encoding text chunk: {e}")
-#                 self.end_of_file = True
-
-#     def num_batches(self):
-#         """Calculate the total number of batches in the current chunk."""
-#         if self.data is not None:
-#             return (len(self.data) - 1) // self.block_size  # Total batches in the current chunk
-#         return 0
-
-#     def get_batch(self):
-#         """Retrieve a batch of data from the current chunk or load a new chunk if needed."""
-#         if self.end_of_file:
-#             return None, None  # Return None when no data is left
-
-#         # Generate a batch of data
-#         ix = torch.randint(len(self.data) - self.block_size, (self.batch_size,))
-#         x = torch.stack([self.data[i:i+self.block_size] for i in ix])
-#         y = torch.stack([self.data[i+1:i+self.block_size+1] for i in ix])
-#         return x, y
-
-#     def reset(self):
-#         """Reset the file and flags for a new epoch."""
-#         self.file.seek(0)
-#         self.end_of_file = False
-#         logging.info("Resetting file for a new epoch.")
-#         self.load_chunk()
-
-#     def close(self):
-#         """Clean up file resources when done."""
-#         self.file.close()
-#         logging.info("File closed.")
-
-#     def __iter__(self):
-#         """Make the data loader iterable so it can be used in a loop."""
-#         while not self.end_of_file:
-#             x, y = self.get_batch()
-#             if x is None or y is None:
-#                 break  # Stop iteration if there's no more data
-
-#             yield x, y  # Yield a batch of data
-
-#         # Once iteration is done, close the file
-#         self.close()
-
-
-# # Set up logging
-# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
-# class TextDataLoader:
-#     def __init__(self, file_path, batch_size, block_size, tokenizer, device='cpu', chunk_size=10**4):
-#         self.file_path = file_path
-#         self.batch_size = batch_size
-#         self.block_size = block_size
-#         self.tokenizer = tokenizer
-#         self.device = device
-#         self.chunk_size = chunk_size
-#         self.file = open(self.file_path, 'r', encoding='utf-8')
-#         self.data = None
-#         self.end_of_file = False
-
-#         # Print a preview of the file
-#         # self.print_file_preview()
-
-#         # Initial chunk loading
-#         self.load_chunk()
-
-#     def print_file_preview(self):
-#         """Prints the first few lines of the text file for preview"""
-#         self.file.seek(0)  # Go to the beginning of the file
-#         lines = [self.file.readline() for _ in range(5)]
-#         preview_text = ''.join(lines)
-#         print("File preview:\n", preview_text)
-#         self.file.seek(0)  # Reset to the start of the file for chunk reading
-
-#     def load_chunk(self):
-#         """Load a chunk from the file, encode it, and handle end-of-file conditions."""
-#         text = self.file.read()
-#         if not text:
-#             self.end_of_file = True
-#             logging.info("End of file reached.")
-#         else:
-#             try:
-#                 # Log the first 100 characters of the text chunk to verify Urdu content
-#                 # logging.info(f"First 100 characters of the chunk: {text[:100]}")
-#                 # print("This is the chunk:", text)
-
-#                 # Encode the text using the tokenizer
-#                 # print("Tokenizer:", self.tokenizer)
-#                 encoded = self.tokenizer.encode(text)
-#                 print(len(encoded))
-#                 print("encoded data: ")
-
-#                 # Log the encoded output length to confirm successful encoding
-#                 logging.info(f"Encoded data length: {len(encoded)} tokens")
-
-#                 # if len(encoded) < self.block_size:
-#                 #     # Only stop if there's absolutely no usable data left
-#                 #     self.end_of_file = len(encoded) == 0
-#                 #     if self.end_of_file:
-#                 #         logging.warning("Insufficient data in chunk; stopping further loading.")
-#                 #     else:
-#                 #         logging.warning("Data chunk smaller than block size loaded; may limit training batch size.")
-
-#                 if len(encoded) > 0:
-#                     self.data = torch.tensor(encoded, dtype=torch.long).to(self.device)
-#                     logging.info(f"Loaded new data chunk of size: {len(self.data)} tokens.")
-#             except Exception as e:
-#                 logging.error(f"Error encoding text chunk: {e}")
-#                 self.end_of_file = True
-
-#     def get_batch(self):
-#         """Retrieve a batch of data from the current chunk or load a new chunk if needed."""
-#         # if self.end_of_file:
-#         #     return None, None  # Return None when no data is left
-
-#         # if self.data is None or len(self.data) <= self.block_size:
-#         #     self.load_chunk()
-#         #     if self.end_of_file or self.data is None or len(self.data) < self.block_size:
-#         #         return None, None  # Stop if there's insufficient data
-
-#         # Generate a batch of data
-#         ix = torch.randint(len(self.data) - self.block_size, (self.batch_size,))
-#         x = torch.stack([self.data[i:i+self.block_size] for i in ix])
-#         y = torch.stack([self.data[i+1:i+self.block_size+1] for i in ix])
-#         return x, y
-
-#     def reset(self):
-#         """Reset the file and flags for a new epoch."""
-#         self.file.seek(0)
-#         self.end_of_file = False
-#         logging.info("Resetting file for a new epoch.")
-#         self.load_chunk()
-
-#     def close(self):
-#         """Clean up file resources when done."""
-#         self.file.close()
-#         logging.info("File closed.")
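The active class above draws random (input, target) windows from a single encoded tensor, with each target row equal to its input shifted one token to the right. A minimal usage sketch; the ByteTokenizer and corpus.txt here are illustrative stand-ins, since TextDataLoader only requires the tokenizer to expose an .encode(str) method returning a list of ints:

    import torch
    from Naive_gpt.data_loader import TextDataLoader

    class ByteTokenizer:
        """Stand-in tokenizer for illustration; any object with .encode() works."""
        def encode(self, text):
            return list(text.encode("utf-8"))

    loader = TextDataLoader("corpus.txt", batch_size=16, block_size=1024,
                            tokenizer=ByteTokenizer())
    for step, (x, y) in enumerate(loader):
        assert torch.equal(x[:, 1:], y[:, :-1])  # targets are inputs shifted by one
        if step == 10:  # nothing sets end_of_file during iteration, so break explicitly
            loader.close()
            break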
 
Naive_gpt/model.py DELETED
@@ -1,258 +0,0 @@
-# Naive_gpt/model.py
-import torch
-import torch.nn as nn
-from .config import *
-import inspect
-
-
-class CausalSelfAttention(nn.Module):
-
-    def __init__(self):
-        super().__init__()
-        assert N_EMBD % N_HEAD == 0
-        # key, query, value projections for all heads, but in a batch
-        self.c_attn = nn.Linear(N_EMBD, 3 * N_EMBD)
-        # output projection
-        self.c_proj = nn.Linear(N_EMBD, N_EMBD)
-        self.c_proj.NANOGPT_SCALE_INIT = 1
-        # regularization
-        self.n_head = N_HEAD
-        self.n_embd = N_EMBD
-
-    def forward(self, x):
-        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
-        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
-        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
-        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
-        qkv = self.c_attn(x)
-        q, k, v = qkv.split(self.n_embd, dim=2)
-        k = k.view(B, T, self.n_head, C //
-                   self.n_head).transpose(1, 2)  # (B, nh, T, hs)
-        q = q.view(B, T, self.n_head, C //
-                   self.n_head).transpose(1, 2)  # (B, nh, T, hs)
-        v = v.view(B, T, self.n_head, C //
-                   self.n_head).transpose(1, 2)  # (B, nh, T, hs)
-        y = nn.functional.scaled_dot_product_attention(
-            q, k, v, is_causal=True)  # flash attention
-        # re-assemble all head outputs side by side
-        y = y.transpose(1, 2).contiguous().view(B, T, C)
-        # output projection
-        y = self.c_proj(y)
-        return y
-
-# class Head(nn.Module):  # this is Sebastian's causal attention
-#     """ one head of self-attention """
-
-#     def __init__(self, head_size):
-#         super().__init__()
-#         self.key = nn.Linear(N_EMBD, head_size, bias=False)
-#         self.query = nn.Linear(N_EMBD, head_size, bias=False)
-#         self.value = nn.Linear(N_EMBD, head_size, bias=False)
-#         self.register_buffer('tril', torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)))
-#         self.dropout = nn.Dropout(DROPOUT)
-
-#     def forward(self, x):
-#         B, T, C = x.shape
-#         k = self.key(x)    # (B, T, head_size)
-#         q = self.query(x)  # (B, T, head_size)
-#         wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5
-#         wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
-#         wei = nn.functional.softmax(wei, dim=-1)
-#         wei = self.dropout(wei)
-#         v = self.value(x)
-#         out = wei @ v
-#         return out
-
-# class MultiHeadAttention(nn.Module):
-#     """ multiple heads of self-attention in parallel """
-
-#     def __init__(self, num_heads, head_size):
-#         super().__init__()
-#         self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
-#         self.proj = nn.Linear(head_size * num_heads, N_EMBD)
-#         self.dropout = nn.Dropout(DROPOUT)
-
-#     def forward(self, x):
-#         out = torch.cat([h(x) for h in self.heads], dim=-1)
-#         out = self.dropout(self.proj(out))
-#         return out
-
-class FeedForward(nn.Module):  # this is Karpathy's MLP -> Sebastian's feed-forward
-    """A simple linear layer followed by a non-linearity."""
-    def __init__(self):
-        super().__init__()
-        self.c_fc = nn.Linear(N_EMBD, 4 * N_EMBD)
-        self.gelu = nn.GELU(approximate='tanh')
-        self.c_proj = nn.Linear(4 * N_EMBD, N_EMBD)
-        self.c_proj.NANOGPT_SCALE_INIT = 1
-
-    def forward(self, x):
-        x = self.c_fc(x)
-        x = self.gelu(x)
-        x = self.c_proj(x)
-        return x
-
-    # def __init__(self, n_embd):
-    #     super().__init__()
-    #     self.net = nn.Sequential(
-    #         nn.Linear(N_EMBD, 4 * N_EMBD),
-    #         nn.ReLU(),
-    #         nn.Linear(4 * N_EMBD, N_EMBD),
-    #         nn.Dropout(DROPOUT),
-    #     )
-
-    # def forward(self, x):
-    #     return self.net(x)
-
-class Block(nn.Module):
-    """ Transformer block: communication followed by computation """
-
-    def __init__(self, n_embd, n_head):
-        super().__init__()
-        head_size = N_EMBD // n_head  # NOTE: unused; CausalSelfAttention computes all heads internally
-        self.sa = CausalSelfAttention()
-        self.ffwd = FeedForward()
-        self.ln1 = nn.LayerNorm(N_EMBD)
-        self.ln2 = nn.LayerNorm(N_EMBD)
-
-    def forward(self, x):
-        x = x + self.sa(self.ln1(x))
-        x = x + self.ffwd(self.ln2(x))
-        return x
-
-class GPTLanguageModel(nn.Module):
-
-    def __init__(self, vocab_size, config):
-        super().__init__()
-        print("This is vocab size:", vocab_size)
-        self.token_embedding_table = nn.Embedding(vocab_size, config.n_embd)
-        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
-        self.blocks = nn.Sequential(
-            *[Block(config.n_embd, n_head=config.n_head) for _ in range(config.n_layer)]
-        )
-        self.ln_f = nn.LayerNorm(config.n_embd)
-        self.lm_head = nn.Linear(config.n_embd, vocab_size)
-
-        # Weight tying: the token embedding and the output projection share one matrix
-        self.token_embedding_table.weight = self.lm_head.weight
-
-        self.apply(self._init_weights)
-        self.config = {"BLOCK_SIZE": config.block_size, "N_EMBD": config.n_embd, "N_HEAD": config.n_head, "N_LAYER": config.n_layer}
-
-
-    def _init_weights(self, module):
-        if isinstance(module, nn.Linear):
-            std = 0.02
-            if hasattr(module, 'NANOGPT_SCALE_INIT'):
-                std *= (2 * N_LAYER) ** -0.5  # scale residual projections down with depth
-            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
-            if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-
-
-    # def _init_weights(self, module):
-    #     if isinstance(module, nn.Linear):
-    #         torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-    #         if module.bias is not None:
-    #             torch.nn.init.zeros_(module.bias)
-    #     elif isinstance(module, nn.Embedding):
-    #         torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-
-    def forward(self, idx, targets=None):
-        B, T = idx.shape
-        assert T <= BLOCK_SIZE, f"Cannot forward sequence of length {T}, block size is only {BLOCK_SIZE}"
-
-
-        tok_emb = self.token_embedding_table(idx)
-        pos_emb = self.position_embedding_table(torch.arange(0, T, dtype=torch.long, device=idx.device))
-        x = tok_emb + pos_emb
-        x = self.blocks(x)
-        x = self.ln_f(x)
-        logits = self.lm_head(x)
-
-        if targets is None:
-            loss = None
-        else:
-            # B, T, C = logits.shape
-            # logits = logits.view(B*T, C)
-            # targets = targets.view(B*T)
-            # loss = nn.functional.cross_entropy(logits, targets)
-            loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
-
-        return logits, loss
-
-    def generate(self, idx, max_new_tokens, temperature=1.0):
-        """
-        Generate tokens using the language model.
-        Args:
-            idx: Input token indices
-            max_new_tokens: Number of tokens to generate
-            temperature: Controls randomness in generation
-                - temperature > 1.0 increases randomness
-                - temperature < 1.0 decreases randomness
-                - temperature = 0 makes it deterministic (always picks the highest probability)
-        """
-        for _ in range(max_new_tokens):
-            # Truncate the sequence to the last BLOCK_SIZE tokens
-            idx_cond = idx[:, -BLOCK_SIZE:]
-            # Get logits from the model
-            logits, _ = self(idx_cond)
-            # Focus only on the last time step
-            logits = logits[:, -1, :]
-
-            if temperature == 0.0:
-                # For temperature = 0, simply take the argmax
-                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
-            else:
-                # Apply temperature scaling
-                logits = logits / temperature
-                # Convert to probabilities
-                probs = torch.softmax(logits, dim=-1)
-                # Sample from the distribution
-                idx_next = torch.multinomial(probs, num_samples=1)
-
-            # Append the new token to the sequence
-            idx = torch.cat((idx, idx_next), dim=1)
-        return idx
-
-    # def save(self, path=MODEL_PATH):
-    #     torch.save(self.state_dict(), path)
-
-    # def load(self, path=MODEL_PATH):
-    #     self.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
-
-    def save(self, path=MODEL_PATH):
-        # NOTE: save() writes a bare state_dict, while load() below expects a
-        # checkpoint dict with a "model" key (as written by the DDP training scripts)
-        torch.save(self.state_dict(), path)
-
-    # def load(self, path=MODEL_PATH):
-    #     self.load_state_dict(torch.load(path))
-
-    def load(self, path=MODEL_PATH):
-        # Load the state dict
-        state_dict = torch.load(path)["model"]
-
-        # Rename the keys to match the expected ones (drop the "_orig_mod."
-        # prefix that torch.compile adds to parameter names)
-        new_state_dict = {}
-        for key, value in state_dict.items():
-            new_key = key.replace('_orig_mod.', '')
-            new_state_dict[new_key] = value
-
-        # Load the renamed state dict into the model
-        self.load_state_dict(new_state_dict)
-
-
-    def configure_optimizers(self, weight_decay=0.1, learning_rate=LEARNING_RATE, device=DEVICE):
-        param_dict = {pn: p for pn, p in self.named_parameters()}
-        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
-
-        # Weight-decay only 2D parameters (matmuls, embeddings); leave biases and norms undecayed
-        decay_parameters = [p for n, p in param_dict.items() if p.dim() >= 2]
-        nodecay_parameters = [p for n, p in param_dict.items() if p.dim() < 2]
-        optim_groups = [
-            {"params": decay_parameters, "weight_decay": weight_decay},
-            {"params": nodecay_parameters, "weight_decay": 0.0},
-        ]
-        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
-        use_fused = fused_available and device == "cuda"
-        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
-        return optimizer
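One wrinkle worth noting: this constructor requires a config object, while the training scripts in this commit call GPTLanguageModel(vocab_size=...) alone, so one side needs reconciling before the pieces run together. A minimal generation sketch under that assumption, with a hypothetical SimpleNamespace standing in for the config module:

    import torch
    from types import SimpleNamespace
    from Naive_gpt.model import GPTLanguageModel

    cfg = SimpleNamespace(n_embd=768, n_head=12, n_layer=12, block_size=1024)
    model = GPTLanguageModel(vocab_size=20_000, config=cfg).eval()

    prompt = torch.zeros((1, 1), dtype=torch.long)  # a single BOS-like token
    with torch.no_grad():
        out = model.generate(prompt, max_new_tokens=20, temperature=0.8)
    print(out.shape)  # (1, 21): the prompt plus 20 sampled tokens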
 
Naive_gpt/train.py DELETED
@@ -1,303 +0,0 @@
-import torch
-from tqdm import tqdm
-from .config import *
-from .data_loader import TextDataLoader
-from .model import GPTLanguageModel
-import math
-
-max_lr = 6e-4
-min_lr = max_lr * 0.1
-warmup_steps = 715
-max_steps = 19073
-
-def get_lr(it):
-    # Linear warmup for the first warmup_steps updates
-    if it < warmup_steps:
-        return max_lr * (it+1) / warmup_steps
-
-    # After max_steps, hold the floor learning rate
-    if it > max_steps:
-        return min_lr
-
-    # In between, cosine-decay from max_lr down to min_lr
-    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
-    assert 0 <= decay_ratio <= 1
-    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
-    return min_lr + coeff * (max_lr - min_lr)
-
-total_batch_size = 524288
-assert total_batch_size % (BATCH_SIZE * BLOCK_SIZE) == 0, "make sure total_batch_size is divisible by BATCH_SIZE * BLOCK_SIZE"
-grad_accumulation_steps = total_batch_size // (BATCH_SIZE * BLOCK_SIZE)
-print(f"grad_accumulation_steps: {grad_accumulation_steps}")
-print(f"total_batch_size: {total_batch_size}")
-
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
-
-from DataLoader import create_dataloader
-
-
-
-def train(folder_path, tokenizer, model=None, optimizer=None, vocab_size=10000, platform='none', checkpoint=None, is_tokenized_data=False):
-
-    torch.set_float32_matmul_precision('high')  # Hammad added this line (need to check if it is necessary)
-    if model is None:
-        model = GPTLanguageModel(vocab_size=vocab_size)  # NOTE: model.py's constructor also expects a config object
-        print("Model Initialised")
-    if checkpoint is not None:
-        print("loading checkpoint........")
-        model.load(checkpoint)
-        print("Model loaded from checkpoint", checkpoint)
-
-    if platform == 'kaggle':
-        model = torch.nn.DataParallel(model, device_ids=[0, 1])
-        model = model.to(DEVICE)
-        optimizer = model.module.configure_optimizers(weight_decay=0.1, learning_rate=LEARNING_RATE, device=DEVICE)  # Hammad added this line
-    else:
-        model = model.to(DEVICE)
-        model = torch.compile(model)  # Hammad added this line
-        optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=LEARNING_RATE, device=DEVICE)
-        # optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)
-
-
-    # # Initialize the data loader
-    # loader = TextDataLoader(file_path, BATCH_SIZE, BLOCK_SIZE, tokenizer)
-
-
-    # Set up a tqdm progress bar for the epoch
-    for epoch in range(MAX_ITERS):
-        print(f"Epoch {epoch}")
-        epoch_loss = None  # Track loss for the epoch
-
-
-        for i in range(len(os.listdir(folder_path))):
-            file_path = os.path.join(folder_path, os.listdir(folder_path)[i])
-            print(f"loading file: {file_path}")
-            loader = create_dataloader(tokenizer, file_path, BATCH_SIZE, BLOCK_SIZE, BLOCK_SIZE, tokenized_data=is_tokenized_data, filename=os.listdir(folder_path)[i])  # Hammad added this line
-
-
-            # Create a progress bar for batch processing
-            batch_progress_bar = tqdm(loader, desc=f"Epoch {epoch+1} Batch Progress", unit="batch", ncols=100)
-            count = 0
-            loss_accum = 0
-            optimizer.zero_grad()  # start clean; gradients accumulate across micro-batches below
-            for xb, yb in batch_progress_bar:
-                if xb is None:
-                    break  # No more batches, stop the epoch
-
-                # Forward pass and loss computation
-                xb = xb.to(DEVICE)
-                yb = yb.to(DEVICE)
-                # with torch.autocast(DEVICE, dtype=torch.bfloat16):  # Hammad added this line
-                logits, loss = model(xb, yb)
-                loss = loss / grad_accumulation_steps
-                if platform == 'kaggle':
-                    loss_accum += loss.mean().detach()
-                    loss.mean().backward()
-                else:
-                    loss_accum += loss.detach()
-                    loss.backward()  # Backpropagate the loss
-                # Step the optimizer once every grad_accumulation_steps micro-batches
-                if (count + 1) % grad_accumulation_steps == 0:
-                    print("one batch completed at (xb,yb):", count)
-                    loss_accum = 0
-                    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Hammad added this line
-                    lr = get_lr(count)  # need to check: count indexes micro-batches, not optimizer steps
-                    for param_group in optimizer.param_groups:
-                        param_group['lr'] = lr
-                    optimizer.step()  # Update model parameters
-                    optimizer.zero_grad()  # clear gradients only after stepping, so accumulation works
-                    torch.cuda.synchronize()  # wait for the computation to finish before moving on
-
-                # Update epoch_loss to the most recent loss value
-                if platform == 'kaggle':
-                    epoch_loss = loss.mean().item()
-                else:
-                    epoch_loss = loss.item()
-
-                # Update tqdm with the latest loss value
-                batch_progress_bar.set_postfix(loss=epoch_loss)
-
-                count += 1
-                if count % 5000 == 0:
-                    if platform == 'kaggle':
-                        torch.save(model.module.state_dict(), f"model_weights_checkpoint_{count}.pth")
-                    else:
-                        torch.save(model.state_dict(), f"model_weights_checkpoint_{count}.pth")
-                    print(f"Model weights saved at checkpoint {count}")
-
-            # Save model weights after each chunk or epoch
-            if platform == 'kaggle':
-                torch.save(model.module.state_dict(),
-                           f"model_weights_epoch_{epoch}_{file_path[-6:-4]}.pth")
-            else:
-                torch.save(model.state_dict(),
-                           f"model_weights_epoch_{epoch}_{file_path[-6:-4]}.pth")
-            print(f"Model weights saved at epoch {epoch}")
-
-        # Print the loss at the end of the epoch
-        if epoch_loss is not None:
-            print(f"Epoch {epoch}, Loss: {epoch_loss}")
-        else:
-            print(f"Epoch {epoch}, No data available for loss calculation.")
-
-        # Reset the loader for a new epoch
-        # loader.reset()
-
-    # loader.close()  # Ensure the file is properly closed at the end
-    torch.cuda.empty_cache()
-
-    return model, optimizer
-
-
-# Before parallelizing the model:
-# def train(file_path, tokenizer, model=None, optimizer=None, vocab_size=10000, platform='none'):
-#     if model is None:
-#         model = GPTLanguageModel(vocab_size=vocab_size)
-#         if platform == 'kaggle':
-#             model = torch.nn.DataParallel(model, device_ids=[0, 1]).to(DEVICE)
-#         else:
-#             model = model.to(DEVICE)
-#         optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
-
-#     # Initialize the data loader
-#     loader = TextDataLoader(file_path, BATCH_SIZE, BLOCK_SIZE, tokenizer, DEVICE)
-
-#     # Set up a tqdm progress bar for the epoch
-#     for epoch in range(MAX_ITERS):
-#         print(f"Epoch {epoch}")
-#         epoch_loss = None  # Track loss for the epoch
-
-#         # Create a progress bar for batch processing
-#         batch_progress_bar = tqdm(loader, total=loader.num_batches(), desc=f"Epoch {epoch+1} Batch Progress", unit="batch", ncols=100)
-
-#         for xb, yb in batch_progress_bar:
-#             if xb is None:
-#                 break  # No more batches, stop the epoch
-
-#             # Forward pass and loss computation
-#             logits, loss = model(xb, yb)
-#             optimizer.zero_grad()
-#             loss.backward()  # Backpropagate the loss
-#             optimizer.step()  # Update model parameters
-
-#             # Update epoch_loss to the most recent loss value
-#             epoch_loss = loss.item()
-
-#             # Update tqdm with the latest loss value
-#             batch_progress_bar.set_postfix(loss=epoch_loss)
-
-#         # Save model weights after each chunk or epoch
-#         model.save(f"model_weights_epoch_{epoch}.pth")
-#         print(f"Model weights saved at epoch {epoch}")
-
-#         # Print the loss at the end of the epoch
-#         if epoch_loss is not None:
-#             print(f"Epoch {epoch}, Loss: {epoch_loss}")
-#         else:
-#             print(f"Epoch {epoch}, No data available for loss calculation.")
-
-#         # Reset the loader for a new epoch
-#         loader.reset()
-
-#     loader.close()  # Ensure the file is properly closed at the end
-
-#     return model, optimizer
-
-# def train(file_path, tokenizer, model=None, optimizer=None, vocab_size=10000, platform='none'):
-#     if model is None:
-#         model = GPTLanguageModel(vocab_size=vocab_size)
-#         if platform == 'kaggle':
-#             model = torch.nn.DataParallel(model, device_ids=[0, 1]).to(DEVICE)
-#         else:
-#             model = model.to(DEVICE)
-#         optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
-#     loader = TextDataLoader(file_path, BATCH_SIZE, BLOCK_SIZE, tokenizer, DEVICE)
-
-
-#     for epoch in range(MAX_ITERS):  # Iterate over the file chunks
-#         print(f"Epoch {epoch}")
-#         epoch_loss = None  # Track loss for the epoch
-#         while not loader.end_of_file:
-#             xb, yb = loader.get_batch()
-#             if xb is None:
-#                 break  # No more batches, stop the epoch
-
-#             # Forward pass and loss computation
-#             # print("This is xb", xb)
-#             # print("This is yb", yb)
-#             logits, loss = model(xb, yb)
-#             optimizer.zero_grad()
-#             loss.backward()  # causes problems on 2 GPUs (needs to be made to work for n GPUs)
-#             optimizer.step()
-
-#             # Update epoch_loss to the most recent loss value
-#             epoch_loss = loss.item()
-
-#         # Save model weights after each chunk or epoch
-#         model.save(f"model_weights_epoch_{epoch}.pth")
-#         print(f"Model weights saved at epoch {epoch}")
-
-#         # Print the loss only if it was computed
-#         if epoch_loss is not None:
-#             print(f"Epoch {epoch}, Loss: {epoch_loss}")
-#         else:
-#             print(f"Epoch {epoch}, No data available for loss calculation.")
-
-#         # Reset the loader for a new epoch
-#         loader.reset()
-
-#     loader.close()  # Ensure file is properly closed at the end
-
-#     return model, optimizer
-
-
-# def train(file_path, tokenizer, model=None, optimizer=None, vocab_size=10000):
-#     # Check if multiple GPUs are available
-#     device = DEVICE
-#     if model is None:
-#         if torch.cuda.is_available() and torch.cuda.device_count() > 1:
-#             print(f"Training on {torch.cuda.device_count()} GPUs")
-#             model = GPTLanguageModel(vocab_size=vocab_size).to(device)
-#             model = torch.nn.DataParallel(model, device_ids=[0, 1])  # Wrap the model for multi-GPU training
-#         else:
-#             print("Training on a single GPU or CPU.")
-
-#             model = GPTLanguageModel(vocab_size=vocab_size).to(device)
-
-#     if optimizer is None:
-#         optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
-
-#     loader = TextDataLoader(file_path, BATCH_SIZE, BLOCK_SIZE, tokenizer, device)
-
-#     for epoch in range(MAX_ITERS):  # Iterate over the file chunks
-#         print(f"Epoch {epoch}")
-#         epoch_loss = None  # Track loss for the epoch
-
-#         xb, yb = loader.get_batch()
-#         if xb is None:
-#             break  # No more batches, stop the epoch
-
-#         # Forward pass and loss computation
-#         logits, loss = model(xb, yb)
-#         optimizer.zero_grad()
-#         loss.backward()
-#         optimizer.step()
-
-#         # Update epoch_loss to the most recent loss value
-#         epoch_loss = loss.item()
-
-#         # Save model weights after each chunk or epoch
-#         model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model  # Get the underlying model if using DataParallel
-#         model_to_save.save(f"model_weights_epoch_{epoch}.pth")
-#         print(f"Model weights saved at epoch {epoch}")
-
-#         # Print the loss only if it was computed
-#         if epoch_loss is not None:
-#             print(f"Epoch {epoch}, Loss: {epoch_loss}")
-#         else:
-#             print(f"Epoch {epoch}, No data available for loss calculation.")
-
-#         # Reset the loader for a new epoch
-#         loader.reset()
-
-#     loader.close()  # Ensure file is properly closed at the end
-
-#     return model, optimizer
 
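The batch arithmetic and schedule above are easy to sanity-check in isolation. Each micro-batch covers BATCH_SIZE * BLOCK_SIZE = 16,384 tokens, so total_batch_size = 524,288 means 32 gradient-accumulation steps per optimizer update, and get_lr ramps linearly to 6e-4 over 715 updates before cosine-decaying to 6e-5 at step 19,073. A self-contained sketch (no repo imports):

    import math

    BATCH_SIZE, BLOCK_SIZE, total_batch_size = 16, 1024, 524288
    print(total_batch_size // (BATCH_SIZE * BLOCK_SIZE))  # 32 micro-batches per update

    max_lr, min_lr, warmup_steps, max_steps = 6e-4, 6e-5, 715, 19073

    def get_lr(it):
        if it < warmup_steps:                      # linear warmup
            return max_lr * (it + 1) / warmup_steps
        if it > max_steps:                         # hold the floor after the schedule
            return min_lr
        decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (max_lr - min_lr)

    for it in (0, 714, 9894, 19073):
        print(it, f"{get_lr(it):.2e}")  # ~8.4e-07 start, 6e-4 peak, ~3.3e-4 midway, 6e-5 floor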
Naive_gpt/train_autoshardload_ddp.py DELETED
@@ -1,307 +0,0 @@
-from DataLoader import create_dataloader, create_dataloader_ddp, GPTDatasetDDP
-import os
-import sys
-import torch
-from tqdm import tqdm
-from .config import *
-from .data_loader import TextDataLoader
-from .model import GPTLanguageModel
-import math
-from torch.distributed import init_process_group, destroy_process_group
-from torch.nn.parallel import DistributedDataParallel as DDP
-import torch.distributed as dist
-import time
-import gc
-
-max_lr = 6e-4
-min_lr = max_lr * 0.1
-
-ratio = 715/19073  # warmup fraction borrowed from the GPT-2 reference schedule
-
-max_steps = 9124  # token_count / (batch_size*block_size) = 4B tokens / 524288, i.e. batch size 512 * block 1024
-# warmup_steps = 72  # 19073 / 715 is the ratio
-
-# max_steps = train_loader.calculate_steps()
-warmup_steps = int(ratio*max_steps)
-
-def get_lr(it):
-    if it < warmup_steps:
-        return max_lr * (it+1) / warmup_steps
-
-    if it > max_steps:
-        return min_lr
-
-    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
-    assert 0 <= decay_ratio <= 1
-    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
-    return min_lr + coeff * (max_lr - min_lr)
-
-
-log_dir = "log"
-os.makedirs(log_dir, exist_ok=True)
-log_file = os.path.join(log_dir, "log.txt")
-with open(log_file, "w") as f:  # open for writing to clear the file
-    pass
-
-
-# set up DDP (distributed data parallel).
-# the torchrun command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
-ddp = int(os.environ.get('RANK', -1)) != -1  # is this a ddp run?
-if ddp:
-    # use of DDP currently demands CUDA; we set the device according to rank
-    assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
-    init_process_group(backend='nccl')
-    ddp_rank = int(os.environ['RANK'])
-    ddp_local_rank = int(os.environ['LOCAL_RANK'])
-    ddp_world_size = int(os.environ['WORLD_SIZE'])
-    device = f'cuda:{ddp_local_rank}'
-    torch.cuda.set_device(device)
-    # this process will do logging, checkpointing etc.
-    master_process = ddp_rank == 0
-else:
-    # vanilla, non-DDP run
-    ddp_rank = 0
-    ddp_local_rank = 0
-    ddp_world_size = 1
-    master_process = True
-    # attempt to autodetect device
-    device = "cpu"
-    if torch.cuda.is_available():
-        device = "cuda"
-    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-        device = "mps"
-    print(f"using device: {device}")
-
-torch.manual_seed(1337)
-if torch.cuda.is_available():
-    torch.cuda.manual_seed(1337)
-
-print(f"ddp: {ddp}, rank: {ddp_rank}, local_rank: {ddp_local_rank}, world_size: {ddp_world_size}")
-print(f"master_process: {master_process}, device: {device}")
-# sys.exit(0)
-
-
-total_batch_size = 524288  # 524288 / 1024 = 512, i.e. a batch size of 512 in our terms
-assert total_batch_size % (
-    BATCH_SIZE * BLOCK_SIZE * ddp_world_size) == 0, "make sure total_batch_size is divisible by BATCH_SIZE * BLOCK_SIZE"
-grad_accumulation_steps = total_batch_size // (
-    BATCH_SIZE * BLOCK_SIZE * ddp_world_size)
-if master_process:
-    print(f"grad_accumulation_steps: {grad_accumulation_steps}")
-    print(f"total_batch_size: {total_batch_size}")
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
-
-
-def train(folder_path, tokenizer, model=None, optimizer=None, vocab_size=20000, platform='none', checkpoint=None, is_tokenized_data=False):
-
-    # Hammad added this line (need to check if it is necessary)
-    torch.set_float32_matmul_precision('high')
-    if model is None:
-        model = GPTLanguageModel(vocab_size=vocab_size)
-        print("Model Initialised")
-    if checkpoint is not None:
-        print("loading checkpoint........")
-        model.load(checkpoint)
-        print("Model loaded from checkpoint", checkpoint)
-
-    # if platform == 'kaggle':
-    #     model = torch.nn.DataParallel(model, device_ids=[0, 1])
-    #     model = model.to(DEVICE)
-    #     optimizer = model.module.configure_optimizers(
-    #         weight_decay=0.1, learning_rate=LEARNING_RATE, device=DEVICE)  # Hammad added this line
-    # else:
-    model = model.to(DEVICE)
-    model = torch.compile(model)  # Hammad added this line
-    if ddp:
-        model = DDP(model, device_ids=[ddp_local_rank])
-    raw_model = model.module if ddp else model
-    # optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)
-    optimizer = raw_model.configure_optimizers(
-        weight_decay=0.1, learning_rate=LEARNING_RATE, device=DEVICE)
-
-    # # Initialize the data loader
-    # loader = TextDataLoader(file_path, BATCH_SIZE, BLOCK_SIZE, tokenizer)
-    train_loader = GPTDatasetDDP(tokenizer, BATCH_SIZE, BLOCK_SIZE, BLOCK_SIZE, folder_path, ddp_rank, ddp_world_size, "train")
-    val_loader = GPTDatasetDDP(tokenizer, BATCH_SIZE, BLOCK_SIZE, BLOCK_SIZE, folder_path, ddp_rank, ddp_world_size, "test")
-
-
-    # Set up a tqdm progress bar for the epoch
-    for epoch in range(MAX_ITERS):
-        if master_process:
-            print(f"Epoch {epoch}")
-
-
-        epoch_loss = None  # Track loss for the epoch
-        train_loader.set_epoch(epoch)
-        val_loader.set_epoch(epoch)
-        train_loader.reset()
-
-        # for i in range(len(os.listdir(folder_path))):
-        #     file_path = os.path.join(folder_path, os.listdir(folder_path)[i])
-        #     print(f"loading file: {file_path}")
-        #     loader = create_dataloader_ddp(tokenizer, file_path, BATCH_SIZE, BLOCK_SIZE, BLOCK_SIZE,
-        #                                    tokenized_data=is_tokenized_data, process_rank=ddp_rank, num_process=ddp_world_size, filename=os.listdir(folder_path)[i])  # Hammad added this line
-
-        # Create a progress bar for batch processing
-        # batch_progress_bar = tqdm(
-        #     loader, desc=f"Epoch {epoch+1} Batch Progress", unit="batch", ncols=100)
-        count = 0
-        halt = False
-        # for xb, yb in batch_progress_bar:
-        while count < max_steps-8:  # NOTE: the -8 margin is an unexplained constant from the original run
-            t0 = time.time()
-            loss_accum = 0
-            model.train()
-            optimizer.zero_grad()
-
-
-            # Forward pass and loss computation, accumulated over micro-steps
-            for micro_step in range(grad_accumulation_steps):
-                xb, yb = train_loader.next_batch()
-                if xb is None:
-                    halt = True
-                    if master_process:
-                        # NOTE: val_loss_accum only exists once the first eval (every 50 steps) has run
                        checkpoint = {
-                            'model': raw_model.state_dict(),
-                            'optimizer': optimizer.state_dict(),
-                            'config': raw_model.config,
-                            'step': count,
-                            'val_loss': val_loss_accum.item()}
-                        # you might also want to add rng seeds etc.,
-                        # if you wanted to more exactly resume training
-                        torch.save(checkpoint, f"model_weights_checkpoint_{count}.pth")
-                        print(f"Model weights saved at checkpoint {count}")
-                        print(f"Epoch {epoch} completed")
-                    torch.cuda.synchronize()
-                    break  # No more batches, stop the epoch
-                xb = xb.to(DEVICE)
-                yb = yb.to(DEVICE)
-                if ddp:
-                    # only sync gradients across ranks on the last micro-step
-                    model.require_backward_grad_sync = (micro_step == grad_accumulation_steps - 1)
-                with torch.autocast(DEVICE, dtype=torch.float32):  # NOTE: float32 disables mixed precision here; bfloat16 was the earlier intent (see below)
-                    logits, loss = model(xb, yb)
-                loss = loss / grad_accumulation_steps
-                loss_accum += loss.detach()
-                loss.backward()
-            # if ddp:
-            #     dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
-            # norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-            if halt:
-                break
-            # with torch.autocast(DEVICE, dtype=torch.bfloat16):  # Hammad added this line
-            #     logits, loss = model(xb, yb)
-            # loss = loss / grad_accumulation_steps
-            # # if platform == 'kaggle':
-            # #     loss_accum += loss.mean().detach()
-            # #     loss.mean().backward()
-            # # else:
-            # loss_accum += loss.detach()
-            # if ddp:
-            #     model.require_backward_grad_sync = (
-            #         (count + 1) % grad_accumulation_steps == 0)
-            # loss.backward()  # Backpropagate the loss
-
-            # for micro_batch in range(grad_accumulation_steps):
-            if ddp:
-                dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)  # average the reported loss across ranks
-            norm = torch.nn.utils.clip_grad_norm_(
-                model.parameters(), 1.0)  # Hammad added this line
-            lr = get_lr(count)  # need to check if this is correct
-            for param_group in optimizer.param_groups:
-                param_group['lr'] = lr
-            optimizer.step()  # Update model parameters
-            # wait for the computation to finish before moving to the next iteration
-            torch.cuda.synchronize()
-            t1 = time.time()
-            dt = t1 - t0  # time difference in seconds
-            tokens_processed = train_loader.batch_size * \
-                train_loader.block_size * grad_accumulation_steps * ddp_world_size
-            tokens_per_sec = tokens_processed / dt
-            if master_process:
-                print(f"epoch {epoch:5d} step {count:5d} | loss_accum: {loss_accum.item():.6f} | loss_unacum: {loss} | lr {lr:.4e} | norm: {norm:.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")
-                with open(log_file, "a") as f:
-                    f.write(f"{epoch} {count} train {loss_accum.item():.6f}\n")
-
-            # Update epoch_loss to the most recent loss value
-            if platform == 'kaggle':
-                epoch_loss = loss.mean().item()
-            else:
-                epoch_loss = loss.item()
-
-            # Update tqdm with the latest loss value
-            # batch_progress_bar.set_postfix(loss=epoch_loss)
-
-            # if count % 5000 == 0:
-            #     torch.save(raw_model.state_dict(),
-            #                f"model_weights_checkpoint_{count}.pth")
-
-            if count % 50 == 0:
-                model.eval()
-                val_loader.reset()
-
-                with torch.no_grad():
-                    val_loss_accum = 0.0
-                    val_loss_steps = 20
-                    for _ in range(val_loss_steps):
-                        xb, yb = val_loader.next_batch()
-                        xb = xb.to(DEVICE)
-                        yb = yb.to(DEVICE)
-                        with torch.autocast(DEVICE, dtype=torch.float32):
-                            logits, loss = model(xb, yb)
-                        loss = loss / val_loss_steps
-                        val_loss_accum += loss.detach()
-                if ddp:
-                    dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
-                if master_process:
-                    print(f"validation loss: {val_loss_accum.item():.4f}")
-                    with open(log_file, "a") as f:
-                        f.write(f"{count} val {val_loss_accum.item():.4f}\n")
-                    if count > 0 and (count % 100 == 0):
-                        # optionally write model checkpoints
-                        # NOTE: checkpoint_path is computed but the save below writes to a CWD filename instead
-                        checkpoint_path = os.path.join(log_dir, f"model_{count:05d}.pt")
-                        checkpoint = {
-                            'model': raw_model.state_dict(),
-                            'optimizer': optimizer.state_dict(),
-                            'config': raw_model.config,
-                            'step': count,
-                            'val_loss': val_loss_accum.item()
-                        }
-                        # you might also want to add rng seeds etc.,
-                        # if you wanted to more exactly resume training
-                        torch.save(
-                            checkpoint, f"model_weights_checkpoint_{epoch}_{count}.pth")
-                        print(f"Model weights saved at checkpoint {count}")
-            count += 1
-
-        # Save model weights after each chunk or epoch
-
-        # Print the loss at the end of the epoch
-        if master_process:
-            if epoch_loss is not None:
-                print(f"Epoch {epoch}, Loss: {epoch_loss}")
-            else:
-                print(f"Epoch {epoch}, No data available for loss calculation.")
-
-
-    train_loader.close()  # Ensure the files are properly closed at the end
-    val_loader.close()
-    if ddp:
-        destroy_process_group()
-        print("Process group destroyed")
-
-    # Delete variables, model, and optimizer
-    del model
-    del optimizer
-    torch.cuda.empty_cache()  # Clears cached memory but not allocated memory
-
-    # Force garbage collection
-    gc.collect()  # Ensures the Python garbage collector releases unreferenced memory
-    torch.cuda.empty_cache()  # Clears freed memory from the GPU cache
-
-    torch.cuda.ipc_collect()  # Helps in multi-GPU setups
-    return (0, 0)
 
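Under DDP the accumulation count divides by the world size: every rank processes BATCH_SIZE * BLOCK_SIZE tokens per micro-step, so each update still covers exactly total_batch_size tokens regardless of GPU count, which is what the tok/sec print above multiplies back out. A standalone sketch of that arithmetic, plus the warmup length this file derives from the 715/19073 reference ratio:

    BATCH_SIZE, BLOCK_SIZE, total_batch_size = 16, 1024, 524288

    for world_size in (1, 2, 4, 8):
        micro = total_batch_size // (BATCH_SIZE * BLOCK_SIZE * world_size)
        tokens = BATCH_SIZE * BLOCK_SIZE * micro * world_size
        print(f"world_size={world_size}: {micro} micro-steps, {tokens} tokens/update")

    ratio, max_steps = 715 / 19073, 9124
    print(int(ratio * max_steps))  # 342 warmup steps at this run length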
Naive_gpt/train_ddp.py DELETED
@@ -1,206 +0,0 @@
-from DataLoader import create_dataloader, create_dataloader_ddp
-import os
-import sys
-import torch
-from tqdm import tqdm
-from .config import *
-from .data_loader import TextDataLoader
-from .model import GPTLanguageModel
-import math
-from torch.distributed import init_process_group, destroy_process_group
-from torch.nn.parallel import DistributedDataParallel as DDP
-import torch.distributed as dist
-
-max_lr = 6e-4
-min_lr = max_lr * 0.1
-warmup_steps = 10
-max_steps = 50
-
-
-def get_lr(it):
-    if it < warmup_steps:
-        return max_lr * (it+1) / warmup_steps
-
-    if it > max_steps:
-        return min_lr
-
-    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
-    assert 0 <= decay_ratio <= 1
-    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
-    return min_lr + coeff * (max_lr - min_lr)
-
-
-# set up DDP (distributed data parallel).
-# the torchrun command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
-ddp = int(os.environ.get('RANK', -1)) != -1  # is this a ddp run?
-if ddp:
-    # use of DDP currently demands CUDA; we set the device according to rank
-    assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
-    init_process_group(backend='nccl')
-    ddp_rank = int(os.environ['RANK'])
-    ddp_local_rank = int(os.environ['LOCAL_RANK'])
-    ddp_world_size = int(os.environ['WORLD_SIZE'])
-    device = f'cuda:{ddp_local_rank}'
-    torch.cuda.set_device(device)
-    # this process will do logging, checkpointing etc.
-    master_process = ddp_rank == 0
-else:
-    # vanilla, non-DDP run
-    ddp_rank = 0
-    ddp_local_rank = 0
-    ddp_world_size = 1
-    master_process = True
-    # attempt to autodetect device
-    device = "cpu"
-    if torch.cuda.is_available():
-        device = "cuda"
-    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-        device = "mps"
-    print(f"using device: {device}")
-
-torch.manual_seed(1337)
-if torch.cuda.is_available():
-    torch.cuda.manual_seed(1337)
-
-print(f"ddp: {ddp}, rank: {ddp_rank}, local_rank: {ddp_local_rank}, world_size: {ddp_world_size}")
-print(f"master_process: {master_process}, device: {device}")
-# import sys; sys.exit(0)  # NOTE: debug exit left active in the original; everything below was unreachable while it ran
-
-
-total_batch_size = 524288  # 524288 / 1024 = 512, i.e. a batch size of 512 in our terms
-assert total_batch_size % (
-    BATCH_SIZE * BLOCK_SIZE * ddp_world_size) == 0, "make sure total_batch_size is divisible by BATCH_SIZE * BLOCK_SIZE"
-grad_accumulation_steps = total_batch_size // (BATCH_SIZE * BLOCK_SIZE * ddp_world_size)
-if master_process:
-    print(f"grad_accumulation_steps: {grad_accumulation_steps}")
-    print(f"total_batch_size: {total_batch_size}")
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
-
-
-def train(folder_path, tokenizer, model=None, optimizer=None, vocab_size=10000, platform='none', checkpoint=None, is_tokenized_data=False):
-
-    # Hammad added this line (need to check if it is necessary)
-    torch.set_float32_matmul_precision('high')
-    if model is None:
-        model = GPTLanguageModel(vocab_size=vocab_size)
-        print("Model Initialised")
-    if checkpoint is not None:
-        print("loading checkpoint........")
-        model.load(checkpoint)
-        print("Model loaded from checkpoint", checkpoint)
-
-    # if platform == 'kaggle':
-    #     model = torch.nn.DataParallel(model, device_ids=[0, 1])
-    #     model = model.to(DEVICE)
-    #     optimizer = model.module.configure_optimizers(
-    #         weight_decay=0.1, learning_rate=LEARNING_RATE, device=DEVICE)  # Hammad added this line
-    # else:
-    model = model.to(DEVICE)
-    model = torch.compile(model)  # Hammad added this line
-    if ddp:
-        model = DDP(model, device_ids=[ddp_local_rank])
-    raw_model = model.module if ddp else model
-    # optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)
-    optimizer = raw_model.configure_optimizers(
-        weight_decay=0.1, learning_rate=LEARNING_RATE, device=DEVICE)
-
-    # # Initialize the data loader
-    # loader = TextDataLoader(file_path, BATCH_SIZE, BLOCK_SIZE, tokenizer)
-
-    # Set up a tqdm progress bar for the epoch
-    for epoch in range(MAX_ITERS):
-        print(f"Epoch {epoch}")
-        epoch_loss = None  # Track loss for the epoch
-
-        # for i in range(len(os.listdir(folder_path))):
-        #     file_path = os.path.join(folder_path, os.listdir(folder_path)[i])
-        #     print(f"loading file: {file_path}")
-        # NOTE: file_path and i are undefined below because the file loop above was commented out
-        loader = create_dataloader_ddp(tokenizer, file_path, BATCH_SIZE, BLOCK_SIZE, BLOCK_SIZE,
-                                       tokenized_data=is_tokenized_data, process_rank=ddp_rank, num_process=ddp_world_size, filename=os.listdir(folder_path)[i])  # Hammad added this line
-
-        # Create a progress bar for batch processing
-        batch_progress_bar = tqdm(
-            loader, desc=f"Epoch {epoch+1} Batch Progress", unit="batch", ncols=100)
-        count = 0
-        loss_accum = 0
-        for xb, yb in batch_progress_bar:
-            if xb is None:
-                break  # No more batches, stop the epoch
-            optimizer.zero_grad()  # NOTE: zeroing every micro-batch defeats the accumulation below; see train.py for the fixed ordering
-
-            # Forward pass and loss computation
-            xb = xb.to(DEVICE)
-            yb = yb.to(DEVICE)
-            # with torch.autocast(DEVICE, dtype=torch.bfloat16):  # Hammad added this line
-            logits, loss = model(xb, yb)
-            loss = loss / grad_accumulation_steps
-            # if platform == 'kaggle':
-            #     loss_accum += loss.mean().detach()
-            #     loss.mean().backward()
-            # else:
-            loss_accum += loss.detach()
-            if ddp:
-                # sync gradients across ranks only on the micro-batch preceding the optimizer step
-                model.require_backward_grad_sync = ((count + 1) % grad_accumulation_steps == 0)
-            loss.backward()  # Backpropagate the loss
-
-            # for micro_batch in range(grad_accumulation_steps):
-            if count % grad_accumulation_steps == 0:  # NOTE: off by one versus the sync condition above, which uses count + 1
-                print("one batch completed at (xb,yb):", count)
-                if ddp:
-                    dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
-                loss_accum = 0
-                norm = torch.nn.utils.clip_grad_norm_(
-                    model.parameters(), 1.0)  # Hammad added this line
-                lr = get_lr(count)  # need to check if this is correct
-                for param_group in optimizer.param_groups:
-                    param_group['lr'] = lr
-                optimizer.step()  # Update model parameters
-                # wait for the computation to finish before moving to the next iteration
-                torch.cuda.synchronize()
-
-            # Update epoch_loss to the most recent loss value
-            if platform == 'kaggle':
-                epoch_loss = loss.mean().item()
-            else:
-                epoch_loss = loss.item()
-
-            # Update tqdm with the latest loss value
-            batch_progress_bar.set_postfix(loss=epoch_loss)
-
-            count += 1
-            if count % 5000 == 0:
-                if platform == 'kaggle':
-                    torch.save(model.module.state_dict(),
-                               f"model_weights_checkpoint_{count}.pth")
-                else:
-                    torch.save(model.state_dict(),
-                               f"model_weights_checkpoint_{count}.pth")
-                print(f"Model weights saved at checkpoint {count}")
-
-        # Save model weights after each chunk or epoch
-        if platform == 'kaggle':
-            torch.save(model.module.state_dict(),
-                       f"model_weights_epoch_{epoch}_{file_path[-6:-4]}.pth")
-        else:
-            torch.save(model.state_dict(),
-                       f"model_weights_epoch_{epoch}_{file_path[-6:-4]}.pth")
-        print(f"Model weights saved at epoch {epoch}")
-
-        # Print the loss at the end of the epoch
-        if epoch_loss is not None:
-            print(f"Epoch {epoch}, Loss: {epoch_loss}")
-        else:
-            print(
-                f"Epoch {epoch}, No data available for loss calculation.")
-
-        # Reset the loader for a new epoch
-        # loader.reset()
-
-    # loader.close()  # Ensure the file is properly closed at the end
-
-    return model, optimizer
-
-if ddp:
-    destroy_process_group()
-    print("Process group destroyed")
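Both DDP scripts key off the RANK, LOCAL_RANK, and WORLD_SIZE environment variables that torchrun sets; run without torchrun, they fall back to the single-device branch. A hypothetical two-GPU launch, with launch_train_ddp.py as an assumed entry point (this commit does not ship one):

    # torchrun --standalone --nproc_per_node=2 launch_train_ddp.py
    from Naive_gpt.train_ddp import train
    from my_tokenizer import tokenizer  # hypothetical: any tokenizer exposing .encode()

    if __name__ == "__main__":
        train("data_folder/", tokenizer, vocab_size=20000, is_tokenized_data=True)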