| import os |
| import math |
| import numpy as np |
| import time |
| from dataclasses import dataclass |
| import tiktoken |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| |
|
|
| from model import GPT |
| from dataloader import DataLoaderLite |
| from hellaswag_eval import render_example, iterate_examples, get_most_likely_row |
|
|
# Allow TF32 matmul kernels on Ampere+ GPUs: big speedup, negligible
# precision loss for training.
torch.set_float32_matmul_precision('high')


# NOTE(review): compile is presumably disabled because the generation /
# HellaSwag eval paths feed variable-length sequences — confirm before enabling.
use_torch_compile = False
|
|
|
|
class Trainer:
    """
    Orchestrates GPT training: gradient-accumulated optimization steps,
    periodic validation loss / HellaSwag accuracy evaluation, text sampling,
    logging, and checkpointing. Works single-process and under DDP.
    """

    def __init__(
        self,
        model,
        optimizer,
        train_loader,
        val_loader,
        token_encoder,
        eval_freq,
        grad_accum_steps,
        ddp,
        ddp_rank,
        ddp_world_size,
        device,
        logpath
    ):
        self.ddp = ddp
        self.ddp_rank = ddp_rank
        # rank 0 is responsible for all printing, logging, and checkpointing
        self.master_process = ddp_rank == 0
        self.ddp_world_size = ddp_world_size

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.token_encoder = token_encoder

        self.eval_freq = eval_freq
        self.grad_accum_steps = grad_accum_steps
        self.device = device
        # autocast only distinguishes cuda vs cpu ('mps' falls through to cpu)
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
        self.logpath = logpath
        # BUGFIX: previously only assigned inside train(), so calling
        # evaluate_validation() standalone raised AttributeError.
        self.is_last_step = False

    def train(
        self,
        max_steps,
        warmup_steps,
        max_lr,
        min_lr
    ):
        """Run the optimization loop for `max_steps` weight updates.

        Each step accumulates gradients over `self.grad_accum_steps`
        micro-batches, clips the global gradient norm, applies a
        warmup+cosine learning-rate schedule, and periodically evaluates.
        """
        for step in range(max_steps):
            t0 = time.time()
            self.is_last_step = (step == max_steps - 1)

            # periodic validation loss (checkpointing happens inside)
            if step % self.eval_freq == 0 or self.is_last_step:
                self.evaluate_validation(step)

            # HellaSwag eval / sampling are skipped when torch.compile is on
            if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
                self.evaluate_helloswag(step)

            if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
                self.generate_sequences(num_seq=5, max_tokens=32)

            # ---- one optimization step with gradient accumulation ----
            self.model.train()
            self.optimizer.zero_grad()
            batch_loss = 0.0

            for mini_step in range(self.grad_accum_steps):
                inp, tar = self.train_loader.next_batch()
                inp, tar = inp.to(self.device), tar.to(self.device)

                # bf16 autocast: forward in mixed precision, master weights fp32
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(inp, tar)

                # scale so the accumulated gradient is the mean over the
                # full (virtual) batch, not the sum of micro-batch means
                loss /= self.grad_accum_steps
                batch_loss += loss.detach()

                if self.ddp:
                    # only all-reduce gradients on the last micro-step;
                    # intermediate backward passes skip cross-rank sync
                    self.model.require_backward_grad_sync = (mini_step == self.grad_accum_steps - 1)

                loss.backward()

            if self.ddp:
                # average the detached loss across ranks for logging only
                dist.all_reduce(batch_loss, op=dist.ReduceOp.AVG)

            # clip global grad norm to 1.0 to stabilize training spikes
            norm = nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

            lr = self.estimate_lr(step, warmup_steps, max_steps, max_lr, min_lr)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

            self.optimizer.step()
            if self.device_type == 'cuda':
                # wait for GPU work to finish so the timing below is accurate
                torch.cuda.synchronize()

            dt = (time.time() - t0) * 1000.0  # step duration in milliseconds
            tokens_processed = self.train_loader.B * self.train_loader.T * self.grad_accum_steps * self.ddp_world_size
            # BUGFIX: dt is in milliseconds; dividing by it directly reported
            # tokens/millisecond (1000x too small). Convert to seconds first.
            tokens_per_sec = tokens_processed / (dt / 1000.0)

            if self.master_process:
                print(f'step {step:4d} | loss: {batch_loss.item():.6f} | lr: {lr:.2e} | norm: {norm:.4f} | dt: {dt:.4f}ms | tok/sec: {tokens_per_sec:.4f}')
                with open(self.logpath, 'a') as f:
                    f.write(f'{step} train {batch_loss.item():.6f}\n')

    def evaluate_validation(self, step):
        """Compute validation loss over a fixed number of batches.

        Also writes a checkpoint every 10000 steps (and on the last step),
        on the master process only.
        """
        self.model.eval()
        self.val_loader.reset()

        with torch.no_grad():
            val_loss_accum = 0.0
            val_steps = 20
            for _ in range(val_steps):
                inp, tar = self.val_loader.next_batch()
                inp, tar = inp.to(self.device), tar.to(self.device)
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(inp, tar)
                loss /= val_steps
                val_loss_accum += loss.detach()

        if self.ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        if self.master_process:
            print(f'Val loss: {val_loss_accum.item():.4f}')
            with open(self.logpath, 'a') as f:
                f.write(f'{step} val {val_loss_accum.item():.4f}\n')

        # BUGFIX: checkpoint writing was unguarded, so under DDP every rank
        # rewrote the same file; only the master process should save.
        # (val_loss_accum is identical on all ranks after the all_reduce.)
        if self.master_process and step > 0 and (step % 10000 == 0 or self.is_last_step):
            raw_model = self.model.module if self.ddp else self.model
            logdir = os.path.dirname(self.logpath)
            ckpt_path = os.path.join(logdir, f'model_{step:05d}.pt')
            checkpoint = {
                'model': raw_model.state_dict(),
                'config': raw_model.config,
                'step': step,
                'val_loss': val_loss_accum.item()
            }
            torch.save(checkpoint, ckpt_path)

    def evaluate_helloswag(self, step):
        """
        Evaluate HellaSwag accuracy: each example is a batch of 4 candidate
        completions; the model's pick is the completion with the lowest
        average loss (computed in get_most_likely_row). Examples are sharded
        round-robin across DDP ranks, then counts are summed.
        """
        n_total = 0
        n_correct_norm = 0
        for i, example in enumerate(iterate_examples('val')):
            # each rank only processes its own shard of examples
            if i % self.ddp_world_size != self.ddp_rank:
                continue

            _, tokens, mask, label = render_example(example)
            tokens, mask = tokens.to(self.device), mask.to(self.device)
            with torch.no_grad():
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(tokens)
                pred_norm = get_most_likely_row(tokens, mask, logits)
            n_total += 1
            n_correct_norm += int(pred_norm == label)

        if self.ddp:
            # sum the per-rank counts so accuracy covers the whole val set
            n_total = torch.tensor(n_total, device=self.device, dtype=torch.long)
            n_correct_norm = torch.tensor(n_correct_norm, device=self.device, dtype=torch.long)
            dist.all_reduce(n_total, op=dist.ReduceOp.SUM)
            dist.all_reduce(n_correct_norm, op=dist.ReduceOp.SUM)
            n_total = n_total.item()
            n_correct_norm = n_correct_norm.item()
        acc_norm = n_correct_norm / n_total
        if self.master_process:
            print(f'HelloSwag accuracy: {n_correct_norm}/{n_total}={acc_norm:.4f}')
            with open(self.logpath, 'a') as f:
                f.write(f'{step} hellaswag {acc_norm:.4f}\n')

    def generate_sequences(self, num_seq=4, max_tokens=32):
        """Sample `num_seq` continuations of a fixed prompt via top-50 sampling.

        Uses a rank-seeded generator so each DDP rank prints distinct but
        reproducible samples.
        """
        self.model.eval()
        tokens = self.token_encoder.encode("Hello, I am a language model")
        tokens = torch.tensor(tokens, dtype=torch.long)
        tokens = tokens.unsqueeze(0).repeat(num_seq, 1)  # (num_seq, prompt_len)
        gen_tokens = tokens.to(self.device)

        sample_rng = torch.Generator(device=self.device)
        # offset by rank so ranks don't all sample identical text
        sample_rng.manual_seed(42 + self.ddp_rank)

        while gen_tokens.shape[-1] <= max_tokens:
            with torch.no_grad():
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(gen_tokens)
                logits = logits[:, -1, :]  # next-token logits only
                probs = F.softmax(logits, dim=-1)
                # top-k (k=50) sampling: draw within the 50 most likely tokens
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
                ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng)
                next_tok = torch.gather(topk_indices, -1, ix)
                gen_tokens = torch.cat([gen_tokens, next_tok], dim=1)

        for i in range(num_seq):
            tokens = gen_tokens[i, :max_tokens].tolist()
            gen_text = self.token_encoder.decode(tokens)
            print(f"> rank {self.ddp_rank} sample {i}: {gen_text}")

    def estimate_lr(self, step, warmup_steps, max_steps, max_lr, min_lr):
        """
        Learning rate scheduler: cosine decay from max_lr to min_lr with
        linear warmup over the first `warmup_steps` steps.
        """
        # linear warmup: ramps from max_lr/warmup_steps up to max_lr
        if step < warmup_steps:
            return max_lr * (step+1) / warmup_steps
        # past the schedule end, hold at the floor
        if step > max_steps:
            return min_lr
        # cosine decay: coeff goes 1 -> 0 as decay_ratio goes 0 -> 1
        decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (max_lr - min_lr)
