from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import os
import torch
import time
import json
import math
from .model import GPT, GPTConfig
def train_memgpt(config_path, dataloader_class=None):
    # A dataloader class is required; it is instantiated for each split below.
    assert dataloader_class is not None, "train_memgpt requires a dataloader_class"
    with open(config_path, 'r') as f:
        cfg = json.load(f)
    model_cfg_params = cfg['model']
    train_cfg_params = cfg['training']
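    # Illustrative config layout (the 'training' keys are exactly the ones read
    # below; the 'model' keys must match whatever fields GPTConfig in .model
    # accepts -- the names shown here are only a guess):
    # {
    #   "model":    {"n_layer": 12, "n_head": 12, "n_embd": 768},
    #   "training": {"total_batch_size": 524288, "B": 16, "T": 1024,
    #                "max_steps": 19073, "log_dir": "log", "max_lr": 6e-4,
    #                "min_lr": 6e-5, "warmup_steps": 715, "weight_decay": 0.1,
    #                "learning_rate": 6e-4}
    # }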
    # Running under torchrun? It sets RANK/LOCAL_RANK/WORLD_SIZE in the environment.
    ddp = int(os.environ.get('RANK', -1)) != -1
    if ddp:
        assert torch.cuda.is_available(), "DDP requires CUDA for the NCCL backend"
        init_process_group(backend='nccl')
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        device = f"cuda:{ddp_local_rank}"
        torch.cuda.set_device(device)
        master_process = ddp_rank == 0  # rank 0 does all logging and checkpointing
    else:
        # Single-process fallback: pick the best available device.
        ddp_rank = 0
        ddp_local_rank = 0
        ddp_world_size = 1
        master_process = True
        device = 'cpu'
        if torch.cuda.is_available():
            device = 'cuda'
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            device = "mps"
    if master_process:
        print(f"Using device: {device}")
    # torch.autocast only distinguishes cuda vs cpu here; mps maps to cpu.
    device_type = "cuda" if device.startswith("cuda") else "cpu"
    torch.manual_seed(1337)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)
    total_batch_size = train_cfg_params['total_batch_size']  # in tokens
    B = train_cfg_params['B']  # micro-batch size (sequences per forward pass)
    T = train_cfg_params['T']  # sequence length
    max_steps = train_cfg_params['max_steps']
    log_dir = train_cfg_params['log_dir']
    max_lr = train_cfg_params['max_lr']
    min_lr = train_cfg_params['min_lr']
    warmup_steps = train_cfg_params['warmup_steps']
    weight_decay = train_cfg_params['weight_decay']
    base_learning_rate = train_cfg_params['learning_rate']
    assert total_batch_size % (B * T * ddp_world_size) == 0, "total_batch_size must be divisible by B * T * world_size"
    grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
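    # Worked example with illustrative numbers: total_batch_size=524288 tokens,
    # B=16, T=1024, 8 GPUs -> 524288 / (16 * 1024 * 8) = 4 micro-steps of
    # gradient accumulation per optimizer step on every rank.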
    if master_process:
        print(f"Total desired batch size: {total_batch_size}")
        print(f"Calculated gradient accumulation steps: {grad_accum_steps}")
    train_loader = dataloader_class(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="train", master_process=master_process)
    val_loader = dataloader_class(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="val", master_process=master_process)
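    # Assumed dataloader contract, inferred from the calls below: next_batch()
    # returns an (x, y) pair of (B, T) token tensors, and reset() rewinds the
    # stream to the beginning of the split.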
    torch.set_float32_matmul_precision('high')  # allow TF32 matmuls on supporting GPUs
    # Create model
    model = GPT(GPTConfig(**model_cfg_params))
    model.to(device)
    use_compile = True
    if use_compile:
        model = torch.compile(model)
    if ddp:
        model = DDP(model, device_ids=[ddp_local_rank])
    raw_model = model.module if ddp else model  # unwrap DDP for optimizer/checkpoint access
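    # Note: with use_compile=True, raw_model is still the torch.compile wrapper
    # even after unwrapping DDP, so checkpoint state_dict keys may carry an
    # '_orig_mod.' prefix; strip it when loading into an uncompiled model.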
    def get_lr(it):
        # Linear warmup to max_lr, then cosine decay down to min_lr.
        if it < warmup_steps:
            return max_lr * (it + 1) / warmup_steps
        if it > max_steps:
            return min_lr
        decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # decays from 1 to 0
        return min_lr + coeff * (max_lr - min_lr)
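    # Illustrative values: with max_lr=6e-4, min_lr=6e-5, warmup_steps=100,
    # get_lr(0) = 6e-4 * 1/100 = 6e-6 and get_lr(99) = 6e-4 (warmup complete);
    # the cosine term then eases the rate down to 6e-5 by max_steps.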
    optimizer = raw_model.configure_optimizers(weight_decay=weight_decay, learning_rate=base_learning_rate, device_type=device_type, master_process=master_process)
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, "log.txt")
    with open(log_file, "w") as f:
        pass  # truncate any log left over from a previous run
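    # log.txt accumulates one line per event, "<step> val <loss>" or
    # "<step> train <loss>", which keeps later parsing and plotting simple.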
    for step in range(max_steps):
        t0 = time.time()
        last_step = (step == max_steps - 1)
        # Periodically evaluate on the validation split and write a checkpoint.
        if step % 350 == 0 or last_step:
            model.eval()
            val_loader.reset()
            with torch.no_grad():
                val_loss_accum = 0.0
                val_loss_steps = 20
                for _ in range(val_loss_steps):
                    x, y = val_loader.next_batch()
                    x, y = x.to(device), y.to(device)
                    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                        logits, loss = model(x, y)
                    loss = loss / val_loss_steps
                    val_loss_accum += loss.detach()
            if ddp:
                dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
            if master_process:
                print(f"Validation loss: {val_loss_accum.item():.4f}")
                with open(log_file, "a") as f:
                    f.write(f"{step} val {val_loss_accum.item():.4f}\n")
                # Only rank 0 saves, so ranks never race on the same file.
                checkpoint_name = "model_final.pt" if last_step else f"model_{step:05d}.pt"
                checkpoint_path = os.path.join(log_dir, checkpoint_name)
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'step': step,
                    'val_loss': val_loss_accum.item(),
                    'config': raw_model.config,
                }
                torch.save(checkpoint, checkpoint_path)
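                # Resume sketch (assumes a model built from the same config):
                #   ckpt = torch.load(checkpoint_path, map_location=device)
                #   raw_model.load_state_dict(ckpt['model'])
                #   optimizer.load_state_dict(ckpt['optimizer'])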
        model.train()
        optimizer.zero_grad()
        loss_accum = 0.0
        for micro_step in range(grad_accum_steps):
            x, y = train_loader.next_batch()
            x, y = x.to(device), y.to(device)
            if ddp:
                # Sync gradients across ranks only on the last micro-step;
                # earlier micro-steps just accumulate locally.
                model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                logits, loss = model(x, y)
            # Scale so the accumulated gradient matches the mean over the full batch.
            loss = loss / grad_accum_steps
            loss_accum += loss.detach()
            loss.backward()
        if ddp:
            dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)  # for logging only
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
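        # DDP's documented model.no_sync() context manager is the public
        # equivalent of toggling require_backward_grad_sync above; both skip
        # the inter-rank gradient all-reduce on non-final micro-steps.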
        # Apply the scheduled learning rate for this step.
        lr = get_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        optimizer.step()
        if device_type == 'cuda':
            torch.cuda.synchronize()  # wait for queued GPU work so timing is accurate
        t1 = time.time()
        dt = t1 - t0  # seconds
        tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size
        tokens_per_sec = tokens_processed / dt
        if master_process:
            print(f"Step:{step:5d} | Loss: {loss_accum.item():.6f} | lr: {lr:.4e} | Norm:{norm:.4f} | dt: {dt * 1000:.2f}ms | Tok/sec: {tokens_per_sec:.2f}")
            with open(log_file, 'a') as f:
                f.write(f"{step} train {loss_accum.item():.6f}\n")
    if ddp:
        destroy_process_group()
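
# Minimal usage sketch (module and class names hypothetical; any class that
# satisfies the dataloader contract noted above will do):
#   from .data import DataLoaderLite
#   train_memgpt("config/train.json", dataloader_class=DataLoaderLite)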