| from __future__ import annotations |
|
|
| import argparse |
| import math |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import torch |
| from torch.utils.data import DataLoader |
|
|
# Repository root: two directory levels above this script file.
ROOT = Path(__file__).resolve().parents[1]
# Make the in-repo ``src`` tree importable so the ``sllm`` imports below work
# without installing the package.
# NOTE(review): ``append`` puts ``src`` last on the search path, so an installed
# ``sllm`` distribution would take precedence over the repo copy — confirm intended.
sys.path.append(str(ROOT.joinpath("src")))
|
|
| from sllm.checkpoint import load_checkpoint, save_checkpoint |
| from sllm.config import ModelConfig, SFTConfig, load_json, save_json |
| from sllm.data import FixedSFTDataset |
| from sllm.model import SLLMForCausalLM |
| from sllm.utils import ( |
| append_jsonl, |
| autocast_context, |
| cosine_lr, |
| cuda_memory_snapshot, |
| ensure_dir, |
| format_number, |
| get_device, |
| iso_timestamp, |
| maybe_enable_tf32, |
| model_parameter_count, |
| resolve_runtime_precision, |
| set_optimizer_lr, |
| set_seed, |
| setup_logger, |
| timestamp, |
| tokens_per_step, |
| ) |
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the SFT entry point.

    Flags:
        --model-config: required path to the model JSON config.
        --train-config: required path to the SFT JSON config.
        --max-steps: optional integer overriding ``max_steps`` (debugging aid).
    """
    cli = argparse.ArgumentParser(description="Run supervised fine-tuning for the sLLM.")
    required_paths = (
        ("--model-config", "Path to model JSON config."),
        ("--train-config", "Path to SFT JSON config."),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, required=True, help=help_text)
    cli.add_argument("--max-steps", type=int, default=None, help="Optional debug override.")
    return cli
|
|
|
|
def build_optimizer(model: torch.nn.Module, config: SFTConfig, device: torch.device):
    """Create an AdamW optimizer that applies weight decay only to matrix-like weights.

    Trainable parameters (``requires_grad``) are split into two groups, keeping
    the order of ``model.named_parameters()`` within each group:

    * group 0 — parameters with ``ndim >= 2`` whose name does not end in
      ``"bias"``; uses ``config.weight_decay``;
    * group 1 — everything else (vectors, scalars, biases); weight decay 0.

    Frozen parameters are skipped entirely. The fused AdamW kernel is requested
    only when ``device`` is CUDA.
    """

    def exempt_from_decay(param_name: str, tensor: torch.nn.Parameter) -> bool:
        return tensor.ndim <= 1 or param_name.endswith("bias")

    trainable = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    decayed = [p for name, p in trainable if not exempt_from_decay(name, p)]
    undecayed = [p for name, p in trainable if exempt_from_decay(name, p)]

    param_groups = [
        {"params": decayed, "weight_decay": config.weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]
    use_fused_kernel = device.type == "cuda"
    return torch.optim.AdamW(
        param_groups,
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        fused=use_fused_kernel,
    )
|
|
|
|
@torch.no_grad()
def evaluate(model: SLLMForCausalLM, loader: DataLoader, device: torch.device, precision: str, max_batches: int):
    """Estimate validation loss and perplexity on at most ``max_batches`` batches.

    Args:
        model: module whose forward output supports ``["loss"]``.
        loader: iterable of batches; each batch is a dict of tensors that is moved
            to ``device`` and passed to ``model`` as keyword arguments.
        device: device to evaluate on.
        precision: precision name forwarded to ``autocast_context``.
        max_batches: maximum number of batches to consume.

    Returns:
        ``(mean_loss, perplexity)``. ``mean_loss`` is the unweighted mean of the
        per-batch losses; ``perplexity`` is ``exp(mean_loss)`` with the exponent
        capped at 20 to avoid overflow. Both are ``nan`` when no batch was
        evaluated (empty loader or ``max_batches <= 0``) instead of a misleading
        ``0.0`` / ``1.0`` "perfect" score.

    The model's train/eval mode on entry is restored on exit, even if the
    forward pass raises.
    """
    was_training = model.training
    model.eval()
    losses = []
    try:
        for batch_index, batch in enumerate(loader):
            if batch_index >= max_batches:
                break
            batch = {key: value.to(device) for key, value in batch.items()}
            with autocast_context(device, precision):
                loss = model(**batch)["loss"]
            losses.append(loss.detach().float().item())
    finally:
        model.train(was_training)

    if not losses:
        nan = float("nan")
        return nan, nan
    mean_loss = float(sum(losses) / len(losses))
    return mean_loss, math.exp(min(mean_loss, 20))
|
|
|
|
def save_run_config(output_dir: Path, model_config: ModelConfig, train_config: SFTConfig) -> None:
    """Write the resolved configs to ``run_config.json`` inside ``output_dir``.

    The written object has two keys, ``"model_config"`` and ``"train_config"``,
    each filled from the corresponding config's ``to_dict()``.
    """
    destination = output_dir / "run_config.json"
    snapshot = dict(
        model_config=model_config.to_dict(),
        train_config=train_config.to_dict(),
    )
    save_json(destination, snapshot)
|
|
|
|
def _strip_compile_prefix(state_dict):
    """Return ``state_dict`` with the ``torch.compile`` wrapper prefix removed from its keys.

    ``torch.compile`` wraps a module in an ``OptimizedModule`` whose state-dict
    keys are prefixed with ``"_orig_mod."``. Checkpoints written from such a
    wrapper cannot be loaded into a plain module without normalizing the keys.
    The input is returned unchanged when no key carries the prefix.
    """
    prefix = "_orig_mod."
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _is_due(step: int, interval) -> bool:
    """Return True when 1-based ``step`` is a multiple of a positive ``interval``.

    An interval of ``0``/``None`` (or negative) disables the periodic action
    instead of raising ``ZeroDivisionError``.
    """
    return interval is not None and interval > 0 and step % interval == 0


def main() -> None:
    """Run supervised fine-tuning from the command line.

    Parses CLI args and JSON configs, resolves device and precision, sets up
    the text log and metrics JSONL, builds data loaders, model, optimizer and
    grad scaler, and optionally loads weights (``init_from``) or resumes
    weights, optimizer state and step (``resume_from``). It then runs the
    gradient-accumulation training loop with periodic logging, evaluation and
    checkpointing.
    """
    args = build_parser().parse_args()
    model_config = ModelConfig.from_dict(load_json(args.model_config))
    train_config = SFTConfig.from_dict(load_json(args.train_config))
    if args.max_steps is not None:
        train_config.max_steps = args.max_steps

    set_seed(train_config.seed)
    device = get_device()
    maybe_enable_tf32(device)
    runtime_precision, precision_warning = resolve_runtime_precision(device, train_config.precision)
    # Record the precision actually used (it may differ from the requested one)
    # so the logs, run_config.json and checkpoints reflect what really ran.
    train_config.precision = runtime_precision

    output_dir = ensure_dir(train_config.output_dir)
    checkpoint_dir = ensure_dir(train_config.checkpoint_dir)
    logger, log_path = setup_logger("sllm.train_sft", output_dir, "train_sft")
    metrics_path = Path(output_dir) / "logs" / f"{log_path.stem}.jsonl"
    logger.info("SFT training started")
    logger.info("Log file: %s", log_path)
    logger.info("Metrics JSONL: %s", metrics_path)
    logger.info(
        "Arguments | model_config=%s train_config=%s max_steps_override=%s",
        args.model_config,
        args.train_config,
        args.max_steps,
    )
    if precision_warning is not None:
        logger.warning(precision_warning)
    logger.info("Model config | %s", model_config.to_dict())
    logger.info("SFT config | %s", train_config.to_dict())
    append_jsonl(
        metrics_path,
        {
            "event": "run_started",
            "timestamp": iso_timestamp(),
            "log_path": str(log_path),
            "metrics_path": str(metrics_path),
            "model_config": model_config.to_dict(),
            "train_config": train_config.to_dict(),
            "args": {
                "model_config": args.model_config,
                "train_config": args.train_config,
                "max_steps_override": args.max_steps,
            },
        },
    )
    save_run_config(output_dir, model_config, train_config)

    train_dataset = FixedSFTDataset(train_config.dataset_path, split="train")
    val_dataset = FixedSFTDataset(train_config.dataset_path, split="val")
    if len(train_dataset) == 0:
        # Fail fast with a clear message instead of a bare StopIteration from
        # the batch-fetching loop below.
        raise ValueError(f"SFT train split is empty: {train_config.dataset_path}")
    use_pinned_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_config.micro_batch_size,
        shuffle=True,
        num_workers=train_config.num_workers,
        pin_memory=use_pinned_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=train_config.micro_batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=use_pinned_memory,
    )

    # Keep a handle on the plain module. torch.compile returns a wrapper whose
    # state_dict keys carry an "_orig_mod." prefix, so weights are loaded into
    # and saved from the unwrapped module to keep checkpoints portable.
    raw_model = SLLMForCausalLM(model_config).to(device)
    model = raw_model
    if train_config.compile_model and hasattr(torch, "compile"):
        model = torch.compile(raw_model)

    optimizer = build_optimizer(model, train_config, device)
    # Loss scaling is only enabled for fp16 on CUDA.
    # NOTE(review): scaler state is not checkpointed, so an fp16 resume restarts
    # from the default loss scale.
    scaler = torch.amp.GradScaler(
        "cuda",
        enabled=device.type == "cuda" and train_config.precision.lower() == "fp16",
    )

    start_step = 0
    # resume_from takes precedence over init_from when both are set.
    checkpoint_path = train_config.resume_from or train_config.init_from
    if checkpoint_path:
        payload = load_checkpoint(checkpoint_path, map_location=device)
        raw_model.load_state_dict(_strip_compile_prefix(payload["model"]))
        if train_config.resume_from and payload.get("optimizer") is not None:
            optimizer.load_state_dict(payload["optimizer"])
            start_step = int(payload.get("step", 0))
            logger.info("Resumed SFT | step=%s checkpoint=%s", start_step, checkpoint_path)
            append_jsonl(
                metrics_path,
                {
                    "event": "resumed",
                    "timestamp": iso_timestamp(),
                    "step": start_step,
                    "checkpoint": checkpoint_path,
                },
            )
        else:
            logger.info("Loaded initialization weights | checkpoint=%s", checkpoint_path)
            append_jsonl(
                metrics_path,
                {
                    "event": "initialized_from_checkpoint",
                    "timestamp": iso_timestamp(),
                    "checkpoint": checkpoint_path,
                },
            )

    model.train()
    tokens_step = tokens_per_step(
        train_config.micro_batch_size,
        train_config.grad_accum_steps,
        train_config.seq_len,
    )
    parameter_count = model_parameter_count(model)
    logger.info(
        "Device summary | device=%s precision=%s compile_model=%s",
        device,
        train_config.precision,
        train_config.compile_model,
    )
    logger.info("Model summary | parameters=%s", format_number(parameter_count))
    logger.info(
        "Batch summary | seq_len=%s micro_batch_size=%s grad_accum_steps=%s tokens_per_step=%s",
        train_config.seq_len,
        train_config.micro_batch_size,
        train_config.grad_accum_steps,
        f"{tokens_step:,}",
    )
    logger.info(
        "Dataset summary | dataset_path=%s train_examples=%s val_examples=%s",
        train_config.dataset_path,
        len(train_dataset),
        len(val_dataset),
    )
    append_jsonl(
        metrics_path,
        {
            "event": "runtime_summary",
            "timestamp": iso_timestamp(),
            "device": str(device),
            "precision": train_config.precision,
            "compile_model": train_config.compile_model,
            "parameters": parameter_count,
            "seq_len": train_config.seq_len,
            "micro_batch_size": train_config.micro_batch_size,
            "grad_accum_steps": train_config.grad_accum_steps,
            "tokens_per_step": tokens_step,
            "dataset_path": train_config.dataset_path,
            "train_examples": len(train_dataset),
            "val_examples": len(val_dataset),
        },
    )

    running_loss = 0.0
    # Count steps since the last log so averages stay correct after a resume
    # that does not land on a log_interval boundary.
    steps_since_log = 0
    log_start_time = time.perf_counter()
    train_iterator = iter(train_loader)
    # Stays NaN when gradient clipping is disabled (the norm is never computed).
    last_grad_norm = float("nan")

    for step in range(start_step, train_config.max_steps):
        completed_step = step + 1
        tokens_seen = completed_step * tokens_step
        lr = cosine_lr(
            step=step,
            warmup_steps=train_config.warmup_steps,
            max_steps=train_config.max_steps,
            max_lr=train_config.learning_rate,
            min_lr=train_config.min_lr,
        )
        set_optimizer_lr(optimizer, lr)
        optimizer.zero_grad(set_to_none=True)

        step_loss = 0.0
        for _ in range(train_config.grad_accum_steps):
            try:
                batch = next(train_iterator)
            except StopIteration:
                # Loader exhausted: start a new pass over the training data.
                train_iterator = iter(train_loader)
                batch = next(train_iterator)

            batch = {key: value.to(device, non_blocking=use_pinned_memory) for key, value in batch.items()}
            with autocast_context(device, train_config.precision):
                # Scale so the accumulated gradient is the mean over micro-batches.
                loss = model(**batch)["loss"] / train_config.grad_accum_steps
            step_loss += loss.detach().float().item()
            if scaler.is_enabled():
                scaler.scale(loss).backward()
            else:
                loss.backward()

        if train_config.grad_clip and train_config.grad_clip > 0:
            if scaler.is_enabled():
                # Clip the real gradients, not the loss-scaled ones.
                scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.grad_clip)
            last_grad_norm = float(grad_norm)

        if scaler.is_enabled():
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        running_loss += step_loss
        steps_since_log += 1

        if _is_due(completed_step, train_config.log_interval):
            elapsed = time.perf_counter() - log_start_time
            avg_loss = running_loss / steps_since_log
            tok_per_sec = (tokens_step * steps_since_log) / max(elapsed, 1e-6)
            # Guard against a falsy/None snapshot (e.g. when not on CUDA).
            memory = cuda_memory_snapshot(device) or {}
            memory_suffix = ""
            if memory:
                memory_suffix = (
                    f" mem_alloc_gb={memory['allocated_gb']:.2f}"
                    f" mem_reserved_gb={memory['reserved_gb']:.2f}"
                    f" max_mem_alloc_gb={memory['max_allocated_gb']:.2f}"
                    f" max_mem_reserved_gb={memory['max_reserved_gb']:.2f}"
                )
            logger.info(
                "Train step | step=%s loss=%.4f lr=%.6f tok_per_sec=%s grad_norm=%.4f%s",
                completed_step,
                avg_loss,
                lr,
                f"{tok_per_sec:,.0f}",
                last_grad_norm,
                memory_suffix,
            )
            append_jsonl(
                metrics_path,
                {
                    "event": "train",
                    "timestamp": iso_timestamp(),
                    "step": completed_step,
                    "loss": avg_loss,
                    "lr": lr,
                    "tok_per_sec": tok_per_sec,
                    "grad_norm": last_grad_norm,
                    "tokens_seen": tokens_seen,
                    "elapsed_sec": elapsed,
                    "seq_len": train_config.seq_len,
                    "micro_batch_size": train_config.micro_batch_size,
                    "grad_accum_steps": train_config.grad_accum_steps,
                    **memory,
                },
            )
            running_loss = 0.0
            steps_since_log = 0
            log_start_time = time.perf_counter()

        if _is_due(completed_step, train_config.eval_interval):
            val_loss, val_ppl = evaluate(
                model=model,
                loader=val_loader,
                device=device,
                precision=train_config.precision,
                max_batches=train_config.eval_batches,
            )
            logger.info("Eval step | step=%s val_loss=%.4f perplexity=%.2f", completed_step, val_loss, val_ppl)
            append_jsonl(
                metrics_path,
                {
                    "event": "eval",
                    "timestamp": iso_timestamp(),
                    "step": completed_step,
                    "val_loss": val_loss,
                    "perplexity": val_ppl,
                    "eval_batches": train_config.eval_batches,
                },
            )

        if _is_due(completed_step, train_config.save_interval) or completed_step == train_config.max_steps:
            step_checkpoint_path = checkpoint_dir / f"step_{completed_step:07d}.pt"
            last_checkpoint_path = checkpoint_dir / "last.pt"
            # Write the same state twice: a per-step snapshot and a rolling "last".
            for target_path in (step_checkpoint_path, last_checkpoint_path):
                save_checkpoint(
                    target_path,
                    model=raw_model,
                    optimizer=optimizer,
                    step=completed_step,
                    model_config=model_config.to_dict(),
                    train_config=train_config.to_dict(),
                    extra_state={"tokens_seen": tokens_seen},
                )
            logger.info(
                "Checkpoint saved | step=%s step_checkpoint=%s last_checkpoint=%s",
                completed_step,
                step_checkpoint_path,
                last_checkpoint_path,
            )
            append_jsonl(
                metrics_path,
                {
                    "event": "checkpoint",
                    "timestamp": iso_timestamp(),
                    "step": completed_step,
                    "step_checkpoint": str(step_checkpoint_path),
                    "last_checkpoint": str(last_checkpoint_path),
                    "tokens_seen": tokens_seen,
                },
            )

    logger.info("SFT training finished | final_step=%s", train_config.max_steps)
    append_jsonl(
        metrics_path,
        {
            "event": "run_finished",
            "timestamp": iso_timestamp(),
            "final_step": train_config.max_steps,
            "tokens_seen": train_config.max_steps * tokens_step,
        },
    )
|
|
|
|
# Script entry point: run SFT only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
|
|