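r"""
Checkpointing utilities for training: :class:`CheckpointManager` periodically
serializes models and other checkpointable objects, and can record the best
performing checkpoint based on an observed metric.
"""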
import copy
import pathlib
from typing import Any, Dict, List, Optional

from loguru import logger
import torch
from torch import nn

import virtex.utils.distributed as dist


class CheckpointManager(object):
    r"""
    A helper class to periodically serialize models and other checkpointable
    objects (optimizers, LR schedulers etc., which implement a ``state_dict``
    method) during training, and optionally record the best performing
    checkpoint based on an observed metric.

    .. note::

        For :class:`~torch.nn.parallel.DistributedDataParallel` objects, the
        ``state_dict`` of the internal model is serialized.

    .. note::

        The observed metric for keeping the best checkpoint is assumed
        "higher is better"; flip the sign if otherwise.

    Parameters
    ----------
    serialization_dir: str
        Path to a directory to save checkpoints.
    keep_recent: int, optional (default = 200)
        Number of recent ``k`` checkpoints to keep on disk. Older checkpoints
        will be removed. Set to a very large value to keep all checkpoints.
    checkpointables: Any
        Keyword arguments with any checkpointable objects, for example: model,
        optimizer, learning rate scheduler.

    Examples
    --------
    >>> model = torch.nn.Linear(10, 2)
    >>> optimizer = torch.optim.Adam(model.parameters())
    >>> ckpt_manager = CheckpointManager("/tmp", model=model, optimizer=optimizer)
    >>> num_epochs = 20
    >>> for epoch in range(num_epochs):
    ...     train(model)
    ...     val_loss = validate(model)
    ...     ckpt_manager.step(iteration=epoch, metric=-val_loss)
    """

    def __init__(
        self,
        serialization_dir: str = "/tmp",
        keep_recent: int = 200,
        **checkpointables: Any,
    ):
        self.serialization_dir = pathlib.Path(serialization_dir)
        self.keep_recent = keep_recent

        # Shallow copy, keeps references to tensors as original objects.
        self.checkpointables = copy.copy(checkpointables)

        # Initialize members to hold state dict of best checkpoint and its
        # performance. Start at -inf so the first observed metric (which may
        # be negative, e.g. a sign-flipped loss) always becomes the best.
        self._best_metric: float = -float("inf")
        self._best_ckpt: Dict[str, Any] = {}

        # Keep epoch/iteration numbers of recently saved 'k' checkpoints.
        self._recent_iterations: List[int] = []

    def step(self, iteration: int, metric: Optional[float] = None):
        r"""
        Serialize checkpoint and update best checkpoint based on metric. Keys
        in the serialized checkpoint match those in :attr:`checkpointables`.

        Parameters
        ----------
        iteration: int
            Current training iteration. Will be saved with other checkpointables.
        metric: float, optional (default = None)
            Observed metric (higher is better) for keeping track of the best
            checkpoint. If this is ``None``, the best checkpoint will not be
            recorded/updated.
        """
        checkpointable_state_dict: Dict[str, Any] = self._state_dict()

        # We also checkpoint current iteration.
        checkpointable_state_dict["iteration"] = iteration

        # Update the best checkpoint based on metric, if provided.
        if metric is not None and metric > self._best_metric:
            self._best_metric = metric
            self._best_ckpt = copy.copy(checkpointable_state_dict)

        # Serialize checkpoint corresponding to current iteration.
        torch.save(
            checkpointable_state_dict,
            self.serialization_dir / f"checkpoint_{iteration}.pth",
        )
        if self._best_ckpt:
            # Serialize best performing checkpoint observed so far.
            torch.save(
                self._best_ckpt, self.serialization_dir / "checkpoint_best.pth"
            )

        # Remove the earliest checkpoint if more than ``keep_recent`` are on disk.
        self._recent_iterations.append(iteration)
        if len(self._recent_iterations) > self.keep_recent:
            self.remove_earliest_checkpoint()

    def _state_dict(self):
        r"""Return a dict containing state dicts of all checkpointables."""

        state_dict: Dict[str, Any] = {}
        for key in self.checkpointables:
            if isinstance(
                self.checkpointables[key], nn.parallel.DistributedDataParallel
            ):
                # Unwrap DDP and serialize the internal model's state dict.
                state_dict[key] = self.checkpointables[key].module.state_dict()
            else:
                state_dict[key] = self.checkpointables[key].state_dict()

        return state_dict

    def remove_earliest_checkpoint(self):
        r"""Remove earliest serialized checkpoint from disk."""

        earliest_iteration = self._recent_iterations.pop(0)
        (self.serialization_dir / f"checkpoint_{earliest_iteration}.pth").unlink()

    def load(self, checkpoint_path: str):
        r"""
        Load a serialized checkpoint from a path. This method will try to find
        each of :attr:`checkpointables` in the file and load its state dict.
        Since our checkpointables are held as references, this method does not
        return them.

        Parameters
        ----------
        checkpoint_path: str
            Path to a checkpoint serialized by :meth:`step`.

        Returns
        -------
        int
            Iteration corresponding to the loaded checkpoint. Useful for
            resuming training. This will be -1 if the checkpoint does not
            record an iteration.
        """
        # Each process will log a message after loading checkpoint.
        rank = dist.get_rank()
        logger.info(f"Rank {rank}: Loading checkpoint from {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        iteration = checkpoint.pop("iteration", -1)

        # Keep flags for all checkpointables to log which ones were not loaded.
        is_loaded = {key: False for key in self.checkpointables}

        # Load each checkpointable from checkpoint.
        for key in checkpoint:
            if key in self.checkpointables:
                logger.info(f"Rank {rank}: Loading {key} from {checkpoint_path}")

                if isinstance(
                    self.checkpointables[key], nn.parallel.DistributedDataParallel
                ):
                    self.checkpointables[key].module.load_state_dict(checkpoint[key])
                else:
                    self.checkpointables[key].load_state_dict(checkpoint[key])

                is_loaded[key] = True
            else:
                logger.info(f"Rank {rank}: {key} not found in `checkpointables`.")

        not_loaded: List[str] = [key for key in is_loaded if not is_loaded[key]]
        if len(not_loaded) > 0:
            logger.info(
                f"Rank {rank}: Checkpointables not found in file: {not_loaded}"
            )
        return iteration
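

if __name__ == "__main__":
    # A minimal, self-contained sketch of a training loop around
    # ``CheckpointManager`` (illustrative only: the "training" is a dummy
    # regression step on random data). The loss is sign-flipped when passed
    # as the metric, since the tracked metric is assumed "higher is better".
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.Adam(model.parameters())
    ckpt_manager = CheckpointManager("/tmp", model=model, optimizer=optimizer)

    inputs, targets = torch.randn(8, 10), torch.randn(8, 2)
    for epoch in range(5):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # Serialize this epoch's checkpoint and track the best one so far.
        ckpt_manager.step(iteration=epoch, metric=-loss.item())

    # To resume later: load() restores state dicts in-place (checkpointables
    # are held by reference) and returns the stored iteration (-1 if absent).
    start_epoch = ckpt_manager.load("/tmp/checkpoint_4.pth") + 1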