""" train.py — SLLM Training Loop Supports: --max_steps N Run for exactly N steps then save checkpoint and exit. Omit to train indefinitely (until Ctrl+C or data exhausted). --resume Resume from the latest checkpoint in --run_dir. --config 100M|150M Choose model config (default: 100M). --synthetic Use synthetic data (for testing without real shards). Features: - bf16 mixed precision (autocast) + GradScaler for stable training - Gradient accumulation: --grad_accum N steps per optimizer update - Gradient checkpointing: --grad_checkpoint to save VRAM - Cosine LR schedule with linear warmup - Checkpoint save every --save_every steps (and on clean exit/Ctrl+C) - Metric logging to /train_log.jsonl (one JSON line per log step) - Real-time terminal progress with tqdm Recommended for RTX 3050 4GB: python train.py --config 100M --batch_size 4 --grad_accum 8 \\ --grad_checkpoint --max_steps 1000 Run for N steps, stop, then resume: python train.py --max_steps 500 --run_dir runs/my_run python train.py --max_steps 500 --run_dir runs/my_run --resume """ import os import sys import json import math import time import signal import argparse import torch import torch.nn.functional as F from torch.amp import autocast, GradScaler from tqdm import tqdm sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from model.config import SLLM_100M, SLLM_150M, ModelConfig from model.model import SLLM from data.dataloader import build_dataloader # ------------------------------------------------------------------ # # ARG PARSING # ------------------------------------------------------------------ # def parse_args(): p = argparse.ArgumentParser(description="SLLM Training Loop") # Run management p.add_argument("--run_dir", type=str, default="runs/run_001", help="Directory for checkpoints and logs") p.add_argument("--run_name", type=str, default=None, help="Override run name (defaults to run_dir basename)") p.add_argument("--resume", action="store_true", help="Resume from latest checkpoint 
in run_dir") p.add_argument("--max_steps", type=int, default=None, help="Absolute step target — stop when step reaches this number.") p.add_argument("--extra_steps", type=int, default=None, help="Run N MORE steps from current checkpoint (relative). Converted to --max_steps internally.") # Model p.add_argument("--config", type=str, default="100M", choices=["100M", "150M"]) # Data p.add_argument("--data_dir", type=str, default="tokenizer/data") p.add_argument("--synthetic", action="store_true", help="Use synthetic random data (for testing)") p.add_argument("--num_workers",type=int, default=2) # Training p.add_argument("--batch_size", type=int, default=4, help="Per-device batch size") p.add_argument("--grad_accum", type=int, default=8, help="Gradient accumulation steps") p.add_argument("--max_lr", type=float, default=3e-4) p.add_argument("--min_lr", type=float, default=3e-5) p.add_argument("--warmup_steps", type=int, default=100) p.add_argument("--weight_decay", type=float, default=0.1) p.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping norm (0 = disabled)") # Memory p.add_argument("--grad_checkpoint", action="store_true", help="Enable gradient checkpointing (saves VRAM, slower)") p.add_argument("--dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) # Logging / Saving p.add_argument("--log_every", type=int, default=10, help="Log metrics every N optimizer steps") p.add_argument("--save_every", type=int, default=500, help="Save checkpoint every N optimizer steps") p.add_argument("--val_every", type=int, default=250, help="Run validation every N optimizer steps") p.add_argument("--val_steps", type=int, default=20, help="Number of val batches to average") return p.parse_args() # ------------------------------------------------------------------ # # LEARNING RATE SCHEDULE # ------------------------------------------------------------------ # def get_lr(step: int, warmup_steps: int, total_steps: int, max_lr: float, min_lr: float) 
-> float: """ Linear warmup then cosine decay. If total_steps is None (training indefinitely), uses a fixed 10k step decay window. """ # Linear warmup if step < warmup_steps: return max_lr * (step + 1) / warmup_steps # After decay: hold at min_lr decay_steps = total_steps if total_steps else 10_000 if step >= decay_steps: return min_lr # Cosine decay progress = (step - warmup_steps) / max(1, decay_steps - warmup_steps) coeff = 0.5 * (1.0 + math.cos(math.pi * progress)) return min_lr + coeff * (max_lr - min_lr) # ------------------------------------------------------------------ # # OPTIMIZER (AdamW with selective weight decay) # ------------------------------------------------------------------ # def build_optimizer(model: SLLM, lr: float, weight_decay: float) -> torch.optim.AdamW: """ AdamW with weight decay applied only to 2D params (Linear weights). Excludes: embeddings, norms (RMSNorm weight vectors), biases. This is the standard approach from GPT-2/NanoGPT. """ decay_params = [] no_decay_params = [] for name, param in model.named_parameters(): if not param.requires_grad: continue # 2D tensors (weight matrices) get weight decay if param.dim() >= 2: decay_params.append(param) else: # 1D: norm weights, biases, embeddings no_decay_params.append(param) optim_groups = [ {"params": decay_params, "weight_decay": weight_decay}, {"params": no_decay_params, "weight_decay": 0.0}, ] n_decay = sum(p.numel() for p in decay_params) n_no_decay = sum(p.numel() for p in no_decay_params) print(f" Optimizer: {n_decay/1e6:.1f}M decay params | {n_no_decay/1e6:.1f}M no-decay params") return torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=True) # ------------------------------------------------------------------ # # CHECKPOINT SAVE / LOAD # ------------------------------------------------------------------ # def save_checkpoint(path: str, model: SLLM, optimizer, step: int, args, loss: float): os.makedirs(os.path.dirname(path), exist_ok=True) torch.save({ 
"step": step, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "loss": loss, "config_name": args.config, }, path) print(f"\n [CKPT] Saved checkpoint: {path} (step={step}, loss={loss:.4f})") def load_checkpoint(run_dir: str, model: SLLM, optimizer, device): """Loads the latest checkpoint from run_dir. Returns step number.""" ckpts = sorted([ f for f in os.listdir(run_dir) if f.startswith("ckpt_") and f.endswith(".pt") ]) if not ckpts: raise FileNotFoundError(f"No checkpoints found in {run_dir}") path = os.path.join(run_dir, ckpts[-1]) ckpt = torch.load(path, map_location=device, weights_only=False) model.load_state_dict(ckpt["model_state_dict"]) optimizer.load_state_dict(ckpt["optimizer_state_dict"]) step = ckpt["step"] loss = ckpt.get("loss", float("nan")) print(f" [CKPT] Resumed from: {path} (step={step}, loss={loss:.4f})") return step # ------------------------------------------------------------------ # # VALIDATION # ------------------------------------------------------------------ # @torch.no_grad() def estimate_val_loss(model, val_loader, val_steps: int, device, dtype_ctx) -> float: model.eval() losses = [] for i, (x, y) in enumerate(val_loader): if i >= val_steps: break x, y = x.to(device), y.to(device) with dtype_ctx: _, loss = model(x, y) losses.append(loss.item()) model.train() return sum(losses) / len(losses) if losses else float("nan") # ------------------------------------------------------------------ # # METRIC LOGGING # ------------------------------------------------------------------ # class MetricLogger: """Appends one JSON line per step to train_log.jsonl.""" def __init__(self, log_path: str): self.log_path = log_path os.makedirs(os.path.dirname(log_path), exist_ok=True) # Don't clear existing log when resuming — append print(f" [LOG] Logging to: {log_path}") def log(self, **kwargs): with open(self.log_path, "a") as f: f.write(json.dumps(kwargs) + "\n") # 
# ------------------------------------------------------------------ #
#  MAIN TRAINING LOOP
# ------------------------------------------------------------------ #

