# Source: frankenstallm / source / train / pretrain.py
# Uploaded by pathcosmos using huggingface_hub (#14), commit 17758b3.
"""
train/pretrain.py — Main pretraining entry point.
Launch single-GPU:
python train/pretrain.py --config configs/small.yaml --train_data data/train.bin
Launch multi-GPU with torchrun:
torchrun --nproc_per_node=8 train/pretrain.py --config configs/small.yaml \
--train_data data/train.bin
The script auto-detects whether it is running inside a torchrun launch by
checking for the RANK environment variable.
"""
from __future__ import annotations
import argparse
import os
import random
import signal
import sys
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
# Maximize B200 Tensor Core utilization: TF32 matmul + cuDNN.
# These are process-wide switches applied at import time, before any model is built.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True # fixed seq_len=4096 → safe to auto-tune
torch.set_float32_matmul_precision("high") # TF32 precision for fp32 matmul
# Put the project root (two directory levels above this file) at the front of
# sys.path so that the project-local imports below resolve no matter which
# working directory the script is launched from. Skipped when already present.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))
from data import PackedDataset
from model import LLM, LMConfig
from train.trainer import TrainConfig, Trainer
from train.utils import (
cleanup_ddp,
get_cosine_schedule_with_warmup,
is_main_process,
load_checkpoint,
setup_ddp,
)
# ---------------------------------------------------------------------------
# Optional TransformerEngine import (FP8 support)
# ---------------------------------------------------------------------------
# TransformerEngine is optional: HAS_TE records whether the import succeeded,
# and `te` is bound to None when the package is not installed.
try:
    import transformer_engine.pytorch as te  # type: ignore[import]
except ImportError:
    te = None  # type: ignore[assignment]
    HAS_TE = False
else:
    HAS_TE = True
# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for a pretraining run.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour existing callers get.
            Passing an explicit list allows programmatic use and testing.

    Returns:
        Namespace with the attributes ``config``, ``train_data``, ``val_data``,
        ``checkpoint_dir``, ``resume``, ``max_steps``, ``batch_size``, ``lr``,
        ``weight_decay``, ``warmup_steps``, ``grad_accum``, ``seed``,
        ``log_file`` and ``use_fp8``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, on a missing
            ``--train_data``, or after printing ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="Pretrain a decoder-only LLM.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Paths
    parser.add_argument(
        "--config",
        type=Path,
        default=Path("configs/small.yaml"),
        help="Path to the LMConfig YAML file.",
    )
    parser.add_argument(
        "--train_data",
        type=Path,
        required=True,
        help="Path to the training data .bin file (numpy uint16 memmap).",
    )
    parser.add_argument(
        "--val_data",
        type=Path,
        default=None,
        help="Optional path to validation data .bin file.",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=Path,
        default=Path("checkpoints"),
        help="Root directory for saving checkpoints.",
    )
    parser.add_argument(
        "--resume",
        type=Path,
        default=None,
        help="Path to a checkpoint directory to resume training from.",
    )

    # Training hyper-parameters
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="Override the number of optimiser steps (default: TrainConfig.max_steps).",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Per-GPU micro-batch size.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=3e-4,
        help="Peak learning rate.",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.1,
        help="AdamW weight decay coefficient.",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=2000,
        help="Number of linear warmup steps.",
    )
    parser.add_argument(
        "--grad_accum",
        type=int,
        default=1,
        help="Gradient accumulation steps.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Base random seed (rank offset is added automatically).",
    )
    parser.add_argument(
        "--log_file",
        type=Path,
        default=None,
        help="Path to a text file for structured training logs (rank-0 only). "
        "If omitted, logs go only to stdout.",
    )
    parser.add_argument(
        "--use_fp8",
        action="store_true",
        default=False,
        help="Enable TransformerEngine FP8 training (overrides config; requires B200/H100).",
    )

    return parser.parse_args(argv)
