sllm / train.py
geeteshcodes's picture
Initial commit
7f974df verified
"""
train.py — SLLM Training Loop
Supports:
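--max_steps N Train until the global step counter reaches N (an absolute target,
not a count of additional steps), then save a checkpoint and exit.
Omit to train indefinitely (until Ctrl+C); the training data
iterator is restarted automatically whenever it is exhausted.
--extra_steps N Run N more steps from the current (possibly resumed) step.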
--resume Resume from the latest checkpoint in --run_dir.
--config 100M|150M Choose model config (default: 100M).
--synthetic Use synthetic data (for testing without real shards).
Features:
- bf16/fp16 mixed precision (autocast); GradScaler is enabled only for fp16
- Gradient accumulation: --grad_accum N steps per optimizer update
- Gradient checkpointing: --grad_checkpoint to save VRAM
- Cosine LR schedule with linear warmup
- Checkpoint save every --save_every steps (and on clean exit/Ctrl+C)
- Metric logging to <run_dir>/train_log.jsonl (one JSON line per log step)
- Real-time terminal progress with tqdm
Recommended for RTX 3050 4GB:
python train.py --config 100M --batch_size 4 --grad_accum 8 \\
--grad_checkpoint --max_steps 1000
Run for N steps, stop, then resume for N more:
python train.py --max_steps 500 --run_dir runs/my_run
python train.py --extra_steps 500 --run_dir runs/my_run --resume
"""
import os
import sys
import json
import math
import time
import signal
import argparse
import torch
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from tqdm import tqdm
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import SLLM_100M, SLLM_150M, ModelConfig
from model.model import SLLM
from data.dataloader import build_dataloader
# ------------------------------------------------------------------ #
# ARG PARSING
# ------------------------------------------------------------------ #
def _positive_int(value: str) -> int:
    """argparse ``type=`` callable that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {number}")
    return number


def parse_args(argv=None):
    """
    Build the command-line parser and parse the training options.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, which is what the script does.

    Returns:
        The populated ``argparse.Namespace``.

    Interval and batching options that are later used as a modulus or loop
    count (--batch_size, --grad_accum, --log_every, --save_every, --val_every)
    must be >= 1; a bad value is rejected here instead of failing after the
    first optimizer step.
    """
    p = argparse.ArgumentParser(description="SLLM Training Loop")

    # Run management
    p.add_argument("--run_dir", type=str, default="runs/run_001", help="Directory for checkpoints and logs")
    p.add_argument("--run_name", type=str, default=None, help="Override run name (defaults to run_dir basename)")
    p.add_argument("--resume", action="store_true", help="Resume from latest checkpoint in run_dir")
    p.add_argument("--max_steps", type=int, default=None, help="Absolute step target — stop when step reaches this number.")
    p.add_argument("--extra_steps", type=int, default=None, help="Run N MORE steps from current checkpoint (relative). Converted to --max_steps internally.")

    # Model
    p.add_argument("--config", type=str, default="100M", choices=["100M", "150M"])

    # Data
    p.add_argument("--data_dir", type=str, default="tokenizer/data")
    p.add_argument("--synthetic", action="store_true", help="Use synthetic random data (for testing)")
    p.add_argument("--num_workers", type=int, default=2)

    # Training
    p.add_argument("--batch_size", type=_positive_int, default=4, help="Per-device batch size")
    p.add_argument("--grad_accum", type=_positive_int, default=8, help="Gradient accumulation steps")
    p.add_argument("--max_lr", type=float, default=3e-4)
    p.add_argument("--min_lr", type=float, default=3e-5)
    p.add_argument("--warmup_steps", type=int, default=100)
    p.add_argument("--weight_decay", type=float, default=0.1)
    p.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping norm (0 = disabled)")

    # Memory
    p.add_argument("--grad_checkpoint", action="store_true", help="Enable gradient checkpointing (saves VRAM, slower)")
    p.add_argument("--dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])

    # Logging / Saving
    p.add_argument("--log_every", type=_positive_int, default=10, help="Log metrics every N optimizer steps")
    p.add_argument("--save_every", type=_positive_int, default=500, help="Save checkpoint every N optimizer steps")
    p.add_argument("--val_every", type=_positive_int, default=250, help="Run validation every N optimizer steps")
    p.add_argument("--val_steps", type=int, default=20, help="Number of val batches to average")

    return p.parse_args(argv)
# ------------------------------------------------------------------ #
# LEARNING RATE SCHEDULE
# ------------------------------------------------------------------ #
def get_lr(step: int, warmup_steps: int, total_steps: int, max_lr: float, min_lr: float) -> float:
    """
    Learning rate for a given optimizer step: linear warmup, cosine decay,
    then a constant floor.

    Args:
        step: Zero-based optimizer step.
        warmup_steps: Number of steps over which the rate ramps up linearly.
        total_steps: Step at which the decay ends. A falsy value (None or 0)
            selects a fixed 10,000-step decay horizon.
        max_lr: Peak rate, reached at the end of warmup.
        min_lr: Floor rate, held once the decay horizon has been reached.

    Returns:
        The learning rate to apply at ``step``.
    """
    # Phase 1: ramp up; (step + 1) keeps the very first step above zero.
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps

    horizon = total_steps or 10_000

    # Phase 2: half-cosine from max_lr down to min_lr.
    if step < horizon:
        span = max(1, horizon - warmup_steps)
        fraction = (step - warmup_steps) / span
        blend = 0.5 * (1.0 + math.cos(math.pi * fraction))
        return min_lr + blend * (max_lr - min_lr)

    # Phase 3: past the horizon the rate stays at the floor.
    return min_lr
# ------------------------------------------------------------------ #
# OPTIMIZER (AdamW with selective weight decay)
# ------------------------------------------------------------------ #
def build_optimizer(model: torch.nn.Module, lr: float, weight_decay: float) -> torch.optim.AdamW:
    """
    Create an AdamW optimizer with two parameter groups.

    Trainable parameters with two or more dimensions (weight matrices, which
    includes 2D embedding tables) receive ``weight_decay``; trainable
    parameters with fewer dimensions (norm weight vectors, biases) receive
    none. Parameters with ``requires_grad=False`` are left out entirely.

    Args:
        model: Module whose trainable parameters are optimized.
        lr: Initial learning rate for both groups.
        weight_decay: Decay coefficient applied to the >=2D group.

    Returns:
        The configured ``torch.optim.AdamW`` instance. Group 0 is the decay
        group and group 1 the no-decay group; both are always present so the
        optimizer state layout stays the same across runs.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    decay_params = [p for p in trainable if p.dim() >= 2]
    no_decay_params = [p for p in trainable if p.dim() < 2]

    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    n_decay = sum(p.numel() for p in decay_params)
    n_no_decay = sum(p.numel() for p in no_decay_params)
    print(f" Optimizer: {n_decay/1e6:.1f}M decay params | {n_no_decay/1e6:.1f}M no-decay params")

    # The fused kernel is only requested when every parameter lives on CUDA;
    # asking for it with CPU parameters raises on PyTorch builds that lack a
    # CPU fused implementation, and the script is allowed to run on CPU.
    use_fused = bool(trainable) and all(p.is_cuda for p in trainable)
    return torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
# ------------------------------------------------------------------ #
# CHECKPOINT SAVE / LOAD
# ------------------------------------------------------------------ #
def save_checkpoint(path: str, model: torch.nn.Module, optimizer, step: int, args, loss: float):
    """
    Write a training checkpoint to ``path``.

    The checkpoint is a dict with the keys "step", "model_state_dict",
    "optimizer_state_dict", "loss" and "config_name" (taken from
    ``args.config``).

    Args:
        path: Destination file. Its parent directory is created if needed.
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        step: Optimizer step count to record.
        args: Parsed arguments; only ``args.config`` is read.
        loss: Most recent training loss to record.
    """
    # dirname is "" for a bare filename, and os.makedirs("") raises.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    payload = {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
        "config_name": args.config,
    }

    # Write to a sibling temp file and rename it into place, so an interrupted
    # save never leaves a truncated file under the final checkpoint name. The
    # ".tmp" suffix also keeps it from matching the "ckpt_*.pt" resume pattern.
    tmp_path = f"{path}.tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"\n [CKPT] Saved checkpoint: {path} (step={step}, loss={loss:.4f})")
def load_checkpoint(run_dir: str, model: torch.nn.Module, optimizer, device):
    """
    Load the most recent ``ckpt_*.pt`` file in ``run_dir`` into the model
    and optimizer.

    "Most recent" means the highest step number parsed from the filename, so
    the choice stays correct when step numbers outgrow the zero padding used
    when saving. Files whose name carries no numeric step are only chosen
    when no numbered checkpoint exists.

    Args:
        run_dir: Directory to search.
        model: Module that receives the stored "model_state_dict".
        optimizer: Optimizer that receives the stored "optimizer_state_dict".
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The step number stored in the checkpoint.

    Raises:
        FileNotFoundError: If ``run_dir`` holds no matching checkpoint file
            (or, via ``os.listdir``, if ``run_dir`` itself does not exist).
    """
    prefix, suffix = "ckpt_", ".pt"

    def _order(fname: str):
        # Rank by numeric step; the filename breaks ties deterministically.
        stem = fname[len(prefix):-len(suffix)]
        numbered = stem.isascii() and stem.isdigit()
        return (int(stem) if numbered else -1, fname)

    ckpts = [
        f for f in os.listdir(run_dir)
        if f.startswith(prefix) and f.endswith(suffix)
    ]
    if not ckpts:
        raise FileNotFoundError(f"No checkpoints found in {run_dir}")

    path = os.path.join(run_dir, max(ckpts, key=_order))
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # resume from checkpoint files that this training script produced.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    step = ckpt["step"]
    loss = ckpt.get("loss", float("nan"))
    print(f" [CKPT] Resumed from: {path} (step={step}, loss={loss:.4f})")
    return step
# ------------------------------------------------------------------ #
# VALIDATION
# ------------------------------------------------------------------ #
@torch.no_grad()
def estimate_val_loss(model, val_loader, val_steps: int, device, dtype_ctx) -> float:
    """
    Average the model's loss over at most ``val_steps`` validation batches.

    The model is switched to eval mode for the measurement and put back into
    whichever mode it was in beforehand, even if a forward pass raises.

    Args:
        model: Callable as ``model(x, y)`` returning ``(output, loss)``.
        val_loader: Iterable yielding ``(x, y)`` batches.
        val_steps: Maximum number of batches to evaluate.
        device: Device the batches are moved to before the forward pass.
        dtype_ctx: Context manager entered around every forward pass
            (e.g. an autocast context).

    Returns:
        The mean loss, or NaN when no batch was evaluated.
    """
    was_training = model.training
    model.eval()
    losses = []
    try:
        # zip() consults range() first, so the loader is never asked for a
        # batch beyond the val_steps that are actually evaluated.
        for _, (x, y) in zip(range(val_steps), val_loader):
            x, y = x.to(device), y.to(device)
            with dtype_ctx:
                _, loss = model(x, y)
            losses.append(loss.item())
    finally:
        model.train(was_training)
    return sum(losses) / len(losses) if losses else float("nan")
# ------------------------------------------------------------------ #
# METRIC LOGGING
# ------------------------------------------------------------------ #
class MetricLogger:
    """Appends one JSON object per call to a JSON-lines log file."""

    def __init__(self, log_path: str):
        """
        Remember the log location and make sure its directory exists.

        An existing log file is left untouched so that a resumed run keeps
        appending to it.
        """
        self.log_path = log_path
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        parent = os.path.dirname(log_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        print(f" [LOG] Logging to: {log_path}")

    def log(self, **kwargs):
        """
        Append the keyword arguments as a single JSON line.

        Top-level float values that are NaN or infinite are written as
        ``null``: json.dumps would otherwise emit the bare tokens NaN /
        Infinity, which strict JSON parsers reject.
        """
        record = {
            key: None if isinstance(value, float) and not math.isfinite(value) else value
            for key, value in kwargs.items()
        }
        with open(self.log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")
# ------------------------------------------------------------------ #
# MAIN TRAINING LOOP
# ------------------------------------------------------------------ #
def train():
    """
    Run the SLLM training loop as configured on the command line.

    Sequence: parse arguments, pick the device and autocast dtype, build the
    model, optimizer and train/val dataloaders, optionally resume from the
    latest checkpoint in --run_dir, then repeat optimizer steps (each made of
    --grad_accum micro-batches) until --max_steps is reached or Ctrl+C is
    pressed. Metrics are appended to <run_dir>/train_log.jsonl, validation
    runs every --val_every steps and checkpoints are written every
    --save_every steps plus once on exit.

    Raises:
        RuntimeError: If the training dataloader yields no batches at all.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")
        print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # ---- dtype context --------------------------------------------- #
    # Reduced precision is only used on CUDA; everything else runs in fp32.
    if args.dtype == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
        dtype_torch = torch.bfloat16
        dtype_name = "bf16"
    elif args.dtype == "fp16" and device.type == "cuda":
        dtype_torch = torch.float16
        dtype_name = "fp16"
    else:
        dtype_torch = torch.float32
        dtype_name = "fp32"
    print(f"dtype : {dtype_name}")
    use_amp = dtype_torch in (torch.float16, torch.bfloat16)

    def amp_ctx():
        # Fresh autocast context per use; it is a no-op when use_amp is False.
        return autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp)

    scaler = GradScaler(enabled=(dtype_torch == torch.float16))  # bf16 doesn't need scaler

    # ---- Auto-detect config on resume ------------------------------ #
    # Best effort: peek at the latest checkpoint so the model is built with
    # the config it was trained with. Failures are reported, not fatal.
    if args.resume:
        try:
            ckpts = sorted(
                f for f in os.listdir(args.run_dir)
                if f.startswith("ckpt_") and f.endswith(".pt")
            )
            if ckpts:
                ckpt_path = os.path.join(args.run_dir, ckpts[-1])
                _tmp_ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
                if "config_name" in _tmp_ckpt and _tmp_ckpt["config_name"] != args.config:
                    print(f" [CKPT] Auto-switching config from '{args.config}' to '{_tmp_ckpt['config_name']}' to match checkpoint.")
                    args.config = _tmp_ckpt["config_name"]
                del _tmp_ckpt
        except FileNotFoundError:
            # run_dir does not exist yet; the resume step below reports it.
            pass
        except Exception as exc:
            print(f" [WARN] Could not inspect checkpoint for config auto-detection: {exc}")

    # ---- Model ----------------------------------------------------- #
    cfg_map = {"100M": SLLM_100M, "150M": SLLM_150M}
    cfg = cfg_map[args.config]
    model = SLLM(cfg).to(device)
    if args.grad_checkpoint:
        model.enable_gradient_checkpointing()
        print(" Gradient checkpointing: ON")
    print(f"\nModel : SLLM-{args.config} ({model.count_params()/1e6:.1f}M params)")
    print(f"Config : {cfg}")

    # ---- Optimizer ------------------------------------------------- #
    optimizer = build_optimizer(model, lr=args.max_lr, weight_decay=args.weight_decay)

    # ---- Data ------------------------------------------------------ #
    train_loader = build_dataloader(
        data_dir=args.data_dir,
        split="train",
        context_length=cfg.context_length,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        use_synthetic=args.synthetic,
        vocab_size=cfg.vocab_size,
    )
    val_loader = build_dataloader(
        data_dir=args.data_dir,
        split="val",
        context_length=cfg.context_length,
        batch_size=args.batch_size,
        num_workers=0,
        use_synthetic=args.synthetic,
        vocab_size=cfg.vocab_size,
    )

    # ---- Run directory --------------------------------------------- #
    os.makedirs(args.run_dir, exist_ok=True)
    log_path = os.path.join(args.run_dir, "train_log.jsonl")
    logger = MetricLogger(log_path)

    # ---- Resume ---------------------------------------------------- #
    start_step = 0
    if args.resume:
        try:
            start_step = load_checkpoint(args.run_dir, model, optimizer, device)
        except FileNotFoundError as e:
            print(f" [WARN] {e} — starting from scratch.")

    # ---- Effective batch size info --------------------------------- #
    eff_batch = args.batch_size * args.grad_accum
    tokens_per_step = eff_batch * cfg.context_length
    print("\nTraining:")

    # ---- Resolve extra_steps -> max_steps -------------------------- #
    if args.extra_steps is not None:
        if args.max_steps is not None:
            print(" [WARN] Both --extra_steps and --max_steps given. --extra_steps takes priority.")
        args.max_steps = start_step + args.extra_steps
        print(f" [INFO] --extra_steps {args.extra_steps} → running until step {args.max_steps}")

    print(f" batch_size : {args.batch_size} (grad_accum={args.grad_accum} -> effective={eff_batch})")
    print(f" tokens/step : {tokens_per_step:,}")
    print(f" max_steps : {args.max_steps or 'unlimited'} (absolute step target)")
    print(f" start_step : {start_step}")
    print(f" steps to run : {(args.max_steps - start_step) if args.max_steps else 'unlimited'}")
    print(f" save_every : {args.save_every}")
    print(f" log_every : {args.log_every}")

    # ---- Early exit if already past max_steps ---------------------- #
    if args.max_steps is not None and start_step >= args.max_steps:
        print(f"\n [WARN] start_step ({start_step}) >= max_steps ({args.max_steps}).")
        print(" Nothing to train. Use --extra_steps N to run N more steps.")
        print(f"\nExample: python train.py --resume --run_dir {args.run_dir} --extra_steps 5000")
        return

    # ---- Graceful Ctrl+C handler ----------------------------------- #
    # The handler only sets a flag; the loop checks it between optimizer
    # steps so that a checkpoint can still be written before exiting.
    stop_flag = {"stop": False}

    def _signal_handler(sig, frame):
        print("\n [SIGNAL] Ctrl+C received — will save checkpoint and exit after current step.")
        stop_flag["stop"] = True

    signal.signal(signal.SIGINT, _signal_handler)

    # ---- Training loop --------------------------------------------- #
    model.train()
    step = start_step
    running_loss = 0.0  # loss of the most recent optimizer step
    last_saved_step = None  # step of the latest periodic checkpoint
    t_start = time.time()
    t_step_start = time.time()
    data_iter = iter(train_loader)

    print(f"\n{'='*60}")
    print(f" TRAINING STARTED (step {step} -> {args.max_steps or '∞'})")
    print(f"{'='*60}\n")

    pbar = tqdm(
        initial=step,
        total=args.max_steps,
        desc="Training",
        unit="step",
        dynamic_ncols=True,
    )

    try:
        while True:
            # ---- Stop conditions ----------------------------------- #
            if stop_flag["stop"]:
                break
            if args.max_steps is not None and step >= args.max_steps:
                print(f"\n [DONE] Reached max_steps={args.max_steps}")
                break

            optimizer.zero_grad(set_to_none=True)
            accum_loss = 0.0

            # ---- Gradient accumulation micro-steps ----------------- #
            for _ in range(args.grad_accum):
                # Get next batch; restart the iterator when it runs dry.
                try:
                    x, y = next(data_iter)
                except StopIteration:
                    data_iter = iter(train_loader)
                    try:
                        x, y = next(data_iter)
                    except StopIteration:
                        raise RuntimeError("Training dataloader produced no batches.") from None
                x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

                # Forward + loss (inside AMP context)
                with amp_ctx():
                    _, loss = model(x, y)
                    # Scale loss by grad_accum so gradients average correctly
                    loss = loss / args.grad_accum

                # Backward
                scaler.scale(loss).backward()
                accum_loss += loss.item()

            # ---- Gradient clipping --------------------------------- #
            if args.grad_clip > 0:
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            else:
                grad_norm = float("nan")

            # ---- LR update ----------------------------------------- #
            lr = get_lr(step, args.warmup_steps, args.max_steps, args.max_lr, args.min_lr)
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            # ---- Optimizer step ------------------------------------ #
            scaler.step(optimizer)
            scaler.update()
            step += 1
            running_loss = accum_loss  # loss for this step

            # ---- Tokens per second --------------------------------- #
            t_now = time.time()
            elapsed = t_now - t_step_start
            t_step_start = t_now
            tok_per_sec = tokens_per_step / max(elapsed, 1e-6)

            # ---- Progress bar update ------------------------------- #
            pbar.update(1)
            pbar.set_postfix({
                "loss": f"{running_loss:.4f}",
                "lr": f"{lr:.2e}",
                "tok/s": f"{tok_per_sec:.0f}",
            })

            # ---- Logging ------------------------------------------- #
            if step % args.log_every == 0:
                grad_norm_value = float(grad_norm)
                log_entry = {
                    "step": step,
                    "loss": round(running_loss, 6),
                    "lr": lr,
                    # Non-finite norms (clipping disabled, fp16 overflow) are logged as null.
                    "grad_norm": round(grad_norm_value, 4) if math.isfinite(grad_norm_value) else None,
                    "tok_per_sec": round(tok_per_sec, 1),
                    "elapsed_s": round(t_now - t_start, 1),
                }
                if device.type == "cuda":
                    log_entry["vram_gb"] = round(torch.cuda.memory_allocated() / 1e9, 3)
                logger.log(**log_entry)

            # ---- Validation ---------------------------------------- #
            if step % args.val_every == 0:
                val_loss = estimate_val_loss(model, val_loader, args.val_steps, device, amp_ctx())
                tqdm.write(f" [STEP {step:6d}] train_loss={running_loss:.4f} val_loss={val_loss:.4f} lr={lr:.2e}")
                logger.log(step=step, val_loss=round(val_loss, 6))

            # ---- Checkpoint ---------------------------------------- #
            if step % args.save_every == 0:
                ckpt_path = os.path.join(args.run_dir, f"ckpt_{step:07d}.pt")
                save_checkpoint(ckpt_path, model, optimizer, step, args, running_loss)
                last_saved_step = step
    finally:
        # Release the progress bar even when a step raises.
        pbar.close()

    # ---- Final checkpoint on exit (only if we actually ran steps) -- #
    steps_done = step - start_step
    if steps_done <= 0:
        print("\n [SKIP] No steps were taken — skipping final checkpoint save.")
    elif last_saved_step == step:
        # The periodic save already wrote this exact step; don't rewrite it.
        print(f"\n [CKPT] Final step {step} already checkpointed.")
    else:
        ckpt_path = os.path.join(args.run_dir, f"ckpt_{step:07d}.pt")
        save_checkpoint(ckpt_path, model, optimizer, step, args, running_loss)

    total_time = time.time() - t_start
    print(f"\n{'='*60}")
    print(" TRAINING COMPLETE")
    print(f"{'='*60}")
    print(f" Steps completed : {step - start_step}")
    print(f" Final loss : {running_loss:.4f}")
    print(f" Total time : {total_time/60:.1f} min")
    print(f" Run dir : {args.run_dir}")
    print(f"\nTo resume: python train.py --resume --run_dir {args.run_dir} --max_steps <N>")
    print(f"To plot : python plot_training.py --run_dir {args.run_dir}")
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()