import glob
import os
import shutil

import accelerate
import torch


def restore_checkpoint(
    checkpoint_dir,
    accelerator: accelerate.Accelerator,
    logger=None,
):
    """Load the most recent checkpoint in checkpoint_dir and return its step.

    Returns 0 when no checkpoint exists. Relies on checkpoint folders being
    named with zero-padded step numbers (see save_checkpoint), so a plain
    lexicographic sort puts the newest checkpoint last.
    """
    dirs = glob.glob(os.path.join(checkpoint_dir, "*"))
    dirs.sort()
    path = dirs[-1] if dirs else None
    if path is None:
        if logger is not None:
            logger.info("Checkpoint does not exist. Starting a new training run.")
        init_step = 0
    else:
        if logger is not None:
            logger.info(f"Resuming from checkpoint {path}")
        accelerator.load_state(path)
        # Folder names are the zero-padded global step, e.g. "000100".
        init_step = int(os.path.basename(path))
    return init_step


def save_checkpoint(
    save_dir,
    accelerator: accelerate.Accelerator,
    step=0,
    total_limit=3,
):
    """Save the training state under save_dir/<step>, keeping at most
    total_limit checkpoints (oldest pruned first)."""
    if total_limit > 0:
        folders = glob.glob(os.path.join(save_dir, "*"))
        folders.sort()
        # Remove the oldest folders so that, after the new checkpoint is
        # written below, at most total_limit checkpoints remain.
        for folder in folders[: len(folders) + 1 - total_limit]:
            shutil.rmtree(folder)
    # Zero-padded step keeps lexicographic and numeric order in sync.
    accelerator.save_state(os.path.join(save_dir, f"{step:06d}"))
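

# Usage sketch: one way these helpers might be wired into an Accelerate
# training loop. The model, optimizer, step budget, and save interval below
# are hypothetical placeholders, not part of the helpers above.
if __name__ == "__main__":
    accelerator = accelerate.Accelerator()
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model, optimizer = accelerator.prepare(model, optimizer)

    checkpoint_dir = "checkpoints"
    step = restore_checkpoint(checkpoint_dir, accelerator)  # 0 on a fresh run
    while step < 1_000:
        # ... forward pass, accelerator.backward(loss), optimizer.step() ...
        step += 1
        if step % 100 == 0:
            save_checkpoint(checkpoint_dir, accelerator, step=step, total_limit=3)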