# ---------------------------------------------------------------------------
# Seed helper
# ---------------------------------------------------------------------------
def set_seed(seed: int) -> None:
    """Seed every random number generator this script uses with ``seed``.

    Seeds are applied in this order: Python's ``random`` module, NumPy's
    global RNG, PyTorch's default generator, and PyTorch's CUDA generators
    (via ``torch.cuda.manual_seed_all``).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
# ---------------------------------------------------------------------------
# Optimizer parameter groups
# ---------------------------------------------------------------------------
def build_optimizer_param_groups(
    model: torch.nn.Module,
    weight_decay: float,
) -> list[dict]:
    """Split the trainable parameters of ``model`` into two optimizer groups.

    A parameter goes into the no-decay group when any of these holds:
      - it is owned directly by an ``nn.Embedding`` or ``nn.LayerNorm`` module,
      - it has fewer than two dimensions (biases, 1-D norm weights, ...),
      - its name ends with ``bias``,
      - its leaf name (the part after the last ``.``) is exactly one of the
        SSM parameter names ``A_log``, ``D`` or ``dt_bias``.
    Every other trainable parameter goes into the decay group.

    The SSM names are compared against the leaf name rather than with
    ``str.endswith``: a suffix test on the single letter ``"D"`` would also
    exempt unrelated parameters whose names merely end in "D" (e.g. ``W_D``).

    Args:
        model: Module whose parameters are grouped. Parameters with
            ``requires_grad=False`` are left out of both groups.
        weight_decay: Weight-decay coefficient for the decay group.

    Returns:
        ``[decay_group, no_decay_group]``: two dicts with keys ``"params"``
        and ``"weight_decay"``; the no-decay group uses ``0.0``.
    """
    # Module types whose own parameters are never decayed.
    no_decay_module_types = (
        torch.nn.Embedding,
        torch.nn.LayerNorm,
    )
    # SSM parameters that are never decayed, matched on the exact leaf name.
    ssm_no_decay_names = frozenset({"A_log", "D", "dt_bias"})

    # Identify module-level exclusions by object identity so that a tensor
    # shared with another module is still recognised under any of its names.
    no_decay_module_params: set[int] = {
        id(param)
        for module in model.modules()
        if isinstance(module, no_decay_module_types)
        for param in module.parameters(recurse=False)
    }

    decay_params: list[torch.nn.Parameter] = []
    no_decay_params: list[torch.nn.Parameter] = []
    seen: set[int] = set()  # guards against placing a shared tensor twice
    for name, param in model.named_parameters():
        if not param.requires_grad or id(param) in seen:
            continue
        seen.add(id(param))

        leaf_name = name.rpartition(".")[2]
        exempt = (
            id(param) in no_decay_module_params
            or param.ndim < 2
            or name.endswith("bias")
            or leaf_name in ssm_no_decay_names
        )
        (no_decay_params if exempt else decay_params).append(param)

    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _timestamp() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    import datetime

    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def _pin_numa_affinity(rank: int, local_rank: int) -> None:
    """Pin this process to the CPU cores of the NUMA node that hosts its GPU.

    B200 topology: GPU 0-3 → NUMA node 0 (cores 0-35)
                   GPU 4-7 → NUMA node 1 (cores 36-71)
    Without pinning, 5/8 ranks end up on wrong NUMA → 3.2x memory latency.

    Failure is non-fatal: when ``os.sched_setaffinity`` is missing or rejects
    the core set, a warning is printed on the main process and training goes on.
    """
    on_node0 = local_rank < 4
    cores = range(0, 36) if on_node0 else range(36, 72)
    try:
        os.sched_setaffinity(0, set(cores))
        if is_main_process():
            print(
                f"NUMA affinity: rank {rank} (GPU {local_rank}) → "
                f"{'NUMA0 cores 0-35' if on_node0 else 'NUMA1 cores 36-71'}"
            )
    except (AttributeError, OSError) as e:
        if is_main_process():
            print(f"[WARN] NUMA affinity failed: {e}")


def main() -> None:
    """Run one pretraining job from command-line arguments.

    In order: parse the CLI, initialise DDP when ``RANK`` is present in the
    environment, seed the RNGs, pin CPU affinity, build the model (wrapped in
    DDP for multi-process runs), the data loader, the AdamW optimizer and the
    cosine LR schedule, optionally resume from a checkpoint, install
    SIGHUP/SIGTERM handlers that ask the trainer to shut down, and finally run
    ``Trainer.train``.

    Raises:
        FileNotFoundError: If the config file or the ``--resume`` path is missing.
        ValueError: If FP8 is enabled and ``batch_size * max_seq_len`` is not
            a multiple of 8.
        Exception: Any error raised during training is printed on the main
            process (and appended to ``--log_file`` when given), then re-raised.
    """
    import traceback

    args = parse_args()

    # ---- Distributed setup -------------------------------------------------
    is_ddp = "RANK" in os.environ
    rank = 0
    local_rank = 0
    world_size = 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if is_ddp:
        rank, local_rank, world_size, device = setup_ddp()

    # Per-rank seed so data shuffling differs across replicas.
    set_seed(args.seed + rank)

    # ---- NUMA affinity for optimal GPU↔CPU memory locality ----------------
    _pin_numa_affinity(rank, local_rank)

    # ---- Model -------------------------------------------------------------
    if not args.config.exists():
        raise FileNotFoundError(f"Config file not found: {args.config}")
    lm_config = LMConfig.from_yaml(args.config)

    # CLI --use_fp8 flag overrides whatever the config file says.
    if args.use_fp8:
        lm_config.use_fp8 = True

    # FP8 alignment check: (batch_size × seq_len) must be divisible by 8.
    if lm_config.use_fp8 and (args.batch_size * lm_config.max_seq_len) % 8 != 0:
        raise ValueError(
            f"FP8: batch_size × max_seq_len = {args.batch_size} × {lm_config.max_seq_len} "
            f"= {args.batch_size * lm_config.max_seq_len} must be divisible by 8."
        )

    # Note: fp8_model_init() is intentionally omitted — MXFP8Tensor weights are
    # incompatible with DDP's _broadcast_coalesced during multi-GPU init.
    # Weights remain in float32; TE quantizes on-the-fly inside fp8_autocast.
    model = LLM(lm_config).to(device)
    if is_main_process():
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model parameters: {total_params:,}")
        print(f"LMConfig: {lm_config}")

    # ---- Wrap in DDP -------------------------------------------------------
    if is_ddp:
        from torch.nn.parallel import DistributedDataParallel as DDP

        model = DDP(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            gradient_as_bucket_view=True,  # zero-copy gradient → NCCL buffer
            bucket_cap_mb=800,  # larger buckets for NVLS (was 400)
            find_unused_parameters=False,  # fixed graph, no traversal overhead
            # NOTE: static_graph=True was removed — it conflicts with the
            # dynamic autograd hooks of the TE FP8 layers.
        )

    # ---- Dataset & DataLoader ----------------------------------------------
    # PackedDataset: non-overlapping stride=seq_len windows.
    # Avoids 600M random-index mmap accesses from stride-1 TextDataset.
    train_dataset = PackedDataset(args.train_data, seq_len=lm_config.max_seq_len)

    # Shuffling is always delegated to the sampler, so DataLoader's own
    # ``shuffle`` option is never passed.
    train_sampler: DistributedSampler | RandomSampler
    if is_ddp:
        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=args.seed,
        )
    else:
        train_sampler = RandomSampler(train_dataset)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=6,  # 6×8=48 workers, fits 72-core budget with OMP=4
        pin_memory=True,
        drop_last=True,
        prefetch_factor=4,  # deeper pipeline for larger worker pool
        persistent_workers=True,  # keep workers alive across epochs — eliminates respawn stall
    )

    # ---- Optimizer ---------------------------------------------------------
    # Group parameters on the unwrapped module so names carry no DDP prefix.
    param_groups = build_optimizer_param_groups(
        getattr(model, "module", model), args.weight_decay
    )
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=args.lr,
        betas=(0.9, 0.95),
        eps=1e-8,
        fused=torch.cuda.is_available(),  # Use fused kernel when on CUDA.
    )

    # ---- LR Scheduler ------------------------------------------------------
    train_config = TrainConfig(
        checkpoint_dir=str(args.checkpoint_dir),
        grad_accum_steps=args.grad_accum,
        use_fp8=lm_config.use_fp8,
        log_file=str(args.log_file) if args.log_file is not None else None,
    )
    if args.max_steps is not None:
        train_config.max_steps = args.max_steps

    scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        warmup_steps=args.warmup_steps,
        total_steps=train_config.max_steps,
    )

    # ---- Resume from checkpoint --------------------------------------------
    start_step = 0
    if args.resume is not None:
        if not args.resume.exists():
            raise FileNotFoundError(f"Checkpoint path not found: {args.resume}")
        start_step, resume_loss = load_checkpoint(
            path=args.resume,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        if is_main_process():
            print(f"Resumed from {args.resume} at step {start_step} (loss={resume_loss:.4f})")

    # ---- Checkpoint directory ----------------------------------------------
    args.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # ---- Trainer -----------------------------------------------------------
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        config=train_config,
        device=device,
        rank=rank,
        sampler=train_sampler if is_ddp else None,
    )

    # ---- Signal handlers for graceful shutdown ----------------------------
    # SIGHUP: raised when the SSH session drops → the culprit that killed
    #         earlier training runs.
    # SIGTERM: raised by the kill command or at system shutdown.
    # The handler calls trainer.request_shutdown(); the training loop is then
    # expected to finish the current step, save an emergency checkpoint and
    # exit cleanly.
    def _graceful_shutdown_handler(signum, frame):
        sig_name = signal.Signals(signum).name
        if is_main_process():
            msg = (
                f"[{_timestamp()}] [SIGNAL] Received {sig_name} (signum={signum}). "
                f"Initiating graceful shutdown..."
            )
            print(f"\n{msg}")
            # Also append to the log file right away.
            if args.log_file is not None:
                try:
                    with open(args.log_file, "a", encoding="utf-8") as f:
                        f.write(msg + "\n")
                except Exception:
                    # Deliberate best-effort: a logging failure inside a
                    # signal handler must not prevent the shutdown request.
                    pass
        # NOTE(review): placed outside the rank-0 branch so every rank asks its
        # own trainer to stop — the original nesting is ambiguous; confirm.
        trainer.request_shutdown(sig_name)

    # SIGHUP does not exist on every platform; register only what is available.
    for _sig_attr in ("SIGHUP", "SIGTERM"):
        _sig = getattr(signal, _sig_attr, None)
        if _sig is not None:
            signal.signal(_sig, _graceful_shutdown_handler)

    if is_main_process():
        eff_tokens_per_step = args.batch_size * lm_config.max_seq_len * args.grad_accum * world_size
        nccl_debug = os.environ.get("NCCL_DEBUG", "not set")
        omp_threads = os.environ.get("OMP_NUM_THREADS", "not set")
        print(
            f"\n{'='*70}\n"
            f" LLM Pretraining — {_timestamp()}\n"
            f"{'='*70}\n"
            f" model : {lm_config.num_params:,} params | "
            f"d_model={lm_config.d_model} n_layers={lm_config.n_layers}\n"
            f" precision : {'FP8 (MXFP8BlockScaling)' if lm_config.use_fp8 else 'BF16'}\n"
            f" GPUs : {world_size} | batch/GPU={args.batch_size} "
            f"grad_accum={args.grad_accum}\n"
            f" eff_batch : {args.batch_size * args.grad_accum * world_size} seqs "
            f"= {eff_tokens_per_step:,} tok/step\n"
            f" max_steps : {train_config.max_steps:,} "
            f"({train_config.max_steps * eff_tokens_per_step / 1e9:.1f}B tokens total)\n"
            f" data : {args.train_data}\n"
            f" ckpt_dir : {args.checkpoint_dir}\n"
            f" env : OMP_NUM_THREADS={omp_threads} NCCL_DEBUG={nccl_debug}\n"
            f"{'='*70}\n"
        )

    try:
        trainer.train(start_step=start_step)
        # Report how the run ended: normal completion or graceful shutdown.
        if is_main_process():
            if trainer._shutdown_requested:
                print(
                    f"\n[INFO] Training gracefully shut down via {trainer._shutdown_signal}. "
                    f"Emergency checkpoint saved. Resume with same command."
                )
            else:
                print("\n[INFO] Training completed successfully.")
    except KeyboardInterrupt:
        if is_main_process():
            print("\n[INFO] Training interrupted by user (KeyboardInterrupt).")
    except Exception:
        if is_main_process():
            tb = traceback.format_exc()
            print(f"\n[ERROR] Training failed at rank {rank}:\n{tb}")
            # Record the failure in the log file too. Best-effort: an I/O error
            # here must not replace the training exception re-raised below.
            if args.log_file is not None:
                try:
                    with open(args.log_file, "a", encoding="utf-8") as f:
                        f.write(f"[{_timestamp()}] [FATAL] {tb}\n")
                except OSError as log_err:
                    print(f"[WARN] Could not write fatal error to {args.log_file}: {log_err}")
        raise
    finally:
        if is_ddp:
            cleanup_ddp()
# Note: DDP cleanup is handled in the try/finally block above.
# Run the training entry point only when this file is executed as a script
# (directly or through torchrun), never on import.
if __name__ == "__main__":
    main()