RON-110M / code /train.py
endurasolution's picture
Upload Ron-110M: pretrain + summarizer + tokenizer + code
3b97420 verified
from __future__ import annotations
import argparse
import json
import math
import os
import random
import time
from pathlib import Path
from typing import Any
import numpy as np
import torch
from rich.console import Console
from searshorai.model import GPT, GPTConfig
# Shared rich console; every status and log line in this script is printed through it.
console = Console()
# Named hyperparameter bundles selectable with --preset. Each entry supplies
# model-shape defaults (n_layer / n_head / n_embd / block_size) and schedule
# defaults (batch_size / grad_accum / max_steps). apply_preset() copies a
# value in only when the matching command-line option was left unset.
PRESETS = {
    "quick_test": {
        "n_layer": 6,
        "n_head": 6,
        "n_embd": 384,
        "block_size": 256,
        "batch_size": 8,
        "grad_accum": 8,
        "max_steps": 1000,
    },
    "gpu_16gb": {
        "n_layer": 10,
        "n_head": 10,
        "n_embd": 640,
        "block_size": 512,
        "batch_size": 4,
        "grad_accum": 16,
        "max_steps": 20000,
    },
    "rtx3090_8h": {
        "n_layer": 12,
        "n_head": 12,
        "n_embd": 768,
        "block_size": 512,
        "batch_size": 8,
        "grad_accum": 16,
        "max_steps": 20000,
    },
    "rtx3090_quality": {
        "n_layer": 16,
        "n_head": 16,
        "n_embd": 1024,
        "block_size": 512,
        "batch_size": 4,
        "grad_accum": 24,
        "max_steps": 30000,
    },
    "gpu_40gb_quality": {
        "n_layer": 20,
        "n_head": 16,
        "n_embd": 1024,
        "block_size": 768,
        "batch_size": 4,
        "grad_accum": 32,
        "max_steps": 40000,
    },
}
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    The preset-controlled options (model shape, batch size, grad accumulation,
    max steps) default to ``None`` so that apply_preset() can distinguish
    "not given" from an explicit value.

    Returns:
        The parsed argument namespace.
    """
    parser = argparse.ArgumentParser(description="Train a GPT-style language model from scratch.")
    add = parser.add_argument

    # Locations, preset selection and resume behaviour.
    add("--data_dir", type=Path, default=Path("data/wikitext103"))
    add("--out_dir", type=Path, default=Path("runs/wikitext-gpt"))
    add("--preset", choices=list(PRESETS), default="gpu_16gb")
    add("--resume", type=Path, default=None)
    add("--reset_optimizer", action="store_true")
    add(
        "--reset_step",
        action="store_true",
        help="When resuming, restart step counter at 0 (useful when restarting a fresh schedule).",
    )

    # Preset-controlled options; None means "take the preset value".
    for flag in ("--n_layer", "--n_head", "--n_embd", "--block_size"):
        add(flag, type=int, default=None)
    add("--batch_size", type=int, default=None, help="Micro-batch size.")
    for flag in ("--grad_accum", "--max_steps"):
        add(flag, type=int, default=None)

    # Optimisation hyperparameters.
    add("--learning_rate", type=float, default=2.5e-4)
    add("--min_lr", type=float, default=2.5e-5)
    add("--warmup_steps", type=int, default=1000)
    add("--weight_decay", type=float, default=0.1)
    add("--dropout", type=float, default=0.0)
    add("--grad_clip", type=float, default=1.0)

    # Evaluation / checkpoint / logging cadence and the RNG seed.
    integer_defaults = (
        ("--eval_interval", 500),
        ("--eval_iters", 100),
        ("--save_interval", 1000),
        ("--log_interval", 20),
        ("--seed", 1337),
    )
    for flag, default in integer_defaults:
        add(flag, type=int, default=default)

    # Runtime selection.
    add("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"])
    add("--dtype", type=str, default="auto", choices=["auto", "float32", "float16", "bfloat16"])

    # Boolean switches.
    add("--compile", action="store_true")
    add("--gradient_checkpointing", action="store_true")
    add(
        "--no_gradient_checkpointing",
        "--no-gradient-checkpointing",
        action="store_true",
        help="Disable checkpointing when resuming from a checkpoint that was trained with it.",
    )
    for flag in ("--eval_only", "--always_save_checkpoint", "--save_optimizer"):
        add(flag, action="store_true")

    return parser.parse_args()
def apply_preset(args: argparse.Namespace) -> argparse.Namespace:
    """Fill options the user left unset from the preset named by ``args.preset``.

    Only attributes whose current value is ``None`` are overwritten, so
    explicit command-line values always win. ``args`` is modified in place
    and also returned for call chaining.
    """
    chosen = PRESETS[args.preset]
    for option in chosen:
        if getattr(args, option) is None:
            setattr(args, option, chosen[option])
    return args
def setup_reproducibility(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and set the backend flags.

    Besides seeding, this turns on TF32 for CUDA matmul and cuDNN and enables
    cuDNN benchmark mode. The flags are set regardless of whether CUDA is
    available.
    """
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    backends = torch.backends
    backends.cuda.matmul.allow_tf32 = True
    backends.cudnn.allow_tf32 = True
    backends.cudnn.benchmark = True
