| """ |
| A set of utilities to manage and load checkpoints of training experiments. |
| |
| Author: Paul-Edouard Sarlin (skydes) |
| """ |
|
|
| import logging |
| import os |
| import re |
| import shutil |
| from pathlib import Path |
|
|
| import torch |
| from omegaconf import OmegaConf |
|
|
| from siclib.models import get_model |
| from siclib.settings import TRAINING_PATH |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
| |
|
|
|
|
def list_checkpoints(dir_):
    """Collect the numbered checkpoint archives found in a directory.

    Every entry of ``dir_`` matching ``checkpoint_*.tar`` is inspected. Files
    whose name holds no digits (e.g. ``checkpoint_best.tar``) are ignored.

    Args:
        dir_: directory to scan; must expose ``glob`` (e.g. ``pathlib.Path``).

    Returns:
        A list of ``(index, path)`` tuples, in the order yielded by ``glob``
        (not sorted). ``index`` is the only group of digits in the file name
        when there is one, and the second group when there are two.

    Raises:
        AssertionError: if a file name holds more than two groups of digits.
    """
    found = []
    for path in dir_.glob("checkpoint_*.tar"):
        digit_groups = re.findall(r"(\d+)", path.name)
        assert len(digit_groups) <= 2
        if not digit_groups:
            # No number in the name: not an indexed checkpoint.
            continue
        # With two groups, the second one is used as the ordering index.
        index = digit_groups[0] if len(digit_groups) == 1 else digit_groups[1]
        found.append((int(index), path))
    return found
|
|
|
|
def get_last_checkpoint(exper, allow_interrupted=True):
    """Return the path of the most recent checkpoint of an experiment.

    Checkpoints are looked up in ``TRAINING_PATH / exper`` and ranked by the
    index extracted by ``list_checkpoints``; the highest one wins.

    Args:
        exper: experiment name (sub-directory of ``TRAINING_PATH``).
        allow_interrupted: when False, checkpoints whose file name contains
            ``_interrupted`` are excluded from the candidates.

    Returns:
        The path of the checkpoint with the largest index.

    Raises:
        AssertionError: if no candidate checkpoint is left.
    """
    candidates = list_checkpoints(Path(TRAINING_PATH, exper))
    if not allow_interrupted:
        candidates = [
            entry for entry in candidates if "_interrupted" not in entry[1].name
        ]
    assert len(candidates) > 0
    # Tuples compare on the index first, so the last one is the newest.
    candidates.sort()
    _, newest = candidates[-1]
    return newest
|
|
|
|
def get_best_checkpoint(exper):
    """Return the path of ``checkpoint_best.tar`` for an experiment.

    The path is only built, not checked: the file may not exist.

    Args:
        exper: experiment name (sub-directory of ``TRAINING_PATH``).
    """
    experiment_dir = Path(TRAINING_PATH, exper)
    return experiment_dir / "checkpoint_best.tar"
|
|
|
|
def delete_old_checkpoints(dir_, num_keep):
    """Delete all but the ``num_keep`` most recent checkpoints of a directory.

    Checkpoints are those returned by ``list_checkpoints(dir_)``, visited from
    the largest index to the smallest. An interrupted checkpoint (file name
    containing ``_interrupted``) is kept only when it ranks first; any other
    interrupted checkpoint is deleted regardless of ``num_keep``.

    Args:
        dir_: directory holding the checkpoints.
        num_keep: maximum number of checkpoints left on disk.
    """
    kept = 0
    for _, path in sorted(list_checkpoints(dir_), reverse=True):
        # Look at the file name only, as get_last_checkpoint does: matching
        # against the full path would flag every checkpoint as interrupted
        # whenever a parent directory name contains "_interrupted".
        interrupted = "_interrupted" in path.name
        if kept >= num_keep or (interrupted and kept > 0):
            logger.info(f"Deleting checkpoint {path.name}")
            path.unlink()
        else:
            kept += 1
