| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import os |
| from abc import ABC, abstractmethod |
| from typing import Any, Dict |
|
|
| import torch |
| import torch.distributed as dist |
| from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner |
| from ..utils.import_utils import is_torch_version_greater_than |
| from ..utils.logging import get_logger |
| from pathlib import Path |
|
|
| if is_torch_version_greater_than("2.4"): |
| import torch.distributed.checkpoint as dcp |
| from torch.distributed.checkpoint import ( |
| FileSystemReader, |
| FileSystemWriter, |
| ) |
| from torch.distributed.checkpoint.state_dict import ( |
| get_model_state_dict, |
| get_optimizer_state_dict, |
| set_model_state_dict, |
| set_optimizer_state_dict, |
| ) |
| from torch.distributed.checkpoint.stateful import Stateful |
| else: |
| Stateful = ABC |
|
|
# Module-level logger; this file logs through its info_rank0 method.
logger = get_logger(__name__)


# File-name template for the per-rank extra state; "{}" is filled with dist.get_rank().
_EXTRA_STATE_FORMAT = "extra_state_rank_{}.pt"
# Sub-directory names joined onto the checkpoint directory, one per kind of state.
_MODEL_DIR = "model"
_EMA_DIR = "ema"
_OPTIMIZER_DIR = "optimizer"
_EXTRA_STATE_DIR = "extra_state"
|
|
|
|
| class ModelState(Stateful): |
| """ |
| A wrapper around a model to make it stateful. |
| Args: |
| model (Model): model to wrap. |
| """ |
|
|
| def __init__(self, model): |
| self.model = model |
|
|
| def state_dict(self): |
| model_state_dict = get_model_state_dict(model=self.model) |
| return {"model": model_state_dict} |
|
|
| def load_state_dict(self, state_dict): |
| set_model_state_dict(model=self.model, model_state_dict=state_dict["model"]) |
|
|
|
|
| class OptimizerState(Stateful): |
| """ |
| A wrapper around an optimizer to make it stateful. |
| |
| Args: |
| model (Model): model to wrap. |
| optimizer (Optimizer): optimizer to wrap. |
| """ |
|
|
| def __init__(self, model, optimizer): |
| self.model = model |
| self.optimizer = optimizer |
|
|
| def state_dict(self): |
| optimizer_state_dict = get_optimizer_state_dict(model=self.model, optimizers=self.optimizer) |
| return {"optim": optimizer_state_dict} |
|
|
| def load_state_dict(self, state_dict): |
| set_optimizer_state_dict(model=self.model, optimizers=self.optimizer, optim_state_dict=state_dict["optim"]) |
|
|
|
|
def build_checkpointer(
    dist_backend: str = "fsdp1",
    ckpt_manager: str = "bytecheckpoint",
) -> type:
    """
    Select the checkpointer class for a distributed backend and checkpoint manager.

    Args:
        dist_backend (str, optional): distributed backend in use. Defaults to "fsdp1".
            With ckpt_manager="bytecheckpoint":
                ddp: DDPCheckpointer from bytecheckpoint
                fsdp1: FSDPCheckpointer from bytecheckpoint
                fsdp2-vescale: VeScaleCheckpointer from bytecheckpoint
                fsdp2: FSDP2Checkpointer from bytecheckpoint
            With ckpt_manager="dcp": one of ddp, fsdp1, fsdp2; all map to
                DistributedCheckpointer.
        ckpt_manager (str, optional): checkpoint manager. Defaults to "bytecheckpoint".
            bytecheckpoint: bytecheckpoint checkpoint manager
            dcp: torch dcp checkpoint manager

    Raises:
        ValueError: if ckpt_manager is unknown, if dist_backend is not supported
            by the chosen manager, or if "dcp" is requested while the torch
            version check against "2.4" fails.

    Returns:
        Checkpointer: the checkpointer class (not an instance) for the given mode.
    """
    if ckpt_manager == "bytecheckpoint":
        # bytecheckpoint is imported lazily so it is only required when selected.
        if dist_backend == "ddp":
            from bytecheckpoint import DDPCheckpointer as Checkpointer
        elif dist_backend == "fsdp1":
            from bytecheckpoint import FSDPCheckpointer as Checkpointer
        elif dist_backend == "fsdp2-vescale":
            from bytecheckpoint import VeScaleCheckpointer as Checkpointer
        elif dist_backend == "fsdp2":
            from bytecheckpoint import FSDP2Checkpointer as Checkpointer
        else:
            # Without this branch Checkpointer stays unbound and the return
            # below fails with UnboundLocalError instead of a clear message.
            raise ValueError(
                f"Unsupported distributed backend: {dist_backend} for bytecheckpoint checkpoint manager, "
                "supported modes are: ddp, fsdp1, fsdp2-vescale, fsdp2"
            )
    elif ckpt_manager == "dcp":
        if not is_torch_version_greater_than("2.4"):
            raise ValueError("DCP checkpoint manager requires torch version >= 2.4")
        if dist_backend not in ["ddp", "fsdp1", "fsdp2"]:
            raise ValueError(
                f"Unsupported distributed backend: {dist_backend} for DCP checkpoint manager, supported modes are: ddp, fsdp1, fsdp2"
            )
        Checkpointer = DistributedCheckpointer
    else:
        # Only the two managers handled above are listed as supported.
        raise ValueError(f"Unknown checkpoint manager: {ckpt_manager}, supported modes are: bytecheckpoint, dcp")

    return Checkpointer
