| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import re |
| | import shutil |
| | from datetime import timedelta |
| | from pathlib import Path |
| | from typing import Any, Dict, Iterable, List, Literal, Optional, Union |
| |
|
| | import lightning |
| | import lightning.pytorch as pl |
| | import torch |
| | from _weakref import proxy |
| | from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint |
| | from lightning.pytorch.callbacks.model_checkpoint import _is_local_file_protocol |
| | from lightning.pytorch.utilities import rank_zero_info |
| |
|
| | from nemo.lightning.ckpt_utils import ckpt_to_dir |
| | from nemo.lightning.io.pl import TrainerContext |
| | from nemo.utils import logging |
| | from nemo.utils.app_state import AppState |
| |
|
| |
|
| | class ModelCheckpoint(PTLModelCheckpoint): |
| | """Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end. |
| | Adds support for asynchronous checkpointing and provides some additional logic to clean up invalid checkpoints |
| | |
| | Args: |
| | monitor: Metric to monitor when saving top-k checkpoints. |
| | verbose: Verbosity mode. |
| | save_last: When ``True``, saves a `*-last` copy whenever a checkpoint file gets saved. |
| | save_top_k: Saves the best ``save_top_k`` checkpoints according to ``monitor``. |
| | save_weights_only: if ``True``, then only the model's weights will be saved. Optimizer states will |
| | be omitted from all checkpoints. |
| | mode: One of {min, max}. Whether the objective is to minimize or maximize the monitored quantity. |
| | every_n_epochs: Number of epochs between checkpoints. |
| | every_n_train_steps: Number of train steps between checkpoints. |
| | train_time_interval: After each interval, monitor checkpoints. Not to be used with |
| | ``every_n_epochs`` or ``every_n_train_steps``. |
| | save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch |
| | save_optim_on_train_end: Whether to include the optimizer states in the final checkpoint |
| | at the end of training. Only applicable when save_weights_only is ``False``. |
| | always_save_context: Whether to dump the artifacts needed to reinitialize the current |
| | model, trainer, and dataloader to allow for reproducibility of experiments. |
| | save_context_on_train_end: Whether to dump the artifacts on_train_end regardless of whether |
| | ``always_save_context`` is ``True``. |
| | async_save: Whether to enable asynchronous checkpointing. |
| | |
| | Attributes: |
| | UNFINISHED_CHECKPOINT_SUFFIX (str): Suffix for unfinished checkpoint files. |
| | deferred_ckpts_to_remove (List[List[str]]): List of deferred checkpoints |
| | to remove once async save is completed. |
| | ckpts_to_link (Dict[str, str]): Dictionary of checkpoint paths that need to be symlinked. |
| | future_last_model_path (str): Path to the future 'last' checkpoint, used for symbolic linking. |
| | best_k_models (dict): Dictionary of best-k checkpoints based on the monitored metric. |
| | best_model_score (float): Score of the best checkpoint. |
| | best_model_path (str): Path to the best checkpoint. |
| | kth_best_model_path (str): Path to the kth best checkpoint. |
| | """ |
| |
|
| | UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished" |
| |
|
    def __init__(
        self,
        monitor: Optional[str] = "val_loss",
        verbose: bool = True,
        save_last: Optional[Union[bool, Literal["link"]]] = True,
        save_top_k: int = 3,
        save_weights_only: bool = False,
        mode: str = "min",
        every_n_epochs: Optional[int] = None,
        every_n_train_steps: Optional[int] = None,
        train_time_interval: Optional[timedelta] = None,
        save_on_train_epoch_end: Optional[bool] = False,
        save_optim_on_train_end: Optional[bool] = False,
        always_save_context: bool = True,
        save_context_on_train_end: bool = True,
        **kwargs,
    ):
        """See the class docstring for argument descriptions; any extra kwargs
        are forwarded unchanged to Lightning's ``ModelCheckpoint``."""
        # NeMo-specific options are stored before delegating to the parent.
        self.always_save_context = always_save_context
        self.save_context_on_train_end = save_context_on_train_end
        self.save_optim_on_train_end = save_optim_on_train_end

        # Path the upcoming '-last' checkpoint will live at; reported by
        # state_dict() when save_last == "link" so resume bookkeeping sees the
        # link target even before an async save materializes it.
        self.future_last_model_path = ""

        # With async save, a checkpoint scheduled for deletion may still be
        # in flight; deletions are queued here (one list per scheduled save)
        # and executed by the finalize callback.
        self.deferred_ckpts_to_remove: List[List[str]] = []
        # Maps a saved checkpoint path (str) to the '-last' path that should
        # be symlinked to it once its async save completes.
        self.ckpts_to_link: Dict[str, str] = {}

        super().__init__(
            monitor=monitor,
            verbose=verbose,
            save_last=save_last,
            save_top_k=save_top_k,
            save_weights_only=save_weights_only,
            mode=mode,
            every_n_epochs=every_n_epochs,
            every_n_train_steps=every_n_train_steps,
            train_time_interval=train_time_interval,
            save_on_train_epoch_end=save_on_train_epoch_end,
            **kwargs,
        )