def train():
    """Run the full training session described by the command-line options.

    Flow: parse args, pick device and dtype, optionally adopt the config
    name stored in the latest checkpoint, build model / optimizer / data
    loaders, optionally resume, then loop over optimizer steps (each made of
    ``--grad_accum`` micro-batches) with periodic logging, validation and
    checkpointing. Ctrl+C sets a flag that ends the loop after the current
    step; a final checkpoint is written if any step ran and the last step
    was not already saved.
    """
    args = parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")
        print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # ---- dtype context --------------------------------------------- #
    # Reduced precision is only used on CUDA; anything else falls back to fp32.
    if args.dtype == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
        dtype_torch = torch.bfloat16
        dtype_name = "bf16"
    elif args.dtype == "fp16" and device.type == "cuda":
        dtype_torch = torch.float16
        dtype_name = "fp16"
    else:
        dtype_torch = torch.float32
        dtype_name = "fp32"
    print(f"dtype : {dtype_name}")

    use_amp = dtype_torch in (torch.float16, torch.bfloat16)
    scaler = GradScaler(enabled=(dtype_torch == torch.float16))  # bf16 doesn't need scaler

    # ---- Auto-detect config on resume ------------------------------ #
    # Best effort: a failure here must not abort the run, but it is reported
    # instead of being swallowed silently.
    if args.resume:
        try:
            ckpts = sorted(
                f for f in os.listdir(args.run_dir)
                if f.startswith("ckpt_") and f.endswith(".pt")
            )
            if ckpts:
                ckpt_path = os.path.join(args.run_dir, ckpts[-1])
                # NOTE(security): weights_only=False unpickles arbitrary
                # objects; only resume from a trusted run directory.
                _tmp_ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
                if "config_name" in _tmp_ckpt and _tmp_ckpt["config_name"] != args.config:
                    print(f" [CKPT] Auto-switching config from '{args.config}' to '{_tmp_ckpt['config_name']}' to match checkpoint.")
                    args.config = _tmp_ckpt["config_name"]
                del _tmp_ckpt
        except Exception as e:
            print(f" [WARN] Could not inspect latest checkpoint for config auto-detect: {e}")

    # ---- Model ----------------------------------------------------- #
    cfg_map = {"100M": SLLM_100M, "150M": SLLM_150M}
    cfg = cfg_map[args.config]
    model = SLLM(cfg).to(device)

    if args.grad_checkpoint:
        model.enable_gradient_checkpointing()
        print(" Gradient checkpointing: ON")

    print(f"\nModel : SLLM-{args.config} ({model.count_params()/1e6:.1f}M params)")
    print(f"Config : {cfg}")

    # ---- Optimizer ------------------------------------------------- #
    optimizer = build_optimizer(model, lr=args.max_lr, weight_decay=args.weight_decay)

    # ---- Data ------------------------------------------------------ #
    train_loader = build_dataloader(
        data_dir=args.data_dir,
        split="train",
        context_length=cfg.context_length,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        use_synthetic=args.synthetic,
        vocab_size=cfg.vocab_size,
    )
    val_loader = build_dataloader(
        data_dir=args.data_dir,
        split="val",
        context_length=cfg.context_length,
        batch_size=args.batch_size,
        num_workers=0,
        use_synthetic=args.synthetic,
        vocab_size=cfg.vocab_size,
    )

    # ---- Run directory --------------------------------------------- #
    os.makedirs(args.run_dir, exist_ok=True)
    log_path = os.path.join(args.run_dir, "train_log.jsonl")
    logger = MetricLogger(log_path)

    # ---- Resume ---------------------------------------------------- #
    start_step = 0
    if args.resume:
        try:
            start_step = load_checkpoint(args.run_dir, model, optimizer, device)
        except FileNotFoundError as e:
            print(f" [WARN] {e} — starting from scratch.")

    # ---- Effective batch size info --------------------------------- #
    eff_batch = args.batch_size * args.grad_accum
    tokens_per_step = eff_batch * cfg.context_length

    print(f"\nTraining:")

    # ---- Resolve extra_steps -> max_steps -------------------------- #
    if args.extra_steps is not None:
        if args.max_steps is not None:
            print(" [WARN] Both --extra_steps and --max_steps given. --extra_steps takes priority.")
        args.max_steps = start_step + args.extra_steps
        print(f" [INFO] --extra_steps {args.extra_steps} → running until step {args.max_steps}")

    # Explicit None checks so that a target of 0 is not shown as "unlimited".
    bounded = args.max_steps is not None
    print(f" batch_size : {args.batch_size} (grad_accum={args.grad_accum} -> effective={eff_batch})")
    print(f" tokens/step : {tokens_per_step:,}")
    print(f" max_steps : {args.max_steps if bounded else 'unlimited'} (absolute step target)")
    print(f" start_step : {start_step}")
    print(f" steps to run : {(args.max_steps - start_step) if bounded else 'unlimited'}")
    print(f" save_every : {args.save_every}")
    print(f" log_every : {args.log_every}")

    # ---- Early exit if already past max_steps ---------------------- #
    if bounded and start_step >= args.max_steps:
        print(f"\n [WARN] start_step ({start_step}) >= max_steps ({args.max_steps}).")
        print(f" Nothing to train. Use --extra_steps N to run N more steps.")
        print(f"\nExample: python train.py --resume --run_dir {args.run_dir} --extra_steps 5000")
        return

    # ---- Graceful Ctrl+C handler ----------------------------------- #
    # The handler only sets a flag; the loop checks it between steps so the
    # current optimizer step always completes before the final save.
    stop_flag = {"stop": False}

    def _signal_handler(sig, frame):
        print("\n [SIGNAL] Ctrl+C received — will save checkpoint and exit after current step.")
        stop_flag["stop"] = True

    signal.signal(signal.SIGINT, _signal_handler)

    # ---- Training loop --------------------------------------------- #
    model.train()
    step = start_step
    running_loss = 0.0       # loss of the most recent optimizer step
    last_saved_step = None   # step of the most recent periodic checkpoint
    t_start = time.time()
    t_step_start = time.time()

    data_iter = iter(train_loader)

    print(f"\n{'='*60}")
    print(f" TRAINING STARTED (step {step} -> {args.max_steps if bounded else '∞'})")
    print(f"{'='*60}\n")

    pbar = tqdm(
        initial=step,
        total=args.max_steps,
        desc="Training",
        unit="step",
        dynamic_ncols=True,
    )

    while True:
        # ---- Stop conditions --------------------------------------- #
        if stop_flag["stop"]:
            break
        if bounded and step >= args.max_steps:
            print(f"\n [DONE] Reached max_steps={args.max_steps}")
            break

        optimizer.zero_grad(set_to_none=True)
        accum_loss = 0.0

        # ---- Gradient accumulation micro-steps --------------------- #
        for _ in range(args.grad_accum):
            # Get next batch; restart the loader when it runs out.
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                x, y = next(data_iter)

            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            # Forward + loss (inside AMP context)
            with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
                _, loss = model(x, y)
                # Scale loss by grad_accum so gradients average correctly
                loss = loss / args.grad_accum

            # Backward
            scaler.scale(loss).backward()
            accum_loss += loss.item()

        # ---- Gradient clipping ------------------------------------- #
        if args.grad_clip > 0:
            # Unscale first so the clip threshold applies to true gradients.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        else:
            grad_norm = float("nan")

        # ---- LR update --------------------------------------------- #
        lr = get_lr(step, args.warmup_steps, args.max_steps, args.max_lr, args.min_lr)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        # ---- Optimizer step ---------------------------------------- #
        scaler.step(optimizer)
        scaler.update()

        step += 1
        running_loss = accum_loss  # loss for this step

        # ---- Tokens per second ------------------------------------- #
        t_now = time.time()
        elapsed = t_now - t_step_start
        t_step_start = t_now
        tok_per_sec = tokens_per_step / max(elapsed, 1e-6)

        # ---- Progress bar update ----------------------------------- #
        pbar.update(1)
        pbar.set_postfix({
            "loss": f"{running_loss:.4f}",
            "lr": f"{lr:.2e}",
            "tok/s": f"{tok_per_sec:.0f}",
        })

        # ---- Logging ----------------------------------------------- #
        if step % args.log_every == 0:
            grad_norm_value = float(grad_norm)
            log_entry = {
                "step": step,
                "loss": round(running_loss, 6),
                "lr": lr,
                # NaN marks "clipping disabled"; stored as null in the log.
                "grad_norm": None if math.isnan(grad_norm_value) else round(grad_norm_value, 4),
                "tok_per_sec": round(tok_per_sec, 1),
                "elapsed_s": round(t_now - t_start, 1),
            }
            if device.type == "cuda":
                log_entry["vram_gb"] = round(torch.cuda.memory_allocated() / 1e9, 3)
            logger.log(**log_entry)

        # ---- Validation -------------------------------------------- #
        if step % args.val_every == 0:
            val_loss = estimate_val_loss(
                model, val_loader, args.val_steps, device,
                autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp),
            )
            tqdm.write(f" [STEP {step:6d}] train_loss={running_loss:.4f} val_loss={val_loss:.4f} lr={lr:.2e}")
            logger.log(step=step, val_loss=round(val_loss, 6))

        # ---- Checkpoint -------------------------------------------- #
        if step % args.save_every == 0:
            ckpt_path = os.path.join(args.run_dir, f"ckpt_{step:07d}.pt")
            save_checkpoint(ckpt_path, model, optimizer, step, args, running_loss)
            last_saved_step = step

    # ---- Final checkpoint on exit (only if we actually ran steps) -- #
    pbar.close()

    steps_done = step - start_step
    if steps_done == 0:
        print("\n [SKIP] No steps were taken — skipping final checkpoint save.")
    elif last_saved_step != step:
        # Skip the write when the periodic save already covered this step.
        ckpt_path = os.path.join(args.run_dir, f"ckpt_{step:07d}.pt")
        save_checkpoint(ckpt_path, model, optimizer, step, args, running_loss)

    total_time = time.time() - t_start
    print(f"\n{'='*60}")
    print(f" TRAINING COMPLETE")
    print(f"{'='*60}")
    print(f" Steps completed : {steps_done}")
    print(f" Final loss : {running_loss:.4f}")
    print(f" Total time : {total_time/60:.1f} min")
    print(f" Run dir : {args.run_dir}")
    print(f"\nTo resume: python train.py --resume --run_dir {args.run_dir} --max_steps <N>")
    print(f"To plot : python plot_training.py --run_dir {args.run_dir}")


if __name__ == "__main__":
    train()