| import torch |
| import torch.nn.functional as F |
| import logging |
| import time |
| from dataclasses import asdict |
|
|
| import numpy as np |
| import wandb |
| from torch.cuda.amp import autocast, GradScaler |
| from tqdm import tqdm |
|
|
| from src.engine.generate import generate_greedy |
|
|
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
|
|
|
|
| def calc_cross_entropy_loss(logits, targets): |
| B, T, V = logits.shape |
| return F.cross_entropy(logits.view(B * T, V), targets.view(B * T)) |
|
|
|
|
| def calc_perplexity(loss: float) -> float: |
| return float(np.exp(loss)) |
|
|
|
|
| def train( |
| model, optimizer, train_loader, val_loader, device, cfg, sample_prompt, tokenizer |
| ): |
| run = wandb.init( |
| entity="pjawale-student", |
| project="gpt2-pytorch", |
| config=asdict(cfg), |
| ) |
|
|
| scaler = GradScaler() |
| history = [] |
| tokens_per_batch = cfg.batch_size * cfg.context_window_size |
| global_step = 0 |
|
|
| try: |
| for epoch in range(1, cfg.num_epochs + 1): |
| model.train() |
| train_loss = 0.0 |
| epoch_start = time.perf_counter() |
| window_start = epoch_start |
| window_batches = 0 |
|
|
| pbar = tqdm( |
| train_loader, |
| desc=f"Epoch {epoch}/{cfg.num_epochs}", |
| unit="batch", |
| leave=True, |
| ) |
|
|
| for batch_idx, (inputs, targets) in enumerate(pbar, 1): |
| inputs, targets = inputs.to(device), targets.to(device) |
| optimizer.zero_grad() |
|
|
| with autocast(dtype=torch.bfloat16): |
| loss = calc_cross_entropy_loss(model(inputs), targets) |
|
|
| scaler.scale(loss).backward() |
| scaler.step(optimizer) |
| scaler.update() |
|
|
| train_loss += loss.item() |
| window_batches += 1 |
| global_step += 1 |
|
|
| if batch_idx % 10 == 0 or batch_idx == len(train_loader): |
| now = time.perf_counter() |
| elapsed = now - window_start |
| tok_s = (window_batches * tokens_per_batch) / elapsed if elapsed > 0 else 0.0 |
| window_start = now |
| window_batches = 0 |
|
|
| pbar.set_postfix( |
| loss=f"{loss.item():.4f}", |
| tok_s=f"{tok_s:,.0f}", |
| ) |
|
|
| wandb.log( |
| { |
| "train/batch_loss": loss.item(), |
| "train/batch_perplexity": calc_perplexity(loss.item()), |
| "train/tok_s": tok_s, |
| "epoch": epoch, |
| }, |
| step=global_step, |
| ) |
|
|
| epoch_elapsed = time.perf_counter() - epoch_start |
| epoch_tokens = len(train_loader) * tokens_per_batch |
| epoch_tok_s = epoch_tokens / epoch_elapsed if epoch_elapsed > 0 else 0.0 |
|
|
| model.eval() |
| val_loss = 0.0 |
| with torch.no_grad(): |
| for inputs, targets in tqdm(val_loader, desc="Validating", leave=False): |
| inputs, targets = inputs.to(device), targets.to(device) |
| with autocast(dtype=torch.bfloat16): |
| val_loss += calc_cross_entropy_loss(model(inputs), targets).item() |
|
|
| avg_train = train_loss / len(train_loader) |
| avg_val = val_loss / len(val_loader) |
| avg_train_perplexity = calc_perplexity(avg_train) |
| avg_val_perplexity = calc_perplexity(avg_val) |
| sample = generate_greedy(model, tokenizer, device, sample_prompt) |
|
|
| logging.info( |
| f"Epoch {epoch:2d}/{cfg.num_epochs} | train={avg_train:.4f} | val={avg_val:.4f} | " |
| f"train_perplexity={avg_train_perplexity:.4f} | val_perplexity={avg_val_perplexity:.4f} | " |
| f"{epoch_tok_s:,.0f} tok/s (epoch avg)" |
| ) |
|
|
| wandb.log( |
| { |
| "train/loss": avg_train, |
| "val/loss": avg_val, |
| "train/perplexity": avg_train_perplexity, |
| "val/perplexity": avg_val_perplexity, |
| "train/tok_s_epoch": epoch_tok_s, |
| "sample": sample, |
| "epoch": epoch, |
| }, |
| step=global_step, |
| ) |
|
|
| history.append((avg_train, avg_val)) |
| logging.info(f" Sample: {sample}\n") |
|
|
| finally: |
| wandb.finish() |
|
|
| return np.array(history) |