"""
train.py — SLLM Training Loop

Supports:
    --max_steps N        Absolute step target: train until the global step counter
                         reaches N, then save a checkpoint and exit. Omit to train
                         indefinitely (until Ctrl+C; the training data iterator is
                         restarted whenever it is exhausted).
    --extra_steps N      Train N more steps beyond the resumed step (converted to an
                         absolute --max_steps internally; takes priority over it).
    --resume             Resume from the latest checkpoint in --run_dir.
    --config 100M|150M   Choose model config (default: 100M). On --resume, the
                         config stored in the checkpoint is used automatically.
    --synthetic          Use synthetic data (for testing without real shards).

Features:
    - Mixed precision via autocast (bf16 by default, fp16 optional; falls back to
      fp32 on CPU or when the GPU lacks support). GradScaler loss scaling is
      enabled only for fp16.
    - Gradient accumulation: --grad_accum N micro-batches per optimizer update
    - Gradient checkpointing: --grad_checkpoint to save VRAM
    - Cosine LR schedule with linear warmup
    - Checkpoint save every --save_every steps (and on clean exit/Ctrl+C)
    - Metric logging to <run_dir>/train_log.jsonl (one JSON line per log step,
      plus one per validation run)
    - Real-time terminal progress with tqdm

Recommended for RTX 3050 4GB:
    python train.py --config 100M --batch_size 4 --grad_accum 8 \\
        --grad_checkpoint --max_steps 1000

Run for N steps, stop, then resume for N more:
    python train.py --max_steps 500 --run_dir runs/my_run
    python train.py --extra_steps 500 --run_dir runs/my_run --resume
    (equivalently `--max_steps 1000 --resume`, since --max_steps is absolute)
"""
|
|
| import os |
| import sys |
| import json |
| import math |
| import time |
| import signal |
| import argparse |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.amp import autocast, GradScaler |
| from tqdm import tqdm |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from model.config import SLLM_100M, SLLM_150M, ModelConfig |
| from model.model import SLLM |
| from data.dataloader import build_dataloader |
|
|
|
|
| |
| |
| |
|
|
def parse_args():
    """
    Parse and validate the command-line options for the training loop.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    Exits through ``ArgumentParser.error`` (SystemExit, status 2) when an option
    is out of range. The interval options are later used with ``%`` and
    ``grad_accum`` divides the loss, so zero or negative values would otherwise
    crash (ZeroDivisionError) or silently skip work deep inside ``train()``.
    """
    p = argparse.ArgumentParser(description="SLLM Training Loop")

    # -- Run management ---------------------------------------------------------
    p.add_argument("--run_dir", type=str, default="runs/run_001", help="Directory for checkpoints and logs")
    p.add_argument("--run_name", type=str, default=None, help="Override run name (defaults to run_dir basename)")
    p.add_argument("--resume", action="store_true", help="Resume from latest checkpoint in run_dir")
    p.add_argument("--max_steps", type=int, default=None, help="Absolute step target — stop when step reaches this number.")
    p.add_argument("--extra_steps", type=int, default=None, help="Run N MORE steps from current checkpoint (relative). Converted to --max_steps internally.")

    # -- Model ------------------------------------------------------------------
    p.add_argument("--config", type=str, default="100M", choices=["100M", "150M"])

    # -- Data -------------------------------------------------------------------
    p.add_argument("--data_dir", type=str, default="tokenizer/data")
    p.add_argument("--synthetic", action="store_true", help="Use synthetic random data (for testing)")
    p.add_argument("--num_workers", type=int, default=2)

    # -- Optimization -----------------------------------------------------------
    p.add_argument("--batch_size", type=int, default=4, help="Per-device batch size")
    p.add_argument("--grad_accum", type=int, default=8, help="Gradient accumulation steps")
    p.add_argument("--max_lr", type=float, default=3e-4)
    p.add_argument("--min_lr", type=float, default=3e-5)
    p.add_argument("--warmup_steps", type=int, default=100)
    p.add_argument("--weight_decay", type=float, default=0.1)
    p.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping norm (0 = disabled)")

    # -- Memory / precision -----------------------------------------------------
    p.add_argument("--grad_checkpoint", action="store_true", help="Enable gradient checkpointing (saves VRAM, slower)")
    p.add_argument("--dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])

    # -- Logging, validation, checkpointing -------------------------------------
    p.add_argument("--log_every", type=int, default=10, help="Log metrics every N optimizer steps")
    p.add_argument("--save_every", type=int, default=500, help="Save checkpoint every N optimizer steps")
    p.add_argument("--val_every", type=int, default=250, help="Run validation every N optimizer steps")
    p.add_argument("--val_steps", type=int, default=20, help="Number of val batches to average")

    args = p.parse_args()

    # Divisors / loop counts: must be strictly positive.
    for name in ("batch_size", "grad_accum", "log_every", "save_every", "val_every", "val_steps"):
        value = getattr(args, name)
        if value <= 0:
            p.error(f"--{name} must be a positive integer (got {value})")

    # Counts that may legitimately be zero (or unset) but never negative.
    for name in ("num_workers", "warmup_steps", "max_steps", "extra_steps"):
        value = getattr(args, name)
        if value is not None and value < 0:
            p.error(f"--{name} must be >= 0 (got {value})")

    return args
|
|
|
|
| |
| |
| |
|
|
def get_lr(step: int, warmup_steps: int, total_steps: int, max_lr: float, min_lr: float) -> float:
    """
    Learning rate for optimizer step ``step`` (0-based).

    Schedule:
      * steps ``[0, warmup_steps)``: linear ramp, reaching ``max_lr`` on the last
        warmup step;
      * steps ``[warmup_steps, horizon)``: half-cosine from ``max_lr`` down
        towards ``min_lr``;
      * steps ``>= horizon``: constant ``min_lr``.

    ``horizon`` is ``total_steps``, or a fixed 10k-step window when
    ``total_steps`` is falsy (None when training indefinitely).
    """
    horizon = total_steps if total_steps else 10_000

    # Guard clauses: warmup ramp first, then the flat tail after the horizon.
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step >= horizon:
        return min_lr

    # Fraction of the cosine phase completed; max(1, ...) avoids a zero span.
    cosine_span = max(1, horizon - warmup_steps)
    fraction = (step - warmup_steps) / cosine_span
    coeff = 0.5 * (1.0 + math.cos(math.pi * fraction))
    return min_lr + coeff * (max_lr - min_lr)
|
|
|
|
| |
| |
| |
|
|
def build_optimizer(model: "SLLM", lr: float, weight_decay: float) -> torch.optim.AdamW:
    """
    Build AdamW with weight decay applied only to parameters with ``dim() >= 2``.

    Every trainable tensor with two or more dimensions (Linear weights and,
    because they are 2-D as well, embedding tables) goes into the decayed
    group; 1-D tensors such as norm gains and biases get ``weight_decay=0``.
    This mirrors the GPT-2 / nanoGPT grouping. Both groups are always created,
    so the optimizer state-dict layout is stable across runs (required for
    ``--resume``).

    The fused AdamW kernel is requested only when every trainable parameter is
    on a CUDA device. ``train`` falls back to CPU when no GPU is present, and
    ``fused=True`` is not supported for CPU tensors in all PyTorch versions, so
    on CPU PyTorch's default implementation is used instead.

    Args:
        model: module whose trainable parameters are optimized.
        lr: initial learning rate (the training loop overwrites it every step).
        weight_decay: decay applied to the ``dim() >= 2`` group.

    Returns:
        Configured ``torch.optim.AdamW`` (betas=(0.9, 0.95), eps=1e-8).
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    decay_params = [p for p in trainable if p.dim() >= 2]
    no_decay_params = [p for p in trainable if p.dim() < 2]

    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    n_decay = sum(p.numel() for p in decay_params)
    n_no_decay = sum(p.numel() for p in no_decay_params)
    print(f" Optimizer: {n_decay/1e6:.1f}M decay params | {n_no_decay/1e6:.1f}M no-decay params")

    use_fused = bool(trainable) and all(p.is_cuda for p in trainable)
    fused_kwargs = {"fused": True} if use_fused else {}
    return torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, **fused_kwargs)
|
|
|
|
| |
| |
| |
|
|
def save_checkpoint(path: str, model: "SLLM", optimizer, step: int, args, loss: float):
    """
    Atomically write a training checkpoint to ``path``.

    The saved dict holds ``step``, ``model_state_dict``, ``optimizer_state_dict``,
    ``loss`` and ``config_name`` (taken from ``args.config``).

    The payload is first written to ``<path>.tmp`` and then moved into place with
    ``os.replace``. If the process is interrupted mid-save (Ctrl+C, kill, full
    disk), no truncated ``ckpt_*.pt`` is left behind for a later ``--resume`` to
    pick up as the latest checkpoint. The temp name ends in ``.tmp``, so it never
    matches the ``ckpt_*.pt`` pattern.

    Args:
        path: destination file; parent directories are created if needed.
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        step: global optimizer step the checkpoint corresponds to.
        args: parsed CLI namespace; only ``args.config`` is read.
        loss: last training loss, stored for reporting on resume.
    """
    directory = os.path.dirname(path)
    if directory:  # os.makedirs("") raises for a bare filename (i.e. the CWD)
        os.makedirs(directory, exist_ok=True)

    payload = {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
        "config_name": args.config,
    }
    tmp_path = path + ".tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial temp file, then propagate the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    print(f"\n [CKPT] Saved checkpoint: {path} (step={step}, loss={loss:.4f})")
|
|
|
|
def _ckpt_step_from_name(filename: str):
    """Return the step encoded in a ``ckpt_<digits>.pt`` filename, or None for any other name."""
    if not (filename.startswith("ckpt_") and filename.endswith(".pt")):
        return None
    stem = filename[len("ckpt_"):-len(".pt")]
    # isascii() guards against non-ASCII digit characters that int() rejects.
    return int(stem) if stem.isascii() and stem.isdigit() else None


def load_checkpoint(run_dir: str, model: "SLLM", optimizer, device):
    """
    Restore model and optimizer state from the newest checkpoint in ``run_dir``.

    "Newest" means the ``ckpt_<step>.pt`` file with the highest numeric step.
    Ordering by number rather than by filename keeps this correct once step
    counts outgrow the 7-digit zero padding. Files whose name does not match
    the pattern (including ``*.tmp`` leftovers) are ignored.

    Args:
        run_dir: directory to scan.
        model: module to receive ``model_state_dict``.
        optimizer: optimizer to receive ``optimizer_state_dict``.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The ``step`` stored in the checkpoint.

    Raises:
        FileNotFoundError: if ``run_dir`` does not exist or holds no checkpoint.
    """
    candidates = []
    for fname in os.listdir(run_dir):
        ckpt_step = _ckpt_step_from_name(fname)
        if ckpt_step is not None:
            candidates.append((ckpt_step, fname))
    if not candidates:
        raise FileNotFoundError(f"No checkpoints found in {run_dir}")

    path = os.path.join(run_dir, max(candidates)[1])
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Acceptable only because run_dir holds checkpoints written by this script;
    # never point --run_dir at untrusted files.
    ckpt = torch.load(path, map_location=device, weights_only=False)

    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])

    step = ckpt["step"]
    loss = ckpt.get("loss", float("nan"))
    print(f" [CKPT] Resumed from: {path} (step={step}, loss={loss:.4f})")
    return step
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def estimate_val_loss(model, val_loader, val_steps: int, device, dtype_ctx) -> float:
    """
    Average the model's loss over the first ``val_steps`` batches of ``val_loader``.

    Args:
        model: called as ``model(x, y)``; its second return value must be a
            tensor supporting ``.item()`` (the loss).
        val_loader: iterable of ``(x, y)`` pairs; a fresh iterator is taken
            on each call.
        val_steps: maximum number of batches to evaluate. Exactly this many are
            fetched, never one extra.
        device: device each batch is moved to.
        dtype_ctx: context manager entered around every forward pass (e.g. an
            autocast instance); it is re-entered once per batch.

    Returns:
        Mean of the per-batch losses, or NaN if no batch was evaluated.

    The model is switched to eval mode for the duration. Its previous
    train/eval mode is restored afterwards, even if a batch raises.
    """
    was_training = model.training
    model.eval()
    losses = []
    try:
        # zip stops on the exhausted range before pulling another batch.
        for _, (x, y) in zip(range(val_steps), val_loader):
            x, y = x.to(device), y.to(device)
            with dtype_ctx:
                _, loss = model(x, y)
            losses.append(loss.item())
    finally:
        model.train(was_training)
    return sum(losses) / len(losses) if losses else float("nan")
|
|
|
|
| |
| |
| |
|
|
class MetricLogger:
    """
    Append-only JSON-lines metric logger.

    Each :meth:`log` call appends exactly one JSON object, on its own line, to
    ``log_path``. The file is opened in append mode and closed on every call,
    so an existing log (e.g. from a resumed run) is extended, never truncated.
    """

    def __init__(self, log_path: str):
        """Remember ``log_path`` and create its parent directory if it has one."""
        self.log_path = log_path
        parent = os.path.dirname(log_path)
        if parent:  # os.makedirs("") raises for a bare filename (i.e. the CWD)
            os.makedirs(parent, exist_ok=True)
        print(f" [LOG] Logging to: {log_path}")

    def log(self, **kwargs):
        """Append ``kwargs`` as a single JSON line."""
        with open(self.log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(kwargs) + "\n")
|
|
|
|
| |
| |
| |
|
|
def _resolve_amp_dtype(requested: str, device: torch.device):
    """
    Map the ``--dtype`` flag to the dtype training will actually use.

    bf16 is used only on a CUDA device that reports bf16 support, and fp16 only
    on CUDA. Every other combination, including any request on CPU, falls back
    to fp32.

    Returns:
        ``(torch.dtype, name)`` where name is "bf16", "fp16" or "fp32".
    """
    if requested == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
        return torch.bfloat16, "bf16"
    if requested == "fp16" and device.type == "cuda":
        return torch.float16, "fp16"
    return torch.float32, "fp32"


def _checkpoint_config_name(run_dir: str):
    """
    Return the ``config_name`` stored in the newest ``ckpt_<step>.pt`` of run_dir.

    Checkpoints are ordered by their numeric step, not by filename. This is a
    best-effort lookup: it returns None when run_dir is missing, holds no
    checkpoint, or the checkpoint has no ``config_name``. An unreadable
    checkpoint prints a warning and returns None instead of being silently
    ignored.
    """
    try:
        names = os.listdir(run_dir)
    except OSError:
        return None

    candidates = []
    for name in names:
        if not (name.startswith("ckpt_") and name.endswith(".pt")):
            continue
        stem = name[len("ckpt_"):-len(".pt")]
        if stem.isascii() and stem.isdigit():
            candidates.append((int(stem), name))
    if not candidates:
        return None

    path = os.path.join(run_dir, max(candidates)[1])
    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # acceptable because run_dir holds checkpoints written by this script.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
    except Exception as e:  # best effort; the real resume load reports hard failures
        print(f" [WARN] Could not inspect {path} for its config ({e}); keeping --config.")
        return None
    return ckpt.get("config_name") if isinstance(ckpt, dict) else None


def train():
    """
    Script entry point: parse CLI args, build model/optimizer/data, run the loop.

    Flow:
      1. Pick the device and precision. Autocast is used for bf16/fp16, and
         GradScaler is enabled only for fp16.
      2. On --resume, read the newest checkpoint's config name so the model is
         built with the architecture it was saved with, then restore the model
         and optimizer state.
      3. Translate --extra_steps into an absolute --max_steps target.
      4. Run optimizer steps, each made of ``grad_accum`` micro-batches, until
         max_steps is reached or Ctrl+C is pressed. Metrics are logged,
         validation is run and checkpoints are saved on their intervals.
      5. Save a final checkpoint, unless no step ran or the last step was
         already checkpointed.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")
        print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # ---- Precision ------------------------------------------------------------
    dtype_torch, dtype_name = _resolve_amp_dtype(args.dtype, device)
    print(f"dtype : {dtype_name}")
    use_amp = dtype_torch in (torch.float16, torch.bfloat16)
    # Loss scaling is only enabled for fp16; for bf16/fp32 the scaler is a pass-through.
    scaler = GradScaler(enabled=(dtype_torch == torch.float16))

    def amp_ctx():
        # Fresh autocast context per use; disabled entirely when running fp32.
        return autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp)

    # ---- Model ----------------------------------------------------------------
    cfg_map = {"100M": SLLM_100M, "150M": SLLM_150M}
    if args.resume:
        # Build with the checkpoint's config so its state_dict matches the model.
        saved_config = _checkpoint_config_name(args.run_dir)
        if saved_config is not None and saved_config != args.config:
            if saved_config in cfg_map:
                print(f" [CKPT] Auto-switching config from '{args.config}' to '{saved_config}' to match checkpoint.")
                args.config = saved_config
            else:
                print(f" [WARN] Checkpoint config '{saved_config}' is unknown; keeping '{args.config}'.")

    cfg = cfg_map[args.config]
    model = SLLM(cfg).to(device)
    if args.grad_checkpoint:
        model.enable_gradient_checkpointing()
        print(" Gradient checkpointing: ON")
    print(f"\nModel : SLLM-{args.config} ({model.count_params()/1e6:.1f}M params)")
    print(f"Config : {cfg}")

    optimizer = build_optimizer(model, lr=args.max_lr, weight_decay=args.weight_decay)

    # ---- Data -----------------------------------------------------------------
    loader_kwargs = dict(
        data_dir=args.data_dir,
        context_length=cfg.context_length,
        batch_size=args.batch_size,
        use_synthetic=args.synthetic,
        vocab_size=cfg.vocab_size,
    )
    train_loader = build_dataloader(split="train", num_workers=args.num_workers, **loader_kwargs)
    val_loader = build_dataloader(split="val", num_workers=0, **loader_kwargs)

    # ---- Logging --------------------------------------------------------------
    os.makedirs(args.run_dir, exist_ok=True)
    logger = MetricLogger(os.path.join(args.run_dir, "train_log.jsonl"))

    # ---- Resume ---------------------------------------------------------------
    start_step = 0
    if args.resume:
        try:
            start_step = load_checkpoint(args.run_dir, model, optimizer, device)
        except FileNotFoundError as e:
            print(f" [WARN] {e} — starting from scratch.")

    # ---- Step budget ----------------------------------------------------------
    eff_batch = args.batch_size * args.grad_accum
    tokens_per_step = eff_batch * cfg.context_length
    print(f"\nTraining:")
    if args.extra_steps is not None:
        if args.max_steps is not None:
            print(" [WARN] Both --extra_steps and --max_steps given. --extra_steps takes priority.")
        args.max_steps = start_step + args.extra_steps
        print(f" [INFO] --extra_steps {args.extra_steps} → running until step {args.max_steps}")

    print(f" batch_size : {args.batch_size} (grad_accum={args.grad_accum} -> effective={eff_batch})")
    print(f" tokens/step : {tokens_per_step:,}")
    print(f" max_steps : {args.max_steps or 'unlimited'} (absolute step target)")
    print(f" start_step : {start_step}")
    print(f" steps to run : {(args.max_steps - start_step) if args.max_steps else 'unlimited'}")
    print(f" save_every : {args.save_every}")
    print(f" log_every : {args.log_every}")

    if args.max_steps is not None and start_step >= args.max_steps:
        print(f"\n [WARN] start_step ({start_step}) >= max_steps ({args.max_steps}).")
        print(f" Nothing to train. Use --extra_steps N to run N more steps.")
        print(f"\nExample: python train.py --resume --run_dir {args.run_dir} --extra_steps 5000")
        return

    # ---- Ctrl+C: finish the current step, then checkpoint and exit ----------
    stop_flag = {"stop": False}

    def _signal_handler(sig, frame):
        print("\n [SIGNAL] Ctrl+C received — will save checkpoint and exit after current step.")
        stop_flag["stop"] = True

    previous_sigint = signal.signal(signal.SIGINT, _signal_handler)

    model.train()
    step = start_step
    running_loss = 0.0
    last_saved_step = None
    t_start = time.time()
    t_step_start = time.time()
    data_iter = iter(train_loader)

    print(f"\n{'='*60}")
    print(f" TRAINING STARTED (step {step} -> {args.max_steps or '∞'})")
    print(f"{'='*60}\n")

    try:
        with tqdm(initial=step, total=args.max_steps, desc="Training", unit="step", dynamic_ncols=True) as pbar:
            while True:
                if stop_flag["stop"]:
                    break
                if args.max_steps is not None and step >= args.max_steps:
                    print(f"\n [DONE] Reached max_steps={args.max_steps}")
                    break

                optimizer.zero_grad(set_to_none=True)
                accum_loss = 0.0

                # One optimizer step = grad_accum micro-batches.
                for _ in range(args.grad_accum):
                    try:
                        x, y = next(data_iter)
                    except StopIteration:
                        # Restart the loader so training continues past one epoch.
                        data_iter = iter(train_loader)
                        x, y = next(data_iter)
                    x = x.to(device, non_blocking=True)
                    y = y.to(device, non_blocking=True)

                    with amp_ctx():
                        _, loss = model(x, y)
                    # Divide so the accumulated gradient is the micro-batch mean.
                    loss = loss / args.grad_accum
                    scaler.scale(loss).backward()
                    accum_loss += loss.item()

                if args.grad_clip > 0:
                    # Unscale first so clipping acts on the true (unscaled) gradients.
                    scaler.unscale_(optimizer)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                else:
                    grad_norm = float("nan")

                lr = get_lr(step, args.warmup_steps, args.max_steps, args.max_lr, args.min_lr)
                for pg in optimizer.param_groups:
                    pg["lr"] = lr

                scaler.step(optimizer)
                scaler.update()

                step += 1
                running_loss = accum_loss

                t_now = time.time()
                elapsed = t_now - t_step_start
                tok_per_sec = tokens_per_step / max(elapsed, 1e-6)

                pbar.update(1)
                pbar.set_postfix({
                    "loss": f"{running_loss:.4f}",
                    "lr": f"{lr:.2e}",
                    "tok/s": f"{tok_per_sec:.0f}",
                })

                if step % args.log_every == 0:
                    gn = float(grad_norm)
                    log_entry = {
                        "step": step,
                        "loss": round(running_loss, 6),
                        "lr": lr,
                        # None for NaN/inf: json.dumps would emit non-standard NaN/Infinity.
                        "grad_norm": round(gn, 4) if math.isfinite(gn) else None,
                        "tok_per_sec": round(tok_per_sec, 1),
                        "elapsed_s": round(t_now - t_start, 1),
                    }
                    if device.type == "cuda":
                        log_entry["vram_gb"] = round(torch.cuda.memory_allocated() / 1e9, 3)
                    logger.log(**log_entry)

                if step % args.val_every == 0:
                    val_loss = estimate_val_loss(model, val_loader, args.val_steps, device, amp_ctx())
                    tqdm.write(f" [STEP {step:6d}] train_loss={running_loss:.4f} val_loss={val_loss:.4f} lr={lr:.2e}")
                    logger.log(step=step, val_loss=round(val_loss, 6))

                if step % args.save_every == 0:
                    ckpt_path = os.path.join(args.run_dir, f"ckpt_{step:07d}.pt")
                    save_checkpoint(ckpt_path, model, optimizer, step, args, running_loss)
                    last_saved_step = step

                # Restart the step timer after eval/checkpoint work so the next
                # step's tok/s reflects training throughput only.
                t_step_start = time.time()

        steps_done = step - start_step
        if steps_done <= 0:
            print("\n [SKIP] No steps were taken — skipping final checkpoint save.")
        elif last_saved_step == step:
            print(f"\n [SKIP] Step {step} was already checkpointed — skipping duplicate save.")
        else:
            ckpt_path = os.path.join(args.run_dir, f"ckpt_{step:07d}.pt")
            save_checkpoint(ckpt_path, model, optimizer, step, args, running_loss)
    finally:
        # Reinstate whatever SIGINT handler was active before training.
        signal.signal(signal.SIGINT, previous_sigint)

    total_time = time.time() - t_start
    print(f"\n{'='*60}")
    print(f" TRAINING COMPLETE")
    print(f"{'='*60}")
    print(f" Steps completed : {step - start_step}")
    print(f" Final loss : {running_loss:.4f}")
    print(f" Total time : {total_time/60:.1f} min")
    print(f" Run dir : {args.run_dir}")
    print(f"\nTo resume: python train.py --resume --run_dir {args.run_dir} --max_steps <N>")
    print(f"To plot : python plot_training.py --run_dir {args.run_dir}")
|
|
|
|
# Script entry point: run training only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    train()
|
|