| """ |
| Train model. From root directory of the project, run as: |
| |
| python -m scripts.base_train |
| |
| or distributed as: |
| |
| torchrun --nproc_per_node=8 -m scripts.base_train |
| |
| If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example: |
| python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20 |
| """ |
|
|
| import os |
| os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" |
| import gc |
| import json |
| import time |
| import math |
| import argparse |
| from dataclasses import asdict |
| from contextlib import contextmanager |
|
|
| import wandb |
| import torch |
| import torch.distributed as dist |
|
|
| from nanochat.gpt import GPT, GPTConfig, Linear |
| from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit |
| from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized |
| from nanochat.tokenizer import get_tokenizer, get_token_bytes |
| from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint |
| from nanochat.loss_eval import evaluate_bpb |
| from nanochat.engine import Engine |
| from nanochat.flash_attention import HAS_FA3 |
| from scripts.base_eval import evaluate_core |
| print_banner() |
|
|
| |
| |
| parser = argparse.ArgumentParser(description="Pretrain base model") |
| |
| parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)") |
| |
| parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)") |
| |
| parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU and torchao)") |
| parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)") |
| |
| parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model") |
| parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio") |
| parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") |
| parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") |
| parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')") |
| |
| parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") |
| parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") |
| parser.add_argument("--target-param-data-ratio", type=float, default=10.5, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") |
| |
| parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.") |
| parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)") |
| parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)") |
| parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)") |
| parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)") |
| parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)") |
| parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)") |
| parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup") |
| parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown") |
| parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR") |
| parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)") |
| |
| parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)") |
| parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on") |
| parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)") |
| parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric") |
| parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)") |
| parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)") |
| |
| parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name") |
| args = parser.parse_args() |
| user_config = vars(args).copy() |
| |
| |
|
|
| device_type = autodetect_device_type() if args.device_type == "" else args.device_type |
| ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) |
| master_process = ddp_rank == 0 |
| synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None |
| get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 |
| if device_type == "cuda": |
| gpu_device_name = torch.cuda.get_device_name(0) |
| gpu_peak_flops = get_peak_flops(gpu_device_name) |
| print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}") |
| else: |
| gpu_peak_flops = float('inf') |
| print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})") |
|
|
| |
| use_dummy_wandb = args.run == "dummy" or not master_process |
| wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config) |
|
|
| |
| from nanochat.flash_attention import USE_FA3 |
| using_fa3 = USE_FA3 |
| if using_fa3: |
| print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.") |
| else: |
| print0("!" * 80) |
| if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16: |
| print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback") |
| else: |
| print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback") |
| print0("WARNING: Training will be less efficient without FA3") |
| if args.window_pattern != "L": |
| print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.") |
| print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.") |
| print0("!" * 80) |
|
|
| |
| |
| tokenizer = get_tokenizer() |
| token_bytes = get_token_bytes(device=device) |
| vocab_size = tokenizer.get_vocab_size() |
| print0(f"Vocab size: {vocab_size:,}") |
|
|
| |
| |
|
|
| def build_model_meta(depth): |
| """Build a model on meta device for a given depth (shapes/dtypes only, no data).""" |
| |
| |
| base_dim = depth * args.aspect_ratio |
| model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim |
| num_heads = model_dim // args.head_dim |
| config = GPTConfig( |
| sequence_len=args.max_seq_len, vocab_size=vocab_size, |
| n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, |
| window_pattern=args.window_pattern, |
| ) |
| with torch.device("meta"): |
| model_meta = GPT(config) |
| return model_meta |
|
|
| |
| model = build_model_meta(args.depth) |
| model_config = model.config |
| model_config_kwargs = asdict(model_config) |
| print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}") |
| model.to_empty(device=device) |
| model.init_weights() |
|
|
| |
| base_dir = get_base_dir() |
| output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" |
| checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname) |
| resuming = args.resume_from_step != -1 |
| if resuming: |
| print0(f"Resuming optimization from step {args.resume_from_step}") |
| model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank) |
| model.load_state_dict(model_data, strict=True, assign=True) |
| del model_data |
|
|
| |
| |
|
|
| |
| if args.fp8: |
| if device_type != "cuda": |
| print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag") |
| else: |
| |
| from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training |
| |
| import torch.nn as nn |
|
|
| |
| def fp8_module_filter(mod: nn.Module, fqn: str) -> bool: |
| if not isinstance(mod, nn.Linear): |
| return False |
| if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: |
| return False |
| if min(mod.in_features, mod.out_features) < 128: |
| return False |
| return True |
|
|
| fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe) |
| num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) |
| convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter) |
| num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__) |
| num_skipped = num_linear - num_fp8 |
| print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)") |
|
|
| |
| @contextmanager |
| def disable_fp8(model): |
| """Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation. |
| |
| CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead, |
| we swap out Float8Linear modules entirely and restore them after. |
| """ |
| import torch.nn as nn |
|
|
| |
| fp8_locations = [] |
| for name, module in model.named_modules(): |
| if 'Float8' in type(module).__name__: |
| if '.' in name: |
| parent_name, attr_name = name.rsplit('.', 1) |
| parent = model.get_submodule(parent_name) |
| else: |
| parent = model |
| attr_name = name |
| fp8_locations.append((parent, attr_name, module)) |
|
|
| if not fp8_locations: |
| yield |
| return |
|
|
| |
| for parent, attr_name, fp8_module in fp8_locations: |
| linear = Linear( |
| fp8_module.in_features, |
| fp8_module.out_features, |
| bias=fp8_module.bias is not None, |
| device=fp8_module.weight.device, |
| dtype=fp8_module.weight.dtype, |
| ) |
| linear.weight = fp8_module.weight |
| if fp8_module.bias is not None: |
| linear.bias = fp8_module.bias |
| setattr(parent, attr_name, linear) |
|
|
| try: |
| yield |
| finally: |
| |
| for parent, attr_name, fp8_module in fp8_locations: |
| setattr(parent, attr_name, fp8_module) |
|
|
| |
| |
|
|
| orig_model = model |
| model = torch.compile(model, dynamic=False) |
|
|
| |
| |
|
|
| |
| param_counts = model.num_scaling_params() |
| print0(f"Parameter counts:") |
| for key, value in param_counts.items(): |
| print0(f"{key:24s}: {value:,}") |
| num_params = param_counts['total'] |
| num_flops_per_token = model.estimate_flops() |
| print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") |
|
|
| |
| |
| |
| def get_scaling_params(m): |
| |
| params_counts = m.num_scaling_params() |
| scaling_params = params_counts['transformer_matrices'] + params_counts['lm_head'] |
| return scaling_params |
| num_scaling_params = get_scaling_params(model) |
| target_tokens = int(args.target_param_data_ratio * num_scaling_params) |
|
|
| |
| d12_ref = build_model_meta(12) |
| D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) |
| B_REF = 2**19 |
|
|
| |
| |
| |
| total_batch_size = args.total_batch_size |
| if total_batch_size == -1: |
| batch_size_ratio = target_tokens / D_REF |
| predicted_batch_size = B_REF * batch_size_ratio ** 0.383 |
| total_batch_size = 2 ** round(math.log2(predicted_batch_size)) |
| print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens") |
|
|
| |
| batch_lr_scale = 1.0 |
| batch_ratio = total_batch_size / B_REF |
| if batch_ratio != 1.0: |
| |
| |
| |
| batch_lr_scale = batch_ratio ** 0.5 |
| print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {B_REF:,})") |
|
|
| |
| |
| |
| |
| |
| |
| weight_decay_scaled = args.weight_decay * math.sqrt(total_batch_size / B_REF) * (D_REF / target_tokens) |
| if weight_decay_scaled != args.weight_decay: |
| print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}") |
|
|
| |
| |
| optimizer = model.setup_optimizer( |
| |
| unembedding_lr=args.unembedding_lr * batch_lr_scale, |
| embedding_lr=args.embedding_lr * batch_lr_scale, |
| scalar_lr=args.scalar_lr * batch_lr_scale, |
| |
| matrix_lr=args.matrix_lr * batch_lr_scale, |
| weight_decay=weight_decay_scaled, |
| ) |
|
|
| if resuming: |
| optimizer.load_state_dict(optimizer_data) |
| del optimizer_data |
|
|
| |
| |
| scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None |
| if scaler is not None: |
| print0("GradScaler enabled for fp16 training") |
|
|
| |
| |
| dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] |
| train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict) |
| build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device) |
| x, y, dataloader_state_dict = next(train_loader) |
|
|
| |
| |
|
|
| |
| assert args.num_iterations > 0 or args.target_param_data_ratio > 0 or args.target_flops > 0 |
| if args.num_iterations > 0: |
| |
| num_iterations = args.num_iterations |
| print0(f"Using user-provided number of iterations: {num_iterations:,}") |
| elif args.target_flops > 0: |
| |
| num_iterations = round(args.target_flops / (num_flops_per_token * total_batch_size)) |
| print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}") |
| elif args.target_param_data_ratio > 0: |
| |
| num_iterations = target_tokens // total_batch_size |
| print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}") |
| else: |
| raise ValueError("No training horizon specified") |
| total_tokens = total_batch_size * num_iterations |
| print0(f"Total number of training tokens: {total_tokens:,}") |
| print0(f"Tokens : Scaling params ratio: {total_batch_size * num_iterations / num_scaling_params:.2f}") |
| print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}") |
|
|
| |
| def get_lr_multiplier(it): |
| warmup_iters = args.warmup_steps |
| warmdown_iters = round(args.warmdown_ratio * num_iterations) |
| if it < warmup_iters: |
| return (it + 1) / warmup_iters |
| elif it <= num_iterations - warmdown_iters: |
| return 1.0 |
| else: |
| progress = (num_iterations - it) / warmdown_iters |
| return progress * 1.0 + (1 - progress) * args.final_lr_frac |
|
|
| |
| def get_muon_momentum(it): |
| warmdown_iters = round(args.warmdown_ratio * num_iterations) |
| warmdown_start = num_iterations - warmdown_iters |
| if it < 400: |
| frac = it / 400 |
| return (1 - frac) * 0.85 + frac * 0.97 |
| elif it >= warmdown_start: |
| progress = (it - warmdown_start) / warmdown_iters |
| return 0.97 * (1 - progress) + 0.90 * progress |
| else: |
| return 0.97 |
|
|
| |
| def get_weight_decay(it): |
| return weight_decay_scaled * 0.5 * (1 + math.cos(math.pi * it / num_iterations)) |
|
|
| |
| |
|
|
| |
| if not resuming: |
| step = 0 |
| val_bpb = None |
| min_val_bpb = float("inf") |
| smooth_train_loss = 0 |
| total_training_time = 0 |
| else: |
| step = meta_data["step"] |
| loop_state = meta_data["loop_state"] |
| val_bpb = meta_data["val_bpb"] |
| min_val_bpb = loop_state["min_val_bpb"] |
| smooth_train_loss = loop_state["smooth_train_loss"] |
| total_training_time = loop_state["total_training_time"] |
|
|
| |
| tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len |
| world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size |
| assert total_batch_size % world_tokens_per_fwdbwd == 0 |
| grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd |
| print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}") |
| print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") |
| print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") |
|
|
| |
| while True: |
| last_step = step == num_iterations |
| flops_so_far = num_flops_per_token * total_batch_size * step |
|
|
| |
| if args.eval_every > 0 and (last_step or step % args.eval_every == 0): |
| model.eval() |
| val_loader = build_val_loader() |
| eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size) |
| with disable_fp8(model): |
| val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes) |
| print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}") |
| if val_bpb < min_val_bpb: |
| min_val_bpb = val_bpb |
| wandb_run.log({ |
| "step": step, |
| "total_training_flops": flops_so_far, |
| "total_training_time": total_training_time, |
| "val/bpb": val_bpb, |
| }) |
| model.train() |
|
|
| |
| |
| |
| results = {} |
| if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)): |
| model.eval() |
| with disable_fp8(orig_model): |
| results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task) |
| print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}") |
| wandb_run.log({ |
| "step": step, |
| "total_training_flops": flops_so_far, |
| "core_metric": results["core_metric"], |
| "centered_results": results["centered_results"], |
| }) |
| model.train() |
|
|
| |
| |
| if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)): |
| model.eval() |
| prompts = [ |
| "The capital of France is", |
| "The chemical symbol of gold is", |
| "If yesterday was Friday, then tomorrow will be", |
| "The opposite of hot is", |
| "The planets of the solar system are:", |
| "My favorite color is", |
| "If 5*x + 3 = 13, then x is", |
| ] |
| engine = Engine(orig_model, tokenizer) |
| for prompt in prompts: |
| tokens = tokenizer(prompt, prepend="<|bos|>") |
| with disable_fp8(orig_model): |
| sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0) |
| print0(tokenizer.decode(sample[0])) |
| model.train() |
|
|
| |
| if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0): |
| save_checkpoint( |
| checkpoint_dir, |
| step, |
| orig_model.state_dict(), |
| optimizer.state_dict(), |
| { |
| "step": step, |
| "val_bpb": val_bpb, |
| "model_config": model_config_kwargs, |
| "user_config": user_config, |
| "device_batch_size": args.device_batch_size, |
| "max_seq_len": args.max_seq_len, |
| "total_batch_size": total_batch_size, |
| "dataloader_state_dict": dataloader_state_dict, |
| "loop_state": { |
| "min_val_bpb": min_val_bpb, |
| "smooth_train_loss": smooth_train_loss, |
| "total_training_time": total_training_time, |
| }, |
| }, |
| rank=ddp_rank, |
| ) |
|
|
| |
| if last_step: |
| break |
|
|
| |
| |
| |
| synchronize() |
| t0 = time.time() |
| for micro_step in range(grad_accum_steps): |
| loss = model(x, y) |
| train_loss = loss.detach() |
| loss = loss / grad_accum_steps |
| if scaler is not None: |
| scaler.scale(loss).backward() |
| else: |
| loss.backward() |
| x, y, dataloader_state_dict = next(train_loader) |
| |
| lrm = get_lr_multiplier(step) |
| muon_momentum = get_muon_momentum(step) |
| muon_weight_decay = get_weight_decay(step) |
| for group in optimizer.param_groups: |
| group["lr"] = group["initial_lr"] * lrm |
| if group['kind'] == 'muon': |
| group["momentum"] = muon_momentum |
| group["weight_decay"] = muon_weight_decay |
| if scaler is not None: |
| scaler.unscale_(optimizer) |
| |
| |
| |
| if is_ddp_initialized(): |
| for v in scaler._found_inf_per_device(optimizer).values(): |
| dist.all_reduce(v, op=dist.ReduceOp.MAX) |
| scaler.step(optimizer) |
| scaler.update() |
| else: |
| optimizer.step() |
| model.zero_grad(set_to_none=True) |
| train_loss_f = train_loss.item() |
| synchronize() |
| t1 = time.time() |
| dt = t1 - t0 |
| |
|
|
| |
| ema_beta = 0.9 |
| smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f |
| debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) |
| pct_done = 100 * step / num_iterations |
| tok_per_sec = int(total_batch_size / dt) |
| flops_per_sec = num_flops_per_token * total_batch_size / dt |
| mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) |
| if step > 10: |
| total_training_time += dt |
| |
| steps_done = step - 10 |
| if steps_done > 0: |
| avg_time_per_step = total_training_time / steps_done |
| remaining_steps = num_iterations - step |
| eta_seconds = remaining_steps * avg_time_per_step |
| eta_str = f" | eta: {eta_seconds/60:.1f}m" |
| else: |
| eta_str = "" |
| epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}" |
| print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") |
| if step % 100 == 0: |
| log_data = { |
| "step": step, |
| "total_training_flops": flops_so_far, |
| "total_training_time": total_training_time, |
| "train/loss": debiased_smooth_loss, |
| "train/lrm": lrm, |
| "train/dt": dt, |
| "train/tok_per_sec": tok_per_sec, |
| "train/mfu": mfu, |
| "train/epoch": epoch, |
| } |
| wandb_run.log(log_data) |
|
|
| |
| first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step) |
| step += 1 |
|
|
| |
| |
| |
| if first_step_of_run: |
| gc.collect() |
| gc.freeze() |
| gc.disable() |
| elif step % 5000 == 0: |
| gc.collect() |
|
|
| |
| print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB") |
| print0(f"Total training time: {total_training_time/60:.2f}m") |
| if val_bpb is not None: |
| print0(f"Minimum validation bpb: {min_val_bpb:.6f}") |
|
|
| |
| from nanochat.report import get_report |
| get_report().log(section="Base model training", data=[ |
| user_config, |
| { |
| "Number of parameters": num_params, |
| "Number of FLOPs per token": f"{num_flops_per_token:e}", |
| "Calculated number of iterations": num_iterations, |
| "Number of training tokens": total_tokens, |
| "Tokens : Scaling params ratio": total_batch_size * num_iterations / num_scaling_params, |
| "DDP world size": ddp_world_size, |
| "warmup_steps": args.warmup_steps, |
| "warmdown_ratio": args.warmdown_ratio, |
| "final_lr_frac": args.final_lr_frac, |
| }, |
| { |
| "Minimum validation bpb": min_val_bpb if val_bpb is not None else None, |
| "Final validation bpb": val_bpb, |
| "CORE metric estimate": results.get("core_metric", None), |
| "MFU %": f"{mfu:.2f}%", |
| "Total training flops": f"{flops_so_far:e}", |
| "Total training time": f"{total_training_time/60:.2f}m", |
| "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB", |
| } |
| ]) |
|
|
| |
| wandb_run.finish() |
| compute_cleanup() |
|
|