# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import shutil
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Iterable, List, Literal, Optional, Union

import lightning
import lightning.pytorch as pl
import torch
from _weakref import proxy
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint
from lightning.pytorch.callbacks.model_checkpoint import _is_local_file_protocol
from lightning.pytorch.utilities import rank_zero_info

from nemo.lightning.ckpt_utils import ckpt_to_dir
from nemo.lightning.io.pl import TrainerContext
from nemo.utils import logging
from nemo.utils.app_state import AppState


class ModelCheckpoint(PTLModelCheckpoint):
    """Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end.

    Adds support for asynchronous checkpointing and provides some additional logic to clean up
    invalid (unfinished) checkpoints.

    Args:
        monitor: Metric to monitor when saving top-k checkpoints.
        verbose: Verbosity mode.
        save_last: When ``True``, saves a `*-last` copy whenever a checkpoint file gets saved.
            When ``"link"``, the `*-last` checkpoint is a symlink to the most recent checkpoint.
        save_top_k: Number of best checkpoints (according to ``monitor``) to keep.
            ``-1`` disables the top-k check of previous runs in ``on_train_start``.
        save_weights_only: if ``True``, then only the model's weights will be saved. Optimizer states will
            be omitted from all checkpoints.
        mode: One of {min, max}. Whether the objective is to minimize or maximize the monitored quantity.
        every_n_epochs: Number of epochs between checkpoints.
        every_n_train_steps: Number of train steps between checkpoints.
        train_time_interval: After each interval, monitor checkpoints. Not to be used with
            ``every_n_epochs`` or ``every_n_train_steps``.
        save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch
        save_optim_on_train_end: Whether to include the optimizer states in the final checkpoint
            at the end of training. Only applicable when save_weights_only is ``False``.
        always_save_context: Whether to dump the artifacts needed to reinitialize the current
            model, trainer, and dataloader to allow for reproducibility of experiments.
        save_context_on_train_end: Whether to dump the artifacts on_train_end regardless of whether
            ``always_save_context`` is ``True``.
        **kwargs: Forwarded unchanged to Lightning's ``ModelCheckpoint`` constructor.

    Attributes:
        UNFINISHED_CHECKPOINT_SUFFIX (str): Suffix for unfinished checkpoint files.
        async_save (bool): Whether asynchronous checkpointing is enabled. Not a constructor argument:
            it is read from ``trainer.strategy.async_save`` in ``setup`` (defaulting to ``False``).
        deferred_ckpts_to_remove (List[List[str]]): List of deferred checkpoints
            to remove once async save is completed.
        ckpts_to_link (Dict[str, str]): Dictionary of checkpoint paths that need to be symlinked.
        future_last_model_path (str): Path to the future 'last' checkpoint, used for symbolic linking.
        best_k_models (dict): Dictionary of best-k checkpoints based on the monitored metric.
        best_model_score (float): Score of the best checkpoint.
        best_model_path (str): Path to the best checkpoint.
        kth_best_model_path (str): Path to the kth best checkpoint.
    """

    UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"

    def __init__(
        self,
        monitor: Optional[str] = "val_loss",
        verbose: bool = True,
        save_last: Optional[Union[bool, Literal["link"]]] = True,
        save_top_k: int = 3,
        save_weights_only: bool = False,  # TODO: check support
        mode: str = "min",
        every_n_epochs: Optional[int] = None,
        every_n_train_steps: Optional[int] = None,
        train_time_interval: Optional[timedelta] = None,
        # Save after training, not after validation
        save_on_train_epoch_end: Optional[bool] = False,
        save_optim_on_train_end: Optional[bool] = False,
        always_save_context: bool = True,
        save_context_on_train_end: bool = True,
        **kwargs,
    ):
        self.always_save_context = always_save_context
        self.save_context_on_train_end = save_context_on_train_end
        self.save_optim_on_train_end = save_optim_on_train_end

        # stores the next -last checkpoint to be saved, used only when save_last = 'link'
        # this is needed because when using symlinks, we need to update the non-last checkpoint's
        # last_model_path to point to the corresponding -last version
        self.future_last_model_path = ""

        # Checkpoints which removal is deferred until async save is done.
        # Each element of `deferred_ckpts_to_remove` is a growing list
        # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint`
        # is called, the last element is frozen and a new element is added.
        self.deferred_ckpts_to_remove: List[List[str]] = []
        self.ckpts_to_link: Dict[str, str] = {}

        # Call the parent class constructor with the remaining kwargs.
        super().__init__(
            monitor=monitor,
            verbose=verbose,
            save_last=save_last,
            save_top_k=save_top_k,
            save_weights_only=save_weights_only,
            mode=mode,
            every_n_epochs=every_n_epochs,
            every_n_train_steps=every_n_train_steps,
            train_time_interval=train_time_interval,
            save_on_train_epoch_end=save_on_train_epoch_end,
            **kwargs,
        )

    def on_train_start(self, trainer, pl_module):
        """
        Initializes checkpointing by handling previous runs, setting up file logging,
        and managing files to move or copy.

        This method handles:
        - Moving old files to new folders
        - Copying relevant files to the log directory
        - Creating command argument and git information logs
        - Setting up logging for errors and Lightning logs

        All filesystem work is done on global rank zero only; the other ranks wait at a
        barrier (when ``torch.distributed`` is initialized) before the parent hook runs.

        Args:
            trainer (pl.Trainer): The PyTorch Lightning trainer object.
            pl_module (pl.LightningModule): The Lightning model to be trained.
        """
        from nemo.utils.exp_manager import get_git_diff, get_git_hash
        from nemo.utils.get_rank import is_global_rank_zero
        from nemo.utils.lightning_logger_patch import add_filehandlers_to_pl_logger

        app_state = AppState()
        if self.save_top_k != -1 and app_state.restore:
            logging.debug("Checking previous runs")
            self.nemo_topk_check_previous_run()

        if is_global_rank_zero():
            log_dir = app_state.log_dir

            # Check to see if any files exist that need to be moved
            files_to_move = app_state.files_to_move

            if len(files_to_move) > 0:
                # Move old files to a new folder
                other_run_dirs = Path(log_dir).glob("run_*")
                run_count = 0
                for fold in other_run_dirs:
                    if fold.is_dir():
                        run_count += 1
                new_run_dir = Path(Path(log_dir) / f"run_{run_count}")
                if not new_run_dir.exists():
                    new_run_dir.mkdir()
                    for _file in files_to_move:
                        shutil.move(str(_file), str(new_run_dir))

            # Move files_to_copy to folder and add git information if present
            if app_state.files_to_copy:
                for _file in app_state.files_to_copy:
                    src_path = Path(_file)
                    dst_path = Path(log_dir) / src_path.name
                    if not dst_path.exists():
                        shutil.copy(src_path, dst_path)

            # Create files for cmd args and git info
            if app_state.cmd_args:
                cmd_args_file = log_dir / 'cmd-args.log'
                if not cmd_args_file.exists():
                    with open(cmd_args_file, 'w', encoding='utf-8') as _file:
                        _file.write(" ".join(app_state.cmd_args))

            # Try to get git hash
            git_repo, git_hash = get_git_hash()
            if git_repo:
                git_info_file = log_dir / 'git-info.log'
                if not git_info_file.exists():
                    with open(git_info_file, 'w', encoding='utf-8') as _file:
                        _file.write(f'commit hash: {git_hash}\n')
                        _file.write(get_git_diff())

            # Add err_file logging to global_rank zero
            logging.add_err_file_handler(log_dir / 'nemo_error_log.txt')

            # Add lightning file logging to global_rank zero
            add_filehandlers_to_pl_logger(log_dir / 'lightning_logs.txt', log_dir / 'nemo_error_log.txt')
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        super().on_train_start(trainer, pl_module)

    def nemo_topk_check_previous_run(self):
        """
        Verifies and cleans up the top-k checkpoint state from previous training runs.

        This method ensures that:
        - The top-k models are correctly loaded and ordered.
        - Any outdated or invalid checkpoints are removed.
        - The best model is determined based on the monitored metric.

        The metric value of each checkpoint is parsed from its path: the text following
        ``<monitor>=`` up to (but excluding) the separator character preceding the next letter.

        Raises:
            AttributeError: If the expected attributes for the top-k model are not found.
        """
        try:
            self.best_k_models
            self.kth_best_model_path
            self.best_model_score
            self.best_model_path
        except AttributeError as err:
            raise AttributeError(
                "Lightning's ModelCheckpoint was updated. NeMo's ModelCheckpoint will need an update."
            ) from err
        self.best_k_models = {}
        self.kth_best_model_path = ""
        self.best_model_score = None
        self.best_model_path = ""

        checkpoints = [path for path in self._saved_checkpoint_paths if not self._is_ema_filepath(path)]
        for checkpoint in checkpoints:
            checkpoint = str(checkpoint)
            if checkpoint.endswith(('-last.ckpt', '-last')):
                continue
            # Find monitor in str + 1 for '='
            index = checkpoint.find(self.monitor) + len(self.monitor) + 1
            # `find` returns -1 when the monitor name is absent, which makes index == len(monitor)
            if index != len(self.monitor):
                # Only real letters terminate the value. The former `[A-z]` class also matched
                # `[`, `\`, `]`, `^`, `_` and '`', which truncated values followed by e.g. `_`.
                match = re.search('[A-Za-z]', checkpoint[index:])
                if match:
                    # -1 due to separator hyphen
                    value = checkpoint[index : index + match.start() - 1]
                else:
                    value = checkpoint[index:]
                self.best_k_models[checkpoint] = float(value)
        if len(self.best_k_models) < 1:
            return  # No saved checkpoints yet

        _reverse = self.mode != "min"

        best_k_models = sorted(self.best_k_models, key=self.best_k_models.get, reverse=_reverse)

        # This section should be ok as rank zero will delete all excess checkpoints, since all other ranks are
        # instantiated after rank zero. models_to_delete should be 0 for all other ranks.
        models_to_delete = len(best_k_models) - self.save_top_k
        models_to_delete = max(0, models_to_delete)
        logging.debug(f'Number of models to delete: {models_to_delete}')

        # If EMA enabled, delete the additional EMA weights
        ema_enabled = self._has_ema_ckpts(self._saved_checkpoint_paths)

        for _ in range(models_to_delete):
            model = best_k_models.pop(-1)
            self.best_k_models.pop(model)
            self._del_model_without_trainer(model)
            if ema_enabled and self._fs.exists(self._ema_format_filepath(model)):
                self._del_model_without_trainer(self._ema_format_filepath(model))
            logging.debug(f"Removed checkpoint: {model}")

        if not best_k_models:
            # Every checkpoint was pruned (e.g. save_top_k == 0); the best/kth attributes
            # keep the empty values assigned above instead of indexing an empty list.
            return

        self.kth_best_model_path = best_k_models[-1]
        self.best_model_path = best_k_models[0]
        self.best_model_score = self.best_k_models[self.best_model_path]

    def _remove_invalid_entries_from_topk(self):
        """
        Removes invalid (incomplete or non-existing) checkpoints from the list of top-k checkpoints.

        This function is necessary when checkpointing might have been abruptly interrupted,
        leaving behind incomplete or corrupted checkpoints. The invalid checkpoints are
        identified by checking if their corresponding directory exists and if the checkpoint
        is not unfinished.

        After removing invalid entries, the method updates the best-k models based on the
        existing, valid checkpoints.

        Attributes Updated:
        - `best_k_models`: A dictionary of valid checkpoints from top-k models.
        - `best_model_path`: Path to the best model based on the current sorting order.
        - `best_model_score`: The score associated with the best model.
        - `kth_best_model_path`: Path to the kth best model.
        - `kth_value`: The score associated with the kth best model.
        """

        # Removes invalid (incomplete or not existing) checkpoints from topk checkpoints.
        # This might be needed if the checkpointing was abruptly terminated.
        def __is_ckpt_ok(ckpt_path: str) -> bool:
            exists = os.path.isdir(ckpt_path.removesuffix('.ckpt'))
            return exists and not self.is_checkpoint_unfinished(ckpt_path)

        self.best_k_models = {k: v for k, v in self.best_k_models.items() if __is_ckpt_ok(k)}
        if len(self.best_k_models) > 0:
            reverse_arr = self.mode != "min"
            best_k_models_arr = sorted(self.best_k_models, key=self.best_k_models.get, reverse=reverse_arr)
            self.kth_best_model_path = best_k_models_arr[-1]
            self.kth_value = self.best_k_models[self.kth_best_model_path]
            self.best_model_path = best_k_models_arr[0]
            self.best_model_score = self.best_k_models[self.best_model_path]
        else:
            self.kth_best_model_path = ""
            self.kth_value = None
            self.best_model_path = ""
            self.best_model_score = None

    def state_dict(self):
        """
        Returns the state dictionary of this callback.

        This function adds additional logic to handle the case when using symlinks. If ``save_last``
        is ``"link"``, the ``last_model_path`` entry of the returned state is overwritten with
        ``future_last_model_path`` to avoid off-by-one errors in the checkpointing system.

        Returns:
            Dict[str, Any]: The parent's state dictionary, including any necessary
            modifications for symlinks.
        """
        state = super().state_dict()
        # if using symlinks, overwrite last_model_path to avoid off-by-one issues
        if self.save_last == "link":
            state["last_model_path"] = self.future_last_model_path
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Loads the state dictionary into this callback and removes invalid entries from the top-k checkpoints.

        This method ensures that after loading the state, any invalid (incomplete or missing)
        checkpoints are removed from the top-k models list.

        Args:
            state_dict (Dict[str, Any]): The state dictionary to load.
        """
        super().load_state_dict(state_dict)
        self._remove_invalid_entries_from_topk()

    def setup(self, trainer, *args, **kwargs) -> None:
        """
        Removes any unfinished checkpoints before training and records the async-save mode.

        Unfinished checkpoints are removed on global rank zero; in a distributed setting all ranks
        then synchronize so that every rank continues with the unfinished checkpoints removed.
        ``self.async_save`` is read from ``trainer.strategy.async_save`` (``False`` if absent).

        Args:
            trainer: The trainer instance used for training.
            *args: Additional arguments passed to the parent setup method.
            **kwargs: Additional keyword arguments passed to the parent setup method.
        """
        from nemo.utils.get_rank import is_global_rank_zero

        if is_global_rank_zero():
            logging.debug("Removing unfinished checkpoints if any...")
            ModelCheckpoint._remove_unfinished_checkpoints(self.dirpath)
        # Ensure that all ranks continue with unfinished checkpoints removed
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        self.async_save = getattr(trainer.strategy, "async_save", False)
        super().setup(trainer, *args, **kwargs)

    def on_train_end(self, trainer, pl_module):
        """
        Handles actions to be performed when training ends, such as saving the last checkpoint.

        This method ensures that the last checkpoint is saved if needed, particularly when
        validation steps aren't always run based on the interval. It also manages saving the
        training context to disk, if configured. Nothing is done when ``trainer.fast_dev_run`` is set.

        Args:
            trainer: The trainer instance used for training.
            pl_module: The model being trained.
        """
        from nemo.utils.get_rank import is_global_rank_zero

        if trainer.fast_dev_run:
            return None

        # check if we need to save a last checkpoint manually as validation isn't always run based on the interval
        if self.save_last and trainer.val_check_interval != 0:
            interval = trainer.val_check_interval
            step = trainer.global_step
            should_save_last_checkpoint = False
            # `step != 0` keeps the float modulo from raising ZeroDivisionError when no step was taken;
            # the integer branch already evaluates to "no save" in that situation (0 % interval == 0).
            if isinstance(interval, float) and step != 0 and interval % step != 0:
                should_save_last_checkpoint = True
            if isinstance(interval, int) and step % interval != 0:
                should_save_last_checkpoint = True
            if should_save_last_checkpoint:
                monitor_candidates = self._monitor_candidates(trainer)
                if self.last_model_path == self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST):
                    logging.debug(f'Last checkpoint {self.last_model_path} already saved')
                else:
                    super()._save_last_checkpoint(trainer, monitor_candidates)
            # When always_save_context is set, the context was already dumped by `_save_checkpoint`.
            if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero():
                TrainerContext.from_trainer(trainer).io_dump(
                    ckpt_to_dir(self.last_model_path) / "context", yaml_attrs=["model"]
                )
        # Call parent on_train_end() to save the -last checkpoint
        super().on_train_end(trainer, pl_module)

    def _del_model_without_trainer(self, filepath: str) -> None:
        """
        Deletes the distributed checkpoint directory for ``filepath`` without requiring the trainer.

        The removal only happens on the global rank-zero process and is best-effort: failures are
        logged, not raised. All ranks synchronize afterwards when ``torch.distributed`` is initialized.

        Args:
            filepath (str): The path to the checkpoint to be deleted.
        """
        from nemo.utils.get_rank import is_global_rank_zero

        filepath = Path(filepath)

        if is_global_rank_zero():
            # Pre-bind so the failure message below is well-defined even if `ckpt_to_dir` raises.
            dist_ckpt = filepath
            try:
                dist_ckpt = ckpt_to_dir(filepath)
                shutil.rmtree(dist_ckpt, ignore_errors=True)
                logging.info(f"Removed distributed checkpoint: {dist_ckpt}")
            except Exception:
                logging.info(f"Tried to remove distributed checkpoint: {dist_ckpt} but failed.")
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

    def _ema_callback(self, trainer: 'lightning.pytorch.Trainer'):
        """
        Retrieves the Exponential Moving Average (EMA) callback from the list of trainer callbacks.

        This method scans through the callbacks attached to the trainer and returns the EMA
        callback instance if present. If several are attached, the last one is returned.

        Args:
            trainer ('lightning.pytorch.Trainer'): The trainer instance.

        Returns:
            EMA: The EMA callback instance if found, or None if not present.
        """
        from nemo.collections.common.callbacks import EMA

        ema_callback = None
        for callback in trainer.callbacks:
            if isinstance(callback, EMA):
                ema_callback = callback
        return ema_callback

    @staticmethod
    def format_checkpoint_unfinished_marker_path(checkpoint_path: Union[Path, str]) -> Path:
        """Format the path to the unfinished checkpoint marker file.

        If the marker file exists, corresponding checkpoint is considered unfinished/incomplete.
        NOTE: Marker path for the EMA checkpoint part is the same as for the original checkpoint.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.

        Returns:
            Path to the unfinished checkpoint marker file.
        """
        marker_filepath = str(checkpoint_path).removesuffix(".ckpt")
        marker_filepath = marker_filepath.removesuffix("-EMA")
        return Path(marker_filepath + ModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX)

    @staticmethod
    def is_checkpoint_unfinished(checkpoint_path: Union[Path, str]) -> bool:
        """Check if the checkpoint is unfinished.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.

        Returns:
            True if the checkpoint is unfinished, False otherwise.
        """
        return ModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path).exists()

    @staticmethod
    def set_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_after=False) -> None:
        """Marks given checkpoint as unfinished.

        The marker file is created on global rank zero only.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.
            barrier_after: Synchronize ranks after writing the marker file.
              Defaults to False.
        """
        from nemo.utils.get_rank import is_global_rank_zero

        if is_global_rank_zero():
            marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path)
            marker_path.parent.mkdir(parents=True, exist_ok=True)
            marker_path.touch()
        if barrier_after and torch.distributed.is_initialized():
            torch.distributed.barrier()

    @staticmethod
    def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_before=False) -> None:
        """Clear unfinished marker for given checkpoint.

        Best-effort: ordinary errors are logged and swallowed. ``KeyboardInterrupt`` and
        ``SystemExit`` are not intercepted.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.
            barrier_before: Synchronize ranks before removing the marker file.
              Defaults to False.
        """
        from nemo.utils.get_rank import is_global_rank_zero

        try:
            if barrier_before and torch.distributed.is_initialized():
                torch.distributed.barrier()
            if is_global_rank_zero():
                marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path)
                if marker_path.exists():
                    marker_path.unlink()
        except Exception as e:
            logging.warning(f'Failed to remove unfinished checkpoint marker for {checkpoint_path}: {e}')
            return

    def file_exists(self, filepath: str, trainer: "lightning.pytorch.Trainer", check_dist_ckpt: bool = True) -> bool:
        """Checks if a file or a file without a suffix (distributed checkpoint) exists.

        The local result is passed through ``trainer.strategy.broadcast`` before being returned.
        """
        exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(str(ckpt_to_dir(filepath))))
        return trainer.strategy.broadcast(exists)

    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, torch.Tensor]:
        """Broadcast loss from last pipeline stage.

        ``reduced_train_loss`` is added to the parent's candidates (as ``0.0`` when missing) whenever
        it is the monitored metric or appears in the filename template, and is synchronized from the
        last pipeline stage when the strategy is a ``MegatronStrategy``.
        """
        monitor_candidates = super()._monitor_candidates(trainer)

        from nemo.lightning._strategy_lib import _sync_from_last_pipeline_stage
        from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy

        # `filename` may legitimately be None (no template configured); treat it as an empty template.
        keys = re.findall(r"[\{](.*?)[:\}]", self.filename or "")
        for loss_name in ['reduced_train_loss']:
            if loss_name in keys or loss_name == self.monitor:
                if loss_name not in monitor_candidates:
                    monitor_candidates[loss_name] = torch.tensor(0.0, device=torch.cuda.current_device())
                if isinstance(trainer.strategy, MegatronStrategy):
                    _sync_from_last_pipeline_stage(monitor_candidates[loss_name], broadcast=True)

        return monitor_candidates

    def _link_checkpoint(self, trainer: "pl.Trainer", filepath: str, linkpath: str, override_async=False) -> None:
        """Check to see whether this step has already been saved as top_k
        in which case we can create a symlink
        otherwise, we have to save the checkpoint

        With async save (and ``override_async`` unset) the link request is only recorded in
        ``self.ckpts_to_link``; the finalize callback performs it later with ``override_async=True``.
        """
        saved_current_step = str(ckpt_to_dir(linkpath)).replace("-last", "") == str(ckpt_to_dir(filepath))
        if not saved_current_step:
            self._save_checkpoint(trainer, linkpath)
            return

        # linking will happen as part of the finalize fn
        if self.async_save and not override_async:
            self.ckpts_to_link[str(filepath)] = str(linkpath)
            return

        filepath = ckpt_to_dir(filepath)
        linkpath = ckpt_to_dir(linkpath)
        super()._link_checkpoint(trainer, filepath, linkpath)

    def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str) -> None:
        """Saves the checkpoint to the given filepath

        Args:
            trainer (lightning.pytorch.Trainer): the trainer obj
            filepath (str): path to save checkpoint to.

        Raises:
            ValueError: (mcore) async_save with EMA not supported
            ValueError: (mcore) Async save requires async compatible CheckpointIO
        """
        from nemo.utils.get_rank import is_global_rank_zero

        # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
        # if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
        self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
        ema_callback = self._ema_callback(trainer)

        self._last_global_step_saved = trainer.global_step

        # manually update last_model_path so symlink is up-to-date
        # should only be done when using a symlink
        if self.save_last == "link":
            self.future_last_model_path = str(ckpt_to_dir(filepath))
            if not str(ckpt_to_dir(filepath)).endswith("last"):
                self.future_last_model_path += "-last.ckpt"

        if ema_callback is not None:
            if self.async_save:
                raise ValueError('async_save with EMA not supported')
            with ema_callback.save_original_optimizer_state(trainer):
                super()._save_checkpoint(trainer, filepath)

            # save EMA copy of the model as well.
            with ema_callback.save_ema_model(trainer):
                filepath = self._ema_format_filepath(filepath)
                if self.verbose:
                    rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
                super()._save_checkpoint(trainer, filepath)
            # The EMA path maps to the same marker file as the original checkpoint path.
            self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
        else:
            # Determine whether to include optimizer states in the checkpoint
            # optimizer states are included when
            # 1. save_weights_only is False and
            # 2. either save_optim_on_train_end is True, or save_optim_on_train_end is False but the checkpoint
            #    is an intermediate checkpoint.
            save_weights_only = self.save_weights_only or (
                not self.save_optim_on_train_end and trainer.global_step == trainer.max_steps
            )

            # Async save passes the finalization function to checkpoint_io,
            # sync save calls the finalization function immediately after save.
            finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step)
            if self.async_save:
                checkpoint_io = trainer.strategy.checkpoint_io
                from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO

                if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO):
                    raise ValueError('Async save requires async compatible CheckpointIO')
                storage_options = dict(finalize_fn=finalize_fn)
                # Each upcoming ckpt removal request will be executed as part of this save finalization
                self.deferred_ckpts_to_remove.append([])
            else:
                storage_options = None
            trainer.save_checkpoint(filepath, save_weights_only, storage_options=storage_options)

            if self.always_save_context and is_global_rank_zero():
                TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context", yaml_attrs=["model"])

            if self.async_save:
                self._last_checkpoint_saved = filepath
                logging.info(f'Scheduled async checkpoint save for {filepath}')
            else:
                finalize_fn()

    def _get_finalize_save_checkpoint_callback(
        self, trainer: 'lightning.pytorch.Trainer', filepath: str, global_step: int
    ):
        """Creates a callback that can be used to finalize async (and sync) ckpt saves.

        The returned callable notifies the loggers, clears the unfinished marker and, for async
        saves only, performs the deferred symlinking and the deferred checkpoint removals.
        """

        def _cb():
            logging.debug(f'Finalize callback called for step {global_step}, filepath {filepath}')
            self._last_checkpoint_saved = filepath

            # notify loggers
            if trainer.is_global_zero:
                for logger in trainer.loggers:
                    logger.after_save_checkpoint(proxy(self))

            # barrier_before=True, so all ranks synchronize before removing the unfinished checkpoint marker
            # we don't want to remove the marker until all checkpointing is done.
            self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)

            if not self.async_save:
                return

            logging.info(f'Async checkpoint save for step {global_step} ({filepath}) finalized successfully.')

            # `ckpts_to_link` is keyed by str(filepath) (see `_link_checkpoint`), so the lookup and
            # the pop must both use the string form; popping the raw object fails for Path inputs.
            link_key = str(filepath)
            if link_key in self.ckpts_to_link:
                self._link_checkpoint(trainer, filepath, self.ckpts_to_link.pop(link_key), override_async=True)

            # Remove checkpoints marked for removal by `self._remove_checkpoint`
            # For each finalization there is exactly one entry in self.deferred_ckpts_to_remove
            assert self.deferred_ckpts_to_remove
            ckpts_to_remove = self.deferred_ckpts_to_remove.pop(0)
            logging.debug(f'Checkpoints to remove: {ckpts_to_remove}')
            for ckpt_to_remove in ckpts_to_remove:
                self._remove_checkpoint(trainer, ckpt_to_remove, override_async=True)

        return _cb

    def _remove_checkpoint(self, trainer: "lightning.pytorch.Trainer", filepath: str, override_async=False) -> None:
        """Performs checkpoint removal.

        With async save, `self._remove_checkpoint` is called before the checkpoint
        is actually finished so we can't remove it. Instead we add it to
        `self.deferred_ckpts_to_remove` for future removal.
        """
        if self.async_save and not override_async:
            # Register checkpoint removal in the last (active) checkpoint removal list
            if len(self.deferred_ckpts_to_remove) == 0:
                self.deferred_ckpts_to_remove.append([])
            self.deferred_ckpts_to_remove[-1].append(filepath)
            return
        # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
        # if anything goes wrong during removal, we should be able to detect that data is incomplete.
        self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
        try:
            super()._remove_checkpoint(trainer, filepath)
        except Exception as e:
            logging.warning(
                f'Error removing checkpoint, common if doing manual cleanup and restarting: {filepath}: {e}'
            )
        ema_callback = self._ema_callback(trainer)
        if ema_callback is not None:
            # remove EMA copy of the state dict as well.
            filepath = self._ema_format_filepath(filepath)
            try:
                super()._remove_checkpoint(trainer, filepath)
            except Exception as e:
                logging.warning(
                    f'Error removing checkpoint, common if doing manual cleanup and restarting: {filepath}: {e}'
                )
        # barrier_before=True, so all ranks synchronize before removing the unfinished checkpoint marker
        # we don't want to remove the marker until the checkpoint is actually removed.
        self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)

    def _ema_format_filepath(self, filepath: str) -> str:
        """Formats given path for EMA checkpoint

        Args:
            filepath (str): filepath

        Returns:
            str: EMA-formatted filepath
        """
        return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}')

    def _has_ema_ckpts(self, checkpoints: Iterable[Path]) -> bool:
        """Checks whether any of the filepaths is EMA-formatted

        Args:
            checkpoints (Iterable[Path]): paths to check

        Returns:
            bool: True indicates at least one path is EMA-formatted.
        """
        return any(self._is_ema_filepath(checkpoint_path) for checkpoint_path in checkpoints)

    def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool:
        """Checks whether the filepath is EMA-formatted

        Args:
            filepath (Union[Path, str]): path to check

        Returns:
            bool: True indicates path is EMA-formatted.
        """
        return str(filepath).endswith(f'-EMA{self.FILE_EXTENSION}')

    @property
    def _saved_checkpoint_paths(self) -> Iterable[Path]:
        """
        Retrieves the saved checkpoint paths while filtering out unfinished checkpoints.

        - If distributed checkpoints (directories) exist, return only those.
        - Otherwise, return individual checkpoint files with a .ckpt extension.
        - Filters out any checkpoints that are marked as unfinished.

        Returns:
            Iterable[Path]: A lazy iterable containing valid checkpoint paths.
        """
        # distributed checkpoints are directories so we check for them here
        # we filter out unfinished checkpoints, these should be deleted during next cleanup
        dist_checkpoints = [d for d in Path(self.dirpath).glob("*") if d.is_dir()]
        if dist_checkpoints:
            return filter(lambda p: not self.is_checkpoint_unfinished(p), dist_checkpoints)
        else:
            checkpoint_files = [f for f in Path(self.dirpath).rglob("*.ckpt")]
            return filter(lambda p: not self.is_checkpoint_unfinished(p), checkpoint_files)

    @staticmethod
    def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
        """
        Removes all unfinished checkpoints and their associated marker files from the filesystem.

        - Ensures this function runs only on rank 0.
        - Deletes individual unfinished checkpoint files.
        - Removes directories corresponding to unfinished distributed checkpoints.
        - Deletes the marker files indicating unfinished checkpoints.

        Args:
            checkpoint_dir (Union[Path, str]): Path to the directory containing checkpoints.

        Raises:
            AssertionError: If the function is called from a non-rank 0 process.
        """
        from nemo.utils.get_rank import is_global_rank_zero

        # Delete unfinished checkpoints from the filesystems.
        # "Unfinished marker" files are removed as well.

        if not is_global_rank_zero():
            raise AssertionError("_remove_unfinished_checkpoints should run only on rank 0")

        checkpoint_dir = Path(checkpoint_dir)

        existing_marker_filepaths = {
            f.resolve() for f in checkpoint_dir.glob(f"*{ModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}") if f.is_file()
        }

        checkpoint_filepaths = {f.resolve() for f in checkpoint_dir.rglob("*.ckpt")}
        for filepath in checkpoint_filepaths:
            possible_marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(filepath)
            if possible_marker_path in existing_marker_filepaths:
                logging.warning(f'Removing unfinished checkpoint: {filepath}')
                os.remove(filepath)

        # some directories might be distributed checkpoints, we remove these if they have a unfinished marker
        all_dirpaths = {d.resolve() for d in checkpoint_dir.glob("*") if d.is_dir()}
        for ckpt_dirpath in all_dirpaths:
            possible_marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(ckpt_dirpath)
            if possible_marker_path in existing_marker_filepaths:
                logging.warning(f'Removing unfinished dist checkpoint: {ckpt_dirpath}')
                shutil.rmtree(ckpt_dirpath)

        # delete markers
        for marker_path in existing_marker_filepaths:
            os.remove(marker_path)

    def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool:
        """Checks if the previous checkpoint should be deleted.

        A checkpoint won't be deleted if any of the cases apply:
        - The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new)
        - The previous checkpoint is not in the current checkpoint directory and the filesystem is local
        - The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local
            and the resumed from checkpoint is not the last checkpoint

        Raises:
            ValueError: If ``self.dirpath`` is None when the directory check is reached.
        """
        if previous == current:
            return False
        if not _is_local_file_protocol(previous):
            return True
        previous = Path(previous).absolute()
        resume_path = Path(trainer.ckpt_path).absolute() if trainer.ckpt_path is not None else None

        if resume_path is not None and previous == resume_path:
            if str(current).endswith("-last.ckpt") and resume_path.name.endswith("-last.ckpt"):
                # delete the previous `-last.ckpt` checkpoint when current saved checkpoint
                # is also `-last.ckpt`, if they're in the same directory
                pass
            else:
                return False
        if self.dirpath is None:
            raise ValueError(f"{self.__class__}.dirpath is None.")
        dirpath = Path(self.dirpath).absolute()
        return dirpath in previous.parents