| |
| |
| |
| |
| |
|
|
| import enum |
| import functools |
| import os |
| import queue |
| import re |
| import shutil |
| import threading |
| import time |
| from typing import Any |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.distributed.checkpoint as dcp |
| import torch.multiprocessing as mp |
| import torch.nn as nn |
| from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict |
| from torch.distributed.checkpoint.state_dict import ( |
| get_model_state_dict, |
| set_model_state_dict, |
| StateDictOptions, |
| ) |
| from torch.distributed.checkpoint.stateful import Stateful |
| from torch.utils.data import DataLoader |
|
|
| from torchtitan.components.ft import FTManager |
| from torchtitan.components.lr_scheduler import LRSchedulersContainer |
| from torchtitan.components.optimizer import OptimizersContainer |
| from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP |
| from torchtitan.tools.logging import init_logger, logger |
| from torchtitan.tools.utils import GarbageCollection |
|
|
|
|
| MODEL = "model" |
| OPTIMIZER = "optimizer" |
| LR_SCHEDULER = "lr_scheduler" |
| DATALOADER = "dataloader" |
| TRAIN_STATE = "train_state" |
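| # Keys under which each trainer component is registered in the checkpoint |
| # state dict assembled by CheckpointManager below. |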
|
|
|
|
| class AsyncMode(str, enum.Enum): |
| DISABLED = "disabled" |
| ASYNC = "async" |
| ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem" |
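| # These values mirror the accepted strings for the checkpoint ``async_mode`` |
| # config field; the configured value is lowercased before being compared. |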
|
|
|
|
| class ModelWrapper(Stateful): |
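| """Stateful wrapper that flattens one or more model parts into a single state_dict.""" |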
| def __init__(self, model: nn.Module | list[nn.Module]) -> None: |
| self.model = [model] if isinstance(model, nn.Module) else model |
| self.cache_state_dict = { |
| k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items() |
| } |
|
|
| def state_dict(self) -> dict[str, Any]: |
| return self.cache_state_dict |
|
|
| def load_state_dict(self, state_dict: dict[str, Any]) -> None: |
| func = functools.partial( |
| set_model_state_dict, |
| model_state_dict=state_dict, |
| options=StateDictOptions(strict=False), |
| ) |
| list(map(func, self.model)) |
| # Loading may rewrite the keys of the underlying state_dicts, so rebuild |
| # the cached flattened state_dict afterwards. |
| self.cache_state_dict = { |
| k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items() |
| } |
|
|
|
|
| class Terminate: |
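| """Sentinel placed on a queue to ask the background saver or purge worker to exit.""" |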
| pass |
|
|
|
|
| class SaveDone: |
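| """Sentinel returned by the checkpoint background process when it is ready for more work.""" |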
| pass |
|
|
|
|
| @torch.no_grad() |
| def save_with_gc(state, checkpoint_id): |
| dcp.save(state, checkpoint_id=checkpoint_id) |
| GarbageCollection.collect("GC collection invoked by checkpointer.") |
|
|
|
|
| def checkpoint_mp(recv: mp.Queue, send: mp.Queue): |
| """Process to save the checkpoint in the background. |
| |
| This is only used when ``async_mode`` is set to ``async_with_pinned_mem``. |
| |
| Args: |
| recv (mp.Queue): The queue to receive the state_dict and Terminate signal. |
| send (mp.Queue): The queue to send the SaveDone signal. |
| """ |
| init_logger() |
| os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2) |
| os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False" |
| torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) |
| dist.init_process_group() |
| try: |
| while True: |
| logger.debug("Checkpoint background process is done.") |
| send.put(SaveDone()) |
| logger.debug("Wait for the new state_dict.") |
| obj = recv.get() |
| logger.debug("Received the new state_dict.") |
| if isinstance(obj, Terminate): |
| logger.info("Terminating the checkpoint background process.") |
| return |
| assert isinstance(obj, tuple) |
| begin = time.monotonic() |
| state, checkpoint_id = obj |
| save_with_gc(state, checkpoint_id=checkpoint_id) |
| logger.info( |
| "Finish saving the checkpoint in the background process in %.2f seconds.", |
| time.monotonic() - begin, |
| ) |
| finally: |
| logger.info("Destroying the process group.") |
| dist.destroy_process_group() |
|
|
|
|
| def purge_thread(purge_queue: queue.Queue): |
| """Thread to purge the old checkpoints. |
| |
| This is only used when keep_latest_k > 0. |
| |
| Args: |
| purge_queue (queue.Queue): The queue to receive the path to purge and Terminate signal. |
| """ |
| try: |
| while True: |
| path = purge_queue.get() |
| if isinstance(path, Terminate): |
| return |
| assert isinstance(path, str) |
| logger.info("Checkpointer is deleting %s.", path) |
| begin = time.monotonic() |
| shutil.rmtree(path, ignore_errors=True) |
| logger.info( |
| "Checkpointer deleted %s in %.2f seconds.", |
| path, |
| time.monotonic() - begin, |
| ) |
| finally: |
| logger.info("Destroying the purge thread.") |
|
|
|
|
| class CheckpointManager: |
| """This class manages the checkpointing logic for the TorchTitan trainer. |
| |
| |
| Note: Pipeline Parallelism and Virtual Stages |
| |
| 1. Even for simple PP schedules, there is a separate optimizer for each PP rank. |
| rank0's optimizer would have a param_group[0] which refers to layers.0 in the original |
| model. rank1's would _also_ have a param_group[0], since it's index based, but |
| referring to layers.1. When saving, these collide and one of them is lost. Then when |
| reloading, only one stage can restore its optimizer states; the others will error. |
| |
| The solution to this problem is optimizer flattening: it landed in #127071 and is |
| enabled in TorchTitan by passing the 'flatten_optimizer_state_dict' kwarg to DCP |
| functions called in the OptimizerContainer. |
| See PR #127071 (https://github.com/pytorch/pytorch/pull/127071) for the example of |
| a flattening state_dict. |
| |
| 2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds |
| challenge (1) by also requiring us to reason about multiple 'optim' objects locally. |
| |
| We solve this in the Model and Optimizer wrapper classes by flattening the state dicts |
| from each object into one state dict before saving/loading. We rely on the individual |
| state_dicts to not collide, which is guaranteed for the model by correct pipeline |
| splitting and for the optimizer by the flattening support described in (1). |
| |
| 3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers |
| with the assumption that all lr_schedulers have the same state_dict. |
| |
| Note: TorchFT checkpointing flow |
| |
| There are two types of checkpoints when TorchFT is enabled: 1) the full persistent |
| checkpoint, 2) the per-replica checkpoint. |
| |
| The full persistent checkpoint is saved by the replica with |
| ``ft_manager.participating_rank() == 0``. It contains everything including the model, |
| optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent |
| checkpoint is loaded by all replicas. However, we can optimize it to only load if |
| there are no other live replicas. |
| |
| The per-replica checkpoint contains only the dataloader and is saved/loaded by all |
| replicas to/from their own folders. The folder name is prefixed with the ft_replica_id. |
| |
| Args: |
| dataloader (DataLoader): The dataloader used to load the data. |
| model_parts (List[nn.Module]): List of model parts to be optimized. |
| optimizers (OptimizersContainer): The optimizers used to optimize the model. |
| lr_schedulers (LRSchedulersContainer): The lr schedulers used to optimize the model. |
| states (Dict[str, Any]): The states that need to be saved, other than the |
| previous 4 components. |
| job_config (JobConfig): The job config used to configure the checkpointing. |
| ft_manager (Optional[ft.Manager]): The FTManager from TorchFT. |
| """ |
|
|
| def __init__( |
| self, |
| dataloader: DataLoader, |
| model_parts: list[nn.Module], |
| optimizers: OptimizersContainer, |
| lr_schedulers: LRSchedulersContainer, |
| states: dict[str, Any], |
| job_config: JobConfig, |
| ft_manager: FTManager, |
| ) -> None: |
| ckpt_config = job_config.checkpoint |
| self.enable_checkpoint = ckpt_config.enable_checkpoint |
| self.ft_manager = ft_manager.manager if ft_manager.enabled else None |
|
|
| if self.ft_manager: |
| optimizers.init_cache_state_dict() |
|
|
| def state_dict(): |
| ret = {} |
| for k, v in self.states.items(): |
| if k in { |
| MODEL, |
| OPTIMIZER, |
| LR_SCHEDULER, |
| TRAIN_STATE, |
| }: |
| ret[k] = v.state_dict() |
| return ret |
|
|
| def load_state_dict(state_dict): |
| assert state_dict is not None |
| for k, v in state_dict.items(): |
| self.states[k].load_state_dict(v) |
|
|
| self.ft_manager.set_state_dict_fns(load_state_dict, state_dict) |
| self.ft_replica_id = job_config.fault_tolerance.replica_id |
|
|
| async_mode = ckpt_config.async_mode.lower() |
| self.enable_staging = ( |
| self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM |
| ) or self.ft_manager |
|
|
| if not self.enable_checkpoint and self.ft_manager is None: |
| return |
|
|
| self.states = states |
| self.states.update( |
| { |
| MODEL: ModelWrapper(model_parts), |
| OPTIMIZER: optimizers, |
| DATALOADER: dataloader, |
| LR_SCHEDULER: lr_schedulers, |
| } |
| ) |
| self.ft_states = {DATALOADER: dataloader} |
|
|
| self.staging = False |
| self.sending_to_checkpoint_mp = False |
| self.staging_id = None |
| self.cpu_offload_state_dict = None |
| self.staging_stream = torch.cuda.Stream() if self.enable_staging else None |
|
|
| self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder) |
| self.interval = ckpt_config.interval |
| async_mode = ckpt_config.async_mode.lower() |
| if async_mode == AsyncMode.ASYNC or self.ft_manager: |
| self.pg = dist.new_group(backend="gloo") |
|
|
| self.keep_latest_k = ckpt_config.keep_latest_k |
| if self.keep_latest_k > 0: |
| if self.keep_latest_k == 1: |
| raise ValueError( |
| "We need to maintain at least 2 checkpoint replicas, " |
| "as the last one may be in the process of being saved." |
| ) |
| self.purge_queue = queue.Queue() |
| self.purge_thread = threading.Thread( |
| target=purge_thread, args=(self.purge_queue,), daemon=True |
| ) |
| self.purge_thread.start() |
| else: |
| self.purge_thread = None |
|
|
| self.model_weights_only = ckpt_config.model_weights_only |
| self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype] |
| self.exclude_from_loading = ckpt_config.exclude_from_loading |
|
|
| self.mp = None |
| if async_mode == AsyncMode.DISABLED: |
| self.async_mode = AsyncMode.DISABLED |
| elif async_mode == AsyncMode.ASYNC: |
| self.async_mode = AsyncMode.ASYNC |
| self.async_future = None |
| elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: |
| self.async_mode = AsyncMode.ASYNC_WITH_PINNED_MEM |
| ctx = mp.get_context("spawn") |
| self.mp_queue_send = ctx.Queue() |
| self.mp_queue_recv = ctx.Queue() |
| self.mp = ctx.Process( |
| target=checkpoint_mp, |
| args=( |
| self.mp_queue_send, |
| self.mp_queue_recv, |
| ), |
| daemon=True, |
| ) |
| self.mp.start() |
| else: |
| raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}") |
|
|
| logger.info( |
| f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}" |
| ) |
|
|
| def __del__(self): |
| self.close() |
|
|
| def close(self): |
| if self.enable_checkpoint: |
| if self.mp and self.mp.is_alive(): |
| self.mp_queue_send.put(Terminate()) |
| self.mp.join() |
| if self.purge_thread and self.purge_thread.is_alive(): |
| self.purge_queue.put(Terminate()) |
| self.purge_thread.join() |
|
|
| @torch.no_grad() |
| def save(self, curr_step: int, force: bool = False) -> None: |
| """Save the checkpoint for the current step. |
| |
| This function will save the checkpoint for the current step. If ``force`` is |
| true, it will save the checkpoint even if the interval has not been reached. |
| This only happens when train_state.step == job_config.training.steps, or |
| for the initial seed checkpoint. |
| |
| Args: |
| curr_step (int): The current step. |
| force (bool, optional): Whether to force save the checkpoint. Defaults to False. |
| |
| Returns: |
| None |
| """ |
|
|
| if self.ft_manager: |
| self._ft_save(curr_step) |
|
|
| if not self._should_save(curr_step, force): |
| return |
|
|
| begin = time.monotonic() |
| if not self.ft_manager or self.ft_manager.participating_rank() == 0: |
| logger.info("Saving the checkpoint (or staging if async is enabled).") |
| checkpoint_id = self._create_checkpoint_id(curr_step) |
| self._async_wait() |
| |
| # Dispatch based on ``force`` and the configured async mode; GC is invoked |
| # explicitly around the asynchronous save paths. |
| if force: |
| self._save_last_step(curr_step) |
| elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: |
| GarbageCollection.collect("GC collection invoked by checkpointer.") |
| self._async_with_pinned_memory(checkpoint_id) |
| elif self.async_mode == AsyncMode.ASYNC: |
| GarbageCollection.collect("GC collection invoked by checkpointer.") |
| self.async_future = dcp.async_save( |
| self.states, checkpoint_id=checkpoint_id, process_group=self.pg |
| ) |
| GarbageCollection.collect("GC collection invoked by checkpointer.") |
| else: |
| save_with_gc(self.states, checkpoint_id=checkpoint_id) |
| self._purge_stale_checkpoints() |
|
|
| logger.info( |
| "Finished saving the checkpoint (or staging if async is enabled)" |
| f"in {time.monotonic() - begin:.2f} seconds." |
| ) |
| elif self.ft_manager: |
| logger.info( |
| "Replica %d doesn't save checkpoint.", |
| self.ft_manager.participating_rank(), |
| ) |
|
|
| @torch.no_grad() |
| def load(self, step: int = -1) -> bool: |
| """Load the checkpoint for the given step. |
| |
| This function will load the checkpoint for the given step. If ``step`` is -1, it |
| will load the latest checkpoint. If the checkpoint does not exist, it will return |
| False and load nothing. |
| |
| Args: |
| step (int, optional): The step to load the checkpoint for. Defaults to -1. |
| |
| Returns: |
| bool: Whether the checkpoint was loaded successfully. |
| """ |
|
|
| if self.ft_manager: |
| self._ft_load() |
|
|
| if not self.enable_checkpoint or not os.path.isdir(self.folder): |
| return False |
|
|
| if step == -1: |
| step = self._find_load_step() |
| if step == -1: |
| return False |
|
|
| checkpoint_id = self._create_checkpoint_id(step) |
| if not os.path.isdir(checkpoint_id): |
| return False |
|
|
| logger.info(f"Loading the checkpoint at step {step}.") |
| begin = time.monotonic() |
| states = self._states_to_load(step) |
| dcp.load(states, checkpoint_id=checkpoint_id) |
| GarbageCollection.collect("GC collection for checkpoint loading.") |
| logger.info( |
| f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." |
| ) |
| return True |
|
|
| def maybe_wait_for_staging(self) -> None: |
| """Wait for the staging to finish if it is enabled. |
| |
| This function will wait for the staging to finish. Staging is only used when |
| ``async_mode`` is set to ``async_with_pinned_mem``. |
| """ |
| if self.enable_staging and self.staging: |
| if not self.staging_stream.query(): |
| begin = time.monotonic() |
| self.staging_stream.synchronize() |
| logger.info( |
| "Checkpointer waited staging %.2f seconds.", |
| time.monotonic() - begin, |
| ) |
| self.staging = False |
|
|
| if self.sending_to_checkpoint_mp: |
| # Send the staged CPU state_dict to the background checkpoint process. |
| def sync_func(): |
| self.mp_queue_send.put_nowait( |
| (self.cpu_offload_state_dict, self.staging_id) |
| ) |
|
|
| # The put_nowait() call only enqueues the staged state_dict; the actual |
| # dcp.save happens in the background process (see checkpoint_mp). |
| begin = time.monotonic() |
| sync_func() |
| logger.info( |
| "Checkpointer sent staged state_dict to another process %.2f seconds", |
| time.monotonic() - begin, |
| ) |
| self.sending_to_checkpoint_mp = False |
|
|
| def _find_load_step(self, folder: str = "") -> int: |
| """Find the step to load the checkpoint for. |
| |
| Args: |
| folder (str, optional): The folder to find the checkpoint for. If ``folder`` |
| is "", then ``self.folder`` will be used. |
| |
| Returns: |
| int: The step to load the checkpoint for. |
| """ |
| folder = folder if folder else self.folder |
| pattern = r"step-(\d+)" |
| step_counts = [] |
|
|
| if not os.path.isdir(folder): |
| return -1 |
|
|
| for filename in os.listdir(folder): |
| match = re.search(pattern, filename) |
| metadata_probe = os.path.join(folder, filename, ".metadata") |
| if match and os.path.isfile(metadata_probe): |
| step_counts.append(int(match.group(1))) |
| if not step_counts: |
| return -1 |
| return max(step_counts) |
|
|
| def _ft_folder(self) -> str: |
| return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}") |
|
|
| def _create_checkpoint_id(self, step: int, folder: str = "") -> str: |
| folder = folder if folder else self.folder |
| return os.path.join(folder, f"step-{step}") |
|
|
| def _ft_save(self, step: int) -> None: |
| begin = time.monotonic() |
| self._async_wait() |
| checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) |
| self.async_future = dcp.async_save( |
| self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg |
| ) |
| logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.") |
|
|
| def _ft_load(self) -> None: |
| step = self._find_load_step(folder=self._ft_folder()) |
| if step == -1: |
| return |
|
|
| begin = time.monotonic() |
| logger.info(f"Loading the FT checkpoint at step {step}.") |
| checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) |
| dcp.load(self.ft_states, checkpoint_id=checkpoint_id) |
| GarbageCollection.collect("GC collection for checkpoint loading.") |
| logger.info( |
| f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds." |
| ) |
|
|
| def _states_to_load(self, step: int) -> dict[str, Any]: |
| """Determines which states to load for the given step. |
| |
| Once the checkpointer has determined which step to load, this helper decides |
| which of the registered states should be loaded for that step. |
| |
| Args: |
| step (int): The step to load the checkpoint for. |
| |
| Returns: |
| Dict[str, Any]: The states to load for the given step. |
| """ |
| # For a seed checkpoint (step 0), only the model weights are loaded. |
| states = {MODEL: self.states[MODEL]} if step == 0 else self.states |
| states_to_load = { |
| k: v for k, v in states.items() if k not in self.exclude_from_loading |
| } |
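| # e.g. ``exclude_from_loading = ["lr_scheduler"]`` drops that entry here; keys |
| # that do not exist in the registered states are rejected below. |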
| for exclude_key in self.exclude_from_loading: |
| if exclude_key not in states: |
| raise ValueError(f"{exclude_key} not found in state_dict.") |
| if self.ft_manager: |
| states_to_load.pop(DATALOADER) |
| return states_to_load |
|
|
| def _save_last_step(self, curr_step: int) -> None: |
| # Weights-only export and dtype conversion only happen at the very end of |
| # training, so they do not interfere with preemption or training resume. |
| if self.model_weights_only: |
| # Keep only the model state so the exported checkpoint contains flat |
| # parameter entries (e.g. "layers.0.attention.wq.weight") and nothing else. |
| self.states = self.states[MODEL].state_dict() |
|
|
| # freqs_cis is a non-trainable buffer and is excluded from the |
| # weights-only export. |
| self.states.pop("freqs_cis") |
|
|
| if self.export_dtype != torch.float32: |
| self.states = { |
| k: v.to(self.export_dtype) for k, v in self.states.items() |
| } |
| logger.info( |
| f"Saving a model weights only checkpoint in {self.export_dtype} " |
| f"at last step, step {curr_step}." |
| ) |
| else: |
| logger.info(f"Saving a full checkpoint at last step, step {curr_step}.") |
|
|
| save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step)) |
|
|
| def _should_save(self, curr_step: int, force: bool = False) -> bool: |
| if not self.enable_checkpoint: |
| return False |
|
|
| # Always save at the first step so that any checkpointing problem surfaces |
| # as early as possible in a run. |
| if curr_step == 1: |
| return True |
|
|
| if force: |
| return True |
|
|
| if curr_step % self.interval == 0: |
| return True |
|
|
| return False |
|
|
| def _async_wait(self) -> None: |
| if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: |
| logger.debug( |
| f"Waiting for the background process to finish, {time.monotonic()=}.:.2f" |
| ) |
| if not self.mp.is_alive(): |
| raise RuntimeError("The checkpoint background process is dead.") |
| _ = self.mp_queue_recv.get() |
| elif self.async_mode == AsyncMode.ASYNC: |
| if self.async_future is not None: |
| self.async_future.result() |
|
|
| def _async_with_pinned_memory(self, checkpoint_id: str) -> None: |
| self._cpu_staging(checkpoint_id) |
| self.sending_to_checkpoint_mp = True |
|
|
| def _cpu_staging(self, checkpoint_id: str | None) -> None: |
| """Offload state_dict to CPU memory""" |
| state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states) |
| if self.cpu_offload_state_dict is None: |
| logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f") |
| self.cpu_offload_state_dict = _create_cpu_state_dict( |
| state_dict, pin_memory=True, share_memory=True |
| ) |
|
|
| logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f") |
| with torch.cuda.stream(self.staging_stream): |
| self.cpu_offload_state_dict = _copy_state_dict( |
| state_dict, |
| self.cpu_offload_state_dict, |
| non_blocking=True, |
| ) |
| self.staging = True |
| self.staging_id = checkpoint_id |
|
|
| def _purge_stale_checkpoints(self): |
| if ( |
| self.keep_latest_k > 0 |
| and dist.get_rank() == 0 |
| and os.path.isdir(self.folder) |
| and (not self.ft_manager or self.ft_manager.participating_rank() == 0) |
| ): |
| discovered_checkpoints = [] |
| for filename in os.listdir(self.folder): |
| match = re.search(r"step-(\d+)", filename) |
| # Skip entries (e.g. per-replica FT folders) that are not step-* checkpoints. |
| if match is None: |
| continue |
| path = os.path.join(self.folder, filename) |
| discovered_checkpoints.append((int(match.group(1)), path)) |
|
|
| discovered_checkpoints.sort() |
| to_delete = discovered_checkpoints[: -1 * self.keep_latest_k] |
|
|
| for _, path in to_delete: |
| assert self.purge_thread is not None |
| self.purge_queue.put(path) |
|
|