import os
from pathlib import Path
from typing import Optional

import hydra
import pytorch_lightning as pl
import rich
import rich.syntax
import rich.tree
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only

from ppd.utils.logger import Log, monitor_process_wrapper


@monitor_process_wrapper
def get_data(cfg: DictConfig, wo_train: bool = False) -> pl.LightningDataModule:
    """Instantiate the datamodule from `cfg.data`; `wo_train` is forwarded to its __init__."""
    datamodule = hydra.utils.instantiate(cfg.data, wo_train=wo_train, _recursive_=False)
    return datamodule


@monitor_process_wrapper
def get_model(cfg: DictConfig) -> pl.LightningModule:
    """Instantiate the model from `cfg.model`."""
    model = hydra.utils.instantiate(cfg.model, _recursive_=False)
    return model


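# For reference, `cfg.data` and `cfg.model` are expected to be Hydra-instantiable
# nodes. The targets below are hypothetical, shown only to illustrate the shape
# (note that get_data forwards the `wo_train` kwarg to the datamodule's __init__):
#
#   data:
#     _target_: ppd.data.DataModule   # hypothetical target
#     batch_size: 32
#   model:
#     _target_: ppd.models.Model      # hypothetical target
#     lr: 1e-3

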
@monitor_process_wrapper
def get_callbacks(cfg: DictConfig) -> Optional[list]:
    """Instantiate every non-null callback listed under `cfg.callbacks`."""
    if "callbacks" not in cfg:
        return None
    callbacks = []
    for callback in cfg.callbacks.values():
        if callback is not None:
            callbacks.append(hydra.utils.instantiate(callback, _recursive_=False))
    return callbacks


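# For reference, a hypothetical `callbacks` config group as consumed by
# get_callbacks (each non-null entry is a Hydra-instantiable node; the
# `e{epoch}` filename matches the convention assumed by find_last_ckpt_path):
#
#   callbacks:
#     checkpoint:
#       _target_: pytorch_lightning.callbacks.ModelCheckpoint
#       filename: "e{epoch}"
#     lr_monitor: null

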
@rank_zero_only
def print_cfg(cfg: DictConfig, use_rich: bool = False):
    """Print the config on rank zero, optionally rendered as a rich tree."""
    if use_rich:
        print_order = ("data", "model", "callbacks", "logger", "pl_trainer", "exp")
        style = "dim"
        tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

        # Queue the known fields in the preferred order, then append any
        # remaining top-level fields.
        queue = []
        for field in print_order:
            if field in cfg:
                queue.append(field)
            else:
                Log.warn(f"Field '{field}' not found in config. Skipping.")
        for field in cfg:
            if field not in queue:
                queue.append(field)

        # Render each top-level field as a YAML-highlighted branch of the tree.
        for field in queue:
            branch = tree.add(field, style=style, guide_style=style)
            config_group = cfg[field]
            if isinstance(config_group, DictConfig):
                branch_content = OmegaConf.to_yaml(config_group, resolve=False)
            else:
                branch_content = str(config_group)
            branch.add(rich.syntax.Syntax(branch_content, "yaml"))
        rich.print(tree)
    else:
        Log.info(OmegaConf.to_yaml(cfg, resolve=False))
|
|
def find_last_ckpt_path(dirpath) -> Optional[Path]:
    """
    Return the latest checkpoint in `dirpath`, or None if there is none.

    Checkpoints are assumed to be named `e{epoch}*` or `last*`, following the
    pytorch-lightning convention. `last*` files are skipped, so the result is
    the lexicographically greatest epoch-named `*.ckpt` (which is the latest
    epoch as long as epoch numbers sort correctly by name).
    """
    dirpath = Path(dirpath)
    model_paths = []
    for p in sorted(dirpath.glob("*.ckpt")):
        if "last" in p.name:
            continue
        model_paths.append(p)
    if len(model_paths) > 0:
        return model_paths[-1]
    else:
        Log.info("No checkpoint found, setting model_path to None")
        return None
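

# A minimal usage sketch, not part of the original module: a typical Hydra
# entry point composing the helpers above. The config path/name, the
# `cfg.exp.ckpt_dir` field, and Hydra >= 1.2 (for `version_base`) are
# assumptions for illustration only.
@hydra.main(config_path="../configs", config_name="train", version_base=None)
def main(cfg: DictConfig) -> None:
    print_cfg(cfg, use_rich=True)
    datamodule = get_data(cfg)
    model = get_model(cfg)
    trainer = pl.Trainer(
        **OmegaConf.to_container(cfg.pl_trainer, resolve=True),
        callbacks=get_callbacks(cfg),
    )
    # Resume from the latest epoch checkpoint if one exists (None starts fresh).
    trainer.fit(model, datamodule=datamodule, ckpt_path=find_last_ckpt_path(cfg.exp.ckpt_dir))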