# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner

from ..utils.import_utils import is_torch_version_greater_than
from ..utils.logging import get_logger


if is_torch_version_greater_than("2.4"):
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint import (
        FileSystemReader,
        FileSystemWriter,
    )
    from torch.distributed.checkpoint.state_dict import (
        get_model_state_dict,
        get_optimizer_state_dict,
        set_model_state_dict,
        set_optimizer_state_dict,
    )
    from torch.distributed.checkpoint.stateful import Stateful
else:
    # The DCP names above stay undefined in this case; `build_checkpointer`
    # refuses to hand out the DCP checkpointer when the version check fails.
    Stateful = ABC

logger = get_logger(__name__)

_EXTRA_STATE_FORMAT = "extra_state_rank_{}.pt"
_MODEL_DIR = "model"
_EMA_DIR = "ema"
_OPTIMIZER_DIR = "optimizer"
_EXTRA_STATE_DIR = "extra_state"


class ModelState(Stateful):
    """
    Stateful wrapper exposing a model's parameters under the ``"model"`` key.

    Args:
        model (Model): model to wrap.
    """

    def __init__(self, model):
        self.model = model

    def state_dict(self):
        """Return ``{"model": <state dict from get_model_state_dict>}``."""
        model_state_dict = get_model_state_dict(model=self.model)
        return {"model": model_state_dict}

    def load_state_dict(self, state_dict):
        """Apply ``state_dict["model"]`` to the wrapped model."""
        set_model_state_dict(model=self.model, model_state_dict=state_dict["model"])


class OptimizerState(Stateful):
    """
    Stateful wrapper exposing an optimizer's state under the ``"optim"`` key.

    Args:
        model (Model): model the optimizer belongs to.
        optimizer (Optimizer): optimizer to wrap.
    """

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        """Return ``{"optim": <state dict from get_optimizer_state_dict>}``."""
        optimizer_state_dict = get_optimizer_state_dict(model=self.model, optimizers=self.optimizer)
        return {"optim": optimizer_state_dict}

    def load_state_dict(self, state_dict):
        """Apply ``state_dict["optim"]`` to the wrapped optimizer."""
        set_optimizer_state_dict(model=self.model, optimizers=self.optimizer, optim_state_dict=state_dict["optim"])


def build_checkpointer(
    dist_backend: str = "fsdp1",
    ckpt_manager: str = "bytecheckpoint",
):
    """
    Select the checkpointer class for a distributed backend / manager pair.

    Args:
        dist_backend (str, optional): distributed backend. Defaults to "fsdp1".
            With ``ckpt_manager="bytecheckpoint"``:
                ddp: bytecheckpoint ``DDPCheckpointer``
                fsdp1: bytecheckpoint ``FSDPCheckpointer``
                fsdp2-vescale: bytecheckpoint ``VeScaleCheckpointer``
                fsdp2: bytecheckpoint ``FSDP2Checkpointer``
            With ``ckpt_manager="dcp"``: one of ddp, fsdp1, fsdp2.
        ckpt_manager (str, optional): checkpoint manager. Defaults to "bytecheckpoint".
            bytecheckpoint: bytecheckpoint checkpoint manager
            dcp: torch dcp checkpoint manager (``DistributedCheckpointer``)

    Raises:
        ValueError: if ``ckpt_manager`` is unknown, if ``dist_backend`` is not
            supported by the chosen manager, or if the torch version check for
            the dcp manager fails.

    Returns:
        Checkpointer: the checkpointer class (not an instance) for the given mode.
    """
    if ckpt_manager == "bytecheckpoint":
        # bytecheckpoint is imported lazily so it is only required when selected.
        if dist_backend == "ddp":
            from bytecheckpoint import DDPCheckpointer as Checkpointer
        elif dist_backend == "fsdp1":
            from bytecheckpoint import FSDPCheckpointer as Checkpointer
        elif dist_backend == "fsdp2-vescale":
            from bytecheckpoint import VeScaleCheckpointer as Checkpointer
        elif dist_backend == "fsdp2":
            from bytecheckpoint import FSDP2Checkpointer as Checkpointer
        else:
            # Without this branch `Checkpointer` would be unbound at the return
            # below and surface as an UnboundLocalError.
            raise ValueError(
                f"Unsupported distributed backend: {dist_backend} for bytecheckpoint checkpoint manager, "
                "supported modes are: ddp, fsdp1, fsdp2-vescale, fsdp2"
            )
    elif ckpt_manager == "dcp":
        if not is_torch_version_greater_than("2.4"):
            raise ValueError("DCP checkpoint manager requires torch version >= 2.4")
        if dist_backend not in ["ddp", "fsdp1", "fsdp2"]:
            raise ValueError(
                f"Unsupported distributed backend: {dist_backend} for DCP checkpoint manager, supported modes are: ddp, fsdp1, fsdp2"
            )
        Checkpointer = DistributedCheckpointer
    else:
        raise ValueError(f"Unknown checkpoint manager: {ckpt_manager}, supported modes are: bytecheckpoint, dcp")

    return Checkpointer


