| | """ |
| | train/sft.py — Supervised Fine-Tuning (SFT) entry point. |
| | |
| | Loads a pretrained checkpoint and fine-tunes it on instruction/conversation |
| | data using SFTDataset, which masks prompt tokens with ignore_index=-1 so only |
| | the assistant response tokens contribute to the loss. |
| | |
| | Launch single-GPU: |
| | python train/sft.py \\ |
| | --base_checkpoint checkpoints/korean_1b_fp8_run1/checkpoint-0034000 \\ |
| | --sft_data data/sft/train.jsonl \\ |
| | --device cuda:0 |
| | |
| | Launch multi-GPU (DDP via torchrun): |
| | torchrun --nproc_per_node=8 train/sft.py \\ |
| | --base_checkpoint checkpoints/korean_1b_fp8_run1/checkpoint-0034000 \\ |
| | --sft_data data/sft/train.jsonl |
| | |
| | KEY DIFFERENCES from pretrain.py: |
| | - Loads weights from a pretrained checkpoint via LLM.from_pretrained() |
| | - Uses SFTDataset (JSONL instruction data) instead of PackedDataset |
| | - Lower default learning rate (2e-5 vs 2e-4) |
| | - Fewer default steps (3000 vs 100000) |
| | - Copies tokenizer.json to checkpoint_dir for easy deployment |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import os |
| | import random |
| | import signal |
| | import shutil |
| | import sys |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader, DistributedSampler, RandomSampler |
| |
|
| | |
| | |
| | |
| |
|
| |
|
class MixingDataLoader:
    """
    Interleaves batches from an SFT DataLoader and a pretrain DataLoader.

    Each yielded batch is drawn from the pretrain loader with probability
    ``pretrain_ratio`` (e.g. 0.3 → 30% pretrain, 70% SFT). Both sources
    cycle forever: when one runs dry its epoch counter is bumped, its
    DistributedSampler (if any) is reshuffled via ``set_epoch``, and it is
    restarted.

    Duck-type compatible with DataLoader for the Trainer's purposes:
    ``__iter__`` yields batches and ``__len__`` reports the SFT loader's
    length (used only for epoch estimation).
    """

    def __init__(
        self,
        sft_loader: DataLoader,
        pretrain_loader: DataLoader,
        pretrain_ratio: float = 0.3,
        sft_sampler: DistributedSampler | RandomSampler | None = None,
        pretrain_sampler: DistributedSampler | RandomSampler | None = None,
    ) -> None:
        self.sft_loader = sft_loader
        self.pretrain_loader = pretrain_loader
        self.pretrain_ratio = pretrain_ratio
        self.sft_sampler = sft_sampler
        self.pretrain_sampler = pretrain_sampler
        self._epoch = 0

    def __len__(self) -> int:
        # Epoch length is defined by the SFT stream only.
        return len(self.sft_loader)

    def __iter__(self):
        # One live iterator per source, keyed by "draw from pretrain?".
        streams = {
            False: iter(self.sft_loader),
            True: iter(self.pretrain_loader),
        }

        while True:
            # Exactly one Bernoulli draw per yielded batch.
            from_pretrain = random.random() < self.pretrain_ratio
            try:
                yield next(streams[from_pretrain])
                continue
            except StopIteration:
                pass

            # The chosen source is exhausted: advance its epoch, reshuffle
            # its sampler (DDP case), and restart the iterator.
            self._epoch += 1
            if from_pretrain:
                sampler = self.pretrain_sampler
                loader = self.pretrain_loader
                failure_msg = (
                    "Pretrain DataLoader is empty after restart. "
                    "Check pretrain_data path and drop_last settings."
                )
            else:
                sampler = self.sft_sampler
                loader = self.sft_loader
                failure_msg = (
                    "SFT DataLoader is empty after restart. "
                    "Check sft_data path and drop_last settings."
                )
            if sampler is not None and hasattr(sampler, 'set_epoch'):
                sampler.set_epoch(self._epoch)
            streams[from_pretrain] = iter(loader)
            try:
                yield next(streams[from_pretrain])
            except StopIteration:
                raise RuntimeError(failure_msg)
| |
|
| | |
# Enable TF32 tensor-core kernels for fp32 matmul/conv (Ampere+ GPUs):
# trades a few mantissa bits for a large throughput win, standard for LLM training.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")

# Make the repository root importable so `model`, `train.*`, and `data.*`
# resolve when this file is launched directly as `python train/sft.py`.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))
| |
|
| | from model import LLM |
| | from train.trainer import TrainConfig, Trainer |
| | from train.utils import ( |
| | cleanup_ddp, |
| | get_cosine_schedule_with_warmup, |
| | is_main_process, |
| | load_checkpoint, |
| | setup_ddp, |
| | ) |
| |
|
| | |
| | |
| | |
# TransformerEngine is optional: it is only needed for the FP8 training path
# (--use_fp8). When absent, `te` is None and training falls back to BF16.
try:
    import transformer_engine.pytorch as te
    HAS_TE = True
except ImportError:
    te = None
    HAS_TE = False
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for SFT.

    Precedence is: explicit CLI flags > values from the optional --config
    YAML file ('train:' section) > built-in defaults. This is achieved by a
    provisional ``parse_known_args`` pass that only discovers --config,
    installing the YAML values as argparse *defaults*, then re-parsing so
    explicitly passed flags still win.

    Returns:
        argparse.Namespace with all training options resolved.

    Raises:
        FileNotFoundError: if --config is given but the file does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Supervised Fine-Tuning (SFT) of a pretrained decoder-only LLM.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # --- required inputs -------------------------------------------------
    parser.add_argument(
        "--base_checkpoint",
        type=Path,
        required=True,
        help=(
            "Path to the pretrained checkpoint directory. "
            "Must contain model.pt and config.yaml (produced by save_checkpoint)."
        ),
    )
    parser.add_argument(
        "--sft_data",
        type=Path,
        required=True,
        help="Path to the JSONL SFT training data file.",
    )

    # --- optional paths --------------------------------------------------
    parser.add_argument(
        "--val_data",
        type=Path,
        default=None,
        help="Optional path to JSONL SFT validation data file.",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=Path,
        default=Path("checkpoints/korean_1b_sft"),
        help="Root directory for saving SFT checkpoints.",
    )
    parser.add_argument(
        "--resume",
        type=Path,
        default=None,
        help="Path to an SFT checkpoint directory to resume fine-tuning from.",
    )
    parser.add_argument(
        "--tokenizer",
        type=Path,
        default=None,
        help=(
            "Override path to tokenizer.json. "
            "Defaults to <base_checkpoint>/tokenizer.json, "
            "then falls back to tokenizer/korean_sp/tokenizer.json."
        ),
    )
    parser.add_argument(
        "--log_file",
        type=Path,
        default=None,
        help=(
            "Path to a text file for structured training logs (rank-0 only). "
            "If omitted, logs go only to stdout."
        ),
    )

    # --- optimisation hyper-parameters ----------------------------------
    parser.add_argument(
        "--max_steps",
        type=int,
        default=3000,
        help="Total number of optimiser steps.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Per-GPU micro-batch size.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=2e-5,
        help=(
            "Peak learning rate. "
            "SFT uses a much lower lr than pretraining (2e-5 vs 2e-4) "
            "to preserve pretrained representations."
        ),
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
        help="AdamW weight decay. Lower than pretrain (0.01 vs 0.1).",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=100,
        help="Number of linear LR warmup steps.",
    )
    parser.add_argument(
        "--grad_accum",
        type=int,
        default=2,
        help="Gradient accumulation steps.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Base random seed (rank offset is added automatically in DDP).",
    )
    parser.add_argument(
        "--use_fp8",
        action="store_true",
        default=False,
        help=(
            "Enable TransformerEngine FP8 training "
            "(requires B200/H100, uses MXFP8BlockScaling)."
        ),
    )

    # --- runtime ---------------------------------------------------------
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help=(
            "Explicit device string (e.g. 'cuda:0'). "
            "Ignored when running under torchrun (DDP auto-assigns devices)."
        ),
    )

    parser.add_argument(
        "--config", type=Path, default=None,
        help="YAML config file. Values under 'train:' section are used as CLI defaults.",
    )
    parser.add_argument("--save_interval", type=int, default=500, help="Checkpoint save interval (steps).")
    parser.add_argument("--eval_interval", type=int, default=250, help="Validation eval interval (steps).")
    parser.add_argument("--neftune_alpha", type=float, default=5.0, help="NEFTune noise magnitude (0 to disable).")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Maximum gradient L2 norm for clipping.")

    # --- data mixing ------------------------------------------------------
    parser.add_argument(
        "--pretrain_data",
        type=Path,
        default=None,
        help="Path to pretrain .bin file for data mixing. Enables SFT+pretrain interleaving.",
    )
    parser.add_argument(
        "--pretrain_mix_ratio",
        type=float,
        default=0.3,
        help="Fraction of batches from pretrain data (0.3 = 30%% pretrain, 70%% SFT).",
    )

    # Provisional parse: only needed to discover --config before the real parse.
    args, _ = parser.parse_known_args()

    # Install YAML values as defaults so explicitly passed CLI flags still win.
    if args.config is not None:
        if not args.config.exists():
            raise FileNotFoundError(f"Config file not found: {args.config}")
        import yaml
        with open(args.config, "r") as f:
            # An empty YAML file parses to None — treat it as an empty mapping.
            yaml_cfg = yaml.safe_load(f) or {}
        train_section = yaml_cfg.get("train", {})
        # YAML key -> argparse dest (names mostly match; grad_accum differs).
        yaml_to_arg = {
            "max_steps": "max_steps",
            "batch_size": "batch_size",
            "lr": "lr",
            "weight_decay": "weight_decay",
            "warmup_steps": "warmup_steps",
            "grad_accum_steps": "grad_accum",
            "save_interval": "save_interval",
            "eval_interval": "eval_interval",
            "neftune_alpha": "neftune_alpha",
            "pretrain_mix_ratio": "pretrain_mix_ratio",
            "max_grad_norm": "max_grad_norm",
        }
        # pretrain_data needs a Path() conversion, handled separately.
        if "pretrain_data" in train_section:
            parser.set_defaults(pretrain_data=Path(train_section["pretrain_data"]))
        new_defaults = {}
        for yaml_key, arg_name in yaml_to_arg.items():
            if yaml_key in train_section:
                new_defaults[arg_name] = train_section[yaml_key]
        if new_defaults:
            parser.set_defaults(**new_defaults)

    # Final parse: CLI values override the (possibly YAML-updated) defaults.
    return parser.parse_args()
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def set_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices)."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
def build_optimizer_param_groups(
    model: torch.nn.Module,
    weight_decay: float,
) -> list[dict]:
    """
    Partition trainable parameters into a decayed and an undecayed group.

    Decayed   : weight tensors with ndim >= 2 (Linear, etc.)
    Undecayed : biases, LayerNorm/RMSNorm weights, and embedding weights.

    This mirrors standard GPT-style training practice.
    """
    # Parameters owned directly by Embedding / LayerNorm modules are always
    # exempt from weight decay, regardless of their name or ndim.
    exempt_ids: set[int] = {
        id(p)
        for m in model.modules()
        if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm))
        for p in m.parameters(recurse=False)
    }

    decay: list[torch.nn.Parameter] = []
    no_decay: list[torch.nn.Parameter] = []
    handled: set[int] = set()

    for pname, p in model.named_parameters():
        # Skip frozen params and shared tensors already assigned to a group.
        if not p.requires_grad or id(p) in handled:
            continue
        handled.add(id(p))

        exempt = (
            id(p) in exempt_ids
            or pname.endswith("bias")
            or p.ndim < 2
        )
        (no_decay if exempt else decay).append(p)

    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def _resolve_tokenizer_path(args: argparse.Namespace) -> Path: |
| | """ |
| | Determine the tokenizer path in priority order: |
| | 1. Explicit --tokenizer argument |
| | 2. tokenizer.json inside the base_checkpoint directory |
| | 3. Project default: tokenizer/korean_sp/tokenizer.json |
| | """ |
| | if args.tokenizer is not None: |
| | p = Path(args.tokenizer) |
| | if not p.exists(): |
| | raise FileNotFoundError(f"Tokenizer not found at --tokenizer path: {p}") |
| | return p |
| |
|
| | ckpt_tok = args.base_checkpoint / "tokenizer.json" |
| | if ckpt_tok.exists(): |
| | return ckpt_tok |
| |
|
| | default_tok = _PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json" |
| | if default_tok.exists(): |
| | return default_tok |
| |
|
| | raise FileNotFoundError( |
| | "Could not locate tokenizer.json. Tried:\n" |
| | f" 1. {ckpt_tok}\n" |
| | f" 2. {default_tok}\n" |
| | "Use --tokenizer to specify an explicit path." |
| | ) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def dynamic_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Pad each batch to its own maximum length rather than a global max_seq_len,
    saving FLOPs on the highly variable sequence lengths typical of SFT data.

    The target length is the batch-local maximum rounded up to a multiple of
    64 (Flash Attention friendly), with a floor of 512 tokens.

    Args:
        batch: List of ``(input_ids, labels)`` tuples from SFTDataset.

    Returns:
        ``(input_ids, labels, attention_mask)`` tensors of shape ``[B, max_len]``:
        input_ids right-padded with 0 (pad token), labels right-padded with -1
        (cross-entropy ignore_index), attention_mask 1 for real tokens / 0 for pad.
    """
    longest = max(ids.size(0) for ids, _ in batch)
    # Ceil to a multiple of 64, floor at 512.
    target = max(512, -(-longest // 64) * 64)

    padded_ids = []
    padded_labels = []
    padded_masks = []
    for ids, labels in batch:
        n = ids.size(0)
        tail = target - n
        padded_ids.append(F.pad(ids, (0, tail), value=0))
        padded_labels.append(F.pad(labels, (0, tail), value=-1))
        padded_masks.append(
            F.pad(torch.ones(n, dtype=torch.long), (0, tail), value=0)
        )

    return (
        torch.stack(padded_ids),
        torch.stack(padded_labels),
        torch.stack(padded_masks),
    )
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def add_neftune_hook(model: torch.nn.Module, noise_alpha: float = 10.0):
    """
    Attach a NEFTune forward hook to the model's input-embedding layer.

    While the module is in training mode, the hook perturbs the embedding
    output with uniform noise of magnitude ``noise_alpha / sqrt(L * d)``.

    Reference: "NEFTune: Noisy Embeddings Improve Instruction Finetuning"
    (Jain et al., 2023). https://arxiv.org/abs/2310.05914

    Args:
        model: Raw (non-DDP) model instance.
        noise_alpha: Noise magnitude parameter (paper default: 10).

    Returns:
        The hook handle (call ``handle.remove()`` to deactivate), or None if
        the embedding layer could not be located.
    """
    # Unwrap a DDP-style container if present.
    raw = getattr(model, "module", model)

    embedding: torch.nn.Embedding | None = None

    # Preferred: the HuggingFace-style accessor, if the model provides one.
    if hasattr(raw, "get_input_embeddings"):
        try:
            candidate = raw.get_input_embeddings()
        except Exception:
            candidate = None
        if isinstance(candidate, torch.nn.Embedding):
            embedding = candidate

    # Fallback: probe a list of conventional attribute paths.
    if embedding is None:
        known_paths = (
            "embedding",
            "embed_tokens",
            "token_embedding",
            "wte",
            "word_embeddings",
            "tok_embeddings",
            "transformer.wte",
            "model.embed_tokens",
            "model.embedding",
        )
        for dotted in known_paths:
            node = raw
            for attr in dotted.split("."):
                node = getattr(node, attr, None)
                if node is None:
                    break
            if isinstance(node, torch.nn.Embedding):
                embedding = node
                break

    if embedding is None:
        print("[WARN] NEFTune: embedding layer을 찾지 못함, NEFTune 비활성화")
        return None

    print(
        f"[INFO] NEFTune: {type(embedding).__name__} hook 등록 "
        f"(shape={tuple(embedding.weight.shape)}, alpha={noise_alpha})"
    )

    def _neftune_hook(module, inp, out):
        if not module.training:
            return out
        # Paper scaling: alpha / sqrt(seq_len * hidden_dim) for a [B, L, d] output.
        scale = noise_alpha / ((out.size(1) * out.size(2)) ** 0.5)
        return out + torch.empty_like(out).uniform_(-scale, scale)

    return embedding.register_forward_hook(_neftune_hook)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def main() -> None:
    """SFT entry point: set up DDP/devices, build model + data + optimizer,
    then hand control to the Trainer. Supports resume, data mixing, NEFTune,
    and graceful SIGHUP/SIGTERM shutdown."""
    args = parse_args()

    # --- distributed / device setup ------------------------------------
    # torchrun exports RANK; its presence distinguishes DDP from single-process.
    is_ddp = "RANK" in os.environ
    rank = 0
    local_rank = 0
    world_size = 1

    if is_ddp:
        rank, local_rank, world_size, device = setup_ddp()
    else:
        # Single-process: honor --device, else best available.
        if args.device is not None:
            device = torch.device(args.device)
        elif torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")

    # Per-rank seed offset so stochastic choices differ across ranks.
    set_seed(args.seed + rank)

    # --- CPU/NUMA affinity ----------------------------------------------
    # NOTE(review): assumes a 72-core dual-socket host with GPUs 0-3 on NUMA0
    # and 4-7 on NUMA1 — verify against the actual hardware. Failures are
    # non-fatal (sched_setaffinity is Linux-only / may be denied).
    try:
        if local_rank < 4:
            os.sched_setaffinity(0, set(range(0, 36)))
        else:
            os.sched_setaffinity(0, set(range(36, 72)))
        if is_main_process():
            print(f"NUMA affinity: rank {rank} (GPU {local_rank}) → "
                  f"{'NUMA0 cores 0-35' if local_rank < 4 else 'NUMA1 cores 36-71'}")
    except (AttributeError, OSError) as e:
        if is_main_process():
            print(f"[WARN] NUMA affinity failed: {e}")

    # --- validate the base checkpoint before any heavy work -------------
    if not args.base_checkpoint.exists():
        raise FileNotFoundError(
            f"Base checkpoint directory not found: {args.base_checkpoint}"
        )
    for required_file in ("model.pt", "config.yaml"):
        if not (args.base_checkpoint / required_file).exists():
            raise FileNotFoundError(
                f"Expected {required_file} inside base checkpoint: {args.base_checkpoint}"
            )

    # --- model -----------------------------------------------------------
    model = LLM.from_pretrained(args.base_checkpoint)

    # --use_fp8 can upgrade a BF16-pretrained checkpoint to FP8 fine-tuning.
    if args.use_fp8:
        model.config.use_fp8 = True

    # SFT runs the whole model in bfloat16 on the target device.
    model = model.to(device=device, dtype=torch.bfloat16)

    # Trade compute for memory: recompute activations in backward if supported.
    if hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()
        if rank == 0:
            print("[INFO] Gradient checkpointing enabled")

    # FP8 kernels require token counts divisible by 8; fail fast otherwise.
    if model.config.use_fp8:
        seq_len = model.config.max_seq_len
        if (args.batch_size * seq_len) % 8 != 0:
            raise ValueError(
                f"FP8: batch_size × max_seq_len = {args.batch_size} × {seq_len} "
                f"= {args.batch_size * seq_len} must be divisible by 8."
            )

    if is_main_process():
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Pretrained model loaded: {total_params:,} parameters")
        print(f"LMConfig: {model.config}")

    # --- DDP wrap ---------------------------------------------------------
    if is_ddp:
        from torch.nn.parallel import DistributedDataParallel as DDP

        model = DDP(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            gradient_as_bucket_view=True,
            bucket_cap_mb=800,
            find_unused_parameters=False,
        )

    # --- tokenizer --------------------------------------------------------
    tokenizer_path = _resolve_tokenizer_path(args)
    if is_main_process():
        print(f"Loading tokenizer from: {tokenizer_path}")

    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(str(tokenizer_path))

    # --- SFT dataset ------------------------------------------------------
    from data.sft_dataset import SFTDataset

    # Reach through the DDP wrapper for the config when needed.
    max_seq_len_cfg = (
        model.config.max_seq_len
        if not isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model.module.config.max_seq_len
    )

    # Rank 0 tokenizes with many workers and (presumably) writes a cache that
    # the other ranks pick up after the barrier — TODO confirm SFTDataset caching.
    tok_workers = 64 if is_main_process() else 0
    if is_ddp:
        if is_main_process():
            train_dataset = SFTDataset(
                data_path=args.sft_data,
                tokenizer=tokenizer,
                max_seq_len=max_seq_len_cfg,
                tokenizer_path=tokenizer_path,
                num_workers=tok_workers,
            )
        # Non-main ranks wait here until rank 0 has finished building.
        torch.distributed.barrier()
        if not is_main_process():
            train_dataset = SFTDataset(
                data_path=args.sft_data,
                tokenizer=tokenizer,
                max_seq_len=max_seq_len_cfg,
            )
    else:
        train_dataset = SFTDataset(
            data_path=args.sft_data,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len_cfg,
            tokenizer_path=tokenizer_path,
            num_workers=tok_workers,
        )

    # DistributedSampler shards data per rank; shuffle handled by the sampler.
    if is_ddp:
        train_sampler: DistributedSampler | RandomSampler = DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=args.seed,
        )
        shuffle = False
    else:
        train_sampler = RandomSampler(train_dataset)
        # NOTE(review): `shuffle` is assigned but never used — DataLoader
        # shuffling is fully delegated to the sampler.
        shuffle = False

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2,
        persistent_workers=True,
        collate_fn=dynamic_collate_fn,  # per-batch dynamic padding
    )

    # --- optional SFT + pretrain data mixing ------------------------------
    pretrain_sampler = None
    if args.pretrain_data is not None:
        if not args.pretrain_data.exists():
            raise FileNotFoundError(f"Pretrain data not found: {args.pretrain_data}")

        from data import PackedDataset

        max_seq_len = (
            model.config.max_seq_len
            if not isinstance(model, torch.nn.parallel.DistributedDataParallel)
            else model.module.config.max_seq_len
        )
        pretrain_dataset = PackedDataset(args.pretrain_data, seq_len=max_seq_len)

        if is_ddp:
            # Different seed offset from the SFT sampler so the two streams
            # do not shuffle in lockstep.
            pretrain_sampler = DistributedSampler(
                pretrain_dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=True,
                seed=args.seed + 1000,
            )
        else:
            pretrain_sampler = RandomSampler(pretrain_dataset)

        # Packed sequences are fixed-length — no dynamic collate needed.
        pretrain_loader = DataLoader(
            pretrain_dataset,
            batch_size=args.batch_size,
            sampler=pretrain_sampler,
            num_workers=4,
            pin_memory=True,
            drop_last=True,
            prefetch_factor=2,
            persistent_workers=True,
        )

        effective_loader = MixingDataLoader(
            sft_loader=train_loader,
            pretrain_loader=pretrain_loader,
            pretrain_ratio=args.pretrain_mix_ratio,
            sft_sampler=train_sampler if is_ddp else None,
            pretrain_sampler=pretrain_sampler if is_ddp else None,
        )

        if is_main_process():
            print(
                f"[INFO] Data mixing enabled: "
                f"{(1 - args.pretrain_mix_ratio) * 100:.0f}% SFT + "
                f"{args.pretrain_mix_ratio * 100:.0f}% pretrain"
            )
            print(f"[INFO] Pretrain data: {args.pretrain_data} ({len(pretrain_dataset):,} samples)")
    else:
        effective_loader = train_loader

    # --- optional validation set -----------------------------------------
    # Same rank-0-first build pattern as the train dataset.
    val_loader: DataLoader | None = None
    if args.val_data is not None:
        if not args.val_data.exists():
            raise FileNotFoundError(f"Validation data not found: {args.val_data}")
        if is_ddp:
            if is_main_process():
                val_dataset = SFTDataset(
                    data_path=args.val_data,
                    tokenizer=tokenizer,
                    max_seq_len=train_dataset.max_seq_len,
                    tokenizer_path=tokenizer_path,
                    num_workers=tok_workers,
                )
            torch.distributed.barrier()
            if not is_main_process():
                val_dataset = SFTDataset(
                    data_path=args.val_data,
                    tokenizer=tokenizer,
                    max_seq_len=train_dataset.max_seq_len,
                )
        else:
            val_dataset = SFTDataset(
                data_path=args.val_data,
                tokenizer=tokenizer,
                max_seq_len=train_dataset.max_seq_len,
                tokenizer_path=tokenizer_path,
                num_workers=tok_workers,
            )
        val_loader = DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
            drop_last=False,
            collate_fn=dynamic_collate_fn,
        )
        if is_main_process():
            print(f"Validation dataset: {len(val_dataset):,} samples")

    # --- optimizer ---------------------------------------------------------
    # Param groups built on the unwrapped model; decay applies to >=2-D weights.
    raw_model = getattr(model, "module", model)
    param_groups = build_optimizer_param_groups(raw_model, args.weight_decay)
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=args.lr,
        betas=(0.9, 0.95),
        eps=1e-8,
        fused=torch.cuda.is_available(),  # fused AdamW kernel on CUDA
    )

    # use_fp8 reflects both the checkpoint config and the --use_fp8 override above.
    use_fp8 = raw_model.config.use_fp8

    train_config = TrainConfig(
        max_steps=args.max_steps,
        checkpoint_dir=str(args.checkpoint_dir),
        grad_accum_steps=args.grad_accum,
        use_fp8=use_fp8,
        log_file=str(args.log_file) if args.log_file is not None else None,
        save_interval=args.save_interval,
        log_interval=10,
        eval_interval=args.eval_interval,
        max_grad_norm=args.max_grad_norm,
    )

    # Linear warmup then cosine decay over the full run.
    scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        warmup_steps=args.warmup_steps,
        total_steps=train_config.max_steps,
    )

    # --- optional resume ----------------------------------------------------
    start_step = 0
    if args.resume is not None:
        if not args.resume.exists():
            raise FileNotFoundError(f"Resume checkpoint not found: {args.resume}")
        start_step, resume_loss = load_checkpoint(
            path=args.resume,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        if is_main_process():
            print(f"Resumed SFT from {args.resume} at step {start_step} (loss={resume_loss:.4f})")

    # Restore an approximate sampler epoch so resumed shuffling differs from epoch 0.
    if args.resume is not None and isinstance(train_sampler, DistributedSampler):
        steps_per_epoch = len(train_loader)
        approx_epoch = start_step // steps_per_epoch if steps_per_epoch > 0 else 0
        train_sampler.set_epoch(approx_epoch)
        if is_main_process():
            print(f"[INFO] Resume: sampler epoch set to {approx_epoch}")

    args.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Ship the tokenizer alongside the checkpoints for easy deployment.
    if is_main_process():
        dest_tok = args.checkpoint_dir / "tokenizer.json"
        if not dest_tok.exists():
            shutil.copy2(str(tokenizer_path), str(dest_tok))
            print(f"Tokenizer copied to {dest_tok}")

    # --- trainer -------------------------------------------------------------
    # The Trainer only gets the SFT sampler directly when no mixing is active;
    # with mixing, MixingDataLoader owns set_epoch for both samplers.
    trainer = Trainer(
        model=model,
        train_loader=effective_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        config=train_config,
        device=device,
        rank=rank,
        sampler=train_sampler if is_ddp and args.pretrain_data is None else None,
        val_loader=val_loader,
    )

    # --- graceful shutdown on SIGHUP/SIGTERM ---------------------------------
    import signal as _signal_mod

    _trainer_ref = trainer

    def _graceful_shutdown_handler(signum, frame):
        # Log once (rank 0), then ask the trainer to stop at a safe point.
        sig_name = _signal_mod.Signals(signum).name
        if is_main_process():
            import datetime as _dt
            ts = _dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            msg = (
                f"[{ts}] [SIGNAL] Received {sig_name} (signum={signum}). "
                f"Initiating graceful shutdown..."
            )
            print(f"\n{msg}")
            if args.log_file is not None:
                try:
                    with open(args.log_file, "a", encoding="utf-8") as f:
                        f.write(msg + "\n")
                except Exception:
                    pass
        _trainer_ref.request_shutdown(sig_name)

    for _sig in (_signal_mod.SIGHUP, _signal_mod.SIGTERM):
        _signal_mod.signal(_sig, _graceful_shutdown_handler)

    # --- run banner -----------------------------------------------------------
    if is_main_process():
        import datetime

        inner_config = raw_model.config
        eff_batch_seqs = args.batch_size * args.grad_accum * world_size
        eff_tokens_per_step = eff_batch_seqs * inner_config.max_seq_len
        train_samples = len(train_dataset)
        precision_label = "FP8 (MXFP8BlockScaling)" if use_fp8 else "BF16"
        nccl_debug = os.environ.get("NCCL_DEBUG", "not set")
        omp_threads = os.environ.get("OMP_NUM_THREADS", "not set")

        mix_label = "none"
        if args.pretrain_data is not None:
            mix_label = (
                f"{(1 - args.pretrain_mix_ratio) * 100:.0f}% SFT + "
                f"{args.pretrain_mix_ratio * 100:.0f}% pretrain"
            )

        print(
            f"\n{'='*70}\n"
            f" LLM Supervised Fine-Tuning — "
            f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
            f"{'='*70}\n"
            f" base ckpt : {args.base_checkpoint}\n"
            f" sft data  : {args.sft_data} ({train_samples:,} samples)\n"
            f" data mix  : {mix_label}\n"
            f" model     : {inner_config.num_params:,} params | "
            f"d_model={inner_config.d_model} n_layers={inner_config.n_layers}\n"
            f" precision : {precision_label}\n"
            f" GPUs      : {world_size} | batch/GPU={args.batch_size} "
            f"grad_accum={args.grad_accum}\n"
            f" eff_batch : {eff_batch_seqs} seqs "
            f"= {eff_tokens_per_step:,} tok/step\n"
            f" max_steps : {train_config.max_steps:,}\n"
            f" lr        : {args.lr:.2e} "
            f"warmup={args.warmup_steps} weight_decay={args.weight_decay}\n"
            f" ckpt_dir  : {args.checkpoint_dir}\n"
            f" env       : OMP_NUM_THREADS={omp_threads} NCCL_DEBUG={nccl_debug}\n"
            f"{'='*70}\n"
        )

    # --- NEFTune --------------------------------------------------------------
    # Hook the raw (unwrapped) model so DDP buckets are unaffected.
    neftune_alpha = getattr(args, 'neftune_alpha', 5.0)
    neftune_handle = add_neftune_hook(raw_model, noise_alpha=neftune_alpha)
    if rank == 0:
        if neftune_handle is not None:
            print(f"[INFO] NEFTune enabled (noise_alpha={neftune_alpha})")
        else:
            print("[WARN] NEFTune disabled - embedding layer not found")

    # --- train ----------------------------------------------------------------
    try:
        trainer.train(start_step=start_step)
    except KeyboardInterrupt:
        if is_main_process():
            print("\n[INFO] SFT interrupted by user (KeyboardInterrupt).")
    except Exception as e:
        # Log the full traceback (and append to the log file) before re-raising.
        import traceback
        if is_main_process():
            tb = traceback.format_exc()
            print(f"\n[ERROR] SFT failed at rank {rank}:\n{tb}")
            if args.log_file is not None:
                with open(args.log_file, "a", encoding="utf-8") as f:
                    import datetime
                    f.write(
                        f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
                        f"[FATAL] {tb}\n"
                    )
        raise
    finally:
        # Always detach the NEFTune hook and tear down the process group.
        if neftune_handle is not None:
            neftune_handle.remove()
        if is_ddp:
            cleanup_ddp()
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|