""" LUNA 100M — Config-Driven Dynamic Training Script ================================================== Reads train_config.yaml for all hyperparameters. auto_config: true -> hardware probed; batch/lr/workers set automatically auto_config: false -> every value in config used exactly as-is Usage: python train.py # uses train_config.yaml defaults python train.py --config train_config.yaml # explicit config path python train.py --data_path /mnt/data/litdata_final # override data path only python train.py --max_tokens 10000000 # short smoke-test run """ import os import gc import sys import math import time import json import argparse import yaml import psutil import torch import torch.nn as nn import torch.nn.functional as F from torch.amp import autocast, GradScaler from pathlib import Path # Reduce CUDA memory fragmentation os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") # ─── Model ──────────────────────────────────────────────────────────────────── class RotaryEmbedding(nn.Module): def __init__(self, dim, max_seq_len=1024): super().__init__() inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) t = torch.arange(max_seq_len).float() freqs = torch.einsum("i,j->ij", t, inv_freq) emb = torch.cat([freqs, freqs], dim=-1) self.register_buffer("cos_cached", emb.cos()) self.register_buffer("sin_cached", emb.sin()) def forward(self, seq_len): return self.cos_cached[:seq_len], self.sin_cached[:seq_len] def rotate_half(x): x1, x2 = x.chunk(2, dim=-1) return torch.cat([-x2, x1], dim=-1) def apply_rotary(x, cos, sin): c = cos.unsqueeze(0).unsqueeze(0) s = sin.unsqueeze(0).unsqueeze(0) return x * c + rotate_half(x) * s class CausalSelfAttention(nn.Module): def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25): super().__init__() self.n_head = n_head self.head_dim = n_embd // n_head self.rot_dim = int(self.head_dim * rotary_pct) self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True) self.c_proj = nn.Linear(n_embd, n_embd, bias=True) self.rotary = RotaryEmbedding(self.rot_dim, block_size) def forward(self, x): B, T, C = x.size() qkv = self.c_attn(x).reshape(B, T, 3, self.n_head, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) cos, sin = self.rotary(T) q = torch.cat([apply_rotary(q[..., :self.rot_dim], cos, sin), q[..., self.rot_dim:]], dim=-1) k = torch.cat([apply_rotary(k[..., :self.rot_dim], cos, sin), k[..., self.rot_dim:]], dim=-1) y = F.scaled_dot_product_attention(q, k, v, is_causal=True) return self.c_proj(y.transpose(1, 2).contiguous().view(B, T, C)) class MLP(nn.Module): def __init__(self, n_embd): super().__init__() self.fc = nn.Linear(n_embd, 4 * n_embd, bias=True) self.gelu = nn.GELU() self.proj = nn.Linear(4 * n_embd, n_embd, bias=True) def forward(self, x): return self.proj(self.gelu(self.fc(x))) class Block(nn.Module): def __init__(self, n_embd, n_head, block_size): super().__init__() self.ln1 = nn.LayerNorm(n_embd) self.attn = CausalSelfAttention(n_embd, n_head, block_size) self.ln2 = nn.LayerNorm(n_embd) self.mlp = MLP(n_embd) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class LUNAModel(nn.Module): def __init__(self, vocab_size, block_size, n_layer, n_embd, n_head): super().__init__() self.wte = nn.Embedding(vocab_size, n_embd) self.blocks = nn.ModuleList([Block(n_embd, n_head, block_size) for _ in range(n_layer)]) self.ln_f = nn.LayerNorm(n_embd) self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) self.lm_head.weight = self.wte.weight # tie self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, (nn.Linear, nn.Embedding)): m.weight.data.normal_(mean=0.0, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: m.bias.data.zero_() def forward(self, idx, targets=None, return_logits=True): x = self.wte(idx) for block in self.blocks: x = block(x) x = self.ln_f(x) logits = self.lm_head(x) loss = None if targets is not None: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) if not return_logits: logits = None return logits, loss @property def num_params(self): return sum(p.numel() for p in self.parameters()) - self.wte.weight.numel() # ─── Dataset ────────────────────────────────────────────────────────────────── class LitDataDataset(torch.utils.data.Dataset): def __init__(self, data_path: str, block_size: int = 1024): import struct, numpy as np self.block_size = block_size self.data_path = Path(data_path) with open(self.data_path / "index.json") as f: idx = json.load(f) self.chunks_meta = idx["chunks"] self._cum_blocks = [] total = 0 for c in self.chunks_meta: n = c["dim"] // (block_size + 1) total += n self._cum_blocks.append(total) self.total_blocks = total self._chunk_cache = {} def _load_chunk(self, chunk_idx: int): if chunk_idx in self._chunk_cache: return self._chunk_cache[chunk_idx] import struct, numpy as np meta = self.chunks_meta[chunk_idx] with open(self.data_path / meta["filename"], "rb") as f: raw = f.read() num_items = struct.unpack_from("= 4: del self._chunk_cache[next(iter(self._chunk_cache))] self._chunk_cache[chunk_idx] = tokens return tokens def __len__(self): return self.total_blocks def __getitem__(self, idx): chunk_idx = 0 for i, cum in enumerate(self._cum_blocks): if idx < cum: chunk_idx = i break prev = self._cum_blocks[chunk_idx - 1] if chunk_idx > 0 else 0 tokens = self._load_chunk(chunk_idx) s = (idx - prev) * (self.block_size + 1) e = s + self.block_size + 1 chunk = tokens[s:e] if len(chunk) < self.block_size + 1: pad = torch.zeros(self.block_size + 1, dtype=torch.int32) pad[:len(chunk)] = chunk chunk = pad chunk = chunk.long() return chunk[:self.block_size], chunk[1:self.block_size + 1] # ─── Hardware Detection ──────────────────────────────────────────────────────── def probe_hardware(): info = { "cpu_cores": os.cpu_count() or 4, "ram_gb": psutil.virtual_memory().total / 1024**3, } if torch.cuda.is_available(): props = torch.cuda.get_device_properties(0) info.update({ "device": "cuda", "gpu_name": props.name, "vram_gb": props.total_memory / 1024**3, "sm_major": props.major, }) if props.major >= 8: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True info["precision"] = "bf16" info["dtype"] = torch.bfloat16 else: info["precision"] = "fp16" info["dtype"] = torch.float16 else: info.update({ "device": "cpu", "gpu_name": "CPU", "vram_gb": 0, "sm_major": 0, "precision": "fp32", "dtype": torch.float32, }) return info def probe_max_batch(model, device, dtype, seq_len, vocab_size, max_search=4096, grad_accum_sim=4): """Binary search for max micro_batch. Simulates grad_accum forward+backward passes to account for real training memory patterns. Safety: x0.70.""" tmp_opt = torch.optim.AdamW(model.parameters(), lr=1e-4) lo, hi, best = 1, max_search, 1 while lo <= hi: mid = (lo + hi) // 2 try: torch.cuda.empty_cache(); gc.collect() tmp_opt.zero_grad(set_to_none=True) # Simulate grad_accum micro-batches (real training pattern) for _ in range(grad_accum_sim): x = torch.randint(0, vocab_size, (mid, seq_len), device=device) t = torch.randint(0, vocab_size, (mid, seq_len), device=device) with autocast(device_type="cuda", dtype=dtype): _, loss = model(x, t, return_logits=False) loss = loss / grad_accum_sim loss.backward() del x, t, loss tmp_opt.step() tmp_opt.zero_grad(set_to_none=True) best = mid; lo = mid + 1 torch.cuda.empty_cache() except torch.cuda.OutOfMemoryError: try: del x, t, loss except: pass torch.cuda.empty_cache() tmp_opt.zero_grad(set_to_none=True) hi = mid - 1 except RuntimeError as e: if "out of memory" in str(e).lower(): try: del x, t, loss except: pass torch.cuda.empty_cache() tmp_opt.zero_grad(set_to_none=True) hi = mid - 1 else: raise del tmp_opt; torch.cuda.empty_cache(); gc.collect() safe = max(1, int(best * 0.70)) print(f" Probe found max_batch={best}, using {safe} (70% safety, tested with {grad_accum_sim} accum steps)") return safe # ─── LR Schedule ────────────────────────────────────────────────────────────── def cosine_lr(step, warmup, total, lr_max, lr_min): if step < warmup: return lr_max * (step + 1) / warmup p = (step - warmup) / max(1, total - warmup) return lr_min + 0.5 * (1 + math.cos(math.pi * p)) * (lr_max - lr_min) # ─── Config Loading ─────────────────────────────────────────────────────────── def load_config(config_path: str) -> dict: """Load YAML config and return flat namespace dict.""" with open(config_path, encoding="utf-8") as f: raw = yaml.safe_load(f) cfg = { # top-level "auto_config": raw.get("auto_config", True), "data_path": raw.get("data_path", "Base/data/litdata_pretrain_final"), "out_dir": raw.get("out_dir", "out/pretrain/luna-100m"), "tokenizer_dir": raw.get("tokenizer_dir", "Base/checkpoints/EleutherAI/pythia-160m"), # model "vocab_size": raw["model"]["vocab_size"], "seq_len": raw["model"]["seq_len"], "n_layer": raw["model"]["n_layer"], "n_embd": raw["model"]["n_embd"], "n_head": raw["model"]["n_head"], # train "max_tokens": raw["train"]["max_tokens"], "lr_warmup_steps":raw["train"]["lr_warmup_steps"], "save_interval": raw["train"]["save_interval"], "log_interval": raw["train"]["log_interval"], "max_norm": raw["train"]["max_norm"], # optimizer "lr": raw["optimizer"]["lr"], "min_lr": raw["optimizer"]["min_lr"], "weight_decay": raw["optimizer"]["weight_decay"], "betas": tuple(raw["optimizer"]["betas"]), "eps": raw["optimizer"]["eps"], # batch "global_batch": raw["batch"]["global_batch"], "micro_batch": raw["batch"]["micro_batch"], "grad_accum": raw["batch"]["grad_accum"], # dataloader "num_workers": raw["dataloader"]["num_workers"], "pin_memory": raw["dataloader"]["pin_memory"], # hardware "precision": raw["hardware"]["precision"], "compile": raw["hardware"]["compile"], } return cfg def apply_cli_overrides(cfg: dict, cli_args: argparse.Namespace) -> dict: """CLI args override config values (only if explicitly provided).""" for key, val in vars(cli_args).items(): if key == "config": continue if val is not None: # argparse default=None means "not provided" cfg[key] = val return cfg def resolve_auto(cfg: dict, hw: dict) -> dict: """ When auto_config=True: override batch, workers, lr-warmup, pin_memory, precision from real hardware. Never touches model arch or max_tokens. Returns updated cfg plus injected hw info. """ if not cfg["auto_config"]: print(" [CONFIG] auto_config=false -- using manual values as-is") cfg.update({"_hw": hw}) return cfg print(" [CONFIG] auto_config=true -- tuning settings to this hardware") # Precision cfg["precision"] = hw["precision"] cfg["_dtype"] = hw["dtype"] # Workers auto_workers = hw["cpu_cores"] // 2 # Cap by RAM: each worker caches up to 4 chunks × ~67MB max_by_ram = max(0, int(hw["ram_gb"] * 0.25 * 1024 / 268)) cfg["num_workers"] = min(auto_workers, max_by_ram, hw["cpu_cores"]) if cfg["num_workers"] == -1: cfg["num_workers"] = 0 # Pin memory cfg["pin_memory"] = hw["ram_gb"] > 16 and hw["device"] == "cuda" # LR warmup: 5% of total steps (will be computed again in train()) cfg["_auto_warmup"] = True # flag: recompute once total_steps is known # LR scaling: sqrt(global_batch / 120) relative to base lr base_global = 120 cfg["lr"] = cfg["lr"] * math.sqrt(cfg["global_batch"] / base_global) cfg["min_lr"] = cfg["min_lr"] * math.sqrt(cfg["global_batch"] / base_global) cfg["_hw"] = hw return cfg # ─── Training ───────────────────────────────────────────────────────────────── SEP = "=" * 72 def train(cfg: dict): hw = cfg["_hw"] device = torch.device(hw["device"]) # Clean GPU before anything — kill leftover allocations from prior runs if device.type == "cuda": torch.cuda.empty_cache() gc.collect() free_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()) / 1024**3 print(f" GPU free before model load: {free_gb:.1f} GB") # Pick precision dtype if cfg["auto_config"]: dtype = hw.get("dtype", torch.float32) else: dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}.get(cfg["precision"], torch.float32) print(SEP) print(" LUNA 100M - Training") print(SEP) mode = "AUTO" if cfg["auto_config"] else "MANUAL" print(f" Config mode : {mode}") print(f" GPU : {hw['gpu_name']} ({hw['vram_gb']:.1f} GB)") print(f" RAM : {hw['ram_gb']:.1f} GB CPU: {hw['cpu_cores']} cores") print(f" Precision : {cfg['precision']} dtype={dtype}") print(f" Workers : {cfg['num_workers']} pin_memory={cfg['pin_memory']}") # ── Model ───────────────────────────────────────────────────────────────── print(f"\n Building LUNA-100M...") model = LUNAModel( vocab_size=cfg["vocab_size"], block_size=cfg["seq_len"], n_layer=cfg["n_layer"], n_embd=cfg["n_embd"], n_head=cfg["n_head"], ).to(device) compiled_model = False # torch.compile disabled: causes CUDA graph / OOM issues with tied # embeddings at this model size. Raw PyTorch + SDPA is already fast. print(" torch.compile: disabled (not needed for 100M params)") print(f" Parameters: {model.num_params:,} (unique)") # ── Batch sizing ────────────────────────────────────────────────────────── if cfg["auto_config"] and device.type == "cuda": print(f"\n Probing max micro_batch_size (VRAM search)...") # Probe using the actual model — no second copy wasting VRAM max_mbs = probe_max_batch( model, device, dtype, cfg["seq_len"], cfg["vocab_size"] ) # Re-init model weights after probe (probe dirties optimizer state) model.apply(model._init_weights) torch.cuda.empty_cache(); gc.collect() # grad_accum to hit global_batch grad_accum = max(1, math.ceil(cfg["global_batch"] / max_mbs)) effective_batch = max_mbs * grad_accum print(f" AUTO -> micro_batch={max_mbs}, grad_accum={grad_accum}, " f"effective_batch={effective_batch}") else: max_mbs = cfg["micro_batch"] grad_accum = cfg["grad_accum"] effective_batch = max_mbs * grad_accum print(f"\n MANUAL -> micro_batch={max_mbs}, grad_accum={grad_accum}, " f"effective_batch={effective_batch}") tokens_per_step = effective_batch * cfg["seq_len"] print(f" Tokens/step : {tokens_per_step:,}") # ── Dataset ─────────────────────────────────────────────────────────────── print(f"\n Dataset: {cfg['data_path']}") dataset = LitDataDataset(cfg["data_path"], block_size=cfg["seq_len"]) print(f" Blocks : {len(dataset):,} ({len(dataset) * cfg['seq_len']:,} tokens)") loader = torch.utils.data.DataLoader( dataset, batch_size=max_mbs, shuffle=True, num_workers=cfg["num_workers"], pin_memory=cfg["pin_memory"], drop_last=True, prefetch_factor=4 if cfg["num_workers"] > 0 else None, persistent_workers=cfg["num_workers"] > 0, ) # ── Optimiser ───────────────────────────────────────────────────────────── fused_ok = device.type == "cuda" and hasattr(torch.optim, "AdamW") try: optimizer = torch.optim.AdamW( model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"], betas=cfg["betas"], eps=cfg["eps"], fused=True, ) except TypeError: optimizer = torch.optim.AdamW( model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"], betas=cfg["betas"], eps=cfg["eps"], ) use_scaler = dtype == torch.float16 scaler = GradScaler(enabled=use_scaler) # ── Schedule ────────────────────────────────────────────────────────────── total_steps = max(1, cfg["max_tokens"] // tokens_per_step) if cfg["auto_config"] and cfg.get("_auto_warmup"): warmup_steps = max(50, min(500, total_steps // 20)) else: warmup_steps = min(cfg["lr_warmup_steps"], total_steps) out_dir = Path(cfg["out_dir"]) out_dir.mkdir(parents=True, exist_ok=True) print(f"\n max_tokens : {cfg['max_tokens']:,}") print(f" total_steps : {total_steps:,}") print(f" warmup_steps : {warmup_steps}") print(f" lr : {cfg['lr']:.2e} -> {cfg['min_lr']:.2e}") print(f" save every : {cfg['save_interval']} steps") print(f" out_dir : {out_dir}") print(SEP) # ── Resume ──────────────────────────────────────────────────────────────── start_step = 0 ckpt_path = out_dir / "latest.pt" if ckpt_path.exists(): print(f"\n Resuming from {ckpt_path}...") ckpt = torch.load(ckpt_path, map_location=device, weights_only=True) model.load_state_dict(ckpt["model"]) optimizer.load_state_dict(ckpt["optimizer"]) start_step = ckpt["step"] print(f" Resumed at step {start_step}") # ── Loop ────────────────────────────────────────────────────────────────── model.train() data_iter = iter(loader) def get_batch(): nonlocal data_iter try: return next(data_iter) except StopIteration: data_iter = iter(loader) return next(data_iter) run_t0 = time.perf_counter() tokens_seen = start_step * tokens_per_step step = start_step print(f"\n Starting training (step {start_step} -> {total_steps})...") while step < total_steps: t0 = time.perf_counter() lr_now = cosine_lr(step, warmup_steps, total_steps, cfg["lr"], cfg["min_lr"]) for pg in optimizer.param_groups: pg["lr"] = lr_now optimizer.zero_grad(set_to_none=True) total_loss = 0.0 for _ in range(grad_accum): x, t = get_batch() x = x.to(device, non_blocking=True) t = t.to(device, non_blocking=True) with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")): _, loss = model(x, t, return_logits=False) loss = loss / grad_accum scaler.scale(loss).backward() total_loss += loss.item() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["max_norm"]) scaler.step(optimizer) scaler.update() if device.type == "cuda": torch.cuda.synchronize() dt = time.perf_counter() - t0 step += 1 tokens_seen += tokens_per_step if step % cfg["log_interval"] == 0 or step <= 2: tps = tokens_per_step / dt steps_left = total_steps - step eta_h = steps_left * dt / 3600 vram = torch.cuda.max_memory_allocated() / 1024**3 if device.type == "cuda" else 0 print(f" step {step:6d}/{total_steps} | loss {total_loss:.4f} | " f"lr {lr_now:.2e} | {tps:,.0f} tok/s | VRAM {vram:.1f}GB | ETA {eta_h:.1f}h") if step % cfg["save_interval"] == 0 or step == total_steps: raw = model._orig_mod if hasattr(model, "_orig_mod") else model step_dir = out_dir / f"step-{step:08d}" step_dir.mkdir(parents=True, exist_ok=True) torch.save(raw.state_dict(), step_dir / "lit_model.pth") torch.save({"step": step, "model": raw.state_dict(), "optimizer": optimizer.state_dict(), "tokens_seen": tokens_seen}, out_dir / "latest.pt") print(f" Saved -> {step_dir}") # ── Final ───────────────────────────────────────────────────────────────── final_dir = out_dir / "final" final_dir.mkdir(parents=True, exist_ok=True) raw = model._orig_mod if hasattr(model, "_orig_mod") else model torch.save(raw.state_dict(), final_dir / "lit_model.pth") import shutil tok_src = Path(cfg["tokenizer_dir"]) if tok_src.exists(): shutil.copytree(tok_src, final_dir / "tokenizer", dirs_exist_ok=True) total_h = (time.perf_counter() - run_t0) / 3600 print(SEP) print(f" Done! {total_h:.2f} h -> {final_dir}") print(SEP) # ─── Entry point ────────────────────────────────────────────────────────────── def parse_args(): p = argparse.ArgumentParser(description="LUNA 100M Trainer") p.add_argument("--config", type=str, default="train_config.yaml", help="Path to train_config.yaml") # CLI overrides (all optional - omit to use config value) p.add_argument("--data_path", type=str, default=None) p.add_argument("--out_dir", type=str, default=None) p.add_argument("--max_tokens", type=int, default=None) p.add_argument("--micro_batch", type=int, default=None) p.add_argument("--global_batch",type=int, default=None) p.add_argument("--lr", type=float, default=None) p.add_argument("--num_workers", type=int, default=None) p.add_argument("--save_interval",type=int, default=None) p.add_argument("--log_interval",type=int, default=None) p.add_argument("--auto_config", type=lambda x: x.lower() in ("1","true","yes"), default=None, help="Override auto_config (true/false)") return p.parse_args() if __name__ == "__main__": args = parse_args() cfg = load_config(args.config) cfg = apply_cli_overrides(cfg, args) hw = probe_hardware() cfg = resolve_auto(cfg, hw) train(cfg)