class CheckpointerBase(ABC):
    """Base class for checkpointer"""

    @classmethod
    @abstractmethod
    def save(
        cls,
        path: str,
        state: Dict[str, Any],
    ):
        """Persist ``state`` under ``path``. Must be overridden."""
        return

    @classmethod
    @abstractmethod
    def load(
        cls,
        path: str,
        state: Dict[str, Any],
    ):
        """Restore ``state`` from ``path``. Must be overridden."""
        return


def _remove_stale_metadata(output_dir: str) -> None:
    """Delete a leftover ``.metadata.tmp`` file in ``output_dir``, if one exists."""
    tmp_path = Path(output_dir) / ".metadata.tmp"
    if tmp_path.exists():
        print(f"Warning: removing existing tmp file: {tmp_path}")
        # missing_ok tolerates the file disappearing between the check and the
        # unlink (e.g. removed by another process in the meantime).
        tmp_path.unlink(missing_ok=True)


def _build_writer(output_dir: str):
    """Create the FileSystemWriter configuration shared by every save below."""
    return FileSystemWriter(
        output_dir,
        thread_count=16,
        single_file_per_rank=True,
        sync_files=False,
    )


class DistributedCheckpointer(CheckpointerBase):
    """
    Distributed checkpointer for torch.distributed.checkpoint

    Layout under the checkpoint directory: ``model/``, ``ema/`` (optional),
    ``optimizer/`` (optional) and ``extra_state/extra_state_rank_<rank>.pt``
    (optional).
    """

    @classmethod
    def save(
        cls,
        path: str,
        state: Dict[str, Any],
        global_steps: Optional[int] = None,
        save_async: bool = False,
    ) -> None:
        """
        save training state to distributed checkpoint

        args:
            path: path to save checkpoint
            state: state to save. "model" is required; "ema" (skipped when None),
                "optimizer" and "extra_state" are optional
            global_steps: global steps; when given (including 0) the checkpoint
                is written to ``{path}/global_step_{global_steps}``
            save_async: whether to save asynchronously

        raises:
            ValueError: if "model" is missing from ``state``

        return:
            None
        """
        # `is not None` so that step 0 also gets its own sub-directory.
        checkpoint_dir = f"{path}/global_step_{global_steps}" if global_steps is not None else path
        os.makedirs(checkpoint_dir, exist_ok=True)

        if "model" not in state:
            raise ValueError("Model must be provided to save a distributed checkpoint.")

        # (sub-directory, stateful wrapper) pairs, written in this order.
        components = [(_MODEL_DIR, ModelState(state["model"]))]
        if state.get("ema") is not None:
            components.append((_EMA_DIR, ModelState(state["ema"])))
        if "optimizer" in state:
            components.append(
                (_OPTIMIZER_DIR, OptimizerState(model=state["model"], optimizer=state["optimizer"]))
            )

        for sub_dir, stateful in components:
            target_dir = os.path.join(checkpoint_dir, sub_dir)
            if save_async:
                # NOTE(review): the value returned by dcp.async_save is discarded,
                # so callers cannot wait for completion - confirm this is intended.
                dcp.async_save(
                    state_dict={"state": stateful},
                    storage_writer=_build_writer(target_dir),
                )
            else:
                # Applied uniformly to model, ema and optimizer.
                _remove_stale_metadata(target_dir)
                dcp.save(
                    state_dict={"state": stateful},
                    storage_writer=_build_writer(target_dir),
                )

        if "extra_state" in state:
            extra_state_dir = os.path.join(checkpoint_dir, _EXTRA_STATE_DIR)
            os.makedirs(extra_state_dir, exist_ok=True)
            # One file per rank, named after dist.get_rank().
            extra_state_path = os.path.join(extra_state_dir, _EXTRA_STATE_FORMAT.format(dist.get_rank()))
            torch.save(
                state["extra_state"],
                extra_state_path,
            )

        logger.info_rank0(f"Saved checkpoint to {checkpoint_dir}")

    @classmethod
    def load(
        cls,
        path: str,
        state: Dict[str, Any],
        process_group=None,
    ) -> Dict[str, Any]:
        """
        load training state from distributed checkpoint

        args:
            path: path to load checkpoint
            state: state to load, "model" is required, "ema", "optimizer" and
                "extra_state" are optional
            process_group: forwarded to every ``dcp.load`` call

        raises:
            ValueError: if ``state`` is None or has no "model" entry

        return:
            state: the same dict; "extra_state" is replaced by the loaded value
        """
        checkpoint_dir = path

        if state is None:
            raise ValueError("State dict must be provided to load a distributed checkpoint.")

        if "model" not in state:
            raise ValueError("Model must be provided to load a distributed checkpoint.")

        if state.get("ema") is not None:
            ema_dir = os.path.join(checkpoint_dir, _EMA_DIR)
            dcp.load(
                state_dict={"state": ModelState(state["ema"])},
                storage_reader=FileSystemReader(ema_dir),
                process_group=process_group,
            )

        model_dir = os.path.join(checkpoint_dir, _MODEL_DIR)
        dcp.load(
            state_dict={"state": ModelState(state["model"])},
            storage_reader=FileSystemReader(model_dir),
            process_group=process_group,
        )

        if "optimizer" in state:
            optimizer_dir = os.path.join(checkpoint_dir, _OPTIMIZER_DIR)
            try:
                dcp.load(
                    state_dict={"state": OptimizerState(model=state["model"], optimizer=state["optimizer"])},
                    storage_reader=FileSystemReader(optimizer_dir),
                    planner=DefaultLoadPlanner(allow_partial_load=True),
                    process_group=process_group,
                )
            except Exception as exc:
                # Optimizer loading stays best-effort, but KeyboardInterrupt /
                # SystemExit now propagate and the reason is reported.
                logger.info_rank0(f"Skip loading Optimizer from {checkpoint_dir}: {exc!r}")

        if "extra_state" in state:
            extra_state_dir = os.path.join(checkpoint_dir, _EXTRA_STATE_DIR)
            extra_state_path = os.path.join(extra_state_dir, _EXTRA_STATE_FORMAT.format(dist.get_rank()))
            # SECURITY: weights_only=False unpickles arbitrary Python objects.
            # Only load checkpoints from a trusted source. It is set explicitly
            # because the default differs between torch releases.
            state["extra_state"] = torch.load(
                extra_state_path,
                weights_only=False,
            )

        logger.info_rank0(f"Loaded checkpoint from {checkpoint_dir}")

        return state