File size: 2,667 Bytes
436b829 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 | import hydra
import pytorch_lightning as pl
import rich
import rich.syntax
import rich.tree
import os
from pathlib import Path
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
from ppd.utils.logger import Log, monitor_process_wrapper
@monitor_process_wrapper
def get_data(cfg: DictConfig, wo_train: bool = False) -> pl.LightningDataModule:
    """Build the datamodule described by ``cfg.data``.

    Args:
        cfg: Full experiment config; only its ``data`` node is read here.
        wo_train: Forwarded unchanged to the datamodule constructor as the
            ``wo_train`` keyword (presumably "without train split" — confirm
            against the datamodule implementation).

    Returns:
        Whatever ``hydra.utils.instantiate`` builds from ``cfg.data``.
    """
    # _recursive_=False: nested configs are handed to the target as-is instead
    # of being instantiated by Hydra.
    return hydra.utils.instantiate(cfg.data, wo_train=wo_train, _recursive_=False)
@monitor_process_wrapper
def get_model(cfg: DictConfig) -> pl.LightningModule:
    """Build the model described by ``cfg.model``.

    Args:
        cfg: Full experiment config; only its ``model`` node is read here.

    Returns:
        Whatever ``hydra.utils.instantiate`` builds from ``cfg.model``.
    """
    # _recursive_=False: nested configs are handed to the target as-is instead
    # of being instantiated by Hydra.
    return hydra.utils.instantiate(cfg.model, _recursive_=False)
@monitor_process_wrapper
def get_callbacks(cfg: DictConfig) -> list:
    """Instantiate every callback listed under ``cfg.callbacks``.

    Args:
        cfg: Full experiment config; only its ``callbacks`` node is read here.

    Returns:
        A list with one instantiated object per non-null entry of
        ``cfg.callbacks`` (possibly empty), or ``None`` when the config has no
        ``callbacks`` node or the node is set to null.
    """
    # `get` covers both "key absent" and "key present but null"; the previous
    # hasattr() check let `callbacks: null` through and then failed on
    # `None.values()`.
    callbacks_cfg = cfg.get("callbacks")
    if callbacks_cfg is None:
        return None

    # Individual entries may be nulled out to disable a single callback.
    return [
        hydra.utils.instantiate(callback_cfg, _recursive_=False)
        for callback_cfg in callbacks_cfg.values()
        if callback_cfg is not None
    ]
@rank_zero_only
def print_cfg(cfg: DictConfig, use_rich: bool = False):
    """Print the config, either as plain YAML or as a rich tree.

    Args:
        cfg: Config to display. Interpolations are left unresolved.
        use_rich: When False, the whole config is logged as one YAML dump via
            ``Log.info``. When True, a ``rich`` tree with one branch per
            top-level field is printed.
    """
    if not use_rich:
        Log.info(OmegaConf.to_yaml(cfg, resolve=False))
        return

    preferred_fields = ("data", "model", "callbacks", "logger", "pl_trainer", "exp")
    dim = "dim"
    root = rich.tree.Tree("CONFIG", style=dim, guide_style=dim)

    # Preferred fields come first (with a warning for each one that is absent)...
    ordered = []
    for name in preferred_fields:
        if name in cfg:
            ordered.append(name)
        else:
            Log.warn(f"Field '{name}' not found in config. Skipping.")
    # ...followed by every remaining top-level field in config order.
    ordered += [name for name in cfg if name not in ordered]

    # One branch per field, rendered as YAML-highlighted text.
    for name in ordered:
        node = root.add(name, style=dim, guide_style=dim)
        group = cfg[name]
        text = OmegaConf.to_yaml(group, resolve=False) if isinstance(group, DictConfig) else str(group)
        node.add(rich.syntax.Syntax(text, "yaml"))

    rich.print(root)
def find_last_ckpt_path(dirpath):
    """
    Return the latest ``*.ckpt`` file directly inside `dirpath`, or None.

    Assume ckpt is named as e{}* or last*, following the convention of pytorch-lightning.
    Files whose name contains "last" are skipped. The remaining files are ordered
    by name with embedded digit runs compared as integers, so ``e10.ckpt`` ranks
    after ``e9.ckpt`` even when the numbers are not zero-padded.

    Args:
        dirpath: Directory to search (str or Path); not searched recursively.

    Returns:
        Path of the highest-ranked checkpoint, or None (after logging a message)
        when no eligible checkpoint exists.
    """
    import re  # local import: only needed for the natural-sort key below

    def natural_key(path):
        # re.split with a capturing group alternates text / digit-run chunks, so
        # odd indices are always digit runs and same-index chunks share a type.
        chunks = re.split(r"([0-9]+)", path.name)
        numeric_aware = [int(tok) if idx % 2 else tok for idx, tok in enumerate(chunks)]
        # The raw name breaks ties between names differing only by leading zeros.
        return numeric_aware, path.name

    candidates = [p for p in Path(dirpath).glob("*.ckpt") if "last" not in p.name]
    if not candidates:
        Log.info("No checkpoint found, set model_path to None")
        return None
    return max(candidates, key=natural_key)