def choose_device(args: argparse.Namespace) -> str:
    """Resolve ``args.device`` to a concrete device string.

    ``"auto"`` picks ``"cuda"`` when available, otherwise ``"cpu"``.

    Raises:
        RuntimeError: if ``"cuda"`` was requested explicitly but is unavailable.
    """
    requested = args.device
    if requested == "auto":
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"
    if requested == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA was requested, but torch.cuda.is_available() is False.")
    return requested
def choose_dtype(args: argparse.Namespace, device: str) -> torch.dtype:
    """Pick the compute dtype for the run.

    CPU always uses float32. On other devices an explicit ``float32`` or
    ``float16`` request is honoured as-is. ``bfloat16`` and ``auto`` both
    resolve to bfloat16 when ``torch.cuda.is_bf16_supported()`` is true and to
    float16 otherwise; an explicit ``bfloat16`` request that has to fall back
    prints a warning.
    """
    if device == "cpu":
        return torch.float32

    requested = args.dtype
    explicit = {"float32": torch.float32, "float16": torch.float16}
    if requested in explicit:
        return explicit[requested]

    # Remaining cases ("bfloat16" and "auto") differ only in the warning.
    if torch.cuda.is_bf16_supported():
        return torch.bfloat16
    if requested == "bfloat16":
        console.print("[yellow]bfloat16 requested but not supported. Falling back to float16.[/yellow]")
    return torch.float16
def make_autocast_context(device: str, dtype: torch.dtype):
enabled = device == "cuda" and dtype in (torch.float16, torch.bfloat16)
return torch.amp.autocast(device_type=device, dtype=dtype, enabled=enabled)
def make_grad_scaler(device: str, dtype: torch.dtype):
    """Create the gradient scaler for mixed-precision training.

    Scaling is enabled only for float16 on CUDA; otherwise the scaler is
    constructed with ``enabled=False``.

    The device-generic ``torch.amp.GradScaler`` is preferred. If it is missing
    (``AttributeError``) or rejects the call signature (``TypeError``), the
    legacy ``torch.cuda.amp.GradScaler`` is used instead.
    """
    enabled = device == "cuda" and dtype == torch.float16
    try:
        return torch.amp.GradScaler("cuda", enabled=enabled)
    except (AttributeError, TypeError):
        # A missing attribute raises AttributeError, not TypeError, so both
        # must be caught for the legacy fallback to be reachable.
        return torch.cuda.amp.GradScaler(enabled=enabled)
def get_lr(step: int, args: argparse.Namespace) -> float:
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    * ``step < warmup_steps``: ramps linearly from 0 towards ``learning_rate``.
    * ``step > max_steps``: held at ``min_lr``.
    * otherwise: cosine interpolation from ``learning_rate`` down to ``min_lr``.
    """
    peak, floor = args.learning_rate, args.min_lr
    warmup, total = args.warmup_steps, args.max_steps

    if step < warmup:
        return peak * step / max(1, warmup)
    if step > total:
        return floor

    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup) / max(1, total - warmup)
    progress = min(1.0, max(0.0, progress))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return floor + cosine * (peak - floor)
def load_json(path: Path) -> dict[str, Any]:
    """Read and parse a UTF-8 JSON file.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if path.exists():
        with path.open(encoding="utf-8") as handle:
            return json.load(handle)
    raise FileNotFoundError(f"Missing required file: {path}")
