# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import shutil
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Iterable, List, Literal, Optional, Union
import lightning
import lightning.pytorch as pl
import torch
from _weakref import proxy
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint
from lightning.pytorch.callbacks.model_checkpoint import _is_local_file_protocol
from lightning.pytorch.utilities import rank_zero_info
from nemo.lightning.ckpt_utils import ckpt_to_dir
from nemo.lightning.io.pl import TrainerContext
from nemo.utils import logging
from nemo.utils.app_state import AppState
class ModelCheckpoint(PTLModelCheckpoint):
"""Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end.
    Adds support for asynchronous checkpointing and provides additional logic to clean up invalid checkpoints.
Args:
monitor: Metric to monitor when saving top-k checkpoints.
verbose: Verbosity mode.
        save_last: When ``True``, saves a ``*-last`` copy whenever a checkpoint file gets saved.
            When set to ``"link"``, the ``*-last`` copy is saved as a symlink to the latest checkpoint.
        save_top_k: The number of best checkpoints to keep according to ``monitor``; ``-1`` keeps all checkpoints.
save_weights_only: if ``True``, then only the model's weights will be saved. Optimizer states will
be omitted from all checkpoints.
mode: One of {min, max}. Whether the objective is to minimize or maximize the monitored quantity.
every_n_epochs: Number of epochs between checkpoints.
every_n_train_steps: Number of train steps between checkpoints.
        train_time_interval: Checkpoints are monitored at the specified time interval. Not to be used with
``every_n_epochs`` or ``every_n_train_steps``.
        save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
save_optim_on_train_end: Whether to include the optimizer states in the final checkpoint
at the end of training. Only applicable when save_weights_only is ``False``.
        always_save_context: Whether to dump the artifacts needed to reinitialize the current
model, trainer, and dataloader to allow for reproducibility of experiments.
save_context_on_train_end: Whether to dump the artifacts on_train_end regardless of whether
``always_save_context`` is ``True``.
async_save: Whether to enable asynchronous checkpointing.
Attributes:
UNFINISHED_CHECKPOINT_SUFFIX (str): Suffix for unfinished checkpoint files.
deferred_ckpts_to_remove (List[List[str]]): List of deferred checkpoints
to remove once async save is completed.
ckpts_to_link (Dict[str, str]): Dictionary of checkpoint paths that need to be symlinked.
future_last_model_path (str): Path to the future 'last' checkpoint, used for symbolic linking.
best_k_models (dict): Dictionary of best-k checkpoints based on the monitored metric.
best_model_score (float): Score of the best checkpoint.
best_model_path (str): Path to the best checkpoint.
kth_best_model_path (str): Path to the kth best checkpoint.
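    Example:
        A minimal usage sketch; ``model`` and ``datamodule`` below are placeholders
        rather than objects provided by this module::

            checkpoint_callback = ModelCheckpoint(
                monitor="val_loss",
                save_top_k=3,
                every_n_train_steps=1000,
                always_save_context=True,
            )
            trainer = pl.Trainer(max_steps=10_000, callbacks=[checkpoint_callback])
            trainer.fit(model, datamodule=datamodule)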
"""
UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
def __init__(
self,
monitor: Optional[str] = "val_loss",
verbose: bool = True,
save_last: Optional[Union[bool, Literal["link"]]] = True,
save_top_k: int = 3,
save_weights_only: bool = False, # TODO: check support
mode: str = "min",
        every_n_epochs: Optional[int] = None,
every_n_train_steps: Optional[int] = None,
train_time_interval: Optional[timedelta] = None,
# Save after training, not after validation
save_on_train_epoch_end: Optional[bool] = False,
save_optim_on_train_end: Optional[bool] = False,
always_save_context: bool = True,
save_context_on_train_end: bool = True,
**kwargs,
):
self.always_save_context = always_save_context
self.save_context_on_train_end = save_context_on_train_end
self.save_optim_on_train_end = save_optim_on_train_end
# stores the next -last checkpoint to be saved, used only when save_last = 'link'
# this is needed because when using symlinks, we need to update the non-last checkpoint's
# last_model_path to point to the corresponding -last version
self.future_last_model_path = ""
# Checkpoints which removal is deferred until async save is done.
# Each element of `deferred_ckpts_to_remove` is a growing list
# that `self._remove_checkpoint` adds to. Once `self._save_checkpoint`
# is called, the last element is frozen and a new element is added.
self.deferred_ckpts_to_remove: List[List[str]] = []
self.ckpts_to_link: Dict[str, str] = {}
# Call the parent class constructor with the remaining kwargs.
super().__init__(
monitor=monitor,
verbose=verbose,
save_last=save_last,
save_top_k=save_top_k,
save_weights_only=save_weights_only,
mode=mode,
every_n_epochs=every_n_epochs,
every_n_train_steps=every_n_train_steps,
train_time_interval=train_time_interval,
save_on_train_epoch_end=save_on_train_epoch_end,
**kwargs,
)
def on_train_start(self, trainer, pl_module):
"""
Initializes checkpointing by handling previous runs,
setting up file logging, and managing files to move or copy.
This method handles:
- Moving old files to new folders
- Copying relevant files to the log directory
- Creating command argument and git information logs
- Setting up logging for errors and Lightning logs
Args:
trainer (pl.Trainer): The PyTorch Lightning trainer object.
pl_module (pl.LightningModule): The Lightning model to be trained.
"""
from nemo.utils.exp_manager import get_git_diff, get_git_hash
from nemo.utils.get_rank import is_global_rank_zero
from nemo.utils.lightning_logger_patch import add_filehandlers_to_pl_logger
app_state = AppState()
if self.save_top_k != -1 and app_state.restore:
logging.debug("Checking previous runs")
self.nemo_topk_check_previous_run()
if is_global_rank_zero():
log_dir = app_state.log_dir
# Check to see if any files exist that need to be moved
files_to_move = app_state.files_to_move
if len(files_to_move) > 0:
# Move old files to a new folder
other_run_dirs = Path(log_dir).glob("run_*")
run_count = 0
for fold in other_run_dirs:
if fold.is_dir():
run_count += 1
new_run_dir = Path(Path(log_dir) / f"run_{run_count}")
if not new_run_dir.exists():
new_run_dir.mkdir()
for _file in files_to_move:
shutil.move(str(_file), str(new_run_dir))
# Move files_to_copy to folder and add git information if present
if app_state.files_to_copy:
for _file in app_state.files_to_copy:
src_path = Path(_file)
dst_path = Path(log_dir) / src_path.name
if not dst_path.exists():
shutil.copy(src_path, dst_path)
# Create files for cmd args and git info
if app_state.cmd_args:
cmd_args_file = log_dir / 'cmd-args.log'
if not cmd_args_file.exists():
with open(cmd_args_file, 'w', encoding='utf-8') as _file:
_file.write(" ".join(app_state.cmd_args))
# Try to get git hash
git_repo, git_hash = get_git_hash()
if git_repo:
git_info_file = log_dir / 'git-info.log'
if not git_info_file.exists():
with open(git_info_file, 'w', encoding='utf-8') as _file:
_file.write(f'commit hash: {git_hash}\n')
_file.write(get_git_diff())
# Add err_file logging to global_rank zero
logging.add_err_file_handler(log_dir / 'nemo_error_log.txt')
# Add lightning file logging to global_rank zero
add_filehandlers_to_pl_logger(log_dir / 'lightning_logs.txt', log_dir / 'nemo_error_log.txt')
if torch.distributed.is_initialized():
torch.distributed.barrier()
super().on_train_start(trainer, pl_module)
def nemo_topk_check_previous_run(self):
"""
Verifies and cleans up the top-k checkpoint state from previous training runs.
This method ensures that:
- The top-k models are correctly loaded and ordered.
- Any outdated or invalid checkpoints are removed.
- The best model is determined based on the monitored metric.
Raises:
AttributeError: If the expected attributes for the top-k model are not found.
"""
try:
self.best_k_models
self.kth_best_model_path
self.best_model_score
self.best_model_path
except AttributeError:
raise AttributeError(
"Lightning's ModelCheckpoint was updated. NeMo's ModelCheckpoint will need an update."
)
self.best_k_models = {}
self.kth_best_model_path = ""
self.best_model_score = None
self.best_model_path = ""
checkpoints = list(path for path in self._saved_checkpoint_paths if not self._is_ema_filepath(path))
for checkpoint in checkpoints:
checkpoint = str(checkpoint)
if checkpoint[-10:] == '-last.ckpt' or checkpoint[-5:] == '-last':
continue
# Find monitor in str + 1 for '='
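            # e.g. with monitor="val_loss" and a name like "...-val_loss=0.123-epoch=2.ckpt"
            # (illustrative only), `index` lands just past the '=' so the slice below
            # recovers the numeric value "0.123".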
index = checkpoint.find(self.monitor) + len(self.monitor) + 1
if index != len(self.monitor):
match = re.search('[A-z]', checkpoint[index:])
if match:
# -1 due to separator hyphen
value = checkpoint[index : index + match.start() - 1]
else:
value = checkpoint[index:]
self.best_k_models[checkpoint] = float(value)
if len(self.best_k_models) < 1:
return # No saved checkpoints yet
_reverse = False if self.mode == "min" else True
best_k_models = sorted(self.best_k_models, key=self.best_k_models.get, reverse=_reverse)
# This section should be ok as rank zero will delete all excess checkpoints, since all other ranks are
# instantiated after rank zero. models_to_delete should be 0 for all other ranks.
models_to_delete = len(best_k_models) - self.save_top_k
models_to_delete = max(0, models_to_delete)
logging.debug(f'Number of models to delete: {models_to_delete}')
# If EMA enabled, delete the additional EMA weights
ema_enabled = self._has_ema_ckpts(self._saved_checkpoint_paths)
for _ in range(models_to_delete):
model = best_k_models.pop(-1)
self.best_k_models.pop(model)
self._del_model_without_trainer(model)
if ema_enabled and self._fs.exists(self._ema_format_filepath(model)):
self._del_model_without_trainer(self._ema_format_filepath(model))
logging.debug(f"Removed checkpoint: {model}")
self.kth_best_model_path = best_k_models[-1]
self.best_model_path = best_k_models[0]
self.best_model_score = self.best_k_models[self.best_model_path]
def _remove_invalid_entries_from_topk(self):
"""
Removes invalid (incomplete or non-existing) checkpoints from the list of top-k checkpoints.
This function is necessary when checkpointing might have been abruptly interrupted, leaving behind
incomplete or corrupted checkpoints. The invalid checkpoints are identified by checking if their
corresponding directory exists and if the checkpoint is not unfinished.
After removing invalid entries, the method updates the best-k models based on the existing, valid checkpoints.
Attributes Updated:
- `best_k_models`: A dictionary of valid checkpoints from top-k models.
- `best_model_path`: Path to the best model based on the current sorting order.
- `best_model_score`: The score associated with the best model.
- `kth_best_model_path`: Path to the kth best model.
- `kth_value`: The score associated with the kth best model.
"""
# Removes invalid (incomplete or not existing) checkpoints from topk checkpoints.
# This might be needed if the checkpointing was abruptly terminated.
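        # Note: a distributed checkpoint is stored as a directory named like the '.ckpt'
        # entry without its suffix, so validity is checked against that directory.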
def __is_ckpt_ok(ckpt_path: str) -> bool:
exists = os.path.isdir(ckpt_path.removesuffix('.ckpt'))
return exists and not self.is_checkpoint_unfinished(ckpt_path)
self.best_k_models = {k: v for k, v in self.best_k_models.items() if __is_ckpt_ok(k)}
if len(self.best_k_models) > 0:
reverse_arr = self.mode != "min"
best_k_models_arr = sorted(self.best_k_models, key=self.best_k_models.get, reverse=reverse_arr)
self.kth_best_model_path = best_k_models_arr[-1]
self.kth_value = self.best_k_models[self.kth_best_model_path]
self.best_model_path = best_k_models_arr[0]
self.best_model_score = self.best_k_models[self.best_model_path]
else:
self.kth_best_model_path = ""
self.kth_value = None
self.best_model_path = ""
self.best_model_score = None
def state_dict(self):
"""
Returns the state dictionary of the model.
This function adds additional logic to handle the case when using symlinks. If the model is configured
to save the last checkpoint as a symlink, the path to the last checkpoint is updated in the returned
state dictionary to avoid off-by-one errors in the checkpointing system.
Returns:
Dict[str, Any]: The state dictionary of the model, including any necessary modifications for symlinks.
"""
state = super().state_dict()
# if using symlinks, overwrite last_model_path to avoid off-by-one issues
if self.save_last == "link":
state["last_model_path"] = self.future_last_model_path
return state
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
Loads the state dictionary into the model and removes invalid entries from the top-k checkpoints.
This method ensures that after loading the model state, any invalid (incomplete or missing) checkpoints
are removed from the top-k models list.
Args:
state_dict (Dict[str, Any]): The state dictionary to load into the model.
"""
super().load_state_dict(state_dict)
self._remove_invalid_entries_from_topk()
def setup(self, trainer, *args, **kwargs) -> None:
"""
Initializes the model and removes any unfinished checkpoints before training.
This method is responsible for ensuring that unfinished checkpoints are removed prior to starting the training.
It also synchronizes all ranks in a distributed setting to ensure that unfinished checkpoints are removed
across all ranks.
Args:
trainer: The trainer instance used for training.
*args: Additional arguments passed to the parent setup method.
**kwargs: Additional keyword arguments passed to the parent setup method.
"""
from nemo.utils.get_rank import is_global_rank_zero
if is_global_rank_zero():
logging.debug("Removing unfinished checkpoints if any...")
ModelCheckpoint._remove_unfinished_checkpoints(self.dirpath)
# Ensure that all ranks continue with unfinished checkpoints removed
if torch.distributed.is_initialized():
torch.distributed.barrier()
self.async_save = getattr(trainer.strategy, "async_save", False)
super().setup(trainer, *args, **kwargs)
def on_train_end(self, trainer, pl_module):
"""
Handles actions to be performed when training ends, such as saving the last checkpoint.
This method ensures that the last checkpoint is saved if needed, particularly when validation steps
aren't always run based on the interval. It also manages saving the training context to disk, if configured.
Args:
trainer: The trainer instance used for training.
pl_module: The model being trained.
"""
from nemo.utils.get_rank import is_global_rank_zero
if trainer.fast_dev_run:
return None
# check if we need to save a last checkpoint manually as validation isn't always run based on the interval
if self.save_last and trainer.val_check_interval != 0:
should_save_last_checkpoint = False
if isinstance(trainer.val_check_interval, float) and trainer.val_check_interval % trainer.global_step != 0:
should_save_last_checkpoint = True
if isinstance(trainer.val_check_interval, int) and trainer.global_step % trainer.val_check_interval != 0:
should_save_last_checkpoint = True
if should_save_last_checkpoint:
monitor_candidates = self._monitor_candidates(trainer)
if self.last_model_path == self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST):
logging.debug(f'Last checkpoint {self.last_model_path} already saved')
else:
super()._save_last_checkpoint(trainer, monitor_candidates)
if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(
ckpt_to_dir(self.last_model_path) / "context", yaml_attrs=["model"]
)
# Call parent on_train_end() to save the -last checkpoint
super().on_train_end(trainer, pl_module)
def _del_model_without_trainer(self, filepath: str) -> None:
"""
Deletes the checkpoint model directory from distributed storage without requiring the trainer.
This method ensures that distributed checkpoints are properly removed when necessary, especially
if the model file is no longer needed or is incomplete. The removal only happens on the rank-zero process.
Args:
filepath (str): The path to the checkpoint model file to be deleted.
"""
from nemo.utils.get_rank import is_global_rank_zero
filepath = Path(filepath)
if is_global_rank_zero():
try:
dist_ckpt = ckpt_to_dir(filepath)
shutil.rmtree(dist_ckpt, ignore_errors=True)
logging.info(f"Removed distributed checkpoint: {dist_ckpt}")
            except Exception:
                logging.info(f"Tried to remove distributed checkpoint: {filepath} but failed.")
if torch.distributed.is_initialized():
torch.distributed.barrier()
def _ema_callback(self, trainer: 'lightning.pytorch.Trainer'):
"""
Retrieves the Exponential Moving Average (EMA) callback from the list of trainer callbacks.
This method scans through the list of callbacks attached to the trainer and returns the EMA callback
instance if present. The EMA callback is often used to track the exponential moving average of model parameters
during training.
Args:
trainer ('lightning.pytorch.Trainer'): The trainer instance.
Returns:
EMA: The EMA callback instance if found, or None if not present.
"""
from nemo.collections.common.callbacks import EMA
ema_callback = None
for callback in trainer.callbacks:
if isinstance(callback, EMA):
ema_callback = callback
return ema_callback
@staticmethod
def format_checkpoint_unfinished_marker_path(checkpoint_path: Union[Path, str]) -> Path:
"""Format the path to the unfinished checkpoint marker file.
        If the marker file exists, the corresponding checkpoint is considered unfinished/incomplete.
NOTE: Marker path for the EMA checkpoint part is the same as for the original checkpoint.
Args:
checkpoint_path: Path to the checkpoint file or dir.
Does not need to exist.
Returns:
Path to the unfinished checkpoint marker file.
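        Example:
            With the default suffix, ``/ckpts/step=100.ckpt`` and ``/ckpts/step=100-EMA.ckpt``
            both map to ``/ckpts/step=100-unfinished`` (paths are illustrative).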
"""
marker_filepath = str(checkpoint_path).removesuffix(".ckpt")
marker_filepath = marker_filepath.removesuffix("-EMA")
return Path(marker_filepath + ModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX)
@staticmethod
def is_checkpoint_unfinished(checkpoint_path: Union[Path, str]) -> bool:
"""Check if the checkpoint is unfinished.
Args:
checkpoint_path: Path to the checkpoint file or dir.
Does not need to exist.
Returns:
True if the checkpoint is unfinished, False otherwise.
"""
return ModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path).exists()
@staticmethod
def set_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_after=False) -> None:
"""Marks given checkpoint as unfinished.
Args:
checkpoint_filepath: Path to the checkpoint file or dir.
Does not need to exist.
barrier_after: Synchronize ranks after writing the marker file.
Defaults to False.
"""
from nemo.utils.get_rank import is_global_rank_zero
if is_global_rank_zero():
marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path)
marker_path.parent.mkdir(parents=True, exist_ok=True)
marker_path.touch()
if barrier_after and torch.distributed.is_initialized():
torch.distributed.barrier()
@staticmethod
def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_before=False) -> None:
"""Clear unfinished marker for given checkpoint.
Args:
checkpoint_path: Path to the checkpoint file or dir.
Does not need to exist.
barrier_before: Synchronize ranks before removing the marker file.
Defaults to False.
"""
from nemo.utils.get_rank import is_global_rank_zero
try:
if barrier_before and torch.distributed.is_initialized():
torch.distributed.barrier()
if is_global_rank_zero():
marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path)
if marker_path.exists():
marker_path.unlink()
        except Exception:
            return
def file_exists(self, filepath: str, trainer: "lightning.pytorch.Trainer", check_dist_ckpt: bool = True) -> bool:
"""Checks if a file or a file without a suffix (distributed checkpoint) exists."""
exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(str(ckpt_to_dir(filepath))))
return trainer.strategy.broadcast(exists)
def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, torch.Tensor]:
"""Broadcast loss from last pipeline stage."""
monitor_candidates = super()._monitor_candidates(trainer)
from nemo.lightning._strategy_lib import _sync_from_last_pipeline_stage
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
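        # Extract the metric names referenced in the filename template, e.g. a template
        # containing "{reduced_train_loss:.4f}" yields "reduced_train_loss".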
keys = re.findall(r"[\{](.*?)[:\}]", self.filename)
for loss_name in ['reduced_train_loss']:
if loss_name in keys or loss_name == self.monitor:
if loss_name not in monitor_candidates:
monitor_candidates[loss_name] = torch.tensor(0.0, device=torch.cuda.current_device())
if isinstance(trainer.strategy, MegatronStrategy):
_sync_from_last_pipeline_stage(monitor_candidates[loss_name], broadcast=True)
return monitor_candidates
def _link_checkpoint(self, trainer: "pl.Trainer", filepath: str, linkpath: str, override_async=False) -> None:
"""Check to see whether this step has already been saved as top_k
in which case we can create a symlink
otherwise, we have to save the checkpoint
"""
saved_current_step = str(ckpt_to_dir(linkpath)).replace("-last", "") == str(ckpt_to_dir(filepath))
if not saved_current_step:
self._save_checkpoint(trainer, linkpath)
return
# linking will happen as part of the finalize fn
if self.async_save and not override_async:
self.ckpts_to_link[str(filepath)] = str(linkpath)
return
filepath = ckpt_to_dir(filepath)
linkpath = ckpt_to_dir(linkpath)
super()._link_checkpoint(trainer, filepath, linkpath)
def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str) -> None:
"""Saves the checkpoint to the given filepath
Args:
trainer (lightning.pytorch.Trainer): the trainer obj
filepath (str): path to save checkpoint to.
Raises:
ValueError: (mcore) async_save with EMA not supported
ValueError: (mcore) Async save requires async compatible CheckpointIO
"""
from nemo.utils.get_rank import is_global_rank_zero
# barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
# if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
ema_callback = self._ema_callback(trainer)
self._last_global_step_saved = trainer.global_step
# manually update last_model_path so symlink is up-to-date
# should only be done when using a symlink
if self.save_last == "link":
self.future_last_model_path = str(ckpt_to_dir(filepath))
if not str(ckpt_to_dir(filepath)).endswith("last"):
self.future_last_model_path += "-last.ckpt"
if ema_callback is not None:
if self.async_save:
raise ValueError('async_save with EMA not supported')
with ema_callback.save_original_optimizer_state(trainer):
super()._save_checkpoint(trainer, filepath)
# save EMA copy of the model as well.
with ema_callback.save_ema_model(trainer):
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
filepath = self._ema_format_filepath(filepath)
if self.verbose:
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
super()._save_checkpoint(trainer, filepath)
self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
else:
# Determine whether to include optimizer states in the checkpoint
# optimizer states are included when
# 1. save_weights_only is False and
# 2. either save_optim_on_train_end is True, or save_optim_on_train_end is False but the checkpoint
# is an intermediate checkpoint.
save_weights_only = self.save_weights_only or (
not self.save_optim_on_train_end and trainer.global_step == trainer.max_steps
)
# Async save passes the finalization function to checkpoint_io,
# sync save calls the finalization function immediately after save.
finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step)
if self.async_save:
checkpoint_io = trainer.strategy.checkpoint_io
from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO
if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO):
raise ValueError('Async save requires async compatible CheckpointIO')
storage_options = dict(finalize_fn=finalize_fn)
# Each upcoming ckpt removal request will be executed as part of this save finalization
self.deferred_ckpts_to_remove.append([])
else:
storage_options = None
trainer.save_checkpoint(filepath, save_weights_only, storage_options=storage_options)
if self.always_save_context and is_global_rank_zero():
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context", yaml_attrs=["model"])
if self.async_save:
self._last_checkpoint_saved = filepath
logging.info(f'Scheduled async checkpoint save for {filepath}')
else:
finalize_fn()
def _get_finalize_save_checkpoint_callback(
self, trainer: 'lightning.pytorch.Trainer', filepath: str, global_step: int
):
"""Creates a callback that can be used to finalize async (and sync) ckpt saves."""
def _cb():
logging.debug(f'Finalize callback called for step {global_step}, filepath {filepath}')
self._last_checkpoint_saved = filepath
# notify loggers
if trainer.is_global_zero:
for logger in trainer.loggers:
logger.after_save_checkpoint(proxy(self))
# barrier_before=True, so all ranks synchronize before removing the unfinished checkpoint marker
# we don't want to remove the marker until all checkpointing is done.
self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
if not self.async_save:
return
logging.info(f'Async checkpoint save for step {global_step} ({filepath}) finalized successfully.')
if str(filepath) in self.ckpts_to_link:
self._link_checkpoint(trainer, filepath, self.ckpts_to_link.pop(filepath), override_async=True)
# Remove checkpoints marked for removal by `self._remove_checkpoint`
# For each finalization there is exactly one entry in self.deferred_ckpts_to_remove
assert self.deferred_ckpts_to_remove
ckpts_to_remove = self.deferred_ckpts_to_remove.pop(0)
logging.debug(f'Checkpoints to remove: {ckpts_to_remove}')
for ckpt_to_remove in ckpts_to_remove:
self._remove_checkpoint(trainer, ckpt_to_remove, override_async=True)
return _cb
def _remove_checkpoint(self, trainer: "lightning.pytorch.Trainer", filepath: str, override_async=False) -> None:
"""Performs checkpoint removal.
With async save, `self._remove_checkpoint` is called before the checkpoint
is actually finished so we can't remove it. Instead we add it to
`self.deferred_ckpts_to_remove` for future removal.
"""
if self.async_save and not override_async:
# Register checkpoint removal in the last (active) checkpoint removal list
if len(self.deferred_ckpts_to_remove) == 0:
self.deferred_ckpts_to_remove.append([])
self.deferred_ckpts_to_remove[-1].append(filepath)
return
# barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
# if anything goes wrong during removal, we should be able to detect that data is incomplete.
self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
try:
super()._remove_checkpoint(trainer, filepath)
except Exception as e:
logging.warning(
f'Error removing checkpoint, common if doing manual cleanup and restarting: {filepath}: {e}'
)
ema_callback = self._ema_callback(trainer)
if ema_callback is not None:
# remove EMA copy of the state dict as well.
filepath = self._ema_format_filepath(filepath)
try:
super()._remove_checkpoint(trainer, filepath)
except Exception as e:
logging.warning(
f'Error removing checkpoint, common if doing manual cleanup and restarting: {filepath}: {e}'
)
# barrier_before=True, so all ranks synchronize before removing the unfinished checkpoint marker
# we don't want to remove the marker until the checkpoint is actually removed.
self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
def _ema_format_filepath(self, filepath: str) -> str:
"""Formats given path for EMA checkpoint
Args:
filepath (str): filepath
Returns:
str: EMA-formatted filepath
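        Example:
            ``model.ckpt`` becomes ``model-EMA.ckpt`` (the name is illustrative).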
"""
return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}')
def _has_ema_ckpts(self, checkpoints: Iterable[Path]) -> bool:
"""Checkes whether filepaths are EMA-formatted
Args:
checkpoints (Iterable[Path]): paths to check
Returns:
bool: True indicates path is EMA-formatted.
"""
return any(self._is_ema_filepath(checkpoint_path) for checkpoint_path in checkpoints)
def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool:
"""Checkes whether filepaths are EMA-formatted
Args:
filepath (Union[Path, str]): path to check
Returns:
bool: True indicates path is EMA-formatted.
"""
return str(filepath).endswith(f'-EMA{self.FILE_EXTENSION}')
@property
def _saved_checkpoint_paths(self) -> Iterable[Path]:
"""
Retrieves a list of saved checkpoint paths while filtering out unfinished checkpoints.
- If distributed checkpoints (directories) exist, return only those.
- Otherwise, return individual checkpoint files with a .ckpt extension.
- Filters out any checkpoints that are marked as unfinished.
Returns:
Iterable[Path]: An iterable containing valid checkpoint paths.
"""
# distributed checkpoints are directories so we check for them here
# we filter out unfinished checkpoints, these should be deleted during next cleanup
dist_checkpoints = [d for d in Path(self.dirpath).glob("*") if d.is_dir()]
if dist_checkpoints:
return filter(lambda p: not self.is_checkpoint_unfinished(p), dist_checkpoints)
else:
checkpoint_files = [f for f in Path(self.dirpath).rglob("*.ckpt")]
return filter(lambda p: not self.is_checkpoint_unfinished(p), checkpoint_files)
@staticmethod
def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
"""
Removes all unfinished checkpoints and their associated marker files from the filesystem.
- Ensures this function runs only on rank 0.
- Deletes individual unfinished checkpoint files.
- Removes directories corresponding to unfinished distributed checkpoints.
- Deletes the marker files indicating unfinished checkpoints.
Args:
checkpoint_dir (Union[Path, str]): Path to the directory containing checkpoints.
Raises:
AssertionError: If the function is called from a non-rank 0 process.
"""
from nemo.utils.get_rank import is_global_rank_zero
# Delete unfinished checkpoints from the filesystems.
# "Unfinished marker" files are removed as well.
if not is_global_rank_zero():
raise AssertionError("_remove_unfinished_checkpoints should run only on rank 0")
checkpoint_dir = Path(checkpoint_dir)
existing_marker_filepaths = {
f.resolve() for f in checkpoint_dir.glob(f"*{ModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}") if f.is_file()
}
checkpoint_filepaths = {f.resolve() for f in checkpoint_dir.rglob("*.ckpt")}
for filepath in checkpoint_filepaths:
possible_marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(filepath)
if possible_marker_path in existing_marker_filepaths:
logging.warning(f'Removing unfinished checkpoint: {filepath}')
os.remove(filepath)
        # some directories might be distributed checkpoints; we remove these if they have an unfinished marker
all_dirpaths = {d.resolve() for d in checkpoint_dir.glob("*") if d.is_dir()}
for ckpt_dirpath in all_dirpaths:
possible_marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(ckpt_dirpath)
if possible_marker_path in existing_marker_filepaths:
logging.warning(f'Removing unfinished dist checkpoint: {ckpt_dirpath}')
shutil.rmtree(ckpt_dirpath)
# delete markers
for marker_path in existing_marker_filepaths:
os.remove(marker_path)
def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool:
"""Checks if the previous checkpoint should be deleted.
A checkpoint won't be deleted if any of the cases apply:
- The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new)
- The previous checkpoint is not in the current checkpoint directory and the filesystem is local
- The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local
and the resumed from checkpoint is not the last checkpoint
"""
if previous == current:
return False
if not _is_local_file_protocol(previous):
return True
previous = Path(previous).absolute()
resume_path = Path(trainer.ckpt_path).absolute() if trainer.ckpt_path is not None else None
if resume_path is not None and previous == resume_path:
if str(current).endswith("-last.ckpt") and resume_path.name.endswith("-last.ckpt"):
# delete the previous `-last.ckpt` checkpoint when current saved checkpoint
# is also `-last.ckpt`, if they're in the same directory
pass
else:
return False
if self.dirpath is None:
raise ValueError(f"{self.__class__}.dirpath is None.")
dirpath = Path(self.dirpath).absolute()
return dirpath in previous.parents