| import os |
| from omegaconf import OmegaConf |
| from collections import OrderedDict |
| import logging |
|
|
| mainlogger = logging.getLogger("mainlogger") |
|
|
| import torch |
| from collections import OrderedDict |
|
|
|
|
def init_workspace(name, logdir, model_config, lightning_config, rank=0):
    """Create the experiment directory tree and, on rank 0, save the configs.

    The workspace lives at ``<logdir>/<name>`` and contains the
    sub-directories ``checkpoints``, ``configs`` and ``loginfo``. The
    directories are created regardless of ``rank``; only rank 0 writes
    ``configs/model.yaml`` and ``configs/lightning.yaml`` and, when the
    ``metrics_over_trainsteps_checkpoint`` callback is configured, creates
    ``checkpoints/trainstep_checkpoints``.

    Args:
        name: Experiment name, used as the workspace directory name.
        logdir: Parent directory of the workspace.
        model_config: Model config saved to ``model.yaml``.
        lightning_config: Lightning config saved under the ``lightning`` key.
        rank: Process rank; config files are written only when it is 0.

    Returns:
        Tuple ``(workdir, ckptdir, cfgdir, loginfo)`` of directory paths.
    """
    workdir = os.path.join(logdir, name)
    ckptdir, cfgdir, loginfo = (
        os.path.join(workdir, sub) for sub in ("checkpoints", "configs", "loginfo")
    )
    workspace = (workdir, ckptdir, cfgdir, loginfo)

    # Created on every rank; exist_ok makes repeated creation harmless.
    for directory in workspace:
        os.makedirs(directory, exist_ok=True)

    if rank != 0:
        return workspace

    keeps_trainstep_ckpts = (
        "callbacks" in lightning_config
        and "metrics_over_trainsteps_checkpoint" in lightning_config.callbacks
    )
    if keeps_trainstep_ckpts:
        os.makedirs(os.path.join(ckptdir, "trainstep_checkpoints"), exist_ok=True)

    OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml"))
    lightning_yaml = os.path.join(cfgdir, "lightning.yaml")
    OmegaConf.save(OmegaConf.create({"lightning": lightning_config}), lightning_yaml)
    return workspace
|
|
|
|
def check_config_attribute(config, name):
    """Return ``config.<name>`` when ``name in config``, otherwise ``None``.

    Args:
        config: Object supporting both ``in`` membership tests and
            attribute access for its keys.
        name: Key / attribute name to look up.
    """
    if name not in config:
        return None
    return getattr(config, name)
|
|
|
|
def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger):
    """Build the trainer callback configuration.

    Default callback specs (checkpointing, image logging, learning-rate
    monitoring, CUDA callback) are merged with ``lightning_config.callbacks``;
    values from the user config win on conflict.

    Args:
        lightning_config: Lightning config; its ``callbacks`` section is
            optional.
        config: Full config; ``config.model.params.monitor`` (if present and
            not ``None``) becomes the metric monitored by the default
            ``ModelCheckpoint`` (``save_top_k=3``, ``mode="min"``).
        logdir: Directory passed to the image logger as ``save_dir``.
        ckptdir: Directory the default checkpoints are written to.
        logger: Not used by this function; kept in the signature for
            existing callers.

    Returns:
        The merged callbacks config as produced by ``OmegaConf.merge``.
    """
    default_callbacks_cfg = {
        "model_checkpoint": {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch}",
                "verbose": True,
                "save_last": True,
            },
        },
        "batch_logger": {
            "target": "utils.callbacks.ImageLogger",
            "params": {
                "save_dir": logdir,
                "batch_frequency": 1000,
                "max_images": 4,
                "clamp": True,
            },
        },
        "learning_rate_logger": {
            "target": "pytorch_lightning.callbacks.LearningRateMonitor",
            "params": {"logging_interval": "step", "log_momentum": False},
        },
        "cuda_callback": {"target": "utils.callbacks.CUDACallback"},
    }

    # Optional monitored metric for the default checkpoint callback.
    monitor_metric = check_config_attribute(config.model.params, "monitor")
    if monitor_metric is not None:
        mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.")
        default_callbacks_cfg["model_checkpoint"]["params"].update(
            monitor=monitor_metric,
            save_top_k=3,
            mode="min",
        )

    # Resolve the user section once. The "callbacks" key is optional, so it
    # must not be dereferenced before checking that it exists.
    if "callbacks" in lightning_config:
        callbacks_cfg = lightning_config.callbacks
    else:
        callbacks_cfg = OmegaConf.create()

    if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
        mainlogger.info(
            "Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
        )
        default_callbacks_cfg["metrics_over_trainsteps_checkpoint"] = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
                "filename": "{epoch}-{step}",
                "verbose": True,
                "save_top_k": -1,  # keep every periodic checkpoint
                "every_n_train_steps": 10000,
                "save_weights_only": True,
            },
        }

    return OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