def validate_meta(meta: dict[str, Any]) -> None:
    """Sanity-check the dataset metadata loaded from ``meta.json``.

    Requires the keys ``vocab_size`` and ``dtype``; ``dtype`` must be
    ``"uint16"`` or ``"uint32"`` and ``vocab_size`` must be positive and
    representable in the declared dtype.

    Raises:
        KeyError: if a required key is missing.
        ValueError: if the dtype is unsupported, ``vocab_size`` is not
            positive, or the vocabulary cannot be stored as uint16.
    """
    for key in ("vocab_size", "dtype"):
        if key not in meta:
            raise KeyError(f"meta.json is missing required key: {key}")

    dtype = meta["dtype"]
    if dtype not in ("uint16", "uint32"):
        raise ValueError(f"Unsupported meta dtype: {dtype}. Expected uint16 or uint32.")

    vocab_size = int(meta["vocab_size"])
    if vocab_size <= 0:
        raise ValueError("meta.json vocab_size must be greater than zero.")

    # Token ids run from 0 to vocab_size - 1, so uint16 (max value 65535) can
    # hold a vocabulary of up to 65536 entries.
    uint16_capacity = 1 << 16
    if dtype == "uint16" and vocab_size > uint16_capacity:
        raise ValueError(
            f"meta dtype is uint16 but vocab_size is greater than {uint16_capacity}. Use uint32 data files."
        )
def load_memmap(path: Path, dtype: str) -> np.memmap:
    """Open a token file as a read-only NumPy memmap.

    ``dtype == "uint16"`` maps the file as uint16; any other value maps it as
    uint32.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"Missing required file: {path}")
    element_type = np.uint32
    if dtype == "uint16":
        element_type = np.uint16
    return np.memmap(path, dtype=element_type, mode="r")
def _check_token_range(data: np.memmap, file_name: str, vocab_size: int, max_samples: int = 10000) -> None:
    """Spot-check evenly spaced tokens of ``data`` against ``[0, vocab_size)``."""
    sample_count = min(max_samples, len(data))
    positions = np.linspace(0, len(data) - 1, sample_count, dtype=np.int64)
    sample = np.asarray(data[positions], dtype=np.int64)
    min_token = int(sample.min())
    max_token = int(sample.max())
    if min_token < 0:
        raise ValueError(f"Dataset contains negative token id: {min_token}")
    if max_token >= vocab_size:
        raise ValueError(
            f"Dataset token id {max_token} is >= vocab_size {vocab_size}. "
            f"This usually means tokenizer/meta/{file_name} mismatch."
        )


def validate_dataset(train_data: np.memmap, val_data: np.memmap, block_size: int, vocab_size: int) -> None:
    """Check that both splits are large enough and contain valid token ids.

    Each split needs at least ``block_size + 2`` tokens. Token ids are then
    spot-checked on up to 10000 evenly spaced positions per split; this is a
    sample, not a full scan. Both splits are checked, so an out-of-range id in
    the validation data is reported here rather than during evaluation.

    Raises:
        ValueError: if a split is too small or a sampled token id falls
            outside ``[0, vocab_size)``.
    """
    splits = (("train.bin", train_data), ("val.bin", val_data))
    min_required = block_size + 2

    for file_name, data in splits:
        if len(data) < min_required:
            raise ValueError(
                f"{file_name} is too small. Need at least {min_required} tokens for block_size={block_size}, "
                f"but got {len(data)}."
            )

    for file_name, data in splits:
        _check_token_range(data, file_name, vocab_size)
def get_batch(
    data: np.memmap,
    batch_size: int,
    block_size: int,
    device: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a random batch of (input, target) token windows from ``data``.

    ``batch_size`` start offsets are drawn with ``np.random.randint``. For
    each one, ``block_size`` consecutive tokens are copied into a row of
    ``x``, and the same window shifted right by one token into the matching
    row of ``y``. Rows are filled one at a time from contiguous slices of
    ``data``. Both results are int64 tensors of shape
    ``(batch_size, block_size)``.

    On ``"cuda"`` each tensor is pinned and then copied with
    ``non_blocking=True``; on any other device a plain ``.to(device)`` is used.

    Raises:
        ValueError: if ``len(data) - block_size - 1`` is not positive.
    """
    start_limit = len(data) - block_size - 1
    if start_limit <= 0:
        raise ValueError("Dataset is too small for the configured block_size.")

    starts = np.random.randint(0, start_limit, size=(batch_size,), dtype=np.int64)

    inputs = np.empty((batch_size, block_size), dtype=np.int64)
    targets = np.empty((batch_size, block_size), dtype=np.int64)
    for row in range(batch_size):
        begin = starts[row]
        inputs[row] = data[begin : begin + block_size]
        targets[row] = data[begin + 1 : begin + 1 + block_size]

    host_tensors = (torch.from_numpy(inputs), torch.from_numpy(targets))
    if device == "cuda":
        x, y = (t.pin_memory().to(device, non_blocking=True) for t in host_tensors)
    else:
        x, y = (t.to(device) for t in host_tensors)
    return x, y
