File size: 2,667 Bytes
436b829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import hydra
import pytorch_lightning as pl
import rich
import rich.syntax
import rich.tree
import os
from pathlib import Path
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
from ppd.utils.logger import Log, monitor_process_wrapper


@monitor_process_wrapper
def get_data(cfg: DictConfig, wo_train: bool = False) -> pl.LightningDataModule:
    """Build the datamodule described by ``cfg.data``.

    Args:
        cfg: Root config; its ``data`` node is handed to ``hydra.utils.instantiate``.
        wo_train: Forwarded unchanged to the instantiated target as the
            ``wo_train`` keyword argument.

    Returns:
        Whatever object Hydra instantiates from ``cfg.data``.
    """
    # `_recursive_=False` is passed so Hydra does not instantiate nested nodes itself.
    return hydra.utils.instantiate(cfg.data, wo_train=wo_train, _recursive_=False)


@monitor_process_wrapper
def get_model(cfg: DictConfig) -> pl.LightningModule:
    """Build the model described by ``cfg.model``.

    Args:
        cfg: Root config; its ``model`` node is handed to ``hydra.utils.instantiate``.

    Returns:
        Whatever object Hydra instantiates from ``cfg.model``.
    """
    # `_recursive_=False` is passed so Hydra does not instantiate nested nodes itself.
    return hydra.utils.instantiate(cfg.model, _recursive_=False)


@monitor_process_wrapper
def get_callbacks(cfg: DictConfig) -> list:
    """Instantiate every callback configured under ``cfg.callbacks``.

    Args:
        cfg: Root config. ``cfg.callbacks`` is iterated with ``.values()``, so it
            is expected to be a mapping of callback configs; entries whose value
            is None are skipped.

    Returns:
        A list with one instantiated object per non-None entry, or None when
        the config has no ``callbacks`` attribute or it is set to null.
    """
    if not hasattr(cfg, "callbacks"):
        return None

    callback_cfgs = cfg.callbacks
    if callback_cfgs is None:
        # `callbacks: null` passes the hasattr check above but has no `.values()`;
        # treat it the same as a config without a callbacks section.
        return None

    return [
        hydra.utils.instantiate(callback_cfg, _recursive_=False)
        for callback_cfg in callback_cfgs.values()
        if callback_cfg is not None
    ]


@rank_zero_only
def print_cfg(cfg: DictConfig, use_rich: bool = False):
    """Print the config, either as plain YAML or as a rich tree.

    Args:
        cfg: Config to display. Interpolations are left unresolved
            (``resolve=False``) when dumping to YAML.
        use_rich: When False, the whole config is logged as YAML through
            ``Log.info``. When True, each top-level field becomes a branch of a
            ``rich`` tree, with a fixed set of well-known fields listed first.
    """
    if not use_rich:
        Log.info(OmegaConf.to_yaml(cfg, resolve=False))
        return

    preferred_fields = ("data", "model", "callbacks", "logger", "pl_trainer", "exp")
    dim = "dim"
    root = rich.tree.Tree("CONFIG", style=dim, guide_style=dim)

    # Preferred fields come first (warning about any that are absent) ...
    ordered_fields = []
    for name in preferred_fields:
        if name in cfg:
            ordered_fields.append(name)
        else:
            Log.warn(f"Field '{name}' not found in config. Skipping.")
    # ... followed by every remaining top-level field in config order.
    for name in cfg:
        if name not in ordered_fields:
            ordered_fields.append(name)

    # One branch per field; nested configs are rendered as YAML, scalars via str().
    for name in ordered_fields:
        node = root.add(name, style=dim, guide_style=dim)
        section = cfg[name]
        if isinstance(section, DictConfig):
            text = OmegaConf.to_yaml(section, resolve=False)
        else:
            text = str(section)
        node.add(rich.syntax.Syntax(text, "yaml"))
    rich.print(root)

def find_last_ckpt_path(dirpath):
    """
    Return the most recent non-"last" checkpoint found directly in ``dirpath``.

    Assume ckpt is named as e{}* or last*, following the convention of pytorch-lightning.
    Files whose name contains "last" are ignored. The remaining ``*.ckpt`` files
    are ranked in natural order, i.e. digit runs in the file name are compared
    as integers, so ``e10.ckpt`` ranks after ``e9.ckpt`` even without zero padding.

    Args:
        dirpath: Directory to search (anything accepted by ``pathlib.Path``).

    Returns:
        The ``Path`` of the highest-ranked checkpoint, or None (after logging
        a message) when no candidate exists.
    """
    import re  # local import so this function does not rely on extra module-level imports

    def _natural_key(path):
        # re.split with a capturing group alternates text / digit chunks, with
        # digit runs always at odd indices, so keys compare element-wise without
        # mixing str and int. The raw name breaks ties such as "e01" vs "e1".
        chunks = re.split(r"(\d+)", path.name)
        return [int(c) if i % 2 else c for i, c in enumerate(chunks)], path.name

    candidates = [p for p in Path(dirpath).glob("*.ckpt") if "last" not in p.name]
    if not candidates:
        Log.info("No checkpoint found, set model_path to None")
        return None
    return max(candidates, key=_natural_key)