| |
|
    def on_train_start(self, trainer, pl_module):
        """
        Initializes checkpointing by handling previous runs,
        setting up file logging, and managing files to move or copy.

        This method handles:
        - Moving old files to new folders
        - Copying relevant files to the log directory
        - Creating command argument and git information logs
        - Setting up logging for errors and Lightning logs

        Args:
            trainer (pl.Trainer): The PyTorch Lightning trainer object.
            pl_module (pl.LightningModule): The Lightning model to be trained.
        """
        from nemo.utils.exp_manager import get_git_diff, get_git_hash
        from nemo.utils.get_rank import is_global_rank_zero
        from nemo.utils.lightning_logger_patch import add_filehandlers_to_pl_logger

        app_state = AppState()
        # When resuming with a bounded top-k, reconcile the top-k bookkeeping
        # with whatever checkpoints survived the previous run.
        if self.save_top_k != -1 and app_state.restore:
            logging.debug("Checking previous runs")
            self.nemo_topk_check_previous_run()

        # All filesystem bookkeeping below runs on the global-zero rank only.
        if is_global_rank_zero():
            log_dir = app_state.log_dir

            files_to_move = app_state.files_to_move

            if len(files_to_move) > 0:
                # Archive files from the previous run into run_<N>, where N is
                # the number of existing run_* directories.
                other_run_dirs = Path(log_dir).glob("run_*")
                run_count = 0
                for fold in other_run_dirs:
                    if fold.is_dir():
                        run_count += 1
                new_run_dir = Path(Path(log_dir) / f"run_{run_count}")
                if not new_run_dir.exists():
                    new_run_dir.mkdir()
                for _file in files_to_move:
                    shutil.move(str(_file), str(new_run_dir))

            # Copy requested files into the log dir, never overwriting.
            if app_state.files_to_copy:
                for _file in app_state.files_to_copy:
                    src_path = Path(_file)
                    dst_path = Path(log_dir) / src_path.name
                    if not dst_path.exists():
                        shutil.copy(src_path, dst_path)

            # Record the command line once per experiment.
            # NOTE(review): assumes app_state.log_dir is a Path here (uses `/`)
            # — confirm against AppState.
            if app_state.cmd_args:
                cmd_args_file = log_dir / 'cmd-args.log'
                if not cmd_args_file.exists():
                    with open(cmd_args_file, 'w', encoding='utf-8') as _file:
                        _file.write(" ".join(app_state.cmd_args))

            # Record the git commit and working-tree diff once per experiment.
            git_repo, git_hash = get_git_hash()
            if git_repo:
                git_info_file = log_dir / 'git-info.log'
                if not git_info_file.exists():
                    with open(git_info_file, 'w', encoding='utf-8') as _file:
                        _file.write(f'commit hash: {git_hash}\n')
                        _file.write(get_git_diff())

            # Route error-level logs to a file on rank zero.
            logging.add_err_file_handler(log_dir / 'nemo_error_log.txt')

            # Route Lightning's own logs to files on rank zero.
            add_filehandlers_to_pl_logger(log_dir / 'lightning_logs.txt', log_dir / 'nemo_error_log.txt')
        # Other ranks wait until rank zero has finished the log-dir setup.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        super().on_train_start(trainer, pl_module)
