| import importlib |
| from argparse import ArgumentParser |
| from omegaconf import OmegaConf |
| from os.path import join as pjoin |
| import os |
| import glob |
|
|
|
|
def get_module_config(cfg, filepath="./configs"):
    """
    Load yaml config files from the subfolders of ``filepath`` into ``cfg``.

    Every ``<filepath>/<subdir>/<name>.yaml`` file is loaded with
    ``OmegaConf.load`` and stored in ``cfg`` under the dotted key
    ``<subdir>.<name>`` via ``OmegaConf.update``.

    Args:
        cfg: OmegaConf config that is updated in place.
        filepath: Root folder whose immediate subfolders hold the yaml files.

    Returns:
        The same ``cfg`` object, after the updates.
    """
    # Sorted so the update order does not depend on filesystem listing order.
    yaml_paths = sorted(glob.glob(pjoin(filepath, '*', '*.yaml')))
    for yaml_path in yaml_paths:
        # Derive the key from the path relative to ``filepath``. relpath and
        # os.sep keep this correct for any root folder and on platforms whose
        # path separator is not '/'.
        rel_path = os.path.relpath(yaml_path, filepath)
        node = os.path.splitext(rel_path)[0].replace(os.sep, '.')
        # Load from the globbed path itself. Re-prefixing a hard-coded
        # './configs' loaded the wrong file for any other ``filepath``.
        OmegaConf.update(cfg, node, OmegaConf.load(yaml_path))
    return cfg
