# GPT-124m / model_core/training.py
# abhinavv3 — "minor changes" (commit 498886e)
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import os
import torch
import time
import json
import math
from .model import GPT,GPTConfig
def _setup_distributed():
    """Initialise DDP when launched under torchrun, otherwise pick a local device.

    A distributed run is detected by the presence of the ``RANK`` environment
    variable. In that case the NCCL process group is initialised and this
    process is bound to ``cuda:<LOCAL_RANK>``. Otherwise a single process runs
    on CUDA if available, else Apple MPS if available, else CPU.

    Returns:
        ``(ddp, rank, local_rank, world_size, device, master_process)``.

    Raises:
        RuntimeError: if ``RANK`` is set but CUDA is unavailable (the NCCL
            backend needs CUDA).
    """
    if int(os.environ.get('RANK', -1)) == -1:
        if torch.cuda.is_available():
            device = 'cuda'
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device = "mps"
        else:
            device = 'cpu'
        return False, 0, 0, 1, device, True

    if not torch.cuda.is_available():
        raise RuntimeError("DDP training requires CUDA (NCCL backend), but CUDA is not available")
    init_process_group(backend='nccl')
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    device = f"cuda:{local_rank}"
    torch.cuda.set_device(device)
    return True, rank, local_rank, world_size, device, rank == 0


def _cosine_lr(it, warmup_steps, max_steps, max_lr, min_lr):
    """Return the learning rate for iteration ``it``.

    The rate rises linearly to ``max_lr`` over the first ``warmup_steps``
    iterations. It then follows a half-cosine from ``max_lr`` down to ``min_lr``
    at ``it == max_steps``, and stays at ``min_lr`` afterwards.
    """
    if it < warmup_steps:
        # ``it + 1`` so that step 0 already trains with a non-zero rate.
        return max_lr * (it + 1) / warmup_steps
    if it >= max_steps:
        # At it == max_steps the cosine below yields exactly min_lr anyway;
        # returning early also avoids 0/0 when max_steps == warmup_steps.
        return min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 1 -> 0 over the decay
    return min_lr + coeff * (max_lr - min_lr)


@torch.no_grad()
def _estimate_val_loss(model, val_loader, device, device_type, num_steps, ddp):
    """Return the mean loss over ``num_steps`` validation batches.

    The loader is rewound with ``reset()`` first. Under DDP the per-rank means
    are averaged across ranks. The model is left in eval mode.
    """
    model.eval()
    val_loader.reset()
    val_loss_accum = torch.zeros((), device=device)
    for _ in range(num_steps):
        x, y = val_loader.next_batch()
        x, y = x.to(device), y.to(device)
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            _, loss = model(x, y)
        val_loss_accum += loss.detach() / num_steps
    if ddp:
        dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
    return val_loss_accum.item()