| |
|
| | def nemo_topk_check_previous_run(self): |
| | """ |
| | Verifies and cleans up the top-k checkpoint state from previous training runs. |
| | |
| | This method ensures that: |
| | - The top-k models are correctly loaded and ordered. |
| | - Any outdated or invalid checkpoints are removed. |
| | - The best model is determined based on the monitored metric. |
| | |
| | Raises: |
| | AttributeError: If the expected attributes for the top-k model are not found. |
| | """ |
| | try: |
| | self.best_k_models |
| | self.kth_best_model_path |
| | self.best_model_score |
| | self.best_model_path |
| | except AttributeError: |
| | raise AttributeError( |
| | "Lightning's ModelCheckpoint was updated. NeMo's ModelCheckpoint will need an update." |
| | ) |
| | self.best_k_models = {} |
| | self.kth_best_model_path = "" |
| | self.best_model_score = None |
| | self.best_model_path = "" |
| |
|
| | checkpoints = list(path for path in self._saved_checkpoint_paths if not self._is_ema_filepath(path)) |
| | for checkpoint in checkpoints: |
| | checkpoint = str(checkpoint) |
| | if checkpoint[-10:] == '-last.ckpt' or checkpoint[-5:] == '-last': |
| | continue |
| | |
| | index = checkpoint.find(self.monitor) + len(self.monitor) + 1 |
| | if index != len(self.monitor): |
| | match = re.search('[A-z]', checkpoint[index:]) |
| | if match: |
| | |
| | value = checkpoint[index : index + match.start() - 1] |
| | else: |
| | value = checkpoint[index:] |
| | self.best_k_models[checkpoint] = float(value) |
| | if len(self.best_k_models) < 1: |
| | return |
| |
|
| | _reverse = False if self.mode == "min" else True |
| |
|
| | best_k_models = sorted(self.best_k_models, key=self.best_k_models.get, reverse=_reverse) |
| |
|
| | |
| | |
| | models_to_delete = len(best_k_models) - self.save_top_k |
| | models_to_delete = max(0, models_to_delete) |
| | logging.debug(f'Number of models to delete: {models_to_delete}') |
| |
|
| | |
| | ema_enabled = self._has_ema_ckpts(self._saved_checkpoint_paths) |
| |
|
| | for _ in range(models_to_delete): |
| | model = best_k_models.pop(-1) |
| | self.best_k_models.pop(model) |
| | self._del_model_without_trainer(model) |
| | if ema_enabled and self._fs.exists(self._ema_format_filepath(model)): |
| | self._del_model_without_trainer(self._ema_format_filepath(model)) |
| | logging.debug(f"Removed checkpoint: {model}") |
| |
|
| | self.kth_best_model_path = best_k_models[-1] |
| | self.best_model_path = best_k_models[0] |
| | self.best_model_score = self.best_k_models[self.best_model_path] |
| |
|
| | def _remove_invalid_entries_from_topk(self): |
| | """ |
| | Removes invalid (incomplete or non-existing) checkpoints from the list of top-k checkpoints. |
| | |
| | This function is necessary when checkpointing might have been abruptly interrupted, leaving behind |
| | incomplete or corrupted checkpoints. The invalid checkpoints are identified by checking if their |
| | corresponding directory exists and if the checkpoint is not unfinished. |
| | |
| | After removing invalid entries, the method updates the best-k models based on the existing, valid checkpoints. |
| | |
| | Attributes Updated: |
| | - `best_k_models`: A dictionary of valid checkpoints from top-k models. |
| | - `best_model_path`: Path to the best model based on the current sorting order. |
| | - `best_model_score`: The score associated with the best model. |
| | - `kth_best_model_path`: Path to the kth best model. |
| | - `kth_value`: The score associated with the kth best model. |
| | """ |
| |
|
| | |
| | |
| | def __is_ckpt_ok(ckpt_path: str) -> bool: |
| | exists = os.path.isdir(ckpt_path.removesuffix('.ckpt')) |
| | return exists and not self.is_checkpoint_unfinished(ckpt_path) |
| |
|
| | self.best_k_models = {k: v for k, v in self.best_k_models.items() if __is_ckpt_ok(k)} |
| | if len(self.best_k_models) > 0: |
| | reverse_arr = self.mode != "min" |
| | best_k_models_arr = sorted(self.best_k_models, key=self.best_k_models.get, reverse=reverse_arr) |
| | self.kth_best_model_path = best_k_models_arr[-1] |
| | self.kth_value = self.best_k_models[self.kth_best_model_path] |
| | self.best_model_path = best_k_models_arr[0] |
| | self.best_model_score = self.best_k_models[self.best_model_path] |
| | else: |
| | self.kth_best_model_path = "" |
| | self.kth_value = None |
| | self.best_model_path = "" |
| | self.best_model_score = None |
| |
|
| | def state_dict(self): |
| | """ |
| | Returns the state dictionary of the model. |
| | |
| | This function adds additional logic to handle the case when using symlinks. If the model is configured |
| | to save the last checkpoint as a symlink, the path to the last checkpoint is updated in the returned |
| | state dictionary to avoid off-by-one errors in the checkpointing system. |
| | |
| | Returns: |
| | Dict[str, Any]: The state dictionary of the model, including any necessary modifications for symlinks. |
| | """ |
| | state = super().state_dict() |
| | |
| | if self.save_last == "link": |
| | state["last_model_path"] = self.future_last_model_path |
| | return state |
| |
|
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Loads the state dictionary into the model and removes invalid entries from the top-k checkpoints.

        This method ensures that after loading the model state, any invalid (incomplete or missing) checkpoints
        are removed from the top-k models list.

        Args:
            state_dict (Dict[str, Any]): The state dictionary to load into the model.
        """
        super().load_state_dict(state_dict)
        # A preempted run may have recorded checkpoints that never finished
        # writing; prune them from the restored top-k bookkeeping.
        self._remove_invalid_entries_from_topk()
| |
|
    def setup(self, trainer, *args, **kwargs) -> None:
        """
        Initializes the model and removes any unfinished checkpoints before training.

        This method is responsible for ensuring that unfinished checkpoints are removed prior to starting the training.
        It also synchronizes all ranks in a distributed setting to ensure that unfinished checkpoints are removed
        across all ranks.

        Args:
            trainer: The trainer instance used for training.
            *args: Additional arguments passed to the parent setup method.
            **kwargs: Additional keyword arguments passed to the parent setup method.
        """
        from nemo.utils.get_rank import is_global_rank_zero

        if is_global_rank_zero():
            logging.debug("Removing unfinished checkpoints if any...")
            ModelCheckpoint._remove_unfinished_checkpoints(self.dirpath)
        # Other ranks must wait until rank zero has finished the cleanup.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        # Async checkpointing is configured on the strategy, not on this
        # callback's constructor; default to synchronous saving.
        self.async_save = getattr(trainer.strategy, "async_save", False)
        super().setup(trainer, *args, **kwargs)
| |
|
    def on_train_end(self, trainer, pl_module):
        """
        Handles actions to be performed when training ends, such as saving the last checkpoint.

        This method ensures that the last checkpoint is saved if needed, particularly when validation steps
        aren't always run based on the interval. It also manages saving the training context to disk, if configured.

        Args:
            trainer: The trainer instance used for training.
            pl_module: The model being trained.
        """
        from nemo.utils.get_rank import is_global_rank_zero

        if trainer.fast_dev_run:
            return None

        # Save a final '-last' checkpoint if training stopped between
        # scheduled validation/checkpoint intervals.
        if self.save_last and trainer.val_check_interval != 0:
            should_save_last_checkpoint = False
            # NOTE(review): the float branch computes interval % global_step,
            # while the int branch computes global_step % interval — the
            # operand order looks inconsistent; confirm intended semantics
            # before changing.
            if isinstance(trainer.val_check_interval, float) and trainer.val_check_interval % trainer.global_step != 0:
                should_save_last_checkpoint = True
            if isinstance(trainer.val_check_interval, int) and trainer.global_step % trainer.val_check_interval != 0:
                should_save_last_checkpoint = True
            if should_save_last_checkpoint:
                monitor_candidates = self._monitor_candidates(trainer)
                # Skip the save when the '-last' checkpoint for this step
                # already exists.
                if self.last_model_path == self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST):
                    logging.debug(f'Last checkpoint {self.last_model_path} already saved')
                else:
                    super()._save_last_checkpoint(trainer, monitor_candidates)
                    # Dump the reinit context next to the final checkpoint if it
                    # was not already being saved with every checkpoint.
                    if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero():
                        TrainerContext.from_trainer(trainer).io_dump(
                            ckpt_to_dir(self.last_model_path) / "context", yaml_attrs=["model"]
                        )
        # Let the parent finalize its own checkpointing state.
        super().on_train_end(trainer, pl_module)
| |
|
| | def _del_model_without_trainer(self, filepath: str) -> None: |
| | """ |
| | Deletes the checkpoint model directory from distributed storage without requiring the trainer. |
| | |
| | This method ensures that distributed checkpoints are properly removed when necessary, especially |
| | if the model file is no longer needed or is incomplete. The removal only happens on the rank-zero process. |
| | |
| | Args: |
| | filepath (str): The path to the checkpoint model file to be deleted. |
| | """ |
| |
|
| | from nemo.utils.get_rank import is_global_rank_zero |
| |
|
| | filepath = Path(filepath) |
| |
|
| | if is_global_rank_zero(): |
| | try: |
| | dist_ckpt = ckpt_to_dir(filepath) |
| | shutil.rmtree(dist_ckpt, ignore_errors=True) |
| | logging.info(f"Removed distributed checkpoint: {dist_ckpt}") |
| | except: |
| | logging.info(f"Tried to remove distributed checkpoint: {dist_ckpt} but failed.") |
| | if torch.distributed.is_initialized(): |
| | torch.distributed.barrier() |
| |
|
| | def _ema_callback(self, trainer: 'lightning.pytorch.Trainer'): |
| | """ |
| | Retrieves the Exponential Moving Average (EMA) callback from the list of trainer callbacks. |
| | |
| | This method scans through the list of callbacks attached to the trainer and returns the EMA callback |
| | instance if present. The EMA callback is often used to track the exponential moving average of model parameters |
| | during training. |
| | |
| | Args: |
| | trainer ('lightning.pytorch.Trainer'): The trainer instance. |
| | |
| | Returns: |
| | EMA: The EMA callback instance if found, or None if not present. |
| | """ |
| | from nemo.collections.common.callbacks import EMA |
| |
|
| | ema_callback = None |
| | for callback in trainer.callbacks: |
| | if isinstance(callback, EMA): |
| | ema_callback = callback |
| | return ema_callback |
| |
|
| | @staticmethod |
| | def format_checkpoint_unfinished_marker_path(checkpoint_path: Union[Path, str]) -> Path: |
| | """Format the path to the unfinished checkpoint marker file. |
| | |
| | If the marker file exists, corresponding checkpoint is considered unfinished/incomplete. |
| | NOTE: Marker path for the EMA checkpoint part is the same as for the original checkpoint. |
| | |
| | Args: |
| | checkpoint_path: Path to the checkpoint file or dir. |
| | Does not need to exist. |
| | |
| | Returns: |
| | Path to the unfinished checkpoint marker file. |
| | """ |
| | marker_filepath = str(checkpoint_path).removesuffix(".ckpt") |
| | marker_filepath = marker_filepath.removesuffix("-EMA") |
| | return Path(marker_filepath + ModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX) |
| |
|
| | @staticmethod |
| | def is_checkpoint_unfinished(checkpoint_path: Union[Path, str]) -> bool: |
| | """Check if the checkpoint is unfinished. |
| | |
| | Args: |
| | checkpoint_path: Path to the checkpoint file or dir. |
| | Does not need to exist. |
| | |
| | Returns: |
| | True if the checkpoint is unfinished, False otherwise. |
| | """ |
| | return ModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path).exists() |
| |
|
| | @staticmethod |
| | def set_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_after=False) -> None: |
| | """Marks given checkpoint as unfinished. |
| | |
| | Args: |
| | checkpoint_filepath: Path to the checkpoint file or dir. |
| | Does not need to exist. |
| | barrier_after: Synchronize ranks after writing the marker file. |
| | Defaults to False. |
| | """ |
| | from nemo.utils.get_rank import is_global_rank_zero |
| |
|
| | if is_global_rank_zero(): |
| | marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path) |
| | marker_path.parent.mkdir(parents=True, exist_ok=True) |
| | marker_path.touch() |
| | if barrier_after and torch.distributed.is_initialized(): |
| | torch.distributed.barrier() |
| |
|
| | @staticmethod |
| | def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_before=False) -> None: |
| | """Clear unfinished marker for given checkpoint. |
| | |
| | Args: |
| | checkpoint_path: Path to the checkpoint file or dir. |
| | Does not need to exist. |
| | barrier_before: Synchronize ranks before removing the marker file. |
| | Defaults to False. |
| | """ |
| | from nemo.utils.get_rank import is_global_rank_zero |
| |
|
| | try: |
| | if barrier_before and torch.distributed.is_initialized(): |
| | torch.distributed.barrier() |
| | if is_global_rank_zero(): |
| | marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path) |
| | if marker_path.exists(): |
| | marker_path.unlink() |
| | except: |
| | return |
| |
|
| | def file_exists(self, filepath: str, trainer: "lightning.pytorch.Trainer", check_dist_ckpt: bool = True) -> bool: |
| | """Checks if a file or a file without a suffix (distributed checkpoint) exists.""" |
| | exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(str(ckpt_to_dir(filepath)))) |
| | return trainer.strategy.broadcast(exists) |
| |
|
    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, torch.Tensor]:
        """Broadcast loss from last pipeline stage.

        With pipeline parallelism the training loss only exists on the last
        stage, so it is synced to all ranks before being used for checkpoint
        naming or top-k monitoring.
        """
        monitor_candidates = super()._monitor_candidates(trainer)

        from nemo.lightning._strategy_lib import _sync_from_last_pipeline_stage
        from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy

        # Metric names referenced in the filename template, e.g. '{val_loss:.2f}'.
        keys = re.findall(r"[\{](.*?)[:\}]", self.filename)
        for loss_name in ['reduced_train_loss']:
            if loss_name in keys or loss_name == self.monitor:
                # Ranks that do not own the loss get a zero placeholder which
                # the broadcast below overwrites.
                if loss_name not in monitor_candidates:
                    monitor_candidates[loss_name] = torch.tensor(0.0, device=torch.cuda.current_device())
                if isinstance(trainer.strategy, MegatronStrategy):
                    _sync_from_last_pipeline_stage(monitor_candidates[loss_name], broadcast=True)

        return monitor_candidates
| |
|
    def _link_checkpoint(self, trainer: "pl.Trainer", filepath: str, linkpath: str, override_async=False) -> None:
        """Check to see whether this step has already been saved as top_k
        in which case we can create a symlink
        otherwise, we have to save the checkpoint
        """
        # linkpath (the '-last' name) may only alias filepath when filepath is
        # this step's checkpoint; otherwise a real save is needed.
        saved_current_step = str(ckpt_to_dir(linkpath)).replace("-last", "") == str(ckpt_to_dir(filepath))
        if not saved_current_step:
            self._save_checkpoint(trainer, linkpath)
            return

        # With async save the target may still be in flight; defer the link to
        # the finalize callback by recording the pair.
        if self.async_save and not override_async:
            self.ckpts_to_link[str(filepath)] = str(linkpath)
            return

        # Link the directory forms (distributed checkpoints are directories).
        filepath = ckpt_to_dir(filepath)
        linkpath = ckpt_to_dir(linkpath)
        super()._link_checkpoint(trainer, filepath, linkpath)
| |
|
| | def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str) -> None: |
| | """Saves the checkpoint to the given filepath |
| | |
| | Args: |
| | trainer (lightning.pytorch.Trainer): the trainer obj |
| | filepath (str): path to save checkpoint to. |
| | |
| | Raises: |
| | ValueError: (mcore) async_save with EMA not supported |
| | ValueError: (mcore) Async save requires async compatible CheckpointIO |
| | """ |
| |
|
| | from nemo.utils.get_rank import is_global_rank_zero |
| |
|
| | |
| | |
| | self.set_checkpoint_unfinished_marker(filepath, barrier_after=True) |
| | ema_callback = self._ema_callback(trainer) |
| |
|
| | self._last_global_step_saved = trainer.global_step |
| |
|
| | |
| | |
| | if self.save_last == "link": |
| | self.future_last_model_path = str(ckpt_to_dir(filepath)) |
| | if not str(ckpt_to_dir(filepath)).endswith("last"): |
| | self.future_last_model_path += "-last.ckpt" |
| |
|
| | if ema_callback is not None: |
| | if self.async_save: |
| | raise ValueError('async_save with EMA not supported') |
| | with ema_callback.save_original_optimizer_state(trainer): |
| | super()._save_checkpoint(trainer, filepath) |
| |
|
| | |
| | with ema_callback.save_ema_model(trainer): |
| | rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") |
| | filepath = self._ema_format_filepath(filepath) |
| | if self.verbose: |
| | rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") |
| | super()._save_checkpoint(trainer, filepath) |
| | self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) |
| | else: |
| | |
| | |
| | |
| | |
| | |
| | save_weights_only = self.save_weights_only or ( |
| | not self.save_optim_on_train_end and trainer.global_step == trainer.max_steps |
| | ) |
| |
|
| | |
| | |
| | finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step) |
| | if self.async_save: |
| | checkpoint_io = trainer.strategy.checkpoint_io |
| | from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO |
| |
|
| | if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO): |
| | raise ValueError('Async save requires async compatible CheckpointIO') |
| | storage_options = dict(finalize_fn=finalize_fn) |
| | |
| | self.deferred_ckpts_to_remove.append([]) |
| | else: |
| | storage_options = None |
| | trainer.save_checkpoint(filepath, save_weights_only, storage_options=storage_options) |
| |
|
| | if self.always_save_context and is_global_rank_zero(): |
| | TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context", yaml_attrs=["model"]) |
| |
|
| | if self.async_save: |
| | self._last_checkpoint_saved = filepath |
| | logging.info(f'Scheduled async checkpoint save for {filepath}') |
| | else: |
| | finalize_fn() |
| |
|
| | def _get_finalize_save_checkpoint_callback( |
| | self, trainer: 'lightning.pytorch.Trainer', filepath: str, global_step: int |
| | ): |
| | """Creates a callback that can be used to finalize async (and sync) ckpt saves.""" |
| |
|
| | def _cb(): |
| | logging.debug(f'Finalize callback called for step {global_step}, filepath {filepath}') |
| | self._last_checkpoint_saved = filepath |
| |
|
| | |
| | if trainer.is_global_zero: |
| | for logger in trainer.loggers: |
| | logger.after_save_checkpoint(proxy(self)) |
| |
|
| | |
| | |
| | self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) |
| |
|
| | if not self.async_save: |
| | return |
| |
|
| | logging.info(f'Async checkpoint save for step {global_step} ({filepath}) finalized successfully.') |
| |
|
| | if str(filepath) in self.ckpts_to_link: |
| | self._link_checkpoint(trainer, filepath, self.ckpts_to_link.pop(filepath), override_async=True) |
| |
|
| | |
| | |
| | assert self.deferred_ckpts_to_remove |
| | ckpts_to_remove = self.deferred_ckpts_to_remove.pop(0) |
| | logging.debug(f'Checkpoints to remove: {ckpts_to_remove}') |
| | for ckpt_to_remove in ckpts_to_remove: |
| | self._remove_checkpoint(trainer, ckpt_to_remove, override_async=True) |
| |
|
| | return _cb |
| |
|
    def _remove_checkpoint(self, trainer: "lightning.pytorch.Trainer", filepath: str, override_async=False) -> None:
        """Performs checkpoint removal.

        With async save, `self._remove_checkpoint` is called before the checkpoint
        is actually finished so we can't remove it. Instead we add it to
        `self.deferred_ckpts_to_remove` for future removal.
        """
        if self.async_save and not override_async:
            # No async save scheduled yet: start the first deferral bucket.
            if len(self.deferred_ckpts_to_remove) == 0:
                self.deferred_ckpts_to_remove.append([])
            # The trailing bucket belongs to the most recently scheduled save;
            # its finalize callback will perform the actual removal.
            self.deferred_ckpts_to_remove[-1].append(filepath)
            return
        # Flag the checkpoint as unfinished while deleting so a crash
        # mid-removal leaves it marked for cleanup on restart.
        self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
        try:
            super()._remove_checkpoint(trainer, filepath)
        except Exception as e:
            logging.warning(
                f'Error removing checkpoint, common if doing manual cleanup and restarting: {filepath}: {e}'
            )
        ema_callback = self._ema_callback(trainer)
        if ema_callback is not None:
            # Remove the EMA copy of the checkpoint as well.
            filepath = self._ema_format_filepath(filepath)
            try:
                super()._remove_checkpoint(trainer, filepath)
            except Exception as e:
                logging.warning(
                    f'Error removing checkpoint, common if doing manual cleanup and restarting: {filepath}: {e}'
                )
        # The EMA path maps onto the same marker file, so this clears the
        # marker set above even after the reassignment.
        self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
| |
|
| | def _ema_format_filepath(self, filepath: str) -> str: |
| | """Formats given path for EMA checkpoint |
| | |
| | Args: |
| | filepath (str): filepath |
| | |
| | Returns: |
| | str: EMA-formatted filepath |
| | """ |
| | return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}') |
| |
|
| | def _has_ema_ckpts(self, checkpoints: Iterable[Path]) -> bool: |
| | """Checkes whether filepaths are EMA-formatted |
| | |
| | Args: |
| | checkpoints (Iterable[Path]): paths to check |
| | |
| | Returns: |
| | bool: True indicates path is EMA-formatted. |
| | """ |
| | return any(self._is_ema_filepath(checkpoint_path) for checkpoint_path in checkpoints) |
| |
|
| | def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool: |
| | """Checkes whether filepaths are EMA-formatted |
| | |
| | Args: |
| | filepath (Union[Path, str]): path to check |
| | |
| | Returns: |
| | bool: True indicates path is EMA-formatted. |
| | """ |
| | return str(filepath).endswith(f'-EMA{self.FILE_EXTENSION}') |
| |
|
| | @property |
| | def _saved_checkpoint_paths(self) -> Iterable[Path]: |
| | """ |
| | Retrieves a list of saved checkpoint paths while filtering out unfinished checkpoints. |
| | |
| | - If distributed checkpoints (directories) exist, return only those. |
| | - Otherwise, return individual checkpoint files with a .ckpt extension. |
| | - Filters out any checkpoints that are marked as unfinished. |
| | |
| | Returns: |
| | Iterable[Path]: An iterable containing valid checkpoint paths. |
| | """ |
| | |
| | |
| | dist_checkpoints = [d for d in Path(self.dirpath).glob("*") if d.is_dir()] |
| | if dist_checkpoints: |
| | return filter(lambda p: not self.is_checkpoint_unfinished(p), dist_checkpoints) |
| | else: |
| | checkpoint_files = [f for f in Path(self.dirpath).rglob("*.ckpt")] |
| | return filter(lambda p: not self.is_checkpoint_unfinished(p), checkpoint_files) |
| |
|
| | @staticmethod |
| | def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None: |
| | """ |
| | Removes all unfinished checkpoints and their associated marker files from the filesystem. |
| | |
| | - Ensures this function runs only on rank 0. |
| | - Deletes individual unfinished checkpoint files. |
| | - Removes directories corresponding to unfinished distributed checkpoints. |
| | - Deletes the marker files indicating unfinished checkpoints. |
| | |
| | Args: |
| | checkpoint_dir (Union[Path, str]): Path to the directory containing checkpoints. |
| | |
| | Raises: |
| | AssertionError: If the function is called from a non-rank 0 process. |
| | """ |
| | from nemo.utils.get_rank import is_global_rank_zero |
| |
|
| | |
| | |
| |
|
| | if not is_global_rank_zero(): |
| | raise AssertionError("_remove_unfinished_checkpoints should run only on rank 0") |
| |
|
| | checkpoint_dir = Path(checkpoint_dir) |
| |
|
| | existing_marker_filepaths = { |
| | f.resolve() for f in checkpoint_dir.glob(f"*{ModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}") if f.is_file() |
| | } |
| |
|
| | checkpoint_filepaths = {f.resolve() for f in checkpoint_dir.rglob("*.ckpt")} |
| | for filepath in checkpoint_filepaths: |
| | possible_marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(filepath) |
| | if possible_marker_path in existing_marker_filepaths: |
| | logging.warning(f'Removing unfinished checkpoint: {filepath}') |
| | os.remove(filepath) |
| |
|
| | |
| | all_dirpaths = {d.resolve() for d in checkpoint_dir.glob("*") if d.is_dir()} |
| | for ckpt_dirpath in all_dirpaths: |
| | possible_marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(ckpt_dirpath) |
| | if possible_marker_path in existing_marker_filepaths: |
| | logging.warning(f'Removing unfinished dist checkpoint: {ckpt_dirpath}') |
| | shutil.rmtree(ckpt_dirpath) |
| |
|
| | |
| | for marker_path in existing_marker_filepaths: |
| | os.remove(marker_path) |
| |
|
| | def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool: |
| | """Checks if the previous checkpoint should be deleted. |
| | A checkpoint won't be deleted if any of the cases apply: |
| | - The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new) |
| | - The previous checkpoint is not in the current checkpoint directory and the filesystem is local |
| | - The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local |
| | and the resumed from checkpoint is not the last checkpoint |
| | """ |
| | if previous == current: |
| | return False |
| | if not _is_local_file_protocol(previous): |
| | return True |
| | previous = Path(previous).absolute() |
| | resume_path = Path(trainer.ckpt_path).absolute() if trainer.ckpt_path is not None else None |
| |
|
| | if resume_path is not None and previous == resume_path: |
| | if str(current).endswith("-last.ckpt") and resume_path.name.endswith("-last.ckpt"): |
| | |
| | |
| | pass |
| | else: |
| | return False |
| | if self.dirpath is None: |
| | raise ValueError(f"{self.__class__}.dirpath is None.") |
| | dirpath = Path(self.dirpath).absolute() |
| | return dirpath in previous.parents |
| |
|