@torch.no_grad()
def estimate_loss(
    model: GPT,
    train_data: np.memmap,
    val_data: np.memmap,
    args: argparse.Namespace,
    device: str,
    autocast_ctx,
) -> dict[str, float]:
    """Estimate the mean loss on the train and validation splits.

    For each split, ``args.eval_iters`` random batches are drawn with
    get_batch() and run through the model in eval mode under ``autocast_ctx``.
    Non-finite batch losses are dropped from the average.

    Returns:
        ``{"train": mean_loss, "val": mean_loss}``. A split with no finite
        loss at all is reported as ``nan`` rather than ``0.0``. A fabricated
        zero would compare lower than any real loss and be recorded as a new
        best validation loss by the caller.

    The model is switched back to train mode on exit, including when an
    exception propagates.
    """
    out: dict[str, float] = {}
    model.eval()
    try:
        for split, data in (("train", train_data), ("val", val_data)):
            finite_losses: list[float] = []
            for _ in range(args.eval_iters):
                x, y = get_batch(data, args.batch_size, args.block_size, device)
                with autocast_ctx:
                    _, loss = model(x, y)
                if torch.isfinite(loss):
                    finite_losses.append(float(loss.item()))
            if finite_losses:
                out[split] = float(sum(finite_losses) / len(finite_losses))
            else:
                out[split] = float("nan")
    finally:
        model.train()
    return out
def unwrap_model(model: GPT) -> GPT:
    """Return ``model._orig_mod`` when that attribute exists, else ``model`` itself.

    main() applies this to models that may have been wrapped by
    ``torch.compile``.
    """
    return getattr(model, "_orig_mod", model)