|
|
|
|
class CheckpointerBase(ABC):
    """
    Base class for checkpointer

    Declares the two operations a checkpointer must provide. Both are
    abstract classmethods, so they are called on the class itself and a
    subclass cannot be instantiated until it overrides both.
    """

    @classmethod
    @abstractmethod
    def save(
        cls,
        path: str,
        state: Dict[str, Any],
    ):
        """
        Save ``state`` under ``path``.

        Args:
            path: location to save the checkpoint to
            state: mapping of state to save
        """
        return

    @classmethod
    @abstractmethod
    def load(
        cls,
        path: str,
        state: Dict[str, Any],
    ):
        """
        Load the checkpoint stored at ``path`` into ``state``.

        Args:
            path: location to load the checkpoint from
            state: mapping of state to load into
        """
        return
|
|
|
|
class DistributedCheckpointer(CheckpointerBase):
    """
    Distributed checkpointer for torch.distributed.checkpoint

    Layout under the checkpoint directory:
        model/        model state (always written)
        ema/          EMA model state (only when state["ema"] is not None)
        optimizer/    optimizer state (only when state["optimizer"] is not None)
        extra_state/  one extra_state_rank_{rank}.pt per rank, written with torch.save
    """

    # Thread count handed to every FileSystemWriter created by this class.
    _WRITER_THREAD_COUNT = 16

    @classmethod
    def _build_writer(cls, output_dir: str, remove_stale_tmp: bool = False):
        """Create the FileSystemWriter for one checkpoint sub-directory."""
        if remove_stale_tmp:
            # Presumably a leftover from an interrupted save — TODO confirm.
            tmp_path = Path(output_dir) / ".metadata.tmp"
            try:
                # Unlink directly instead of exists()+unlink(): several ranks may
                # run this for the same directory, and only one can win.
                tmp_path.unlink()
            except FileNotFoundError:
                pass
            else:
                print(f"Warning: removing existing tmp file: {tmp_path}")
        return FileSystemWriter(
            output_dir,
            thread_count=cls._WRITER_THREAD_COUNT,
            single_file_per_rank=True,
            sync_files=False,
        )

    @classmethod
    def _stateful_parts(cls, state: Dict[str, Any]):
        """List (sub-directory, Stateful) pairs for the model, ema and optimizer entries present in state."""
        parts = [(_MODEL_DIR, ModelState(state["model"]))]
        if state.get("ema") is not None:
            parts.append((_EMA_DIR, ModelState(state["ema"])))
        # None is treated as absent, mirroring the ema check above.
        if state.get("optimizer") is not None:
            parts.append(
                (_OPTIMIZER_DIR, OptimizerState(model=state["model"], optimizer=state["optimizer"]))
            )
        return parts

    @classmethod
    def save(
        cls,
        path: str,
        state: Dict[str, Any],
        global_steps: int = None,
        save_async: bool = False,
    ) -> None:
        """
        save training state to distributed checkpoint

        args:
            path: path to save checkpoint
            state: state to save; "model" is required, "ema", "optimizer" and
                "extra_state" are optional
            global_steps: global steps; when not None the checkpoint is written
                to "{path}/global_step_{global_steps}" (step 0 included),
                otherwise directly to path
            save_async: whether to save asynchronously (dcp.async_save instead
                of dcp.save)
        raises:
            ValueError: if "model" is missing from state
        return:
            None
        """
        # Validate before touching the filesystem.
        if "model" not in state:
            raise ValueError("Model must be provided to save a distributed checkpoint.")

        # Compare against None so that step 0 still gets its own directory.
        checkpoint_dir = f"{path}/global_step_{global_steps}" if global_steps is not None else path
        os.makedirs(checkpoint_dir, exist_ok=True)

        # NOTE(review): the objects returned by dcp.async_save are not kept, so
        # nothing here waits for an asynchronous save to complete.
        save_fn = dcp.async_save if save_async else dcp.save
        for sub_dir, stateful in cls._stateful_parts(state):
            save_fn(
                state_dict={"state": stateful},
                storage_writer=cls._build_writer(
                    os.path.join(checkpoint_dir, sub_dir),
                    # Stale-tmp cleanup is applied on the synchronous path only.
                    remove_stale_tmp=not save_async,
                ),
            )

        if "extra_state" in state:
            # Extra state bypasses DCP: each rank writes its own file.
            extra_state_dir = os.path.join(checkpoint_dir, _EXTRA_STATE_DIR)
            os.makedirs(extra_state_dir, exist_ok=True)
            extra_state_path = os.path.join(extra_state_dir, _EXTRA_STATE_FORMAT.format(dist.get_rank()))
            torch.save(
                state["extra_state"],
                extra_state_path,
            )

        logger.info_rank0(f"Saved checkpoint to {checkpoint_dir}")

    @classmethod
    def load(
        cls,
        path: str,
        state: Dict[str, Any],
        process_group=None,
    ) -> Dict[str, Any]:
        """
        load training state from distributed checkpoint

        args:
            path: path to load checkpoint
            state: state to load, "model" are required, "ema", "optimizer" and
                "extra_state" are optional
            process_group: forwarded to every dcp.load call

        raises:
            ValueError: if state is None or "model" is missing from state

        return:
            state: state loaded; model, ema and optimizer are restored in place
                and state["extra_state"] is replaced by the value read for this rank
        """
        checkpoint_dir = path

        if state is None:
            raise ValueError("State dict must be provided to load a distributed checkpoint.")

        if "model" not in state:
            raise ValueError("Model must be provided to load a distributed checkpoint.")

        if state.get("ema") is not None:
            dcp.load(
                state_dict={"state": ModelState(state["ema"])},
                storage_reader=FileSystemReader(os.path.join(checkpoint_dir, _EMA_DIR)),
                process_group=process_group,
            )

        # The model is loaded whether or not an optimizer was requested.
        dcp.load(
            state_dict={"state": ModelState(state["model"])},
            storage_reader=FileSystemReader(os.path.join(checkpoint_dir, _MODEL_DIR)),
            process_group=process_group,
        )

        # None is treated as absent, mirroring the ema check above.
        if state.get("optimizer") is not None:
            try:
                dcp.load(
                    state_dict={"state": OptimizerState(model=state["model"], optimizer=state["optimizer"])},
                    storage_reader=FileSystemReader(os.path.join(checkpoint_dir, _OPTIMIZER_DIR)),
                    planner=DefaultLoadPlanner(allow_partial_load=True),
                    process_group=process_group,
                )
            except Exception as err:
                # Optimizer loading is best effort, but KeyboardInterrupt and
                # SystemExit must still propagate and the cause must be visible.
                logger.info_rank0(f"Skip loading Optimizer from {checkpoint_dir}: {err!r}")

        if "extra_state" in state:
            # Loading never creates directories; a missing file surfaces from torch.load.
            extra_state_path = os.path.join(
                checkpoint_dir, _EXTRA_STATE_DIR, _EXTRA_STATE_FORMAT.format(dist.get_rank())
            )
            # NOTE(review): torch.load unpickles this file, so only load
            # checkpoints from trusted sources. Newer torch releases default to
            # weights_only=True, which may reject non-tensor extra state —
            # confirm the intended setting.
            state["extra_state"] = torch.load(
                extra_state_path,
            )

        logger.info_rank0(f"Loaded checkpoint from {checkpoint_dir}")

        return state
|
|