"""
Train the base model. Run as:

python -m scripts.base_train

or distributed as:

torchrun --nproc_per_node=8 -m scripts.base_train

If you are only on a CPU/MacBook, you'll want to train a much smaller LLM. Example:

python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
"""

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # helps reduce CUDA memory fragmentation; set before torch is imported so it takes effect
import time
from contextlib import nullcontext

import wandb
import torch

from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from scripts.base_eval import evaluate_model

print_banner()
run = "dummy" |
|
|
|
|
|
device_type = "" |
|
|
|
|
|
depth = 20 |
|
|
max_seq_len = 2048 |
|
|
|
|
|
num_iterations = -1 |
|
|
target_flops = -1.0 |
|
|
target_param_data_ratio = 20 |
|
|
|
|
|
device_batch_size = 32 |
|
|
total_batch_size = 524288 |
|
|
embedding_lr = 0.2 |
|
|
unembedding_lr = 0.004 |
|
|
weight_decay = 0.0 |
|
|
matrix_lr = 0.02 |
|
|
grad_clip = 1.0 |
|
|
warmup_ratio = 0.0 |
|
|
warmdown_ratio = 0.2 |
|
|
final_lr_frac = 0.0 |
|
|
|
|
|
eval_every = 250 |
|
|
eval_tokens = 20*524288 |
|
|
core_metric_every = 2000 |
|
|
core_metric_max_per_task = 500 |
|
|
sample_every = 2000 |
|
|
|
|
|
model_tag = "" |
|
|
|
|
|
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] |
|
|
exec(open(os.path.join('nanochat', 'configurator.py')).read()) |
|
|
user_config = {k: globals()[k] for k in config_keys} |
|
|
|
|
|
|
|
|
|
|
|

# Compute/precision setup
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0  # this rank handles logging, sampling and checkpointing
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)

tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size()
print0(f"Vocab size: {vocab_size:,}")

# Derive the Transformer shape from the single `depth` knob
num_layers = depth
model_dim = depth * 64  # model width grows with depth (aspect ratio of 64)
num_heads = max(1, (model_dim + 127) // 128)  # pick num_heads to target a head dim of ~128
num_kv_heads = num_heads  # keep num_kv_heads equal to num_heads (i.e. no grouped-query attention)
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
print0(f"num_heads: {num_heads}")
print0(f"num_kv_heads: {num_kv_heads}")
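
# Figure out gradient accumulation: each optimizer step should process total_batch_size tokens,
# which are split across the ddp ranks and then across grad_accum_steps sequential micro-batches.
# E.g. with the defaults on 8 GPUs: 32 * 2048 = 65,536 tokens per rank per micro-batch,
# 65,536 * 8 = 524,288 tokens per micro-batch across ranks => grad_accum_steps = 1.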
tokens_per_fwdbwd = device_batch_size * max_seq_len
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size
assert total_batch_size % world_tokens_per_fwdbwd == 0, "total_batch_size must be divisible by the tokens per micro-batch across all ranks"
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch / all ranks: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")

# Initialize the model: construct it on the meta device (no memory allocated), then
# materialize the empty parameters on the target device and initialize the weights.
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
with torch.device("meta"):
    model_config = GPTConfig(**model_config_kwargs)
    model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()
orig_model = model  # keep a reference to the uncompiled model (for saving, eval and sampling)
model = torch.compile(model, dynamic=False)
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
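
# Determine the number of training iterations. Priority order:
# 1) an explicit num_iterations,
# 2) a FLOPs budget: iterations = target_flops / (flops_per_token * total_batch_size),
# 3) a tokens:params ratio (in the spirit of Chinchilla): tokens = ratio * num_params, iterations = tokens / total_batch_size.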
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
if num_iterations > 0:
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
    num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
    target_tokens = target_param_data_ratio * num_params
    num_iterations = target_tokens // total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_tokens / num_params:.2f}")
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

# Set up the optimizers: AdamW (embedding/unembedding) and Muon (matrix parameters)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers

# Initialize the data loaders
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader)  # fetch the very first batch
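
# Learning rate schedule: optional linear warmup, then a constant LR, then a linear
# warmdown over the last warmdown_ratio fraction of training, ending at final_lr_frac
# of the base LR. The returned multiplier scales each param group's "initial_lr".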
def get_lr_multiplier(it):
    warmup_iters = round(warmup_ratio * num_iterations)
    warmdown_iters = round(warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    elif it <= num_iterations - warmdown_iters:
        return 1.0
    else:
        progress = (num_iterations - it) / warmdown_iters
        return progress * 1.0 + (1 - progress) * final_lr_frac
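
# Muon momentum schedule: linearly ramp the momentum from 0.85 to 0.95 over the first 300 steps.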
def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum
min_val_bpb = float("inf") |
|
|
smooth_train_loss = 0 |
|
|
ema_beta = 0.9 |
|
|
total_training_time = 0 |
|
|
|
|
|
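
# Main training loop. Note the num_iterations + 1 steps: the extra final step only runs the
# evaluation / sampling / checkpointing branches and then breaks out before another update.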
for step in range(num_iterations + 1):
    last_step = step == num_iterations
    flops_so_far = num_flops_per_token * total_batch_size * step
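
    # Once in a while: evaluate the model on the validation split, in bits per byte (bpb).
    # bpb normalizes the loss by the number of raw bytes, so it is comparable across tokenizers.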
    if last_step or step % eval_every == 0:
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # Once in a while: evaluate the CORE metric (on the uncompiled model)
    results = {}
    if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
        model.eval()
        with autocast_ctx:
            results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "core_metric": results["core_metric"],
            "centered_results": results["centered_results"],
        })
        model.train()

    # Once in a while: sample from the model (master process only, greedy decoding)
    if master_process and (last_step or (step > 0 and step % sample_every == 0)):
        model.eval()
        prompts = [
            "The capital of France is",
            "The chemical symbol of gold is",
            "If yesterday was Friday, then tomorrow will be",
            "The opposite of hot is",
            "The planets of the solar system are:",
            "My favorite color is",
            "If 5*x + 3 = 13, then x is",
        ]
        engine = Engine(orig_model, tokenizer)
        for prompt in prompts:
            tokens = tokenizer(prompt, prepend="<|bos|>")
            with autocast_ctx:
                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
            print0(tokenizer.decode(sample[0]))
        model.train()

    # On the last step: save a checkpoint of the model, optimizer states and metadata (master process only)
    if master_process and last_step:
        output_dirname = model_tag if model_tag else f"d{depth}"
        checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            [opt.state_dict() for opt in optimizers],
            {
                "step": step,
                "val_bpb": val_bpb,
                "model_config": model_config_kwargs,
                "user_config": user_config,
                "device_batch_size": device_batch_size,
                "max_seq_len": max_seq_len,
            }
        )

    if last_step:
        break
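
    # A single training step: accumulate gradients over grad_accum_steps micro-batches
    # (fetching the next micro-batch right after each backward), then clip gradients,
    # apply the LR / momentum schedules, and step both optimizers.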
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()  # keep the (unscaled) loss of the last micro-batch for logging
        loss = loss / grad_accum_steps  # so that gradients average (rather than sum) over micro-steps
        loss.backward()
        x, y = next(train_loader)  # already fetch the data for the next micro-batch
    # optionally clip the gradients by their global norm
    grad_clip_enabled = grad_clip > 0.0
    if grad_clip_enabled:
        grad_norm_tensor = torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
        grad_norm = grad_norm_tensor.item()
    # set the learning rates and the Muon momentum according to their schedules
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    muon_momentum = get_muon_momentum(step)
    for group in muon_optimizer.param_groups:
        group["momentum"] = muon_momentum
    # step the optimizers and reset the gradients
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0
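
    # Logging: keep a debiased EMA of the training loss and report throughput / MFU.
    # MFU is measured against 989e12 FLOP/s per device (roughly the dense bf16 peak of an
    # H100 SXM), so the number is only meaningful when actually training on H100s.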
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))  # debias the EMA (it was initialized at 0)
    pct_done = 100 * step / num_iterations
    tok_per_sec = int(total_batch_size / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size
    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100
    if step > 10:  # skip the first few steps (warmup/compilation) when accumulating total training time
        total_training_time += dt
    print_grad_norm = f" grad norm: {grad_norm:.4f} |" if grad_clip_enabled else ""
    print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} |{print_grad_norm} lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
    if step % 100 == 0:
        log_data = {
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
        }
        if grad_clip_enabled:
            log_data["train/grad_norm"] = grad_norm
        wandb_run.log(log_data)
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") |
|
|
print0(f"Total training time: {total_training_time/60:.2f}m") |
|
|
print0(f"Minimum validation bpb: {min_val_bpb:.4f}") |
|
|
|
|
|
|
|
|

# Log a summary to the report file
from nanochat.report import get_report
get_report().log(section="Base model training", data=[
    user_config,
    {
        "Number of parameters": num_params,
        "Number of FLOPs per token": f"{num_flops_per_token:e}",
        "Calculated number of iterations": num_iterations,
        "Number of training tokens": total_tokens,
        "Tokens : Params ratio": total_tokens / num_params,
        "DDP world size": ddp_world_size,
        "warmup_ratio": warmup_ratio,
        "warmdown_ratio": warmdown_ratio,
        "final_lr_frac": final_lr_frac,
    },
    {
        "Minimum validation bpb": min_val_bpb,
        "Final validation bpb": val_bpb,
        "CORE metric estimate": results.get("core_metric", None),
        "MFU %": f"{mfu:.2f}%",
        "Total training flops": f"{flops_so_far:e}",
        "Total training time": f"{total_training_time/60:.2f}m",
        "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
    }
])

# Cleanup
wandb_run.finish()
compute_cleanup()