| |
| |
| |
| |
| |
|
|
| from enum import Enum |
| import logging |
| from pathlib import Path |
| import re |
| import typing as tp |
|
|
| import flashy |
| import torch |
|
|
| from ..environment import AudioCraftEnvironment |
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class CheckpointSource(Enum):
    """Source from which a checkpoint may be loaded: the current xp,
    a pretrained model, or some other location."""
    CURRENT_XP = "current_xp"
    PRETRAINED = "pretrained"
    OTHER = "other"
|
|
|
|
def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str:
    """Build the canonical checkpoint filename used across the AudioCraft codebase:
    `checkpoint_<name>.th(.<rank>)`. By convention, `name` is empty for the last
    checkpoint, 'best' for the best checkpoint, or the epoch number.

    Args:
        name (str, optional): Name suffix for the checkpoint file stem.
        rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        str: The checkpoint name.
    """
    if rank is None:
        rank = flashy.distrib.rank()
    # Only non-zero ranks under FSDP own a dedicated shard file.
    rank_suffix = f'.{rank}' if use_fsdp and rank > 0 else ''
    stem_suffix = f'_{name}' if name is not None else ''
    return f'checkpoint{stem_suffix}.th{rank_suffix}'
|
|
|
|
def is_sharded_checkpoint(path: Path) -> bool:
    """Whether the checkpoint at the given path corresponds to a sharded checkpoint across rank.

    Sharded checkpoints end with a numeric rank suffix, e.g. `checkpoint.th.3`.
    """
    return bool(re.search(r'\.th\.\d+$', path.name))
|
|
|
|
def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None,
                            use_fsdp: bool = False) -> tp.Optional[Path]:
    """Resolve a given checkpoint path for a provided dora sig or path.

    Args:
        sig_or_path (Path or str): Checkpoint path, or dora signature prefixed with `//sig/`.
        name (str, optional): Name suffix for the checkpoint file stem.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        Path, optional: Resolved checkpoint path, if it exists.
    """
    from audiocraft import train
    xps_root = train.main.dora.dir / 'xps'
    spec = str(sig_or_path)
    sig_prefix = '//sig/'
    if spec.startswith(sig_prefix):
        # A dora signature maps directly to an xp folder under the dora root.
        path = xps_root / spec[len(sig_prefix):]
    else:
        path = AudioCraftEnvironment.resolve_reference_path(Path(spec))
    if path.is_dir():
        # Point at the conventional checkpoint file inside the directory.
        path = path / checkpoint_name(name, use_fsdp=use_fsdp)
    return path if path.exists() else None
|
|
|
|
def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
    """Load state from the checkpoint at the specified path.

    When `is_sharded` is set and the rank-0 checkpoint exists alongside, first
    verify (and possibly repair) the sharded checkpoint state before loading.
    """
    if is_sharded:
        rank0_checkpoint_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
        if rank0_checkpoint_path.exists():
            check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path)
    # Map all storages to CPU at load time; callers decide device placement.
    state = torch.load(checkpoint_path, map_location='cpu')
    logger.info("Checkpoint loaded from %s", checkpoint_path)
    return state
|
|
|
|
def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Save state to disk to the specified checkpoint_path.

    Delegates to the safe-save protocol so an interrupted save can be detected
    and repaired on the next load.
    """
    _safe_save_checkpoint(state, checkpoint_path, is_sharded)
    logger.info("Checkpoint saved to %s", checkpoint_path)
|
|
|
|
def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None:
    """Flush checkpoints to only keep last N checkpoints.

    Only epoch-numbered checkpoints for the current rank are considered; the
    oldest ones beyond `keep_last` are removed.
    """
    if keep_last is None or keep_last <= 0:
        return
    checkpoint_dir = checkpoint_path.parent
    rank = flashy.distrib.rank()
    suffix = f'.{rank}' if rank > 0 else ''
    # Map each epoch checkpoint file for this rank to its epoch number.
    epoch_of: tp.Dict[Path, int] = {}
    for candidate in Path(checkpoint_dir).glob(f'checkpoint_*.th{suffix}'):
        # Name stem looks like 'checkpoint_<epoch>'; keep numeric epochs only
        # (skips e.g. 'checkpoint_best').
        epoch_part = candidate.name.split('.', 1)[0].split('_', 1)[1]
        if epoch_part.isdigit():
            epoch_of[candidate] = int(epoch_part)
    ordered = sorted(epoch_of, key=epoch_of.get)
    stale_count = max(0, len(ordered) - keep_last)
    for stale_path in ordered[:stale_count]:
        logger.debug("Removing checkpoint: %s", str(stale_path))
        stale_path.unlink(missing_ok=True)
|
|
|
|
def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None:
    """Check sharded checkpoint state, ensuring the checkpoints are not corrupted.

    If a save was interrupted, the presence of the rank-0 '.tmp.done' token means
    every rank's '.tmp' file is complete: this rank's '.tmp' file (if any) is then
    promoted to the final checkpoint name. Raises RuntimeError if a legacy '.old'
    file from an older version of the saving code is found.
    """

    # A leftover '.old' file comes from an incompatible earlier saving scheme;
    # we cannot tell which copy is valid, so refuse to proceed.
    old_path = Path(str(checkpoint_path) + '.old')
    if old_path.exists():
        raise RuntimeError(
            f"Old checkpoint {old_path} from previous version of this code exist, cannot safely proceed.")
    # Token written by rank 0 once all ranks finished writing (see _safe_save_checkpoint).
    token = Path(str(rank0_checkpoint_path) + '.tmp.done')
    tmp_path = Path(str(checkpoint_path) + '.tmp')
    if token.exists():
        if tmp_path.exists():
            # Token present => the '.tmp' file is a complete checkpoint; promote it.
            tmp_path.rename(checkpoint_path)
        # Wait for every rank to finish promoting before rank 0 clears the token.
        flashy.distrib.barrier()
        if flashy.distrib.is_rank_zero() and token.exists():
            token.unlink()
|
|
|
|
def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Save checkpoints in a safe manner, even when checkpoints are sharded across nodes.

    A '.tmp.done' token next to rank 0's checkpoint brackets the critical write
    phase so that an interrupted save can be detected and repaired on the next
    load (see check_sharded_checkpoint).

    Args:
        state: Object to serialize with torch.save.
        checkpoint_path (Path): Destination path for this rank's checkpoint.
        is_sharded (bool): Whether each rank writes its own checkpoint shard,
            requiring cross-rank synchronization between phases.
    """
    def _barrier_if_sharded():
        # Synchronization is only needed when every rank writes its own shard.
        if is_sharded:
            flashy.distrib.barrier()

    # Bind the token path on every rank so the name is always defined, even
    # though only rank zero creates and removes the file. On rank zero,
    # checkpoint_path is the rank-0 checkpoint, matching what
    # check_sharded_checkpoint derives from rank0_checkpoint_path.
    token = Path(str(checkpoint_path) + '.tmp.done')
    if flashy.distrib.is_rank_zero():
        # Clear any stale token left over from a previously interrupted save.
        if token.exists():
            token.unlink()
    _barrier_if_sharded()
    with flashy.utils.write_and_rename(checkpoint_path) as f:
        torch.save(state, f)
    _barrier_if_sharded()
    if flashy.distrib.is_rank_zero():
        token.touch()
    # One barrier suffices here: consecutive barriers with no work in between
    # are equivalent to a single one.
    _barrier_if_sharded()
    if flashy.distrib.is_rank_zero():
        token.unlink()
|
|