|
|
|
|
def get_trainer_logger(lightning_config, logdir, on_debug):
    """Build the trainer logger configuration.

    Ensures ``<logdir>/tensorboard`` exists, then merges the TensorBoard
    logger defaults with ``lightning_config.logger`` when that section is
    present (user values win on conflict).

    Args:
        lightning_config: Lightning config; its ``logger`` section is optional.
        logdir: Directory used as ``save_dir`` for the logger.
        on_debug: Not used by this function.

    Returns:
        The merged logger config as produced by ``OmegaConf.merge``.
    """
    os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True)

    # Known logger presets; only "tensorboard" is selected as the default.
    logger_presets = {
        "tensorboard": {
            "target": "pytorch_lightning.loggers.TensorBoardLogger",
            "params": {
                "save_dir": logdir,
                "name": "tensorboard",
            },
        },
        "testtube": {
            "target": "pytorch_lightning.loggers.CSVLogger",
            "params": {
                "name": "testtube",
                "save_dir": logdir,
            },
        },
    }

    user_logger_cfg = (
        lightning_config.logger if "logger" in lightning_config else OmegaConf.create()
    )
    return OmegaConf.merge(logger_presets["tensorboard"], user_logger_cfg)
|
|
|
|
def get_trainer_strategy(lightning_config):
    """Return the trainer strategy configuration.

    A ``strategy`` section in ``lightning_config`` is returned untouched
    (it is not merged with the default). Without one, the default
    ``DDPShardedStrategy`` spec is returned as an OmegaConf config.
    """
    if "strategy" in lightning_config:
        return lightning_config.strategy

    fallback_strategy = {
        "target": "pytorch_lightning.strategies.DDPShardedStrategy"
    }
    return OmegaConf.merge(fallback_strategy, OmegaConf.create())
|
|
|
|
def load_checkpoints(model, model_cfg):
    """Initialise ``model`` from ``model_cfg.pretrained_checkpoint`` if configured.

    Behaviour is selected by ``model_cfg``:

    * ``adapter_only`` truthy: the checkpoint weights are loaded with
      ``strict=False``.
    * otherwise, ``pretrained_checkpoint`` truthy: the weights are loaded
      strictly; if unwrapping or loading fails, the loaded checkpoint object
      itself is tried as a bare state dict.
    * neither: no weights are loaded.

    In every case ``model.empty_paras`` is set to ``None``.

    Args:
        model: Module exposing ``load_state_dict``.
        model_cfg: Config supporting ``in`` and attribute access.

    Returns:
        The same ``model`` object.

    Raises:
        AssertionError: If the configured checkpoint path does not exist.
    """
    if check_config_attribute(model_cfg, "adapter_only"):
        pretrained_ckpt = model_cfg.pretrained_checkpoint
        assert os.path.exists(pretrained_ckpt), (
            "Error: Pre-trained checkpoint NOT found at:%s" % pretrained_ckpt
        )
        mainlogger.info(
            ">>> Load weights from pretrained checkpoint (training adapter only)"
        )
        print(f"Loading model from {pretrained_ckpt}")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(pretrained_ckpt, map_location="cpu")
        # Non-strict: keys missing from / unexpected in the checkpoint are tolerated.
        model.load_state_dict(_unwrap_checkpoint_state_dict(checkpoint), strict=False)
        model.empty_paras = None
        return model

    if check_config_attribute(model_cfg, "pretrained_checkpoint"):
        pretrained_ckpt = model_cfg.pretrained_checkpoint
        assert os.path.exists(pretrained_ckpt), (
            "Error: Pre-trained checkpoint NOT found at:%s" % pretrained_ckpt
        )
        mainlogger.info(">>> Load weights from pretrained checkpoint")
        print(f"Loading model from {pretrained_ckpt}")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(pretrained_ckpt, map_location="cpu")
        try:
            model.load_state_dict(_unwrap_checkpoint_state_dict(checkpoint))
        except Exception:
            # Best-effort fallback: treat the file as a bare state dict.
            # Unlike a bare ``except``, KeyboardInterrupt/SystemExit propagate.
            mainlogger.info(
                "Loading wrapped checkpoint weights failed; retrying with the "
                "checkpoint as a bare state dict"
            )
            model.load_state_dict(checkpoint)

    model.empty_paras = None
    return model


