import logging
import os
import socket
from typing import Any, Callable, Dict, Optional

import torch
from lightning_fabric.plugins import CheckpointIO
from lightning_fabric.utilities.types import _PATH
from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict
from nvidia_resiliency_ext.checkpointing.local.base_state_dict import TensorAwareStateDict
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.base_manager import BaseCheckpointManager
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import LocalCheckpointManager
from nvidia_resiliency_ext.checkpointing.local.replication.strategies import LazyCliqueReplicationStrategy
from nvidia_resiliency_ext.fault_tolerance.dict_utils import dict_list_map_inplace
from nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback import (
    HierarchicalCheckpointIO,
    LocalCheckpointCallback,
)

from nemo.lightning.pytorch.trainer import Trainer
from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO, AsyncFinalizableCheckpointIO

logger = logging.getLogger(__name__)


class MCoreHierarchicalCheckpointIO(HierarchicalCheckpointIO, AsyncCompatibleCheckpointIO):
| """HierarchicalCheckpointIO implementation compatible with MCore distributed checkpointing. |
| |
| Args: |
| wrapped_checkpoint_io (CheckpointIO): previously used checkpoint_io (for global checkpoints). |
| local_ckpt_manager (BaseCheckpointManager): local checkpoint manager used to store the local checkpoints |
| get_global_ckpt_iteration_fn (Callable[[_PATH], int]): a function that retrieves the iteration |
| of a global checkpoint that will be compared with local checkpoint iteration |
| in order to decide which to resume from. |
| async_save (bool, optional): enables asynchronous save. Passed down to the local checkpoint |
| manager unless overriden with `local_ckpt_options` in `_save_local_checkpoint`. |
| If True, MCoreHierarchicalCheckpointIO must be wrapped with `AsyncFinalizableCheckpointIO` wrapper |
| local_ckpt_algo (str, optional): local checkpoint save algorithm. See MCoreTensorAwareStateDict for details. |
| By default, uses a fully parallel save and load algorithm ('fully_parallel`). |
| parallelization_group (ProcessGroup, optional): save/load parallelization group |
| allow_cache (bool, optional): if True, subsequent checkpoint saves will reuse |
| the cached parallelization metadata. |
| """ |
|
|
| def __init__( |
        self,
        wrapped_checkpoint_io: CheckpointIO,
        local_ckpt_manager: BaseCheckpointManager,
        get_global_ckpt_iteration_fn: Callable[[_PATH], int],
        async_save: bool = False,
        local_ckpt_algo: str = "fully_parallel",
        parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
        allow_cache: bool = False,
    ):
        super().__init__(wrapped_checkpoint_io, local_ckpt_manager, get_global_ckpt_iteration_fn, async_save)
        self.local_ckpt_algo = local_ckpt_algo
        self.parallelization_group = parallelization_group
        self.cached_metadata = None
        self.allow_cache = allow_cache

    def to_tensor_aware_state_dict(self, checkpoint: Dict[str, Any]) -> TensorAwareStateDict:
| """Specialized implementation using MCoreTensorAwareStateDict. |
| |
| Wraps the state dict in MCoreTensorAwareStateDict and makes sure |
| that "common" state dict doesn't have any CUDA tensors. |
| """ |
| state_dict_for_save, _ = MCoreTensorAwareStateDict.from_state_dict( |
| checkpoint, |
| algo=self.local_ckpt_algo, |
| parallelization_group=self.parallelization_group, |
| cached_metadata=self.cached_metadata, |
| ) |
|
|
| def to_cpu(x): |
| if isinstance(x, torch.Tensor) and x.device.type != "cpu": |
| logger.debug("Moving CUDA tensor to CPU") |
| x = x.to("cpu", non_blocking=True) |
| return x |
|
|
| dict_list_map_inplace(to_cpu, state_dict_for_save.common) |
| if self.allow_cache: |
| self.cached_metadata = None |
| return state_dict_for_save |
|
|
| def from_tensor_aware_state_dict( |
        self, tensor_aware_checkpoint: TensorAwareStateDict, sharded_state_dict=None, strict=None
    ):
        """Unwraps MCoreTensorAwareStateDict to a plain state dict."""
        assert isinstance(
            tensor_aware_checkpoint, MCoreTensorAwareStateDict
        ), f"Unexpected tensor aware state dict type: {type(tensor_aware_checkpoint)}"
        if strict is not None:
            logger.warning("MCoreTensorAwareStateDict does not yet support the 'strict' argument.")

        return tensor_aware_checkpoint.to_state_dict(
            sharded_state_dict,
            algo=self.local_ckpt_algo,
            parallelization_group=self.parallelization_group,
        )


def update_trainer_local_checkpoint_io(
    trainer: Trainer,
    local_checkpoint_base_dir: str,
    get_global_ckpt_iteration_fn: Callable[[_PATH], int],
    **kwargs,
) -> None:
    """Update the Trainer with the corresponding MCoreHierarchicalCheckpointIO if local checkpointing is used.

    Args:
        trainer (nl.Trainer): Trainer object to drive the training loop.
        local_checkpoint_base_dir (str): Root directory under which to save local checkpoints.
        get_global_ckpt_iteration_fn (Callable): a function that retrieves the iteration of a global checkpoint,
            which is compared with the local checkpoint iteration to decide which checkpoint to resume from.
        **kwargs (dict): Additional kwargs passed to initialize MCoreHierarchicalCheckpointIO.

    Note:
        Async saving of local checkpoints is inferred from the strategy configuration, if available.
    """
    callbacks = trainer.callbacks
    use_local_ckpt = any(isinstance(cb, LocalCheckpointCallback) for cb in callbacks)
    if not use_local_ckpt:
        return

    checkpoint_io = trainer.strategy.checkpoint_io

    async_save = getattr(trainer.strategy, "async_save", False)
    if async_save:
        # Unwrap the async wrapper so MCoreHierarchicalCheckpointIO wraps the underlying checkpoint IO directly.
        assert isinstance(checkpoint_io, AsyncFinalizableCheckpointIO), type(checkpoint_io)
        checkpoint_io = checkpoint_io.checkpoint_io

    if trainer.num_nodes > 1:
        repl_strategy = LazyCliqueReplicationStrategy()
    else:
        # Replication is unnecessary on a single node.
        repl_strategy = None

    local_ckpt_manager = LocalCheckpointManager(
        os.path.join(local_checkpoint_base_dir, "local_ckpt", socket.gethostname()),
        repl_strategy=repl_strategy,
    )
    hierarchical_checkpointing_io = MCoreHierarchicalCheckpointIO(
        checkpoint_io,
        local_ckpt_manager,
        get_global_ckpt_iteration_fn,
        async_save=async_save,
        **kwargs,
    )

    if async_save:
        hierarchical_checkpointing_io = AsyncFinalizableCheckpointIO(hierarchical_checkpointing_io)

    trainer.strategy.checkpoint_io = hierarchical_checkpointing_io
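
# Usage sketch (illustrative only, not part of this module's API): how this helper is typically
# wired into a training script once a LocalCheckpointCallback is registered on the trainer.
# The callback arguments, the local checkpoint path, and the `get_iteration_from_global_ckpt_path`
# helper are assumptions for the example, not names defined here.
#
#     from nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback import LocalCheckpointCallback
#
#     trainer = Trainer(
#         num_nodes=2,
#         callbacks=[LocalCheckpointCallback(every_n_train_steps=100)],  # assumed callback config
#     )
#     update_trainer_local_checkpoint_io(
#         trainer,
#         local_checkpoint_base_dir="/tmp/local_ckpt_base",  # assumed path
#         get_global_ckpt_iteration_fn=get_iteration_from_global_ckpt_path,  # hypothetical helper
#     )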