def strip_compile_prefix(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Return a copy of ``state_dict`` with one leading ``"_orig_mod."`` removed from each key.

    Keys without the prefix are copied unchanged; values are not copied.
    """
    marker = "_orig_mod."
    return {
        (name[len(marker) :] if name.startswith(marker) else name): tensor
        for name, tensor in state_dict.items()
    }
def optimizer_to_device(optimizer: torch.optim.Optimizer, device: str) -> None:
    """Move every tensor held in the optimizer's per-parameter state to ``device``.

    Non-tensor state entries are left untouched. The state dicts are updated
    in place.
    """
    for param_state in optimizer.state.values():
        # Snapshot the keys so entries can be reassigned while looping.
        for name in list(param_state):
            entry = param_state[name]
            if isinstance(entry, torch.Tensor):
                param_state[name] = entry.to(device)
def save_checkpoint(
    path: Path,
    model: GPT,
    optimizer: torch.optim.Optimizer | None,
    args: argparse.Namespace,
    step: int,
    best_val_loss: float,
    meta: dict[str, Any],
) -> None:
    """Write a training checkpoint to ``path``.

    The checkpoint holds the unwrapped model's ``state_dict``, the run
    arguments, the model config (``vars(raw_model.config)``), the step
    counter, the best validation loss so far and the dataset ``meta``.
    Optimizer state is included only when ``args.save_optimizer`` is set and
    an optimizer is supplied.

    The file is first written to ``<name>.tmp`` next to ``path`` and then
    moved into place with ``os.replace``. An interruption during the write
    leaves any previous checkpoint at ``path`` intact instead of truncated.
    """
    path = Path(path)
    raw_model = unwrap_model(model)
    checkpoint: dict[str, Any] = {
        "model": raw_model.state_dict(),
        "args": vars(args),
        "config": vars(raw_model.config),
        "step": step,
        "best_val_loss": best_val_loss,
        "meta": meta,
    }
    if args.save_optimizer and optimizer is not None:
        checkpoint["optimizer"] = optimizer.state_dict()

    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only present here if the save or the rename failed part-way.
        if tmp_path.exists():
            tmp_path.unlink()
def write_run_config(args: argparse.Namespace, meta: dict[str, Any], device: str, dtype: torch.dtype) -> None:
    """Write ``run_config.json`` into ``args.out_dir``.

    The file records the run arguments (with ``Path`` values converted to
    strings), the dataset meta, the resolved device and dtype, the torch
    version, and CUDA availability together with the name of CUDA device 0.
    """
    printable_args = {}
    for name, value in vars(args).items():
        printable_args[name] = str(value) if isinstance(value, Path) else value

    has_cuda = torch.cuda.is_available()
    device_name = torch.cuda.get_device_name(0) if has_cuda else None

    payload = {
        "args": printable_args,
        "meta": meta,
        "device": device,
        "dtype": str(dtype),
        "torch_version": torch.__version__,
        "cuda_available": has_cuda,
        "cuda_device_name": device_name,
    }
    target = args.out_dir / "run_config.json"
    target.write_text(json.dumps(payload, indent=2), encoding="utf-8")
def build_model_from_checkpoint(
    ckpt_path: Path,
    device: str,
    args: argparse.Namespace,
) -> tuple[GPT, int, float, dict[str, Any]]:
    """Rebuild a model from a checkpoint file.

    The model config comes from the checkpoint's ``"config"`` entry. If the
    config has a ``gradient_checkpointing`` attribute, the command-line flags
    can override it, with ``--no_gradient_checkpointing`` taking precedence
    over ``--gradient_checkpointing``. The weights are loaded strictly after
    stripping any ``"_orig_mod."`` key prefix.

    Returns:
        ``(model, start_step, best_val_loss, checkpoint_meta)``. Entries
        missing from the checkpoint default to ``0``, ``inf`` and ``{}``.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; only resume from
    # checkpoint files you trust.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    config = GPTConfig(**ckpt["config"])

    # Resolve the flag override first, then apply it only if the config
    # actually carries the attribute.
    override = None
    if args.no_gradient_checkpointing:
        override = False
    elif args.gradient_checkpointing:
        override = True
    if override is not None and hasattr(config, "gradient_checkpointing"):
        config.gradient_checkpointing = override

    model = GPT(config)
    model.load_state_dict(strip_compile_prefix(ckpt["model"]), strict=True)

    return (
        model,
        int(ckpt.get("step", 0)),
        float(ckpt.get("best_val_loss", float("inf"))),
        ckpt.get("meta", {}),
    )
def build_new_model(meta: dict[str, Any], args: argparse.Namespace) -> tuple[GPT, int, float]:
    """Create a freshly initialised model from the dataset meta and run arguments.

    Returns:
        ``(model, start_step, best_val_loss)`` where the step is ``0`` and the
        best validation loss is ``inf``.
    """
    config_kwargs = {
        "vocab_size": int(meta["vocab_size"]),
        "block_size": int(args.block_size),
        "n_layer": int(args.n_layer),
        "n_head": int(args.n_head),
        "n_embd": int(args.n_embd),
        "dropout": float(args.dropout),
        "gradient_checkpointing": bool(args.gradient_checkpointing),
    }
    fresh_model = GPT(GPTConfig(**config_kwargs))
    return fresh_model, 0, float("inf")
def print_startup_info(
    model: GPT,
    args: argparse.Namespace,
    device: str,
    dtype: torch.dtype,
    train_data: np.memmap,
    val_data: np.memmap,
    start_step: int,
) -> None:
    """Print a summary of the run configuration to the console.

    The parameter count comes from the unwrapped model's ``num_parameters()``
    when that method exists, otherwise from summing ``numel()`` over its
    parameters.

    NOTE(review): layers, heads, embedding size and block size are printed
    from ``args``. On resume the model is built from the checkpoint config,
    so these lines may not describe the loaded model. Confirm.
    """
    raw_model = unwrap_model(model)
    num_params = (
        raw_model.num_parameters()
        if hasattr(raw_model, "num_parameters")
        else sum(p.numel() for p in raw_model.parameters())
    )
    tokens_per_step = args.batch_size * args.grad_accum * args.block_size

    emit = console.print
    emit("")
    emit("[bold green]Training configuration[/bold green]")

    # Runtime and model shape.
    emit(f"Device: {device}")
    emit(f"Dtype: {dtype}")
    emit(f"Preset: {args.preset}")
    emit(f"Parameters: {num_params / 1e6:.2f}M")
    emit(f"Layers: {args.n_layer}")
    emit(f"Heads: {args.n_head}")
    emit(f"Embedding size: {args.n_embd}")
    emit(f"Block size: {args.block_size}")

    # Batch geometry and dataset size.
    emit(f"Batch size: {args.batch_size}")
    emit(f"Grad accumulation: {args.grad_accum}")
    emit(f"Tokens per step: {tokens_per_step:,}")
    emit(f"Train tokens: {len(train_data):,}")
    emit(f"Val tokens: {len(val_data):,}")

    # Schedule.
    emit(f"Start step: {start_step:,}")
    emit(f"Max steps: {args.max_steps:,}")
    emit(f"Learning rate: {args.learning_rate:.2e}")
    emit(f"Min LR: {args.min_lr:.2e}")
    emit(f"Warmup steps: {args.warmup_steps:,}")
    emit(f"Grad clip: {args.grad_clip}")
    emit("")
def _restore_optimizer_state(optimizer: torch.optim.Optimizer, ckpt_path: Path, device: str) -> None:
    """Best-effort reload of the optimizer state stored in ``ckpt_path``."""
    # Load onto the CPU inside this helper so the checkpoint (which also holds
    # a full copy of the model weights) is released on return. Otherwise it
    # would stay on the training device for the whole run.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    if "optimizer" not in ckpt:
        console.print("[yellow]Checkpoint has no optimizer state. Continuing with fresh optimizer.[/yellow]")
        return
    optimizer_state = ckpt["optimizer"]
    del ckpt
    try:
        optimizer.load_state_dict(optimizer_state)
        optimizer_to_device(optimizer, device)
        console.print("[green]Loaded optimizer state from checkpoint.[/green]")
    except Exception as exc:
        # Deliberately best-effort: a mismatched optimizer state should not
        # stop a resume, training just continues with fresh moments.
        console.print(f"[yellow]Could not load optimizer state. Continuing with fresh optimizer. Error: {exc}[/yellow]")


def main() -> None:
    """Script entry point: parse arguments, build or resume the model, then train.

    Flow:
      1. Parse arguments, apply the preset, seed RNGs, pick device and dtype.
      2. Load and validate ``meta.json``, ``train.bin`` and ``val.bin`` from
         ``args.data_dir``.
      3. Build a new model or resume from ``args.resume``, optionally
         restoring optimizer state.
      4. With ``--eval_only``, print one loss estimate and return.
      5. Otherwise run the training loop with gradient accumulation, LR
         scheduling, gradient clipping, periodic logging, evaluation and
         checkpointing, then write a final ``latest.pt``.
    """
    args = apply_preset(parse_args())
    args.out_dir.mkdir(parents=True, exist_ok=True)
    setup_reproducibility(args.seed)

    device = choose_device(args)
    dtype = choose_dtype(args, device)
    autocast_ctx = make_autocast_context(device, dtype)
    scaler = make_grad_scaler(device, dtype)

    meta_path = args.data_dir / "meta.json"
    meta = load_json(meta_path)
    validate_meta(meta)
    train_data = load_memmap(args.data_dir / "train.bin", meta["dtype"])
    val_data = load_memmap(args.data_dir / "val.bin", meta["dtype"])
    validate_dataset(
        train_data=train_data,
        val_data=val_data,
        block_size=int(args.block_size),
        vocab_size=int(meta["vocab_size"]),
    )

    if args.resume is not None:
        console.print(f"[yellow]Resuming from checkpoint:[/yellow] {args.resume}")
        model, start_step, best_val_loss, checkpoint_meta = build_model_from_checkpoint(args.resume, device, args)
        # Prefer the meta stored with the checkpoint when it has one.
        if checkpoint_meta:
            meta = checkpoint_meta
    else:
        model, start_step, best_val_loss = build_new_model(meta, args)

    if args.reset_step:
        start_step = 0
        best_val_loss = float("inf")
        console.print("[yellow]reset_step set: step counter restarted at 0.[/yellow]")

    model.to(device)
    optimizer = model.configure_optimizers(
        args.weight_decay,
        args.learning_rate,
        (0.9, 0.95),
        "cuda" if device == "cuda" else "cpu",
    )

    if args.resume is not None:
        if args.reset_optimizer:
            console.print("[yellow]reset_optimizer set: starting with fresh Adam moments.[/yellow]")
        else:
            _restore_optimizer_state(optimizer, args.resume, device)

    if args.compile:
        console.print("[cyan]Compiling model...[/cyan]")
        model = torch.compile(model)

    write_run_config(args, meta, device, dtype)
    print_startup_info(model, args, device, dtype, train_data, val_data, start_step)

    if args.eval_only:
        losses = estimate_loss(model, train_data, val_data, args, device, autocast_ctx)
        console.print(f"eval only: train {losses['train']:.4f}, val {losses['val']:.4f}")
        return

    model.train()
    tokens_per_step = args.batch_size * args.grad_accum * args.block_size
    start_time = time.time()
    last_log_time = start_time
    last_log_step = start_step

    for completed_step in range(start_step, args.max_steps):
        step = completed_step + 1

        lr = get_lr(step, args)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        optimizer.zero_grad(set_to_none=True)
        loss_accum = 0.0
        skipped_micro = 0
        for _ in range(args.grad_accum):
            x, y = get_batch(train_data, args.batch_size, args.block_size, device)
            with autocast_ctx:
                _, loss = model(x, y)
                # Scale so the accumulated gradient matches one large batch.
                loss = loss / args.grad_accum
            if not torch.isfinite(loss):
                console.print(f"[yellow]Non-finite loss at step {step}, skipping micro-batch.[/yellow]")
                skipped_micro += 1
                continue
            scaler.scale(loss).backward()
            loss_accum += float(loss.item())

        if skipped_micro == args.grad_accum:
            # No backward pass ran, so there are no gradients and the scaler
            # has recorded no inf checks for this step. An enabled GradScaler
            # asserts if update() is called in that state, so skip the
            # optimizer update without touching the scaler.
            console.print(
                f"[yellow]All micro-batches at step {step} were non-finite; skipping optimizer update.[/yellow]"
            )
            continue

        # Unscale before clipping so the threshold applies to true gradients.
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        if step % args.log_interval == 0 or step == start_step + 1:
            now = time.time()
            elapsed = max(now - last_log_time, 1e-9)
            steps_done = max(1, step - last_log_step)
            toks_per_sec = (tokens_per_step * steps_done) / elapsed
            last_log_time = now
            last_log_step = step
            console.print(
                f"step {step:7d} | "
                f"loss {loss_accum:.4f} | "
                f"lr {lr:.2e} | "
                f"grad {float(grad_norm):.2f} | "
                f"{toks_per_sec:,.0f} tok/s"
            )

        should_eval = step % args.eval_interval == 0 or step == args.max_steps
        if should_eval:
            losses = estimate_loss(model, train_data, val_data, args, device, autocast_ctx)
            console.print(
                f"[bold]eval step {step}:[/bold] "
                f"train {losses['train']:.4f}, val {losses['val']:.4f}"
            )
            if losses["val"] < best_val_loss:
                best_val_loss = losses["val"]
                save_checkpoint(
                    args.out_dir / "best.pt",
                    model,
                    optimizer,
                    args,
                    step,
                    best_val_loss,
                    meta,
                )
                console.print(f"[green]saved best checkpoint: val {best_val_loss:.4f}[/green]")
            if args.always_save_checkpoint:
                save_checkpoint(
                    args.out_dir / f"step_{step}.pt",
                    model,
                    optimizer,
                    args,
                    step,
                    best_val_loss,
                    meta,
                )

        if step % args.save_interval == 0:
            save_checkpoint(
                args.out_dir / "latest.pt",
                model,
                optimizer,
                args,
                step,
                best_val_loss,
                meta,
            )
            console.print(f"[cyan]saved latest checkpoint at step {step}[/cyan]")

    # Final checkpoint, written even if the loop body never ran.
    save_checkpoint(
        args.out_dir / "latest.pt",
        model,
        optimizer,
        args,
        args.max_steps,
        best_val_loss,
        meta,
    )

    elapsed_hours = (time.time() - start_time) / 3600.0
    console.print("")
    console.print(f"[bold green]Finished in {elapsed_hours:.2f} hours.[/bold green]")
    console.print(f"[bold green]Best validation loss: {best_val_loss:.4f}[/bold green]")
    console.print(f"[bold green]Best checkpoint: {args.out_dir / 'best.pt'}[/bold green]")
    console.print(f"[bold green]Latest checkpoint: {args.out_dir / 'latest.pt'}[/bold green]")
# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()