def _unwrap_checkpoint_state_dict(checkpoint):
    """Return the weight mapping stored inside a loaded checkpoint object.

    Weights are taken from ``checkpoint["state_dict"]`` when that key exists,
    otherwise from ``checkpoint["module"]`` with the first 16 characters of
    every key dropped. A ``KeyError`` is raised when neither key is present.
    """
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    # NOTE(review): presumably the 16 dropped characters are a wrapper prefix
    # such as DeepSpeed's "_forward_module." — confirm against the checkpoint
    # writer before changing the slice.
    return OrderedDict(
        (key[16:], value) for key, value in checkpoint["module"].items()
    )
|
|
|
|
def get_autoresume_path(logdir):
    """Pick the checkpoint to resume from inside ``<logdir>/checkpoints``.

    ``last.ckpt`` is preferred. It is loaded once (on CPU) to confirm it is
    readable and to log its epoch / global step. If that fails, the
    alphabetically last other entry of the checkpoint directory is selected
    instead. If there is no other entry, the failure is logged and the path
    of ``last.ckpt`` is still returned.

    Args:
        logdir: Workspace directory containing a ``checkpoints`` folder.

    Returns:
        Path of the checkpoint to resume from, or ``None`` when
        ``last.ckpt`` does not exist.
    """
    ckpt_dir = os.path.join(logdir, "checkpoints")
    ckpt = os.path.join(ckpt_dir, "last.ckpt")

    if not os.path.exists(ckpt):
        mainlogger.info(
            f"[INFO] no checkpoint found in current workspace: {ckpt_dir}"
        )
        return None

    try:
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        state = torch.load(ckpt, map_location="cpu")
        epoch = state["epoch"]
        global_step = state["global_step"]
        mainlogger.info(f"[INFO] Resume from epoch {epoch}, global step {global_step}!")
        del state  # release the loaded checkpoint before training starts
    except Exception:
        # Narrower than a bare ``except``: KeyboardInterrupt/SystemExit propagate.
        mainlogger.info("Load last.ckpt failed!")
        # NOTE(review): os.path.isdir(f) is evaluated relative to the current
        # working directory, not ckpt_dir, so sub-directories of ckpt_dir are
        # normally NOT filtered out. Kept as-is on purpose: directory-style
        # checkpoints may rely on being selectable here — confirm before fixing.
        ckpts = sorted(f for f in os.listdir(ckpt_dir) if not os.path.isdir(f))
        mainlogger.info(f"all avaible checkpoints: {ckpts}")
        excluded = ("last.ckpt", "trainstep_checkpoints")
        fallbacks = [f for f in ckpts if f not in excluded]
        if fallbacks:
            # NOTE(review): plain string ordering, so e.g. "epoch=10.ckpt"
            # sorts before "epoch=9.ckpt".
            ckpt = os.path.join(ckpt_dir, fallbacks[-1])
            mainlogger.info(f"Select resuming ckpt: {ckpt}")
        else:
            # Previously this case raised an uncaught IndexError; it now takes
            # the logged "no other ckpts" path and keeps last.ckpt.
            mainlogger.info("Load last.ckpt failed! and there is no other ckpts")

    mainlogger.info(f"[INFO] resume from: {ckpt}")
    return ckpt
|
|
|
|
def set_logger(logfile, name="mainlogger"):
    """Configure the named logger to write to ``logfile`` and to the console.

    The logger level is set to INFO. A file handler (INFO, opened with
    mode ``"w"`` so the file is truncated) and a stream handler (DEBUG) are
    attached. Calling this again for the same logger replaces the handlers
    installed by the previous call instead of stacking new ones, so messages
    are not emitted multiple times. Handlers attached by other code are left
    alone.

    Args:
        logfile: Path of the log file to (re)create.
        name: Name of the logger to configure.

    Returns:
        The configured ``logging.Logger``.
    """
    marker = "_installed_by_set_logger"
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # Remove (and close) only the handlers a previous set_logger call added.
    for stale in [h for h in logger.handlers if getattr(h, marker, False)]:
        logger.removeHandler(stale)
        stale.close()

    file_handler = logging.FileHandler(logfile, mode="w")
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s-%(levelname)s: %(message)s")
    )

    console_handler = logging.StreamHandler()
    # The logger's own INFO level filters records before they reach this handler.
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(logging.Formatter("%(message)s"))

    for handler in (file_handler, console_handler):
        setattr(handler, marker, True)
        logger.addHandler(handler)
    return logger
|
|