|
|
|
|
@dataclass
class GPTConfig:
    """Architecture hyperparameters for the GPT model (GPT-2 small defaults)."""
    # maximum sequence length the model attends over
    context_length: int = 1024
    # GPT-2 BPE vocabulary size; main() overrides this to 50304 when building the model
    vocab_size: int = 50257
    # number of transformer blocks
    num_layers: int = 12
    # embedding / residual-stream dimensionality
    embd_size: int = 768
    # attention heads per transformer block
    num_heads: int = 12
|
|
|
|
def get_args(argv=None):
    """Parse the hyperparameter configuration from the command line.

    Args:
        argv: optional list of argument strings. Defaults to None, in which
            case argparse reads sys.argv[1:] as before; passing an explicit
            list makes this function testable without touching process state.

    Returns:
        argparse.Namespace with all hyperparameter fields.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Hyperparameter Configuration")
    parser.add_argument("--total_batch_size", type=int, default=524288, help="number of tokens processed for each weight update")
    parser.add_argument("--mini_batch_size", type=int, default=32, help="setting of mini_batch_size is just a performance optimization. bigger gpu, bigger mini_batch_size")
    parser.add_argument("--context_length", type=int, default=1024)
    parser.add_argument("--num_layers", type=int, default=12)
    parser.add_argument("--embd_size", type=int, default=768)
    parser.add_argument("--num_heads", type=int, default=12)
    parser.add_argument("--max_lr", type=float, default=1e-3)
    parser.add_argument("--min_lr", type=float, default=1e-3 * 0.1)
    parser.add_argument("--warmup_steps", type=int, default=715)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--num_epochs", type=int, default=5)
    parser.add_argument("--steps_per_epoch", type=int, default=19073)
    parser.add_argument("--eval_freq", type=int, default=250)

    parser.add_argument("--seed", type=int, default=1337, help="Random seed for reproducibility")
    parser.add_argument("--logdir", type=str, default="./logs/")
    return parser.parse_args(argv)
|
|
|
|
def main():
    """Script entry point: set up (optionally distributed) environment,
    build the model, loaders, and optimizer, then run training."""
    args = get_args()

    print("Hyperparameter Configuration:")
    for key, value in vars(args).items():
        print(f"{key}: {value}")

    # create the log directory and truncate the log file for a fresh run
    os.makedirs(args.logdir, exist_ok=True)
    logpath = os.path.join(args.logdir, 'log.txt')
    with open(logpath, 'w') as f:
        pass

    # torchrun sets RANK/LOCAL_RANK/WORLD_SIZE; their presence signals DDP
    ddp = int(os.environ.get('RANK', -1)) != -1
    if ddp:
        assert torch.cuda.is_available(), f'use of DDP requires CUDA'
        dist.init_process_group(backend='nccl')
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        # each process pins itself to its own local GPU
        device = f'cuda:{ddp_local_rank}'
        torch.cuda.set_device(device)
        # rank 0 does all logging / printing
        master_process = ddp_rank == 0
    else:
        # single-process fallback: pick the best available device
        ddp_rank = 0
        ddp_local_rank = 0
        ddp_world_size = 1
        master_process = True
        device = 'cpu'
        if torch.cuda.is_available():
            device = 'cuda'
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            device = 'mps'
        print(f'using device: {device}')

    # autocast device type: anything non-cuda (including mps) maps to cpu
    device_type = 'cuda' if device.startswith('cuda') else 'cpu'

    # seed every RNG source for reproducibility
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # the desired token batch must split evenly into micro-batches per rank
    assert args.total_batch_size % (args.mini_batch_size * args.context_length * ddp_world_size) == 0, f'ensure total_batch_size divisible by B*T*ddp_world_size'
    grad_accum_steps = args.total_batch_size // (args.mini_batch_size * args.context_length * ddp_world_size)
    if master_process:
        print(f'desired batch size (number of tokens): {args.total_batch_size}')
        print(f'gradient accumulation steps: {grad_accum_steps}')
        print(f'GPU: {ddp_rank}, {ddp_local_rank}')

    # each rank reads its own disjoint shard of the data
    train_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='train')
    val_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='val')

    # vocab padded from 50257 to 50304 (a multiple of 128) — presumably for
    # GPU-friendly matmul sizes; the extra tokens are never produced by the
    # tokenizer
    gpt_config = GPTConfig(vocab_size=50304,
                        context_length=args.context_length,
                        num_layers=args.num_layers,
                        num_heads=args.num_heads,
                        embd_size=args.embd_size
                        )
    model = GPT(config=gpt_config)

    model.to(device)
    if use_torch_compile:
        model = torch.compile(model)

    if ddp:
        # wrap AFTER compile/to(device); DDP syncs gradients across ranks
        model = DDP(model, device_ids=[ddp_local_rank])

    # unwrap for optimizer configuration and checkpointing
    raw_model = model.module if ddp else model
    optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, lr=args.max_lr, device_type=device_type, master_process=master_process)
    token_encoder = tiktoken.get_encoding('gpt2')

    start_time = time.time()

    trainer = Trainer(model, optimizer, train_loader, val_loader, token_encoder, args.eval_freq, grad_accum_steps,
                    ddp, ddp_rank, ddp_world_size, device, logpath)

    max_steps = args.steps_per_epoch * args.num_epochs
    trainer.train(max_steps, args.warmup_steps, args.max_lr, args.min_lr)

    # report total wall-clock training time in hours
    dt = (time.time() - start_time) / (60*60)
    print(f"Total training time: {dt:.4f}hr")

    if ddp:
        dist.destroy_process_group()
|
|
|
|
# run only when executed as a script, not when imported
if __name__ == "__main__":
    main()
|
|