|
|
|
|
def get_obj_from_str(string, reload=False):
    """
    Resolve a dotted path such as ``"package.module.Name"`` to the named object.

    Args:
        string: Dotted path. Everything before the last "." is the module to
            import; the part after it is the attribute to fetch.
        reload: If True, re-execute the module with ``importlib.reload``
            before fetching the attribute.

    Returns:
        The attribute named by the last path component.

    Raises:
        ValueError: If ``string`` contains no ".".
        ModuleNotFoundError / AttributeError: If the module or attribute
            does not exist.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    target_module = importlib.import_module(module_name, package=None)
    return getattr(target_module, attr_name)
|
|
|
|
def instantiate_from_config(config):
    """
    Instantiate the object described by ``config``.

    ``config["target"]`` is a dotted import path resolved with
    ``get_obj_from_str``. The optional ``config["params"]`` mapping is passed
    to it as keyword arguments.

    Args:
        config: Mapping (e.g. a dict or OmegaConf DictConfig) with a
            ``"target"`` key and an optional ``"params"`` key.

    Returns:
        The result of calling the resolved target with ``params``.

    Raises:
        KeyError: If ``config`` has no ``"target"`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    # A bare ``params:`` entry in YAML loads as None. Treat it as "no params"
    # instead of failing with "argument after ** must be a mapping".
    params = config.get("params") or {}
    return get_obj_from_str(config["target"])(**params)
|
|
|
|
def resume_config(cfg):
    """
    Point ``cfg`` at a previous run's checkpoint and wandb run when resuming.

    When ``cfg.TRAIN.RESUME`` is set, it must be an existing run folder. Then:

    - ``cfg.TRAIN.PRETRAINED`` is set to ``<RESUME>/checkpoints/last.ckpt``.
    - ``cfg.LOGGER.WANDB.params.id`` is set to the id parsed from the
      ``run-<id>.wandb`` file in ``<RESUME>/wandb/latest-run``.

    When ``cfg.TRAIN.RESUME`` is empty, ``cfg`` is returned unchanged.

    Args:
        cfg: Experiment config. In this project it is an OmegaConf config;
            any object supporting the attribute access below works.

    Returns:
        The same ``cfg`` object.

    Raises:
        ValueError: If ``cfg.TRAIN.RESUME`` does not exist, or if no
            ``run-*.wandb`` file is found in the wandb ``latest-run`` folder.
    """
    resume = cfg.TRAIN.RESUME
    if not resume:
        return cfg
    if not os.path.exists(resume):
        raise ValueError("Resume path is not right.")

    cfg.TRAIN.PRETRAINED = pjoin(resume, "checkpoints", "last.ckpt")

    latest_run = pjoin(resume, "wandb", "latest-run")
    # Match only the run log itself ("run-<id>.wandb"). A plain substring test
    # also accepted e.g. "run-<id>.wandb.synced" and produced a bogus id.
    run_files = sorted(
        name for name in os.listdir(latest_run)
        if name.startswith("run-") and name.endswith(".wandb")
    )
    if not run_files:
        raise ValueError(f"No run-*.wandb file found in {latest_run}")
    cfg.LOGGER.WANDB.params.id = run_files[0][len("run-"):-len(".wandb")]
    return cfg
|
|
def _build_arg_parser(phase):
    """Build the command-line parser with the options available for ``phase``."""
    parser = ArgumentParser()
    group = parser.add_argument_group("Training options")

    group.add_argument(
        "--cfg_assets",
        type=str,
        required=False,
        default="./configs/assets.yaml",
        help="config file for asset paths",
    )

    # Phase-specific default experiment config. Phases without a dedicated
    # file (e.g. "demo") fall back to default.yaml. Previously the default
    # was left unbound for them and argparse setup crashed with
    # UnboundLocalError.
    phase_default_cfgs = {
        "render": "./configs/render.yaml",
        "webui": "./configs/webui.yaml",
    }
    group.add_argument(
        "--cfg",
        type=str,
        required=False,
        default=phase_default_cfgs.get(phase, "./configs/default.yaml"),
        help="config file",
    )

    if phase in ["train", "test"]:
        group.add_argument("--batch_size",
                           type=int,
                           required=False,
                           help="training batch size")
        group.add_argument("--num_nodes",
                           type=int,
                           required=False,
                           help="number of nodes")
        group.add_argument("--device",
                           type=int,
                           nargs="+",
                           required=False,
                           help="training device")
        group.add_argument("--task",
                           type=str,
                           required=False,
                           help="evaluation task type")
        group.add_argument("--nodebug",
                           action="store_true",
                           required=False,
                           help="debug or not")

    if phase == "demo":
        group.add_argument(
            "--example",
            type=str,
            required=False,
            help="input text and lengths with txt format",
        )
        group.add_argument(
            "--out_dir",
            type=str,
            required=False,
            help="output dir",
        )
        group.add_argument("--task",
                           type=str,
                           required=False,
                           help="evaluation task type")
        # --render / --frame_rate are copied into cfg.DEMO later, but were
        # never declared, so the demo phase failed with AttributeError.
        group.add_argument("--render",
                           action="store_true",
                           required=False,
                           help="render the generated motions")
        group.add_argument("--frame_rate",
                           type=float,
                           required=False,
                           default=None,
                           help="frame rate of the output motions")

    if phase == "render":
        group.add_argument("--npy",
                           type=str,
                           required=False,
                           default=None,
                           help="npy motion files")
        group.add_argument("--dir",
                           type=str,
                           required=False,
                           default=None,
                           help="npy motion folder")
        group.add_argument("--fps",
                           type=int,
                           required=False,
                           default=30,
                           help="render fps")
        group.add_argument(
            "--mode",
            type=str,
            required=False,
            default="sequence",
            help="render target: video, sequence, frame",
        )

    return parser


def _apply_cli_overrides(cfg, params, phase):
    """Write the parsed command-line values for ``phase`` into ``cfg`` in place."""
    if phase in ["train", "test"]:
        cfg.TRAIN.BATCH_SIZE = params.batch_size if params.batch_size else cfg.TRAIN.BATCH_SIZE
        cfg.DEVICE = params.device if params.device else cfg.DEVICE
        cfg.NUM_NODES = params.num_nodes if params.num_nodes else cfg.NUM_NODES
        cfg.model.params.task = params.task if params.task else cfg.model.params.task
        # --nodebug is a store_true flag and therefore never None. Debugging is
        # on unless --nodebug is passed; the config's DEBUG value is overridden.
        cfg.DEBUG = not params.nodebug

    if phase == "test":
        cfg.DEBUG = False
        cfg.DEVICE = [0]
        print("Force no debugging and one gpu when testing")

    if phase == "demo":
        cfg.DEMO.RENDER = params.render
        # Keep the configured frame rate unless one is given on the command line.
        if params.frame_rate is not None:
            cfg.DEMO.FRAME_RATE = params.frame_rate
        cfg.DEMO.EXAMPLE = params.example
        cfg.DEMO.TASK = params.task
        cfg.TEST.FOLDER = params.out_dir if params.out_dir else cfg.TEST.FOLDER
        os.makedirs(cfg.TEST.FOLDER, exist_ok=True)

    if phase == "render":
        if params.npy:
            cfg.RENDER.NPY = params.npy
            cfg.RENDER.INPUT_MODE = "npy"
        if params.dir:
            cfg.RENDER.DIR = params.dir
            cfg.RENDER.INPUT_MODE = "dir"
        if params.fps:
            cfg.RENDER.FPS = float(params.fps)
        cfg.RENDER.MODE = params.mode

    if cfg.DEBUG:
        cfg.NAME = "debug--" + cfg.NAME
        cfg.LOGGER.WANDB.params.offline = True
        cfg.LOGGER.VAL_EVERY_STEPS = 1


def parse_args(phase="train"):
    """
    Parse command-line arguments and build the merged experiment config.

    The config is built in this order:

    1. Load the assets config (``--cfg_assets``).
    2. Merge ``<CONFIG_FOLDER>/default.yaml`` with the experiment config
       (``--cfg``).
    3. Unless ``FULL_CONFIG`` is set, load the per-module yaml files from
       ``CONFIG_FOLDER``.
    4. Merge the assets config on top.
    5. Apply the command-line overrides for ``phase``.
    6. Apply any resume settings.

    Args:
        phase: One of "train", "test", "demo", "render" or "webui". It selects
            the available options and the default ``--cfg`` file.

    Returns:
        The merged OmegaConf config.
    """
    params = _build_arg_parser(phase).parse_args()

    # SECURITY NOTE: the "eval" resolver runs arbitrary Python from
    # ${eval:...} interpolations, so only load config files you trust.
    # Register it only once: OmegaConf raises if a resolver name is
    # registered twice, which made a second parse_args() call fail.
    if not OmegaConf.has_resolver("eval"):
        OmegaConf.register_new_resolver("eval", eval)
    cfg_assets = OmegaConf.load(params.cfg_assets)
    cfg_base = OmegaConf.load(pjoin(cfg_assets.CONFIG_FOLDER, 'default.yaml'))
    cfg_exp = OmegaConf.merge(cfg_base, OmegaConf.load(params.cfg))
    if not cfg_exp.FULL_CONFIG:
        cfg_exp = get_module_config(cfg_exp, cfg_assets.CONFIG_FOLDER)
    cfg = OmegaConf.merge(cfg_exp, cfg_assets)

    _apply_cli_overrides(cfg, params, phase)

    return resume_config(cfg)
|
|