import os
import re
import shutil
import time
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
from _weakref import proxy
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol
from lightning.pytorch.trainer import call
from lightning.pytorch.utilities import rank_zero_info

from nemo.collections.common.callbacks import EMA
from nemo.utils import logging
from nemo.utils.app_state import AppState
from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO
from nemo.utils.get_rank import is_global_rank_zero
from nemo.utils.model_utils import ckpt_to_dir, inject_model_parallel_rank, uninject_model_parallel_rank
from nemo.utils.msc_utils import import_multistorageclient, is_multistorageclient_url


class NeMoModelCheckpoint(ModelCheckpoint):
    """Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end.

    Extends Lightning's on_save_checkpoint function to also save the .nemo file, based on the
    best checkpoint saved (according to the monitored value).
    Also contains functionality to save the EMA copy of the model.
    """

    UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"

    def __init__(
        self,
        always_save_nemo: bool = False,
        save_nemo_on_train_end: bool = True,
        save_best_model: bool = False,
        postfix: str = ".nemo",
        n_resume: bool = False,
        model_parallel_size: int = None,
        async_save: bool = False,
        save_last_n_optim_states: int = -1,
        **kwargs,
    ):
        # Parse and store the NeMo-specific parameters before handing the rest to ModelCheckpoint.
        self.always_save_nemo = always_save_nemo
        self.save_nemo_on_train_end = save_nemo_on_train_end
        self.save_best_model = save_best_model
        self.save_last_n_optim_states = save_last_n_optim_states
        if self.save_best_model and not self.save_nemo_on_train_end:
            logging.warning(
                (
                    "Found save_best_model is True and save_nemo_on_train_end is False. "
                    "Set save_nemo_on_train_end to True to automatically save the best model."
                )
            )
        self.postfix = postfix
        self.previous_best_path = ""
        self.model_parallel_size = model_parallel_size
        self.async_save = async_save
        self.async_finalize_cb = None
        # Checkpoints whose removal is deferred until an in-flight async save finishes.
        # `_save_checkpoint` appends a new (empty) list per async save and
        # `_remove_checkpoint` appends filepaths to the most recent one.
        self.deferred_ckpts_to_remove: List[List[str]] = []

        # `prefix` is kept on this class (not passed to ModelCheckpoint) so that
        # `_format_nemo_checkpoint_name` can build the .nemo file name from it.
        if 'prefix' in kwargs:
            self.prefix = kwargs.pop('prefix')
        else:
            self.prefix = ""

        # Call the parent class constructor with the remaining kwargs.
        super().__init__(**kwargs)

        if self.save_top_k != -1 and n_resume:
            logging.debug("Checking previous runs")
            self.nemo_topk_check_previous_run()

    def nemo_topk_check_previous_run(self):
        """
        Rebuild the top-k bookkeeping (best_k_models, best/kth model paths) from checkpoints left
        by previous runs and delete any checkpoints exceeding save_top_k.
        """
        try:
            self.best_k_models
            self.kth_best_model_path
            self.best_model_score
            self.best_model_path
        except AttributeError:
            raise AttributeError("Lightning's ModelCheckpoint was updated. NeMoModelCheckpoint will need an update.")
        self.best_k_models = {}
        self.kth_best_model_path = ""
        self.best_model_score = None
        self.best_model_path = ""

        checkpoints = list(path for path in self._saved_checkpoint_paths if not self._is_ema_filepath(path))
        for checkpoint in checkpoints:
            if 'mp_rank' in str(checkpoint) or 'tp_rank' in str(checkpoint):
                checkpoint = uninject_model_parallel_rank(checkpoint)
            checkpoint = str(checkpoint)
            # Skip '-last' checkpoints: they do not carry a monitored value in their name.
            if checkpoint[-10:] == '-last.ckpt' or checkpoint[-5:] == '-last':
                continue
            index = checkpoint.find(self.monitor) + len(self.monitor) + 1  # Find monitor in str + 1 for '='
            if index != len(self.monitor):  # monitor found in the checkpoint name
                match = re.search('[A-z]', checkpoint[index:])
                if match:
                    value = checkpoint[index : index + match.start() - 1]  # -1 due to separator hyphen
                    self.best_k_models[checkpoint] = float(value)
        if len(self.best_k_models) < 1:
            return  # no previously saved checkpoints found

        _reverse = False if self.mode == "min" else True

        best_k_models = sorted(self.best_k_models, key=self.best_k_models.get, reverse=_reverse)

        # All ranks update their bookkeeping; the actual filesystem removal inside
        # `_del_model_without_trainer` is rank-guarded.
        if self.model_parallel_size is not None:
            # Distributed (directory) checkpoints: one checkpoint per save, so no model-parallel multiplier.
            if checkpoints[0].is_dir():
                models_to_delete = len(best_k_models) - self.save_top_k
            else:
                models_to_delete = len(best_k_models) - self.model_parallel_size * self.save_top_k
        else:
            models_to_delete = len(best_k_models) - self.save_top_k

        models_to_delete = max(0, models_to_delete)
        logging.debug(f'Number of models to delete: {models_to_delete}')

        # If EMA is enabled, delete the additional EMA weights in the same manner as the original checkpoint.
        ema_enabled = self._has_ema_ckpts(self._saved_checkpoint_paths)

        for _ in range(models_to_delete):
            model = best_k_models.pop(-1)
            self.best_k_models.pop(model)
            self._del_model_without_trainer(model)
            if ema_enabled and self._fs.exists(self._ema_format_filepath(model)):
                self._del_model_without_trainer(self._ema_format_filepath(model))
            logging.debug(f"Removed checkpoint: {model}")

        self.kth_best_model_path = best_k_models[-1]
        self.best_model_path = best_k_models[0]
        self.best_model_score = self.best_k_models[self.best_model_path]

    def _remove_invalid_entries_from_topk(self):
        # Removes invalid (incomplete or non-existing) checkpoints from the top-k bookkeeping.
        # This can be needed if checkpointing was abruptly terminated in a previous run.
        def __is_ckpt_ok(ckpt_path: str) -> bool:
            exists = (
                os.path.isfile(ckpt_path)
                or os.path.isfile(inject_model_parallel_rank(ckpt_path))
                or os.path.isdir(ckpt_path.removesuffix('.ckpt'))
            )
            return exists and not self.is_checkpoint_unfinished(ckpt_path)

        self.best_k_models = {k: v for k, v in self.best_k_models.items() if __is_ckpt_ok(k)}
        if len(self.best_k_models) > 0:
            reverse_arr = self.mode != "min"
            best_k_models_arr = sorted(self.best_k_models, key=self.best_k_models.get, reverse=reverse_arr)
            self.kth_best_model_path = best_k_models_arr[-1]
            self.kth_value = self.best_k_models[self.kth_best_model_path]
            self.best_model_path = best_k_models_arr[0]
            self.best_model_score = self.best_k_models[self.best_model_path]
        else:
            self.kth_best_model_path = ""
            self.kth_value = None
            self.best_model_path = ""
            self.best_model_score = None

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Load the callback state dict, then drop top-k entries whose checkpoints are missing or unfinished.
        """
        super().load_state_dict(state_dict)
        self._remove_invalid_entries_from_topk()

    def setup(self, trainer, pl_module, stage: str) -> None:
        """
        Set up the callback: remove unfinished checkpoints and synchronize resume paths across ranks.
        """
        if is_global_rank_zero():
            logging.debug("Removing unfinished checkpoints if any...")
            NeMoModelCheckpoint._remove_unfinished_checkpoints(self.dirpath)
        # Make sure all ranks continue only after the unfinished checkpoints were removed.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        super().setup(trainer, pl_module, stage)

        # Broadcast the resume checkpoint path and the last model path so that all ranks agree on them.
        path = trainer.strategy.broadcast(trainer.ckpt_path)
        trainer.ckpt_path = path

        self.last_model_path = trainer.strategy.broadcast(self.last_model_path)

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        """
        Extend the checkpoint save hook to optionally save a .nemo file as well.
        """
        output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
        if not self.always_save_nemo:
            return output

        app_state = AppState()
        if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
            logging.warning('always_save_nemo will slow down training for model_parallel > 1.')

        # Update the .nemo restore path before creating the tarfile artifact.
        app_state.model_restore_path = self._format_nemo_checkpoint_name()
        if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
            maybe_injected_best_model_path = inject_model_parallel_rank(self.best_model_path)
        else:
            maybe_injected_best_model_path = self.best_model_path

        if self.save_best_model:
            if not os.path.exists(maybe_injected_best_model_path):
                return

            if self.best_model_path == self.previous_best_path:
                logging.debug('Best model has not changed, skipping save.')
                return output

            self.previous_best_path = self.best_model_path
            old_state_dict = deepcopy(pl_module.state_dict())
            checkpoint = torch.load(maybe_injected_best_model_path, map_location='cpu', weights_only=False)
            if 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']

            # Temporarily load the best weights, save the .nemo file, then restore the current weights.
            pl_module.load_state_dict(checkpoint, strict=True)
            if torch.distributed.is_initialized():
                torch.distributed.barrier()
            backup_path = self._backup_existing_nemo_ckpt(trainer)
            pl_module.save_to(save_path=app_state.model_restore_path)
            logging.info(f"New best .nemo model saved to: {app_state.model_restore_path}")
            pl_module.load_state_dict(old_state_dict, strict=True)
        else:
            if torch.distributed.is_initialized():
                torch.distributed.barrier()
            backup_path = self._backup_existing_nemo_ckpt(trainer)
            pl_module.save_to(save_path=app_state.model_restore_path)
            logging.info(f"New .nemo model saved to: {app_state.model_restore_path}")
        if backup_path is not None and is_global_rank_zero():
            logging.info(f'Removing old .nemo backup {backup_path}')
            get_filesystem(backup_path).rm(backup_path)
        return output

    def on_train_end(self, trainer, pl_module):
        """
        Save a final checkpoint (and optionally the .nemo file) at the end of training.
        """
        if trainer.fast_dev_run:
            return None

        # Check if a last checkpoint needs to be saved manually, since validation
        # is not always run on the final training step.
        if self.save_last and trainer.val_check_interval != 0:
            should_save_last_checkpoint = False
            if isinstance(trainer.val_check_interval, float) and trainer.val_check_interval % trainer.global_step != 0:
                should_save_last_checkpoint = True
            if isinstance(trainer.val_check_interval, int) and trainer.global_step % trainer.val_check_interval != 0:
                should_save_last_checkpoint = True
            if should_save_last_checkpoint:
                monitor_candidates = self._monitor_candidates(trainer)
                if self.last_model_path == self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST):
                    logging.debug(f'Last checkpoint {self.last_model_path} already saved')
                else:
                    super()._save_last_checkpoint(trainer, monitor_candidates)

        # Call parent on_train_end() to save the -last checkpoint.
        super().on_train_end(trainer, pl_module)

        # Load the best model and then re-save it.
        if self.save_best_model:
            # Wait for all processes before restoring the best checkpoint.
            trainer.strategy.barrier("SaveBestCheckpointConnector.resume_end")
            if self.best_model_path == "":
                logging.warning(
                    f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints "
                    "were found. Saving latest model instead."
                )
            else:
                # Distributed checkpoints are directories and do not carry the .ckpt suffix.
                if os.path.isdir(self.best_model_path.split('.ckpt')[0]):
                    self.best_model_path = self.best_model_path.split('.ckpt')[0]
                self.best_model_path = trainer.strategy.broadcast(self.best_model_path)
                trainer._checkpoint_connector.restore(self.best_model_path)

        if self.save_nemo_on_train_end:
            backup_path = self._backup_existing_nemo_ckpt(trainer)
            pl_module.save_to(save_path=self._format_nemo_checkpoint_name())
            if backup_path is not None and is_global_rank_zero():
                logging.info(f'Removing old .nemo backup {backup_path}')
                get_filesystem(backup_path).rm(backup_path)

    def _backup_existing_nemo_ckpt(self, trainer) -> Optional[str]:
        """Search for an available name with a version infix and rename the existing checkpoint.

        NOTE: this behavior is slightly different from regular checkpoints.
        PTL creates a new regular checkpoint under the first available name.
        Here, for backward compatibility, the .nemo checkpoint is created under the base name
        as before, and the existing file is moved to a backup under the first available name.

        Args:
            trainer (Trainer): trainer instance.

        Returns:
            Path to the backup checkpoint or None, if no backup was created.
        """
        base_path = self._format_nemo_checkpoint_name()
        available_path = base_path
        if self._enable_version_counter:
            version_cnt = self.STARTING_VERSION
            while self.file_exists(available_path, trainer, check_dist_ckpt=False):
                available_path = self._format_nemo_checkpoint_name(version_cnt)
                version_cnt += 1
        if available_path == base_path:
            # Nothing exists under the base name yet, so there is nothing to back up.
            return None
        if trainer.is_global_zero:
            logging.info(f'{base_path} already exists, moving existing checkpoint to {available_path}')
            if is_multistorageclient_url(base_path):
                # Multistorageclient URLs: no local filesystem move is performed here.
                pass
            else:
                shutil.move(base_path, available_path)
        trainer.strategy.barrier()
        return available_path
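
    # Illustrative sketch of the .nemo backup naming produced by _backup_existing_nemo_ckpt,
    # assuming Lightning's defaults (CHECKPOINT_JOIN_CHAR '-', STARTING_VERSION 1) and, for the
    # example only, prefix='megatron_gpt' and postfix='.nemo':
    #
    #   freshly saved file : <dirpath>/megatron_gpt.nemo      (always the un-versioned base name)
    #   first backup       : <dirpath>/megatron_gpt-v1.nemo
    #   next backup        : <dirpath>/megatron_gpt-v2.nemo   (if -v1 already exists)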

    def _format_nemo_checkpoint_name(self, ver: Optional[int] = None) -> str:
        """Build the .nemo checkpoint path from dirpath, prefix, optional version infix, and postfix."""
        version_infix = '' if ver is None else f'{self.CHECKPOINT_JOIN_CHAR}v{ver}'
        if is_multistorageclient_url(self.dirpath):
            return f"{self.dirpath}/{self.prefix + version_infix + self.postfix}"
        return os.path.abspath(
            os.path.expanduser(os.path.join(self.dirpath, self.prefix + version_infix + self.postfix))
        )

    def _del_model_without_trainer(self, filepath: str) -> None:
        """Remove a checkpoint file or distributed checkpoint directory without a trainer instance."""
        filepath = Path(filepath)

        # Distributed checkpoints are stored as directories: remove the whole directory on rank zero.
        if ckpt_to_dir(filepath).is_dir():
            if is_global_rank_zero():
                try:
                    dist_ckpt = ckpt_to_dir(filepath)
                    shutil.rmtree(dist_ckpt, ignore_errors=True)
                    logging.info(f"Removed distributed checkpoint: {dist_ckpt}")
                except:
                    logging.info(f"Tried to remove distributed checkpoint: {dist_ckpt} but failed.")

        else:
            app_state = AppState()

            # Legacy model parallel checkpoint: the filepath needs the mp_rank injected.
            if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
                filepath = inject_model_parallel_rank(filepath)

            # Each model parallel rank needs to remove its own checkpoint file.
            if is_global_rank_zero() or (
                app_state.model_parallel_size is not None and app_state.data_parallel_rank == 0
            ):
                try:
                    self._fs.rm(filepath)
                    logging.info(f"Removed checkpoint: {filepath}")
                except:
                    logging.info(f"Tried to remove checkpoint: {filepath} but failed.")

    def _ema_callback(self, trainer: 'lightning.pytorch.Trainer') -> Optional[EMA]:
        """Return the EMA callback attached to the trainer, if any."""
        ema_callback = None
        for callback in trainer.callbacks:
            if isinstance(callback, EMA):
                ema_callback = callback
        return ema_callback

    def _drop_optimizer_states(self, trainer, filepath: Union[str, Path], storage_options: Optional[Any]) -> None:
        """Re-save an older checkpoint without optimizer states, keeping optimizer states only
        for the most recent `save_last_n_optim_states` checkpoints."""
        checkpoints = self._get_checkpoints_list(filepath)
        suffix = "-no-optim"

        # Index of the newest checkpoint that should no longer keep its optimizer states.
        checkpoint_index = len(checkpoints) - self.save_last_n_optim_states - 1
        if len(checkpoints) > self.save_last_n_optim_states:
            checkpoint_path = checkpoints[checkpoint_index]

            logging.info(f"Loading '{checkpoint_path}' checkpoint to drop optimizer states...")
            checkpoint = trainer.strategy.load_checkpoint(checkpoint_path=checkpoint_path, load_optimizer_states=False)

            # Load the selected checkpoint's weights into the current module before re-saving.
            self._load_current_state_dict(trainer, checkpoint)

            # Save the checkpoint again, this time without optimizer states.
            if storage_options is None:
                storage_options = dict(include_optimizer=False)
            else:
                storage_options["include_optimizer"] = False

            trainer.save_checkpoint(
                f"{checkpoint_path}{suffix}.ckpt", self.save_weights_only, storage_options=storage_options
            )

            # Remove the original checkpoint and move the optimizer-free copy into its place.
            if is_global_rank_zero():
                trainer.strategy.remove_checkpoint(checkpoint_path)
                shutil.move(f"{checkpoint_path}{suffix}", checkpoint_path)

            if torch.distributed.is_initialized():
                torch.distributed.barrier()

            # Restore the weights of the latest (current) checkpoint so training can continue.
            checkpoint = trainer.strategy.load_checkpoint(
                checkpoint_path=ckpt_to_dir(filepath), load_optimizer_states=False
            )
            self._load_current_state_dict(trainer, checkpoint)

            logging.info(f"Successfully dropped optimizer states for '{checkpoint_path}' checkpoint.")

    def _get_checkpoints_list(self, filepath: Union[str, Path]) -> List[str]:
        """Return the non '-last' checkpoint directories next to `filepath`, sorted by training step."""
        checkpoints_dir = os.path.dirname(filepath)

        # Collect checkpoint directories, skipping any '-last' checkpoints.
        checkpoints = [
            d
            for d in os.listdir(checkpoints_dir)
            if os.path.isdir(os.path.join(checkpoints_dir, d)) and '-last' not in d
        ]
        checkpoints = sorted(checkpoints, key=lambda x: int(x.split('-step=')[1].split('-')[0]))
        checkpoints = [os.path.join(checkpoints_dir, checkpoint) for checkpoint in checkpoints]

        return checkpoints
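
    # Illustrative sketch of the checkpoint directory layout _get_checkpoints_list expects; the
    # names are hypothetical and only the '-step=<int>' segment and the '-last' suffix matter:
    #
    #   checkpoints/
    #     megatron--val_loss=1.72-step=1000-consumed_samples=8000.0/
    #     megatron--val_loss=1.54-step=2000-consumed_samples=16000.0/
    #     megatron--val_loss=1.49-step=3000-consumed_samples=24000.0-last/   <- skipped ('-last')
    #
    # The first two directories are returned (as full paths), sorted by the parsed step value.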

    def _load_current_state_dict(self, trainer, checkpoint) -> None:
        """Load the model weights from `checkpoint` into the current lightning module."""
        # Give the LightningModule a chance to react to the checkpoint before loading weights.
        call._call_lightning_module_hook(trainer, "on_load_checkpoint", checkpoint)

        # Restore the module parameters from the checkpoint's state dict.
        trainer.strategy.load_model_state_dict(
            checkpoint,
            strict=trainer.lightning_module.strict_loading,
        )

    @staticmethod
    def format_checkpoint_unfinished_marker_path(checkpoint_path: Union[Path, str]) -> Path:
        """Format the path to the unfinished checkpoint marker file.

        If the marker file exists, the corresponding checkpoint is considered unfinished/incomplete.
        NOTE: the marker path for the EMA checkpoint part is the same as for the original checkpoint.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.

        Returns:
            Path to the unfinished checkpoint marker file.
        """
        marker_filepath = str(uninject_model_parallel_rank(checkpoint_path))
        marker_filepath = marker_filepath.removesuffix(".nemo")
        marker_filepath = marker_filepath.removesuffix(".ckpt")
        marker_filepath = marker_filepath.removesuffix("-EMA")
        return Path(marker_filepath + NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX)
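
    # Illustrative examples of the marker-path mapping above (paths are hypothetical):
    #
    #   checkpoints/megatron--val_loss=1.54-step=2000.ckpt
    #       -> checkpoints/megatron--val_loss=1.54-step=2000-unfinished
    #   checkpoints/megatron--val_loss=1.54-step=2000-EMA.ckpt
    #       -> checkpoints/megatron--val_loss=1.54-step=2000-unfinished   (same marker as non-EMA)
    #   checkpoints/megatron.nemo
    #       -> checkpoints/megatron-unfinished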

    @staticmethod
    def is_checkpoint_unfinished(checkpoint_path: Union[Path, str]) -> bool:
        """Check if the checkpoint is unfinished.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.

        Returns:
            True if the checkpoint is unfinished, False otherwise.
        """
        return NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path).exists()

    @staticmethod
    def set_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_after=False) -> None:
        """Marks the given checkpoint as unfinished.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.
            barrier_after: Synchronize ranks after writing the marker file.
              Defaults to False.
        """
        if is_global_rank_zero():
            marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path)
            marker_path.parent.mkdir(parents=True, exist_ok=True)
            marker_path.touch()
        if barrier_after and torch.distributed.is_initialized():
            torch.distributed.barrier()

    @staticmethod
    def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_before=False) -> None:
        """Clear unfinished marker for given checkpoint.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.
            barrier_before: Synchronize ranks before removing the marker file.
              Defaults to False.
        """
        try:
            if barrier_before and torch.distributed.is_initialized():
                torch.distributed.barrier()
            if is_global_rank_zero():
                marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path)
                if marker_path.exists():
                    marker_path.unlink()
        except:
            return

    def file_exists(
        self, filepath: str, trainer: "lightning.pytorch.Trainer", check_dist_ckpt: bool = True
    ) -> bool:
        """Checks if a file or a file without a suffix (distributed checkpoint) exists."""
        if is_multistorageclient_url(filepath):
            exists = self._fs.exists(filepath)
        else:
            exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(ckpt_to_dir(filepath)))

        return trainer.strategy.broadcast(exists)

    def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str) -> None:
        """Save a checkpoint, handling EMA weights and (optionally) asynchronous saving."""
        # Mark the checkpoint as unfinished until the save (and its finalization) completes.
        self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
        ema_callback = self._ema_callback(trainer)
        if ema_callback is not None:
            if self.async_save:
                raise ValueError('async_save with EMA not supported')
            with ema_callback.save_original_optimizer_state(trainer):
                super()._save_checkpoint(trainer, filepath)

            # Save an EMA copy of the model as a separate checkpoint.
            with ema_callback.save_ema_model(trainer):
                filepath = self._ema_format_filepath(filepath)
                if self.verbose:
                    rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
                super()._save_checkpoint(trainer, filepath)
            self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
        else:
            # The finalization callback runs after the checkpoint is written: immediately for
            # synchronous saves, on async finalization otherwise.
            finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step)
            if self.async_save:
                checkpoint_io = trainer.strategy.checkpoint_io
                if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO):
                    raise ValueError('Async save requires async compatible CheckpointIO')
                storage_options = dict(finalize_fn=finalize_fn)
                # Upcoming checkpoint removal requests are deferred until this save is finalized.
                self.deferred_ckpts_to_remove.append([])
            else:
                storage_options = None
            logging.info(f'Checkpoint save for step {trainer.global_step} started at {time.time()}.')
            trainer.save_checkpoint(filepath, self.save_weights_only, storage_options=storage_options)
            if self.async_save:
                logging.info(f'Scheduled async checkpoint save for {filepath}')
            else:
                finalize_fn()

            if self.save_last_n_optim_states >= 0 and '-last' in filepath:
                self._drop_optimizer_states(trainer, filepath, storage_options)

    def _get_finalize_save_checkpoint_callback(
        self, trainer: 'lightning.pytorch.Trainer', filepath: str, global_step: int
    ):
        """Creates a callback that can be used to finalize async (and sync) ckpt saves."""

        def _cb():
            logging.debug(f'Finalize callback called for step {global_step}, filepath {filepath}')
            self._last_global_step_saved = global_step
            self._last_checkpoint_saved = filepath

            # Notify loggers that a checkpoint has been saved.
            if trainer.is_global_zero:
                for logger in trainer.loggers:
                    logger.after_save_checkpoint(proxy(self))

            # Remove the unfinished marker only after the checkpoint is fully written.
            self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)

            if not self.async_save:
                return

            logging.info(
                f'Async checkpoint save for step {global_step} ({filepath}) finalized successfully at {time.time()}.'
            )

            # Remove checkpoints whose removal was deferred by `_remove_checkpoint`
            # while this async save was pending.
            assert self.deferred_ckpts_to_remove
            ckpts_to_remove = self.deferred_ckpts_to_remove.pop(0)
            logging.debug(f'Checkpoints to remove: {ckpts_to_remove}')
            for ckpt_to_remove in ckpts_to_remove:
                self._remove_checkpoint(trainer, ckpt_to_remove, override_async=True)

        return _cb

    def _remove_checkpoint(
        self, trainer: "lightning.pytorch.Trainer", filepath: str, override_async=False
    ) -> None:
        """Performs checkpoint removal or deferred removal.

        With async save, `self._remove_checkpoint` is called before the checkpoint
        is actually finished so we can't remove it. Instead we add it to
        `self.deferred_ckpts_to_remove` for future removal.
        """
        if self.async_save and not override_async:
            # Register the removal in the list belonging to the most recent async save.
            self.deferred_ckpts_to_remove[-1].append(filepath)
            return

        # Mark the checkpoint as unfinished while its files are being removed,
        # so that partially removed data can be detected.
        self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
        super()._remove_checkpoint(trainer, filepath)
        ema_callback = self._ema_callback(trainer)
        if ema_callback is not None:
            # Remove the EMA copy of the checkpoint as well.
            filepath = self._ema_format_filepath(filepath)
            super()._remove_checkpoint(trainer, filepath)
        # Clear the marker only after all checkpoint files have been removed.
        self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)

    def _ema_format_filepath(self, filepath: str) -> str:
        return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}')

    def _has_ema_ckpts(self, checkpoints: Iterable[Path]) -> bool:
        return any(self._is_ema_filepath(checkpoint_path) for checkpoint_path in checkpoints)

    def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool:
        return str(filepath).endswith(f'-EMA{self.FILE_EXTENSION}')

    @property
    def _saved_checkpoint_paths(self) -> Iterable[Path]:
        """Checkpoints present in dirpath: distributed checkpoint directories or .ckpt files."""
        if is_multistorageclient_url(self.dirpath):
            msc = import_multistorageclient()
            return msc.glob(f"{self.dirpath}/*.ckpt")
        else:
            # Distributed checkpoints are stored as directories, so check for those first.
            dist_checkpoints = [d for d in Path(self.dirpath).glob("*") if d.is_dir()]
            if dist_checkpoints:
                return filter(lambda p: not self.is_checkpoint_unfinished(p), dist_checkpoints)
            else:
                checkpoint_files = [f for f in Path(self.dirpath).rglob("*.ckpt")]
                return filter(lambda p: not self.is_checkpoint_unfinished(p), checkpoint_files)

    @staticmethod
    def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
        """Remove unfinished checkpoints and their '-unfinished' marker files.

        Must be called on global rank zero only.
        """
        if not is_global_rank_zero():
            raise AssertionError("_remove_unfinished_checkpoints should run only on rank 0")

        if is_multistorageclient_url(checkpoint_dir):
            msc = import_multistorageclient()
            existing_marker_filepaths = msc.glob(
                f"{checkpoint_dir}*{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}"
            )
            fs = get_filesystem(checkpoint_dir)
            for ckpt_filepath in existing_marker_filepaths:
                fs.rm(ckpt_filepath)
        else:
            checkpoint_dir = Path(checkpoint_dir)

            # Collect all existing unfinished markers.
            existing_marker_filepaths = {
                f.resolve()
                for f in checkpoint_dir.glob(f"*{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}")
                if f.is_file()
            }

            # Remove checkpoint files that have a matching marker.
            checkpoint_filepaths = {f.resolve() for f in checkpoint_dir.rglob("*.ckpt")}
            for ckpt_filepath in checkpoint_filepaths:
                possible_marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(ckpt_filepath)
                if possible_marker_path in existing_marker_filepaths:
                    logging.warning(f'Removing unfinished checkpoint: {ckpt_filepath}')
                    os.remove(ckpt_filepath)

            # Some directories may be distributed checkpoints; remove these if they have an unfinished marker.
            all_dirpaths = {d.resolve() for d in checkpoint_dir.glob("*") if d.is_dir()}
            for ckpt_dirpath in all_dirpaths:
                possible_marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(ckpt_dirpath)
                if possible_marker_path in existing_marker_filepaths:
                    logging.warning(f'Removing unfinished dist checkpoint: {ckpt_dirpath}')
                    shutil.rmtree(ckpt_dirpath)

            # Finally, remove the markers themselves.
            for marker_path in existing_marker_filepaths:
                os.remove(marker_path)

    def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool:
        """Checks if the previous checkpoint should be deleted.

        A checkpoint won't be deleted if any of the cases apply:
        - The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new)
        - The previous checkpoint is not in the current checkpoint directory and the filesystem is local
        - The previous checkpoint is the checkpoint the Trainer resumed from, the filesystem is local,
          and the resumed-from checkpoint is not the last checkpoint
        """
        if previous == current:
            return False
        if not _is_local_file_protocol(previous):
            return True
        previous = Path(previous).absolute()
        resume_path = Path(trainer.ckpt_path).absolute() if trainer.ckpt_path is not None else None

        if resume_path is not None and previous == resume_path:
            if str(current).endswith("-last.ckpt") and resume_path.name.endswith("-last.ckpt"):
                # Allow deleting the previous '-last.ckpt' checkpoint when the checkpoint being saved
                # is also a '-last.ckpt' (fall through to the directory check below).
                pass
            else:
                return False
        if self.dirpath is None:
            raise ValueError(f"{self.__class__}.dirpath is None.")
        dirpath = Path(self.dirpath).absolute()
        return dirpath in previous.parents