| | """ |
| | train/utils.py — Training utility functions. |
| | |
| | Provides: |
| | get_cosine_schedule_with_warmup : LambdaLR scheduler with linear warmup + cosine decay |
| | save_checkpoint : Persist model/optimizer/scheduler state to disk |
| | load_checkpoint : Restore state from a saved checkpoint directory |
| | get_grad_norm : Compute total L2 gradient norm across all parameters |
| | setup_ddp : Initialise NCCL distributed process group |
| | cleanup_ddp : Tear down distributed process group |
| | is_main_process : True when this process is rank 0 (or non-distributed) |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import math |
| | import os |
| | import shutil |
| | from pathlib import Path |
| | from typing import Optional, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.distributed as dist |
| | import yaml |
| | from torch.optim import Optimizer |
| | from torch.optim.lr_scheduler import LambdaLR |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def get_cosine_schedule_with_warmup(
    optimizer: Optimizer,
    warmup_steps: int,
    total_steps: int,
    min_lr_ratio: float = 0.1,
) -> LambdaLR:
    """
    Build a LambdaLR whose multiplier ramps up linearly, then follows a cosine down.

    Multiplier applied to each param group's base lr, as a function of ``step``:
      - ``step < warmup_steps``:                 ``step / max(1, warmup_steps)``  (0 -> 1)
      - ``warmup_steps <= step < total_steps``:  ``min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + cos(pi * t))``
        with ``t = (step - warmup_steps) / max(1, total_steps - warmup_steps)``
      - ``step >= total_steps``:                 ``min_lr_ratio`` (held constant)

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        warmup_steps: Length of the linear warmup phase; must be >= 0.
        total_steps: Step at which the cosine decay reaches its floor; must be > 0.
        min_lr_ratio: Floor of the multiplier as a fraction of the peak lr; in [0, 1].

    Returns:
        A LambdaLR wrapping ``optimizer``.

    Raises:
        ValueError: If any argument is outside its valid range (NaN ratios included).
    """
    if warmup_steps < 0:
        raise ValueError(f"warmup_steps must be >= 0, got {warmup_steps}")
    if total_steps <= 0:
        raise ValueError(f"total_steps must be > 0, got {total_steps}")
    # Written as a negated chained comparison so NaN is rejected too.
    if not 0.0 <= min_lr_ratio <= 1.0:
        raise ValueError(f"min_lr_ratio must be in [0, 1], got {min_lr_ratio}")

    # Denominators are clamped to 1 so zero-length phases never divide by zero.
    warmup_denom = float(max(1, warmup_steps))
    decay_denom = float(max(1, total_steps - warmup_steps))
    span = 1.0 - min_lr_ratio

    def _multiplier(current_step: int) -> float:
        if current_step >= warmup_steps:
            if current_step >= total_steps:
                return min_lr_ratio
            progress = float(current_step - warmup_steps) / decay_denom
            cosine_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
            return min_lr_ratio + span * cosine_factor
        return float(current_step) / warmup_denom

    return LambdaLR(optimizer, _multiplier)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: Optimizer,
    scheduler: LambdaLR,
    step: int,
    loss: float,
    path: str | Path,
    suffix: str | None = None,
) -> Path:
    """
    Save a training checkpoint to ``path/checkpoint-{step:07d}/`` or ``path/checkpoint-{suffix}/``.

    Files written:
        - model.pt       : model state_dict (DDP wrappers unwrapped via ``.module``)
        - optimizer.pt   : optimizer state_dict
        - scheduler.pt   : scheduler state_dict
        - train_state.pt : step, loss and the Python / NumPy / torch CPU / torch CUDA RNG states
        - config.yaml    : the model's ``.config`` (only if the unwrapped model has one)

    All files are first written to a staging directory ``.tmp_<name>``. Only
    after every file is written is an existing checkpoint of the same name
    moved aside and the staging directory renamed into place. A failure while
    writing files therefore leaves any existing checkpoint untouched. After
    the swap, numbered checkpoints are pruned with
    :func:`cleanup_old_checkpoints` using its default ``keep``.

    Every value in train_state.pt is a tensor or a plain Python type, so it can
    be read with ``torch.load(..., weights_only=True)``.

    Args:
        model: The model (plain or DDP-wrapped).
        optimizer: The optimizer.
        scheduler: The LR scheduler.
        step: Current training step (stored, and used in the directory name
            when ``suffix`` is not given).
        loss: Current loss value (stored for reference).
        path: Root checkpoint directory (created if missing).
        suffix: Optional name used instead of the zero-padded step, e.g.
            ``"best"``. A falsy value falls back to the step.

    Returns:
        Path to the finished checkpoint directory.
    """
    import random as _random

    root = Path(path)
    dir_name = f"checkpoint-{suffix}" if suffix else f"checkpoint-{step:07d}"
    ckpt_dir = root / dir_name
    tmp_dir = root / f".tmp_{dir_name}"
    trash_dir = root / f".trash_{dir_name}"

    # Start from a clean staging directory; an interrupted save may have left one.
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)
    tmp_dir.mkdir(parents=True, exist_ok=True)

    raw_model: torch.nn.Module = getattr(model, "module", model)

    torch.save(raw_model.state_dict(), tmp_dir / "model.pt")
    torch.save(optimizer.state_dict(), tmp_dir / "optimizer.pt")
    torch.save(scheduler.state_dict(), tmp_dir / "scheduler.pt")

    # np.random.get_state() returns (name, uint32 key ndarray, pos, has_gauss, gauss).
    # The ndarray is converted to a list of ints because torch.load(weights_only=True)
    # rejects numpy array globals unless they are explicitly allowlisted.
    # np.random.set_state accepts the list form as-is.
    np_state = np.random.get_state()
    train_state = {
        "step": step,
        "loss": loss,
        "rng_state": {
            "python": _random.getstate(),
            "numpy": (np_state[0], np_state[1].tolist(), *np_state[2:]),
            "torch_cpu": torch.random.get_rng_state(),
            "torch_cuda": torch.cuda.get_rng_state_all(),
        },
    }
    torch.save(train_state, tmp_dir / "train_state.pt")

    if hasattr(raw_model, "config"):
        cfg = raw_model.config
        if hasattr(cfg, "to_dict"):
            config_dict = cfg.to_dict()
        else:
            # Fall back to the public instance attributes of a plain config object.
            config_dict = {k: v for k, v in vars(cfg).items() if not k.startswith("_")}
        with open(tmp_dir / "config.yaml", "w", encoding="utf-8") as f:
            yaml.safe_dump(config_dict, f, default_flow_style=False, sort_keys=False)

    # Swap the staged directory into place: move any previous checkpoint of the
    # same name aside, rename the staging dir, then delete the old copy.
    if trash_dir.exists():
        shutil.rmtree(trash_dir)
    if ckpt_dir.exists():
        ckpt_dir.rename(trash_dir)
    tmp_dir.rename(ckpt_dir)
    if trash_dir.exists():
        shutil.rmtree(trash_dir)

    cleanup_old_checkpoints(root)

    return ckpt_dir
| |
|
| |
|
def cleanup_old_checkpoints(path: Path, keep: int = 5) -> None:
    """
    Delete all but the ``keep`` most recently modified numbered checkpoints.

    Only directories matching ``checkpoint-[0-9]*`` are considered. Named
    checkpoints such as ``checkpoint-best`` (or any other non-numeric suffix)
    are therefore never removed.

    Args:
        path: Root checkpoint directory (a ``str`` is also accepted).
        keep: Number of newest numbered checkpoints to retain; ``0`` removes
            all of them.

    Raises:
        ValueError: If ``keep`` is negative.
    """
    if keep < 0:
        raise ValueError(f"keep must be >= 0, got {keep}")

    root = Path(path)
    ckpts = sorted(
        (d for d in root.glob("checkpoint-[0-9]*") if d.is_dir()),
        # Oldest first by modification time; the name breaks ties deterministically.
        key=lambda d: (d.stat().st_mtime, d.name),
    )
    # Explicit count instead of ckpts[:-keep]: that slice is ckpts[:0] when keep == 0,
    # which would delete nothing instead of everything.
    n_delete = max(len(ckpts) - keep, 0)
    for old in ckpts[:n_delete]:
        shutil.rmtree(old)
| |
|
| |
|
def load_checkpoint(
    path: str | Path,
    model: torch.nn.Module,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[LambdaLR] = None,
) -> Tuple[int, float]:
    """
    Load a checkpoint directory created by :func:`save_checkpoint`.

    Model weights are always restored. They are mapped onto the device of the
    model's first parameter, or CPU for a parameterless model. Optimizer and
    scheduler states are restored only when those objects are passed.

    If train_state.pt carries an ``rng_state`` entry, the Python, NumPy, torch
    CPU and torch CUDA generators are each restored independently. A failure
    for one generator is printed as a warning and the rest are still restored.

    Args:
        path: Checkpoint directory (e.g. ``checkpoints/checkpoint-0001000``).
        model: Model to load weights into (plain or DDP-wrapped).
        optimizer: Optional optimizer to restore state into.
        scheduler: Optional LR scheduler to restore state into.

    Returns:
        ``(step, loss)`` — the training step and loss recorded at save time.

    Raises:
        FileNotFoundError: If ``path`` is not an existing directory.
    """
    import pickle
    import random as _random

    ckpt_dir = Path(path)
    if not ckpt_dir.is_dir():
        raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_dir}")

    raw_model: torch.nn.Module = getattr(model, "module", model)

    # Load tensors straight onto the model's device; parameterless models fall back to CPU.
    try:
        device = next(raw_model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    raw_model.load_state_dict(
        torch.load(ckpt_dir / "model.pt", map_location=device, weights_only=True)
    )

    if optimizer is not None:
        optimizer.load_state_dict(
            torch.load(ckpt_dir / "optimizer.pt", map_location=device, weights_only=True)
        )

    if scheduler is not None:
        scheduler.load_state_dict(
            torch.load(ckpt_dir / "scheduler.pt", map_location=device, weights_only=True)
        )

    train_state_file = ckpt_dir / "train_state.pt"
    try:
        train_state = torch.load(train_state_file, map_location="cpu", weights_only=True)
    except pickle.UnpicklingError:
        # train_state.pt may hold numpy's RNG state as a raw ndarray (np.random.get_state()
        # stored verbatim), which the weights_only unpickler rejects by default.
        # SECURITY NOTE: the fallback performs full unpickling, which can execute
        # arbitrary code — only load checkpoints from a trusted source.
        print(
            f"[WARN] {train_state_file} is not weights_only-loadable; "
            "retrying with full unpickling (trusted checkpoints only)"
        )
        train_state = torch.load(train_state_file, map_location="cpu", weights_only=False)

    step: int = int(train_state["step"])
    loss: float = float(train_state["loss"])

    rng_state = train_state.get("rng_state")
    if rng_state is not None:

        def _set_numpy_state(state) -> None:
            # Accept both an ndarray key and the plain-list form of the MT19937 key.
            name, key, *rest = state
            np.random.set_state((name, np.asarray(key, dtype=np.uint32), *rest))

        restorers = (
            ("python", _random.setstate),
            ("numpy", _set_numpy_state),
            ("torch_cpu", torch.random.set_rng_state),
            ("torch_cuda", torch.cuda.set_rng_state_all),
        )
        for key, restore in restorers:
            # Best-effort by design: an RNG mismatch (e.g. a different GPU count)
            # should not prevent resuming, so failures are reported, not raised.
            try:
                restore(rng_state[key])
            except Exception as e:
                print(f"[WARN] RNG state restore failed for {key!r} (non-fatal): {e}")

    return step, loss
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def get_grad_norm(model: torch.nn.Module) -> float:
    """
    Compute the total L2 norm of all parameter gradients.

    Each gradient's norm is computed in float32 (``.float()`` before
    ``.norm(2)``). Only the resulting scalar is kept, so low-precision
    gradients are never all upcast and held in memory at once. The
    per-parameter norms are gathered on one device, stacked and reduced. On a
    single-device model the final ``.item()`` is the only device-to-host sync.
    Parameters whose ``.grad`` is ``None`` do not contribute.

    Args:
        model: The model (plain or DDP-wrapped; unwrapped via ``.module``).

    Returns:
        The global gradient L2 norm as a Python float, or ``0.0`` when no
        parameter has a gradient.
    """
    raw_model: torch.nn.Module = getattr(model, "module", model)
    norms = [
        p.grad.detach().float().norm(2)
        for p in raw_model.parameters()
        if p.grad is not None
    ]
    if not norms:
        return 0.0
    # torch.stack requires one device; parameters may be spread across devices.
    target = norms[0].device
    return torch.stack([n.to(target) for n in norms]).norm(2).item()
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def setup_ddp() -> Tuple[int, int, int, torch.device]:
    """
    Initialise the NCCL distributed process group for DDP training.

    Reads ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment
    (set by ``torchrun``). It binds this process to ``cuda:{LOCAL_RANK}`` and
    then joins the NCCL process group with a 7200-second (2-hour) timeout.

    Returns:
        ``(rank, local_rank, world_size, device)``

    Raises:
        KeyError: If ``RANK``, ``LOCAL_RANK`` or ``WORLD_SIZE`` is not set.
    """
    import datetime as _dt

    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Defaults only: values already exported by the launcher or user take precedence.
    # NOTE(review): OpenMP/MKL may read these when their runtime initialises, which
    # can happen before this point once torch is imported — confirm they take effect.
    os.environ.setdefault("OMP_NUM_THREADS", "4")
    os.environ.setdefault("MKL_NUM_THREADS", "4")

    # Bind this rank's GPU *before* creating the NCCL group, so NCCL communicators
    # (and any early barrier) use this device rather than a default/guessed one.
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    dist.init_process_group(
        backend="nccl",
        timeout=_dt.timedelta(seconds=7200),
    )

    return rank, local_rank, world_size, device
| |
|
| |
|
def cleanup_ddp() -> None:
    """Destroy the default process group if one is active; otherwise do nothing."""
    # Check availability first so is_initialized() is only consulted on builds
    # that ship torch.distributed support.
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
| |
|
| |
|
def is_main_process() -> bool:
    """
    Tell whether this process should act as the rank-0 (coordinating) process.

    The decision is based solely on the ``RANK`` environment variable. When it
    is unset, the run is treated as single-process and counts as rank 0.
    """
    env_rank = os.environ.get("RANK")
    rank = 0 if env_rank is None else int(env_rank)
    return rank == 0
| |
|