def train_memgpt(config_path, dataloader_class=None):
    """Train a GPT model from a JSON config, optionally with DDP under torchrun.

    The config must contain a ``"model"`` section and a ``"training"`` section.

    The ``"model"`` section holds the keyword arguments for ``GPTConfig``.

    The ``"training"`` section requires the keys ``total_batch_size``, ``B``,
    ``T``, ``max_steps``, ``log_dir``, ``max_lr``, ``min_lr``, ``warmup_steps``,
    ``weight_decay`` and ``learning_rate``. Optional keys, with defaults in
    parentheses:

    - ``eval_interval`` (350)
    - ``val_loss_steps`` (20)
    - ``use_compile`` (True)
    - ``grad_clip`` (1.0)

    Every ``eval_interval`` steps, and on the last step, the validation loss is
    measured. The master process then writes a checkpoint to ``log_dir``: it is
    named ``model_final.pt`` on the last step and ``model_{step:05d}.pt``
    otherwise. The checkpoint holds the keys ``model``, ``optimizer``,
    ``step``, ``val_loss`` and ``config``. Train and val losses are appended
    to ``log_dir/log.txt``.

    Args:
        config_path: path to the JSON configuration file.
        dataloader_class: callable invoked as
            ``dataloader_class(B=, T=, process_rank=, num_processes=, split=,
            master_process=)`` for ``split`` in ``"train"``/``"val"``. The
            returned loaders must provide ``next_batch() -> (x, y)`` and
            ``reset()``. This argument is required despite its ``None``
            default.

    Raises:
        TypeError: if ``dataloader_class`` is not given.
        ValueError: if ``total_batch_size`` is not a positive multiple of
            ``B * T * world_size``, or ``eval_interval``/``val_loss_steps`` < 1.
    """
    if dataloader_class is None:
        raise TypeError("train_memgpt() requires a dataloader_class")
    with open(config_path, 'r') as f:
        cfg = json.load(f)
    model_cfg_params = cfg['model']
    train_cfg_params = cfg['training']

    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device, master_process = _setup_distributed()
    try:
        if master_process:
            print(f"Using device: {device}")
        # autocast only distinguishes cuda vs. cpu here; MPS falls back to "cpu".
        device_type = "cuda" if device.startswith("cuda") else "cpu"
        torch.manual_seed(1337)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(1337)

        total_batch_size = train_cfg_params['total_batch_size']
        B = train_cfg_params['B']
        T = train_cfg_params['T']
        max_steps = train_cfg_params['max_steps']
        log_dir = train_cfg_params['log_dir']
        max_lr = train_cfg_params['max_lr']
        min_lr = train_cfg_params['min_lr']
        warmup_steps = train_cfg_params['warmup_steps']
        weight_decay = train_cfg_params['weight_decay']
        base_learning_rate = train_cfg_params['learning_rate']
        eval_interval = train_cfg_params.get('eval_interval', 350)
        val_loss_steps = train_cfg_params.get('val_loss_steps', 20)
        use_compile = train_cfg_params.get('use_compile', True)
        grad_clip = train_cfg_params.get('grad_clip', 1.0)

        # Explicit checks rather than asserts so they survive ``python -O``.
        tokens_per_micro_step = B * T * ddp_world_size
        if total_batch_size <= 0 or total_batch_size % tokens_per_micro_step != 0:
            raise ValueError(
                f"total_batch_size ({total_batch_size}) must be a positive multiple of "
                f"B * T * world_size ({tokens_per_micro_step})"
            )
        if eval_interval < 1 or val_loss_steps < 1:
            raise ValueError("eval_interval and val_loss_steps must both be >= 1")
        grad_accum_steps = total_batch_size // tokens_per_micro_step
        if master_process:
            print(f"Total desired batch size: {total_batch_size}")
            print(f"Calculated gradient accumulation steps: {grad_accum_steps}")

        train_loader = dataloader_class(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="train", master_process=master_process)
        val_loader = dataloader_class(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="val", master_process=master_process)
        torch.set_float32_matmul_precision('high')

        # Keep a handle on the plain (uncompiled, unwrapped) module. The
        # torch.compile wrapper's state_dict() keys carry an "_orig_mod."
        # prefix, which would make checkpoints unloadable into a plain GPT.
        # The parameters are shared, so an optimizer built from raw_model also
        # updates the compiled/DDP-wrapped model.
        raw_model = GPT(GPTConfig(**model_cfg_params))
        raw_model.to(device)
        model = torch.compile(raw_model) if use_compile else raw_model
        if ddp:
            model = DDP(model, device_ids=[ddp_local_rank])

        optimizer = raw_model.configure_optimizers(weight_decay=weight_decay, learning_rate=base_learning_rate, device_type=device_type, master_process=master_process)

        log_file = os.path.join(log_dir, "log.txt")
        if master_process:
            os.makedirs(log_dir, exist_ok=True)
            # Start a fresh log for this run; only rank 0 ever writes to it.
            with open(log_file, "w"):
                pass

        for step in range(max_steps):
            # Note: on eval steps dt also includes validation/checkpoint time.
            t0 = time.time()
            last_step = step == max_steps - 1

            if step % eval_interval == 0 or last_step:
                val_loss = _estimate_val_loss(model, val_loader, device, device_type, val_loss_steps, ddp)
                if master_process:
                    print(f"Validation loss: {val_loss:.4f}")
                    with open(log_file, "a") as f:
                        f.write(f"{step} val {val_loss:.4f}\n")
                    checkpoint_name = "model_final.pt" if last_step else f"model_{step:05d}.pt"
                    checkpoint = {
                        'model': raw_model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'step': step,
                        'val_loss': val_loss,
                        'config': raw_model.config,
                    }
                    torch.save(checkpoint, os.path.join(log_dir, checkpoint_name))

            model.train()
            optimizer.zero_grad()
            loss_accum = torch.zeros((), device=device)
            for micro_step in range(grad_accum_steps):
                x, y = train_loader.next_batch()
                x, y = x.to(device), y.to(device)
                if ddp:
                    # Only all-reduce gradients on the final micro-step.
                    model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    _, loss = model(x, y)
                # Scale so the accumulated gradient is the mean over micro-steps.
                loss = loss / grad_accum_steps
                loss_accum += loss.detach()
                loss.backward()
            if ddp:
                dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
            norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            lr = _cosine_lr(step, warmup_steps, max_steps, max_lr, min_lr)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            optimizer.step()
            if device_type == 'cuda':
                torch.cuda.synchronize()
            dt = time.time() - t0  # seconds; converted to ms only for display
            # B * T * grad_accum_steps * world_size == total_batch_size tokens per step.
            tokens_per_sec = total_batch_size / dt
            if master_process:
                print(f"Step:{step:5d} | Loss: {loss_accum.item():.6f} | lr: {lr:.4e} | Norm:{norm.item():.4f} | dt: {dt * 1000:.2f}ms | Tok/sec: {tokens_per_sec:.2f}")
                with open(log_file, 'a') as f:
                    f.write(f"{step} train {loss_accum.item():.6f}\n")
    finally:
        # Tear down the process group even if training raised.
        if ddp:
            destroy_process_group()