import math
import os
import sys

import torch
from tqdm import tqdm

from .config import *
from .data_loader import TextDataLoader
from .model import GPTLanguageModel
# Cosine learning-rate schedule with linear warmup (GPT-2-style hyperparameters).
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 715
max_steps = 19073

def get_lr(it):
    # 1) Linear warmup for the first `warmup_steps` steps.
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # 2) After `max_steps`, hold at the floor.
    if it > max_steps:
        return min_lr
    # 3) In between, cosine-decay from max_lr down to min_lr.
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes 1 -> 0
    return min_lr + coeff * (max_lr - min_lr)
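# Quick sanity checks on the schedule (a sketch; they run at import time, like the
# prints below, and can be dropped if import-time side effects are unwanted):
assert abs(get_lr(0) - max_lr / warmup_steps) < 1e-12       # warmup starts near zero
assert abs(get_lr(warmup_steps - 1) - max_lr) < 1e-12       # warmup ends exactly at max_lr
assert abs(get_lr(max_steps) - min_lr) < 1e-12              # decay bottoms out at min_lr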
# Target ~0.5M tokens per optimizer step, reached via gradient accumulation.
total_batch_size = 524288  # 2**19 tokens
assert total_batch_size % (BATCH_SIZE * BLOCK_SIZE) == 0, \
    "total_batch_size must be divisible by BATCH_SIZE * BLOCK_SIZE"
grad_accumulation_steps = total_batch_size // (BATCH_SIZE * BLOCK_SIZE)
print(f"grad_accumulation_steps: {grad_accumulation_steps}")
print(f"total_batch_size: {total_batch_size}")
# Make the sibling DataLoader package importable.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
from DataLoader import create_dataloader
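# `create_dataloader` is expected to yield (xb, yb) pairs of LongTensors shaped
# [BATCH_SIZE, BLOCK_SIZE], with yb being xb shifted right by one token (standard
# next-token-prediction targets). A minimal stand-in sketch under that assumption,
# for a flat 1-D tensor of token ids (`naive_batches` is hypothetical, not this repo's API):
#
# def naive_batches(tokens, batch_size, block_size):
#     n = batch_size * block_size
#     for i in range(0, len(tokens) - n - 1, n):
#         chunk = tokens[i : i + n + 1]  # one extra token for the shifted targets
#         yield chunk[:-1].view(batch_size, block_size), chunk[1:].view(batch_size, block_size)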
def train(folder_path, tokenizer, model=None, optimizer=None, vocab_size=10000,
          platform='none', checkpoint=None, is_tokenized_data=False):
    # Allow TF32 matmuls on Ampere+ GPUs for extra throughput at negligible accuracy cost.
    torch.set_float32_matmul_precision('high')
    if model is None:
        model = GPTLanguageModel(vocab_size=vocab_size)
        print("Model initialised")
    if checkpoint is not None:
        print("Loading checkpoint...")
        model.load(checkpoint)
        print("Model loaded from checkpoint", checkpoint)
    if platform == 'kaggle':
        # Kaggle provides two GPUs; DataParallel splits each batch across them.
        model = torch.nn.DataParallel(model, device_ids=[0, 1])
        model = model.to(DEVICE)
        optimizer = model.module.configure_optimizers(weight_decay=0.1, learning_rate=LEARNING_RATE, device=DEVICE)
    else:
        model = model.to(DEVICE)
        model = torch.compile(model)
        optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=LEARNING_RATE, device=DEVICE)
    # Plain-AdamW fallback if the model has no configure_optimizers():
    # optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)
    global_step = 0  # optimizer steps taken so far; drives the LR schedule
    for epoch in range(MAX_ITERS):
        print(f"Epoch {epoch}")
        epoch_loss = None  # most recent per-step loss, reported at the end of the epoch
        for file_name in sorted(os.listdir(folder_path)):
            file_path = os.path.join(folder_path, file_name)
            print(f"Loading file: {file_path}")
            loader = create_dataloader(tokenizer, file_path, BATCH_SIZE, BLOCK_SIZE, BLOCK_SIZE,
                                       tokenized_data=is_tokenized_data, filename=file_name)
            batch_progress_bar = tqdm(loader, desc=f"Epoch {epoch+1} Batch Progress", unit="batch", ncols=100)
            count = 0
            loss_accum = 0.0
            optimizer.zero_grad()
            for xb, yb in batch_progress_bar:
                if xb is None:
                    break  # no more batches in this file
                xb = xb.to(DEVICE)
                yb = yb.to(DEVICE)
                # with torch.autocast(DEVICE, dtype=torch.bfloat16):  # optional mixed precision
                logits, loss = model(xb, yb)
                if platform == 'kaggle':
                    loss = loss.mean()  # DataParallel returns one loss per GPU; reduce first
                # Scale so the gradients summed over the accumulation window match
                # one large batch of total_batch_size tokens.
                loss = loss / grad_accumulation_steps
                loss_accum += loss.detach()
                loss.backward()  # gradients accumulate across micro-batches
                count += 1
                if count % grad_accumulation_steps == 0:
                    # A full accumulation window is done: clip, set the LR, and step.
                    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    lr = get_lr(global_step)  # schedule is indexed by optimizer step, not micro-batch
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    optimizer.step()
                    optimizer.zero_grad()  # clear gradients only after they have been applied
                    global_step += 1
                    if torch.cuda.is_available():
                        torch.cuda.synchronize()  # wait for the step to finish before logging
                    epoch_loss = loss_accum.item()
                    batch_progress_bar.set_postfix(loss=epoch_loss, lr=lr)
                    loss_accum = 0.0
                if count % 5000 == 0:
                    # Periodic checkpoint; unwrap DataParallel so the state-dict keys stay clean.
                    if platform == 'kaggle':
                        torch.save(model.module.state_dict(), f"model_weights_checkpoint_{count}.pth")
                    else:
                        torch.save(model.state_dict(), f"model_weights_checkpoint_{count}.pth")
                    print(f"Model weights saved at checkpoint {count}")
            # Save model weights after each file (the slice assumes a two-digit
            # index just before the file extension).
            if platform == 'kaggle':
                torch.save(model.module.state_dict(),
                           f"model_weights_epoch_{epoch}_{file_path[-6:-4]}.pth")
            else:
                torch.save(model.state_dict(),
                           f"model_weights_epoch_{epoch}_{file_path[-6:-4]}.pth")
            print(f"Model weights saved at epoch {epoch}")
        if epoch_loss is not None:
            print(f"Epoch {epoch}, Loss: {epoch_loss}")
        else:
            print(f"Epoch {epoch}, No data available for loss calculation.")
        torch.cuda.empty_cache()
    return model, optimizer
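# Example usage (a sketch; the tokenizer choice and data path are assumptions,
# not part of this repo -- e.g. tiktoken's GPT-2 encoding over pre-tokenized shards):
#
# import tiktoken
# tokenizer = tiktoken.get_encoding("gpt2")
# model, optimizer = train(
#     folder_path="data/shards",   # hypothetical folder of training files
#     tokenizer=tokenizer,
#     vocab_size=50257,            # GPT-2 vocabulary size
#     platform="none",
#     is_tokenized_data=True,
# )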
# Earlier iterations of train(), kept for reference (before gradient accumulation,
# the LR schedule, and the final multi-GPU handling):
# def train(file_path, tokenizer, model=None, optimizer=None, vocab_size=10000, platform='none'):
# if model is None:
# model = GPTLanguageModel(vocab_size=vocab_size)
# if platform == 'kaggle':
# model = torch.nn.DataParallel(model, device_ids=[0, 1]).to(DEVICE)
# else:
# model = model.to(DEVICE)
# optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# # Initialize the data loader
# loader = TextDataLoader(file_path, BATCH_SIZE, BLOCK_SIZE, tokenizer, DEVICE)
# # Set up a tqdm progress bar for the epoch
# for epoch in range(MAX_ITERS):
# print(f"Epoch {epoch}")
# epoch_loss = None # Track loss for the epoch
# # Create a progress bar for batch processing
# batch_progress_bar = tqdm(loader, total=loader.num_batches(), desc=f"Epoch {epoch+1} Batch Progress", unit="batch", ncols=100)
# for xb, yb in batch_progress_bar:
# if xb is None:
# break # No more batches, stop the epoch
# # Forward pass and loss computation
# logits, loss = model(xb, yb)
# optimizer.zero_grad()
# loss.backward() # Backpropagate the loss
# optimizer.step() # Update model parameters
# # Update epoch_loss to the most recent loss value
# epoch_loss = loss.item()
# # Update tqdm with the latest loss value
# batch_progress_bar.set_postfix(loss=epoch_loss)
# # Save model weights after each chunk or epoch
# model.save(f"model_weights_epoch_{epoch}.pth")
# print(f"Model weights saved at epoch {epoch}")
# # Print the loss at the end of the epoch
# if epoch_loss is not None:
# print(f"Epoch {epoch}, Loss: {epoch_loss}")
# else:
# print(f"Epoch {epoch}, No data available for loss calculation.")
# # Reset the loader for a new epoch
# loader.reset()
# loader.close() # Ensure the file is properly closed at the end
# return model, optimizer
# def train(file_path, tokenizer, model = None, optimizer = None, vocab_size=10000, platform='none'):
# if model is None:
# model = GPTLanguageModel(vocab_size=vocab_size)
# if platform == 'kaggle':
# model = torch.nn.DataParallel(model, device_ids=[0, 1]).to(DEVICE)
# else:
# model = model.to(DEVICE)
# optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# loader = TextDataLoader(file_path, BATCH_SIZE, BLOCK_SIZE, tokenizer, DEVICE)
# for epoch in range(MAX_ITERS): # Iterate over the file chunks
# print(f"Epoch {epoch}")
# epoch_loss = None # Track loss for the epoch
# while not loader.end_of_file:
# xb, yb = loader.get_batch()
# if xb is None:
# break # No more batches, stop the epoch
# # Forward pass and loss computation
# # print("This is xb", xb)
# # print("This is yb", yb)
# logits, loss = model(xb, yb)
# optimizer.zero_grad()
# loss.backward() # was causing issues on 2 GPUs (needs generalising to n GPUs)
# optimizer.step()
# # Update epoch_loss to the most recent loss value
# epoch_loss = loss.item()
# # Save model weights after each chunk or epoch
# model.save(f"model_weights_epoch_{epoch}.pth")
# print(f"Model weights saved at epoch {epoch}")
# # Print the loss only if it was computed
# if epoch_loss is not None:
# print(f"Epoch {epoch}, Loss: {epoch_loss}")
# else:
# print(f"Epoch {epoch}, No data available for loss calculation.")
# # Reset the loader for a new epoch
# loader.reset()
# loader.close() # Ensure file is properly closed at the end
# return model, optimizer
# def train(file_path, tokenizer, model=None, optimizer=None, vocab_size=10000):
# # Check if multiple GPUs are available
# device = DEVICE
# if model is None:
# if torch.cuda.is_available() and torch.cuda.device_count() > 1:
# print(f"Training on {torch.cuda.device_count()} GPUs")
# model = GPTLanguageModel(vocab_size=vocab_size).to(device)
# model = torch.nn.DataParallel(model, device_ids=[0, 1]) # Wrap the model for multi-GPU training
# else:
# print("Training on a single GPU or CPU.")
# model = GPTLanguageModel(vocab_size=vocab_size).to(device)
# if optimizer is None:
# optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# loader = TextDataLoader(file_path, BATCH_SIZE, BLOCK_SIZE, tokenizer, device)
# for epoch in range(MAX_ITERS): # Iterate over the file chunks
# print(f"Epoch {epoch}")
# epoch_loss = None # Track loss for the epoch
# xb, yb = loader.get_batch()
# if xb is None:
# break # No more batches, stop the epoch
# # Forward pass and loss computation
# logits, loss = model(xb, yb)
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
# # Update epoch_loss to the most recent loss value
# epoch_loss = loss.item()
# # Save model weights after each chunk or epoch
# model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model # Get the underlying model if using DataParallel
# model_to_save.save(f"model_weights_epoch_{epoch}.pth")
# print(f"Model weights saved at epoch {epoch}")
# # Print the loss only if it was computed
# if epoch_loss is not None:
# print(f"Epoch {epoch}, Loss: {epoch_loss}")
# else:
# print(f"Epoch {epoch}, No data available for loss calculation.")
# # Reset the loader for a new epoch
# loader.reset()
# loader.close() # Ensure file is properly closed at the end
# return model, optimizer