"""
train/sft.py — Supervised Fine-Tuning (SFT) entry point.
Loads a pretrained checkpoint and fine-tunes it on instruction/conversation
data using SFTDataset, which masks prompt tokens with ignore_index=-1 so only
the assistant response tokens contribute to the loss.
Launch single-GPU:
python train/sft.py \\
--base_checkpoint checkpoints/korean_1b_fp8_run1/checkpoint-0034000 \\
--sft_data data/sft/train.jsonl \\
--device cuda:0
Launch multi-GPU (DDP via torchrun):
torchrun --nproc_per_node=8 train/sft.py \\
--base_checkpoint checkpoints/korean_1b_fp8_run1/checkpoint-0034000 \\
--sft_data data/sft/train.jsonl
KEY DIFFERENCES from pretrain.py:
- Loads weights from a pretrained checkpoint via LLM.from_pretrained()
- Uses SFTDataset (JSONL instruction data) instead of PackedDataset
- Lower default learning rate (2e-5 vs 2e-4)
- Fewer default steps (3000 vs 100000)
- Copies tokenizer.json to checkpoint_dir for easy deployment
"""
from __future__ import annotations
import argparse
import os
import random
import signal
import shutil
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
# ---------------------------------------------------------------------------
# Data mixing: interleave SFT and pretrain batches to mitigate catastrophic forgetting
# ---------------------------------------------------------------------------
class MixingDataLoader:
"""
Wraps two DataLoaders and yields batches from one or the other
based on a probability ratio.
    With ``pretrain_ratio=0.3``, roughly 30% of batches come from the
    pretrain loader and 70% from the SFT loader (each batch is an
    independent draw). Both loaders cycle infinitely.
This is duck-type compatible with DataLoader for the Trainer's needs:
- ``__iter__`` yields batches
- ``__len__`` returns the SFT loader length (used for epoch estimation)
"""
def __init__(
self,
sft_loader: DataLoader,
pretrain_loader: DataLoader,
pretrain_ratio: float = 0.3,
sft_sampler: DistributedSampler | RandomSampler | None = None,
pretrain_sampler: DistributedSampler | RandomSampler | None = None,
) -> None:
self.sft_loader = sft_loader
self.pretrain_loader = pretrain_loader
self.pretrain_ratio = pretrain_ratio
self.sft_sampler = sft_sampler
self.pretrain_sampler = pretrain_sampler
self._epoch = 0
def __len__(self) -> int:
return len(self.sft_loader)
def __iter__(self):
sft_iter = iter(self.sft_loader)
pt_iter = iter(self.pretrain_loader)
while True:
use_pretrain = random.random() < self.pretrain_ratio
try:
if use_pretrain:
batch = next(pt_iter)
else:
batch = next(sft_iter)
except StopIteration:
# Whichever exhausted, restart it
if use_pretrain:
self._epoch += 1
if self.pretrain_sampler is not None and hasattr(self.pretrain_sampler, 'set_epoch'):
self.pretrain_sampler.set_epoch(self._epoch)
pt_iter = iter(self.pretrain_loader)
try:
batch = next(pt_iter)
except StopIteration:
raise RuntimeError(
"Pretrain DataLoader is empty after restart. "
"Check pretrain_data path and drop_last settings."
)
else:
self._epoch += 1
if self.sft_sampler is not None and hasattr(self.sft_sampler, 'set_epoch'):
self.sft_sampler.set_epoch(self._epoch)
sft_iter = iter(self.sft_loader)
try:
batch = next(sft_iter)
except StopIteration:
raise RuntimeError(
"SFT DataLoader is empty after restart. "
"Check sft_data path and drop_last settings."
)
yield batch
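# A minimal usage sketch (illustrative; the loader names are hypothetical).
# Since ``__iter__`` cycles both loaders forever, iteration must be bounded
# by the caller -- in this script the Trainer's ``max_steps`` provides the
# bound:
#
#     mixed = MixingDataLoader(sft_loader, pt_loader, pretrain_ratio=0.3)
#     for step, batch in zip(range(max_steps), mixed):
#         ...  # consume batches exactly as from a plain DataLoader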
# Maximize B200 Tensor Core utilization: TF32 matmul + cuDNN
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high") # TF32 precision for fp32 matmul
# Allow imports from the project root regardless of working directory.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(_PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(_PROJECT_ROOT))
from model import LLM
from train.trainer import TrainConfig, Trainer
from train.utils import (
cleanup_ddp,
get_cosine_schedule_with_warmup,
is_main_process,
load_checkpoint,
setup_ddp,
)
# ---------------------------------------------------------------------------
# Optional TransformerEngine import (FP8 support)
# ---------------------------------------------------------------------------
try:
import transformer_engine.pytorch as te # type: ignore[import]
HAS_TE = True
except ImportError:
te = None # type: ignore[assignment]
HAS_TE = False
# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Supervised Fine-Tuning (SFT) of a pretrained decoder-only LLM.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# --- Required paths -----------------------------------------------------
parser.add_argument(
"--base_checkpoint",
type=Path,
required=True,
help=(
"Path to the pretrained checkpoint directory. "
"Must contain model.pt and config.yaml (produced by save_checkpoint)."
),
)
parser.add_argument(
"--sft_data",
type=Path,
required=True,
help="Path to the JSONL SFT training data file.",
)
# --- Optional paths -----------------------------------------------------
parser.add_argument(
"--val_data",
type=Path,
default=None,
help="Optional path to JSONL SFT validation data file.",
)
parser.add_argument(
"--checkpoint_dir",
type=Path,
default=Path("checkpoints/korean_1b_sft"),
help="Root directory for saving SFT checkpoints.",
)
parser.add_argument(
"--resume",
type=Path,
default=None,
help="Path to an SFT checkpoint directory to resume fine-tuning from.",
)
parser.add_argument(
"--tokenizer",
type=Path,
default=None,
help=(
"Override path to tokenizer.json. "
"Defaults to <base_checkpoint>/tokenizer.json, "
"then falls back to tokenizer/korean_sp/tokenizer.json."
),
)
parser.add_argument(
"--log_file",
type=Path,
default=None,
help=(
"Path to a text file for structured training logs (rank-0 only). "
"If omitted, logs go only to stdout."
),
)
# --- Training hyper-parameters ------------------------------------------
parser.add_argument(
"--max_steps",
type=int,
default=3000,
help="Total number of optimiser steps.",
)
parser.add_argument(
"--batch_size",
type=int,
default=4,
help="Per-GPU micro-batch size.",
)
parser.add_argument(
"--lr",
type=float,
default=2e-5,
help=(
"Peak learning rate. "
"SFT uses a much lower lr than pretraining (2e-5 vs 2e-4) "
"to preserve pretrained representations."
),
)
parser.add_argument(
"--weight_decay",
type=float,
default=0.01,
help="AdamW weight decay. Lower than pretrain (0.01 vs 0.1).",
)
parser.add_argument(
"--warmup_steps",
type=int,
default=100,
help="Number of linear LR warmup steps.",
)
parser.add_argument(
"--grad_accum",
type=int,
default=2,
help="Gradient accumulation steps.",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="Base random seed (rank offset is added automatically in DDP).",
)
parser.add_argument(
"--use_fp8",
action="store_true",
default=False,
help=(
"Enable TransformerEngine FP8 training "
"(requires B200/H100, uses MXFP8BlockScaling)."
),
)
# --- Single-GPU device override (ignored when using torchrun) -----------
parser.add_argument(
"--device",
type=str,
default=None,
help=(
"Explicit device string (e.g. 'cuda:0'). "
"Ignored when running under torchrun (DDP auto-assigns devices)."
),
)
parser.add_argument(
"--config", type=Path, default=None,
help="YAML config file. Values under 'train:' section are used as CLI defaults.",
)
parser.add_argument("--save_interval", type=int, default=500, help="Checkpoint save interval (steps).")
parser.add_argument("--eval_interval", type=int, default=250, help="Validation eval interval (steps).")
parser.add_argument("--neftune_alpha", type=float, default=5.0, help="NEFTune noise magnitude (0 to disable).")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Maximum gradient L2 norm for clipping.")
# --- Data mixing (forgetting prevention) --------------------------------
parser.add_argument(
"--pretrain_data",
type=Path,
default=None,
help="Path to pretrain .bin file for data mixing. Enables SFT+pretrain interleaving.",
)
parser.add_argument(
"--pretrain_mix_ratio",
type=float,
default=0.3,
help="Fraction of batches from pretrain data (0.3 = 30%% pretrain, 70%% SFT).",
)
    # First pass: parse only --config; the full parse below then picks up
    # any defaults overridden by the YAML file.
    args, _ = parser.parse_known_args()
# Load YAML config and apply values as defaults
if args.config is not None:
if not args.config.exists():
raise FileNotFoundError(f"Config file not found: {args.config}")
import yaml
with open(args.config, "r") as f:
yaml_cfg = yaml.safe_load(f)
train_section = yaml_cfg.get("train", {})
yaml_to_arg = {
"max_steps": "max_steps",
"batch_size": "batch_size",
"lr": "lr",
"weight_decay": "weight_decay",
"warmup_steps": "warmup_steps",
"grad_accum_steps": "grad_accum",
"save_interval": "save_interval",
"eval_interval": "eval_interval",
"neftune_alpha": "neftune_alpha",
"pretrain_mix_ratio": "pretrain_mix_ratio",
"max_grad_norm": "max_grad_norm",
}
# pretrain_data is a Path, handle separately
if "pretrain_data" in train_section:
parser.set_defaults(pretrain_data=Path(train_section["pretrain_data"]))
new_defaults = {}
for yaml_key, arg_name in yaml_to_arg.items():
if yaml_key in train_section:
new_defaults[arg_name] = train_section[yaml_key]
if new_defaults:
parser.set_defaults(**new_defaults)
return parser.parse_args()
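# For reference, a YAML file consumed via --config might look like the sketch
# below (key names mirror the yaml_to_arg mapping above; the values shown are
# illustrative, not recommendations):
#
#     train:
#       max_steps: 3000
#       batch_size: 4
#       lr: 2.0e-5
#       weight_decay: 0.01
#       warmup_steps: 100
#       grad_accum_steps: 2
#       neftune_alpha: 5.0
#       pretrain_data: data/pretrain/train.bin
#       pretrain_mix_ratio: 0.3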
# ---------------------------------------------------------------------------
# Seed helper
# ---------------------------------------------------------------------------
def set_seed(seed: int) -> None:
"""Set deterministic seeds for Python, NumPy, and PyTorch."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# ---------------------------------------------------------------------------
# Optimizer parameter groups
# (Copied from pretrain.py to avoid circular import; identical logic)
# ---------------------------------------------------------------------------
def build_optimizer_param_groups(
model: torch.nn.Module,
weight_decay: float,
) -> list[dict]:
"""
Split parameters into two groups:
- decay group : weight tensors with ndim >= 2 (Linear, etc.)
- no-decay group: bias, LayerNorm/RMSNorm weights, and embedding weights
This follows standard practice (e.g. GPT-style training).
"""
decay_params: list[torch.nn.Parameter] = []
no_decay_params: list[torch.nn.Parameter] = []
    # Module types whose parameters should never be decayed. RMSNorm is not
    # listed here because its 1-D weight already falls under the ndim < 2
    # check below.
    no_decay_module_types = (
        torch.nn.Embedding,
        torch.nn.LayerNorm,
    )
# Also skip any parameter whose name ends with '.bias'.
no_decay_name_suffixes = ("bias",)
# Collect module-level exclusions.
no_decay_module_params: set[int] = set()
for module in model.modules():
if isinstance(module, no_decay_module_types):
for param in module.parameters(recurse=False):
no_decay_module_params.add(id(param))
seen: set[int] = set()
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if id(param) in seen:
continue
seen.add(id(param))
if (
id(param) in no_decay_module_params
or any(name.endswith(sfx) for sfx in no_decay_name_suffixes)
or param.ndim < 2
):
no_decay_params.append(param)
else:
decay_params.append(param)
return [
{"params": decay_params, "weight_decay": weight_decay},
{"params": no_decay_params, "weight_decay": 0.0},
]
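# A quick sanity check for the split (a sketch; ``model`` stands for any
# nn.Module). Every trainable parameter lands in exactly one group, so the
# two counts must sum to the model's trainable-parameter total:
#
#     groups = build_optimizer_param_groups(model, weight_decay=0.01)
#     n_decay = sum(p.numel() for p in groups[0]["params"])
#     n_no_decay = sum(p.numel() for p in groups[1]["params"])
#     assert n_decay + n_no_decay == sum(
#         p.numel() for p in model.parameters() if p.requires_grad
#     )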
# ---------------------------------------------------------------------------
# Tokenizer resolution helper
# ---------------------------------------------------------------------------
def _resolve_tokenizer_path(args: argparse.Namespace) -> Path:
"""
Determine the tokenizer path in priority order:
1. Explicit --tokenizer argument
2. tokenizer.json inside the base_checkpoint directory
3. Project default: tokenizer/korean_sp/tokenizer.json
"""
if args.tokenizer is not None:
p = Path(args.tokenizer)
if not p.exists():
raise FileNotFoundError(f"Tokenizer not found at --tokenizer path: {p}")
return p
ckpt_tok = args.base_checkpoint / "tokenizer.json"
if ckpt_tok.exists():
return ckpt_tok
default_tok = _PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json"
if default_tok.exists():
return default_tok
raise FileNotFoundError(
"Could not locate tokenizer.json. Tried:\n"
f" 1. {ckpt_tok}\n"
f" 2. {default_tok}\n"
"Use --tokenizer to specify an explicit path."
)
# ---------------------------------------------------------------------------
# Dynamic padding collate function
# ---------------------------------------------------------------------------
def dynamic_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Collate function that pads each batch to its own maximum sequence length
instead of a fixed global max_seq_len. This reduces wasted FLOPs on
    short sequences and speeds up SFT, which tends to have highly variable
    response lengths.
Pads to the batch-local max, aligned to 64 tokens (for Flash Attention
efficiency), with a floor of 512 tokens so micro-batches are not too short.
Args:
batch: List of ``(input_ids, labels)`` tuples from SFTDataset.
Returns:
Tuple of ``(input_ids, labels, attention_mask)`` tensors shaped
``[B, max_len]``.
``input_ids`` is right-padded with 0 (pad token).
``labels`` is right-padded with -1 (cross-entropy ignore_index).
``attention_mask`` is 1 for real tokens, 0 for padding.
"""
# 64-token alignment + minimum 512 floor
raw_max = max(item[0].size(0) for item in batch)
max_len = max(512, ((raw_max + 63) // 64) * 64)
input_ids_list, labels_list, mask_list = [], [], []
for ids, labs in batch:
pad_len = max_len - ids.size(0)
input_ids_list.append(F.pad(ids, (0, pad_len), value=0))
labels_list.append(F.pad(labs, (0, pad_len), value=-1))
mask_list.append(
F.pad(torch.ones(ids.size(0), dtype=torch.long), (0, pad_len), value=0)
)
return (
torch.stack(input_ids_list),
torch.stack(labels_list),
torch.stack(mask_list),
)
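# Worked example of the padding rule (illustrative):
#
#     >>> raw_max = 700                          # longest sequence in batch
#     >>> max(512, ((raw_max + 63) // 64) * 64)  # round up to multiple of 64
#     704
#     >>> max(512, ((100 + 63) // 64) * 64)      # short batch hits the floor
#     512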
# ---------------------------------------------------------------------------
# NEFTune helper
# ---------------------------------------------------------------------------
def add_neftune_hook(model: torch.nn.Module, noise_alpha: float = 10.0):
"""
Register a forward hook on the model's input embedding layer that adds
uniform noise scaled by noise_alpha during training (NEFTune).
Reference: "NEFTune: Noisy Embeddings Improve Instruction Finetuning"
(Jain et al., 2023). https://arxiv.org/abs/2310.05914
Args:
model: Raw (non-DDP) model instance.
noise_alpha: Noise magnitude parameter (paper default: 10).
Returns:
The hook handle (call ``handle.remove()`` to deactivate), or None if
the embedding layer could not be located.
"""
# Unwrap DDP if needed
raw = model.module if hasattr(model, "module") else model
# 1) Try the standard HuggingFace accessor first.
embedding: torch.nn.Embedding | None = None
if hasattr(raw, "get_input_embeddings"):
try:
emb = raw.get_input_embeddings()
if isinstance(emb, torch.nn.Embedding):
embedding = emb
except Exception:
pass
# 2) Fallback: walk common attribute paths found in open-source LLMs.
if embedding is None:
for attr_path in [
"embedding",
"embed_tokens",
"token_embedding",
"wte",
"word_embeddings",
"tok_embeddings",
"transformer.wte",
"model.embed_tokens",
"model.embedding",
]:
obj = raw
for part in attr_path.split("."):
obj = getattr(obj, part, None)
if obj is None:
break
if obj is not None and isinstance(obj, torch.nn.Embedding):
embedding = obj
break
    if embedding is None:
        print("[WARN] NEFTune: embedding layer not found; disabling NEFTune")
        return None
    print(
        f"[INFO] NEFTune: registered forward hook on {type(embedding).__name__} "
        f"(shape={tuple(embedding.weight.shape)}, alpha={noise_alpha})"
    )
def _hook(
module: torch.nn.Module,
inp: tuple,
out: torch.Tensor,
) -> torch.Tensor:
if module.training:
# out shape: [B, seq_len, d_model]
mag = noise_alpha / ((out.size(1) * out.size(2)) ** 0.5)
out = out + torch.empty_like(out).uniform_(-mag, mag)
return out
return embedding.register_forward_hook(_hook)
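# Worked example of the noise scale (illustrative numbers): with alpha=5,
# seq_len=1024, and d_model=2048, mag = 5 / sqrt(1024 * 2048) ~= 0.00345,
# so during training each embedding entry receives Uniform(-0.00345, 0.00345)
# noise; at eval time (module.training is False) the hook is a no-op.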
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
args = parse_args()
# ---- Distributed setup -------------------------------------------------
is_ddp = "RANK" in os.environ
rank = 0
local_rank = 0
world_size = 1
if is_ddp:
rank, local_rank, world_size, device = setup_ddp()
else:
# Single-GPU: honour --device flag, else pick cuda:0 or cpu.
if args.device is not None:
device = torch.device(args.device)
elif torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
    # Per-rank seed so stochastic ops (NEFTune noise, MixingDataLoader draws)
    # differ across replicas; DistributedSampler shuffling uses the shared
    # base seed, so all ranks still agree on the data partition.
    set_seed(args.seed + rank)
# ---- NUMA affinity for optimal GPU↔CPU memory locality ---------------
# B200 topology: GPU 0-3 → NUMA node 0 (cores 0-35)
# GPU 4-7 → NUMA node 1 (cores 36-71)
try:
if local_rank < 4:
os.sched_setaffinity(0, set(range(0, 36))) # NUMA node 0
else:
os.sched_setaffinity(0, set(range(36, 72))) # NUMA node 1
if is_main_process():
print(f"NUMA affinity: rank {rank} (GPU {local_rank}) → "
f"{'NUMA0 cores 0-35' if local_rank < 4 else 'NUMA1 cores 36-71'}")
except (AttributeError, OSError) as e:
if is_main_process():
print(f"[WARN] NUMA affinity failed: {e}")
# ---- Validate base checkpoint ------------------------------------------
if not args.base_checkpoint.exists():
raise FileNotFoundError(
f"Base checkpoint directory not found: {args.base_checkpoint}"
)
for required_file in ("model.pt", "config.yaml"):
if not (args.base_checkpoint / required_file).exists():
raise FileNotFoundError(
f"Expected {required_file} inside base checkpoint: {args.base_checkpoint}"
)
# ---- Load pretrained model ---------------------------------------------
# LLM.from_pretrained() reads config.yaml + model.pt and returns the model on CPU.
# We move it to the target device immediately after loading.
#
# NOTE: fp8_model_init() is intentionally NOT used here (same as pretrain.py).
# MXFP8Tensor weights are incompatible with DDP's _broadcast_coalesced.
# Weights stay in float32; TransformerEngine quantizes on-the-fly inside fp8_autocast.
model = LLM.from_pretrained(args.base_checkpoint)
# When FP8 flag is passed at SFT time, enable it on the loaded config.
# This is useful if the pretrained model was trained without FP8 but you
# want to fine-tune with FP8 precision (the TE layers must exist in the model).
    if args.use_fp8:
        if not HAS_TE:
            raise RuntimeError(
                "--use_fp8 requires TransformerEngine, but it is not installed."
            )
        model.config.use_fp8 = True
# Move model to target device in bfloat16 (more memory-efficient than fp32
# for fine-tuning, and required when BF16 autocast + TE are active).
model = model.to(device=device, dtype=torch.bfloat16)
# ---- Gradient checkpointing ----------------------------------------
# Trades activation memory for recomputation during backward pass.
# Especially useful for large models / long sequences in SFT.
if hasattr(model, 'gradient_checkpointing_enable'):
model.gradient_checkpointing_enable()
if rank == 0:
print("[INFO] Gradient checkpointing enabled")
# FP8 alignment check: (batch_size × seq_len) must be divisible by 8.
if model.config.use_fp8:
seq_len = model.config.max_seq_len
if (args.batch_size * seq_len) % 8 != 0:
raise ValueError(
f"FP8: batch_size × max_seq_len = {args.batch_size} × {seq_len} "
f"= {args.batch_size * seq_len} must be divisible by 8."
)
if is_main_process():
total_params = sum(p.numel() for p in model.parameters())
print(f"Pretrained model loaded: {total_params:,} parameters")
print(f"LMConfig: {model.config}")
# ---- Wrap in DDP -------------------------------------------------------
if is_ddp:
from torch.nn.parallel import DistributedDataParallel as DDP
model = DDP(
model,
device_ids=[local_rank],
output_device=local_rank,
gradient_as_bucket_view=True,
bucket_cap_mb=800,
find_unused_parameters=False,
)
# ---- Tokenizer ---------------------------------------------------------
tokenizer_path = _resolve_tokenizer_path(args)
if is_main_process():
print(f"Loading tokenizer from: {tokenizer_path}")
# Use the fast tokenizers library (same as the rest of the project).
from tokenizers import Tokenizer # type: ignore[import]
tokenizer = Tokenizer.from_file(str(tokenizer_path))
# ---- Dataset & DataLoader ----------------------------------------------
# Import SFTDataset (created separately alongside this file).
# SFTDataset returns (input_ids, targets) where prompt token positions in
# targets are filled with -1. The Trainer._compute_loss already uses
# ignore_index=-1, so only response tokens contribute to the gradient.
from data.sft_dataset import SFTDataset # type: ignore[import]
max_seq_len_cfg = (
model.config.max_seq_len
if not isinstance(model, torch.nn.parallel.DistributedDataParallel)
else model.module.config.max_seq_len
)
    # DDP optimization: rank 0 does tokenization + cache, other ranks load cache.
    # This avoids 8× redundant work and 8× memory usage.
    # Rank 0 spawns 64 tokenizer workers (72 cores minus 8 reserved for the
    # system); note that the NUMA pinning above confines rank 0 to 36 cores,
    # so those workers time-slice rather than each owning a core.
    tok_workers = 64 if is_main_process() else 0
if is_ddp:
if is_main_process():
# Rank 0: full tokenization with parallel workers + save cache
train_dataset = SFTDataset(
data_path=args.sft_data,
tokenizer=tokenizer,
max_seq_len=max_seq_len_cfg,
tokenizer_path=tokenizer_path,
num_workers=tok_workers,
)
# Barrier: wait for rank 0 to finish tokenization and save cache
torch.distributed.barrier()
if not is_main_process():
# Other ranks: load from cache (rank 0 already saved it)
train_dataset = SFTDataset(
data_path=args.sft_data,
tokenizer=tokenizer,
max_seq_len=max_seq_len_cfg,
)
else:
train_dataset = SFTDataset(
data_path=args.sft_data,
tokenizer=tokenizer,
max_seq_len=max_seq_len_cfg,
tokenizer_path=tokenizer_path,
num_workers=tok_workers,
)
    if is_ddp:
        train_sampler: DistributedSampler | RandomSampler = DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=args.seed,
        )
    else:
        # A sampler is always provided, so the DataLoader never sets shuffle.
        train_sampler = RandomSampler(train_dataset)
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
sampler=train_sampler,
# SFT datasets are typically small enough that 2–4 workers suffice.
# We use 4 to balance I/O with CPU parsing overhead from JSONL.
num_workers=4,
pin_memory=True,
drop_last=True,
prefetch_factor=2,
persistent_workers=True,
collate_fn=dynamic_collate_fn,
)
# ---- Pretrain Data Mixing (forgetting prevention) -----------------------
# When --pretrain_data is specified, create a second DataLoader for pretrain
# data and wrap both in MixingDataLoader. This interleaves SFT and pretrain
# batches (default 70/30 ratio) so the model retains pretrained knowledge.
pretrain_sampler = None
if args.pretrain_data is not None:
if not args.pretrain_data.exists():
raise FileNotFoundError(f"Pretrain data not found: {args.pretrain_data}")
        from data import PackedDataset
        # Reuse max_seq_len_cfg resolved above (it already unwraps DDP).
        pretrain_dataset = PackedDataset(args.pretrain_data, seq_len=max_seq_len_cfg)
if is_ddp:
pretrain_sampler = DistributedSampler(
pretrain_dataset,
num_replicas=world_size,
rank=rank,
shuffle=True,
seed=args.seed + 1000, # different seed from SFT
)
else:
pretrain_sampler = RandomSampler(pretrain_dataset)
pretrain_loader = DataLoader(
pretrain_dataset,
batch_size=args.batch_size,
sampler=pretrain_sampler,
num_workers=4,
pin_memory=True,
drop_last=True,
prefetch_factor=2,
persistent_workers=True,
)
# Wrap both loaders in MixingDataLoader
effective_loader = MixingDataLoader(
sft_loader=train_loader,
pretrain_loader=pretrain_loader,
pretrain_ratio=args.pretrain_mix_ratio,
sft_sampler=train_sampler if is_ddp else None,
pretrain_sampler=pretrain_sampler if is_ddp else None,
)
if is_main_process():
print(
f"[INFO] Data mixing enabled: "
f"{(1 - args.pretrain_mix_ratio) * 100:.0f}% SFT + "
f"{args.pretrain_mix_ratio * 100:.0f}% pretrain"
)
print(f"[INFO] Pretrain data: {args.pretrain_data} ({len(pretrain_dataset):,} samples)")
else:
effective_loader = train_loader
    # Optional validation loader. It is wired into the Trainer below via
    # ``val_loader=``; evaluation cadence is governed by the eval_interval
    # field of TrainConfig.
    val_loader: DataLoader | None = None
if args.val_data is not None:
if not args.val_data.exists():
raise FileNotFoundError(f"Validation data not found: {args.val_data}")
if is_ddp:
if is_main_process():
val_dataset = SFTDataset(
data_path=args.val_data,
tokenizer=tokenizer,
max_seq_len=train_dataset.max_seq_len,
tokenizer_path=tokenizer_path,
num_workers=tok_workers,
)
torch.distributed.barrier()
if not is_main_process():
val_dataset = SFTDataset(
data_path=args.val_data,
tokenizer=tokenizer,
max_seq_len=train_dataset.max_seq_len,
)
else:
val_dataset = SFTDataset(
data_path=args.val_data,
tokenizer=tokenizer,
max_seq_len=train_dataset.max_seq_len,
tokenizer_path=tokenizer_path,
num_workers=tok_workers,
)
val_loader = DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=2,
pin_memory=True,
drop_last=False,
collate_fn=dynamic_collate_fn,
)
if is_main_process():
print(f"Validation dataset: {len(val_dataset):,} samples")
# ---- Optimizer ---------------------------------------------------------
# Use the same two-group split (weight_decay / no weight_decay) as pretrain.
# Unwrap DDP to get the raw model's parameters.
raw_model = getattr(model, "module", model)
param_groups = build_optimizer_param_groups(raw_model, args.weight_decay)
optimizer = torch.optim.AdamW(
param_groups,
lr=args.lr,
betas=(0.9, 0.95),
eps=1e-8,
fused=torch.cuda.is_available(), # Use fused kernel when on CUDA.
)
# ---- TrainConfig -------------------------------------------------------
# Set use_fp8 from the (possibly overridden) model config so Trainer builds
# the correct FP8 recipe and wraps forward passes in fp8_autocast.
use_fp8 = raw_model.config.use_fp8
train_config = TrainConfig(
max_steps=args.max_steps,
checkpoint_dir=str(args.checkpoint_dir),
grad_accum_steps=args.grad_accum,
use_fp8=use_fp8,
log_file=str(args.log_file) if args.log_file is not None else None,
save_interval=args.save_interval,
log_interval=10,
eval_interval=args.eval_interval,
max_grad_norm=args.max_grad_norm,
)
# ---- LR Scheduler ------------------------------------------------------
scheduler = get_cosine_schedule_with_warmup(
optimizer=optimizer,
warmup_steps=args.warmup_steps,
total_steps=train_config.max_steps,
)
# ---- Resume from SFT checkpoint ----------------------------------------
# When --resume is given we restore the SFT optimizer/scheduler state as
# well so learning rate, momentum buffers, etc. are correctly restored.
# NOTE: This resumes SFT training, NOT the pretrain checkpoint.
# The pretrain weights were already loaded above via from_pretrained().
start_step = 0
if args.resume is not None:
if not args.resume.exists():
raise FileNotFoundError(f"Resume checkpoint not found: {args.resume}")
start_step, resume_loss = load_checkpoint(
path=args.resume,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
if is_main_process():
print(f"Resumed SFT from {args.resume} at step {start_step} (loss={resume_loss:.4f})")
if args.resume is not None and isinstance(train_sampler, DistributedSampler):
steps_per_epoch = len(train_loader)
approx_epoch = start_step // steps_per_epoch if steps_per_epoch > 0 else 0
train_sampler.set_epoch(approx_epoch)
if is_main_process():
print(f"[INFO] Resume: sampler epoch set to {approx_epoch}")
# ---- Checkpoint directory ----------------------------------------------
args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
# ---- Copy tokenizer to checkpoint dir for easy deployment later --------
# This mirrors the tokenizer into the SFT checkpoint root so that the
# final checkpoint directory is self-contained for convert_to_hf.py, etc.
if is_main_process():
dest_tok = args.checkpoint_dir / "tokenizer.json"
if not dest_tok.exists():
shutil.copy2(str(tokenizer_path), str(dest_tok))
print(f"Tokenizer copied to {dest_tok}")
# ---- Trainer -----------------------------------------------------------
# When data mixing is active, pass effective_loader (MixingDataLoader).
# MixingDataLoader handles its own epoch cycling, so no external sampler needed.
trainer = Trainer(
model=model,
train_loader=effective_loader,
optimizer=optimizer,
scheduler=scheduler,
config=train_config,
device=device,
rank=rank,
sampler=train_sampler if is_ddp and args.pretrain_data is None else None,
val_loader=val_loader,
)
    # ---- Signal handlers for graceful shutdown ----------------------------
    # ``signal`` is already imported at module level. SIGHUP/SIGTERM cover
    # terminal hangups and scheduler kills (e.g. preemption).
    _trainer_ref = trainer
    def _graceful_shutdown_handler(signum, frame):
        sig_name = signal.Signals(signum).name
        if is_main_process():
            import datetime as _dt
            ts = _dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            msg = (
                f"[{ts}] [SIGNAL] Received {sig_name} (signum={signum}). "
                f"Initiating graceful shutdown..."
            )
            print(f"\n{msg}")
            if args.log_file is not None:
                try:
                    with open(args.log_file, "a", encoding="utf-8") as f:
                        f.write(msg + "\n")
                except Exception:
                    pass
        _trainer_ref.request_shutdown(sig_name)
    for _sig in (signal.SIGHUP, signal.SIGTERM):
        signal.signal(_sig, _graceful_shutdown_handler)
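    # To request a graceful stop from another shell (illustrative):
    #
    #     pkill -TERM -f train/sft.py     # or: kill -TERM <rank-0 pid>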
# ---- SFT banner --------------------------------------------------------
if is_main_process():
import datetime
inner_config = raw_model.config
        eff_batch_seqs = args.batch_size * args.grad_accum * world_size
        # Upper bound: dynamic padding typically yields fewer tokens per step.
        eff_tokens_per_step = eff_batch_seqs * inner_config.max_seq_len
train_samples = len(train_dataset)
precision_label = "FP8 (MXFP8BlockScaling)" if use_fp8 else "BF16"
nccl_debug = os.environ.get("NCCL_DEBUG", "not set")
omp_threads = os.environ.get("OMP_NUM_THREADS", "not set")
mix_label = "none"
if args.pretrain_data is not None:
mix_label = (
f"{(1 - args.pretrain_mix_ratio) * 100:.0f}% SFT + "
f"{args.pretrain_mix_ratio * 100:.0f}% pretrain"
)
print(
f"\n{'='*70}\n"
f" LLM Supervised Fine-Tuning — "
f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
f"{'='*70}\n"
f" base ckpt : {args.base_checkpoint}\n"
f" sft data : {args.sft_data} ({train_samples:,} samples)\n"
f" data mix : {mix_label}\n"
f" model : {inner_config.num_params:,} params | "
f"d_model={inner_config.d_model} n_layers={inner_config.n_layers}\n"
f" precision : {precision_label}\n"
f" GPUs : {world_size} | batch/GPU={args.batch_size} "
f"grad_accum={args.grad_accum}\n"
f" eff_batch : {eff_batch_seqs} seqs "
f"= {eff_tokens_per_step:,} tok/step\n"
f" max_steps : {train_config.max_steps:,}\n"
f" lr : {args.lr:.2e} "
f"warmup={args.warmup_steps} weight_decay={args.weight_decay}\n"
f" ckpt_dir : {args.checkpoint_dir}\n"
f" env : OMP_NUM_THREADS={omp_threads} NCCL_DEBUG={nccl_debug}\n"
f"{'='*70}\n"
)
    # ---- NEFTune -----------------------------------------------------------
    # Add uniform noise to embeddings during training to improve instruction
    # following (Jain et al., 2023). Hook is registered on the raw (non-DDP)
    # model so it survives DDP's internal module wrapping. Per the --help
    # text, neftune_alpha=0 disables NEFTune (no hook is registered).
    neftune_handle = None
    if args.neftune_alpha > 0:
        neftune_handle = add_neftune_hook(raw_model, noise_alpha=args.neftune_alpha)
    if is_main_process():
        if neftune_handle is not None:
            print(f"[INFO] NEFTune enabled (noise_alpha={args.neftune_alpha})")
        else:
            print("[WARN] NEFTune disabled (alpha=0 or embedding layer not found)")
# ---- Train -------------------------------------------------------------
try:
trainer.train(start_step=start_step)
except KeyboardInterrupt:
if is_main_process():
print("\n[INFO] SFT interrupted by user (KeyboardInterrupt).")
except Exception as e:
import traceback
if is_main_process():
tb = traceback.format_exc()
print(f"\n[ERROR] SFT failed at rank {rank}:\n{tb}")
if args.log_file is not None:
with open(args.log_file, "a", encoding="utf-8") as f:
import datetime
f.write(
f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"[FATAL] {tb}\n"
)
raise
finally:
# Remove NEFTune hook so the model is clean for inference/saving.
if neftune_handle is not None:
neftune_handle.remove()
if is_ddp:
cleanup_ddp()
if __name__ == "__main__":
main()