|
|
|
|
def load_experiment(exper, conf=None, get_last=False, ckpt=None):
    """Instantiate the model of an experiment and load its saved weights.

    Args:
        exper: either a path ending in ``.tar``, used directly as the
            checkpoint file, or an experiment name resolved through
            ``get_last_checkpoint`` / ``get_best_checkpoint``.
        conf: optional overrides merged on top of the ``model`` section of
            the configuration stored in the checkpoint.
        get_last: for an experiment name, load the last checkpoint instead of
            ``checkpoint_best.tar``.
        ckpt: accepted for signature compatibility; the value passed in is
            never read, the checkpoint is always derived from ``exper``.

    Returns:
        The model in eval mode. Weights are loaded with ``strict=False``;
        model parameters absent from the checkpoint only trigger a warning.
    """
    overrides = {} if conf is None else conf

    exper = Path(exper)
    if exper.suffix == ".tar":
        ckpt_path = exper
    elif get_last:
        ckpt_path = get_last_checkpoint(exper)
    else:
        ckpt_path = get_best_checkpoint(exper)
    logger.info(f"Loading checkpoint {ckpt_path.name}")
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version's defaults; only load checkpoints from trusted sources.
    checkpoint = torch.load(str(ckpt_path), map_location="cpu")

    # The stored config is unlocked so that overrides may introduce new keys.
    stored_conf = OmegaConf.create(checkpoint["conf"])
    OmegaConf.set_struct(stored_conf, False)
    model_conf = OmegaConf.merge(stored_conf.model, OmegaConf.create(overrides))
    model_cls = get_model(model_conf.name)
    model = model_cls(model_conf).eval()

    state_dict = checkpoint["model"]

    # Report the parameters the model expects but the checkpoint lacks.
    saved_names = set(state_dict.keys())
    expected_names = {name for name, _ in model.named_parameters()}
    missing = expected_names - saved_names
    if missing:
        prefix = os.path.commonprefix(list(missing)).rstrip(".")
        logger.warning(f"Missing {len(missing)} parameters in {prefix}: {missing}")
    model.load_state_dict(state_dict, strict=False)
    return model
|
|
|
|
def save_experiment(
    model,
    optimizer,
    lr_scheduler,
    conf,
    losses,
    results,
    best_eval,
    epoch,
    iter_i,
    output_dir,
    stop=False,
    distributed=False,
    cp_name=None,
):
    """Write a training checkpoint and return the best evaluation value so far.

    The checkpoint stores the model, optimizer and scheduler states, the
    resolved configuration, the epoch, the losses and the evaluation results.

    Args:
        model: model to save; with ``distributed`` its ``.module`` is saved.
        optimizer: optimizer whose ``state_dict()`` is stored.
        lr_scheduler: scheduler whose ``state_dict()`` is stored.
        conf: OmegaConf configuration; ``conf.train.best_key`` and
            ``conf.train.keep_last_checkpoints`` are read.
        losses: stored as-is under ``"losses"``.
        results: stored under ``"eval"``; indexed with ``best_key``.
        best_eval: best value of ``results[best_key]`` seen so far.
        epoch: current epoch, stored and used in the default file name.
        iter_i: current iteration, used in the default file name.
        output_dir: ``pathlib.Path`` of the directory receiving the files.
        stop: appends ``_interrupted`` to the default file name.
        distributed: whether ``model`` is wrapped and exposes ``.module``.
        cp_name: explicit file name, replacing the default one.

    Returns:
        ``results[best_key]`` when it is lower than ``best_eval`` (and the
        file is not ``checkpoint_best.tar`` itself), else ``best_eval``.
    """
    net = model.module if distributed else model
    state = net.state_dict()
    checkpoint = {
        "model": state,
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "conf": OmegaConf.to_container(conf, resolve=True),
        "epoch": epoch,
        "losses": losses,
        "eval": results,
    }
    if cp_name is None:
        suffix = "_interrupted" if stop else ""
        cp_name = f"checkpoint_{epoch}_{iter_i}{suffix}.tar"
    logger.info(f"Saving checkpoint {cp_name}")
    cp_path = str(output_dir / cp_name)
    torch.save(checkpoint, cp_path)

    # Lower is better: a smaller value promotes this file to "best".
    if cp_name != "checkpoint_best.tar":
        best_key = conf.train.best_key
        if results[best_key] < best_eval:
            best_eval = results[best_key]
            logger.info(f"New best val: {best_key}={best_eval}")
            shutil.copy(cp_path, str(output_dir / "checkpoint_best.tar"))
    delete_old_checkpoints(output_dir, conf.train.keep_last_checkpoints)
    return best_eval
|
|