Spaces:
Running
Running
| import os | |
| import os.path as osp | |
| import sys | |
| import time | |
| import yaml | |
| import imageio | |
| import random | |
| import shutil | |
| import random | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
class ConfigParser():
    """
    Class to parse configuration: command-line arguments merged with a YAML file.

    After construction, config values are reachable as ``parser.name`` or
    ``parser['name']``; both read from the merged namespace in ``self.cfg``.
    """

    def __init__(self, args):
        """
        Parse arguments, merge them with the YAML config file, and seed RNGs.

        Args:
            args: object exposing ``parse_args()`` (presumably an
                ``argparse.ArgumentParser`` — TODO confirm) whose resulting
                namespace has a ``config`` attribute holding the YAML path.
        """
        args = args.parse_args()
        self.cfg = self.merge_config_file(args)
        # set random seed
        self.set_seed()

    def __str__(self):
        return str(self.cfg.__dict__)

    def __getattr__(self, name):
        """
        Access config items using dot notation.

        Only called when normal attribute lookup fails. Raises AttributeError
        (not KeyError) for unknown names so ``hasattr``/``getattr(..., default)``
        behave normally.
        """
        # Without this guard, an instance that has no ``cfg`` yet (e.g. made via
        # __new__, or during copy/unpickling) would recurse forever on self.cfg.
        if name == 'cfg':
            raise AttributeError(name)
        try:
            return self.cfg.__dict__[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}") from None

    def __getitem__(self, name):
        """
        Access config items like an ordinary dict (raises KeyError if missing).
        """
        return self.cfg.__dict__[name]

    def merge_config_file(self, args, allow_invalid=True):
        """
        Load the YAML config file at ``args.config`` and merge it into ``args``.

        Values given on the command line (non-None in ``args``) take priority
        and are reported; every other key from the file is copied onto ``args``.

        Args:
            args: parsed namespace; ``args.config`` must not be None.
            allow_invalid: if False, raise when the file contains keys that are
                not already attributes of ``args``.

        Returns:
            The same ``args`` object, updated in place.

        Raises:
            AssertionError: if ``args.config`` is None.
            ValueError: if ``allow_invalid`` is False and unknown keys are found.
        """
        assert args.config is not None
        with open(args.config, 'r') as f:
            # An empty YAML document loads as None; treat it as an empty config.
            cfg = yaml.safe_load(f) or {}
        # Drop any 'config' key so the file cannot overwrite args.config below.
        cfg.pop('config', None)
        invalid_args = list(set(cfg) - set(dir(args)))
        if invalid_args and not allow_invalid:
            raise ValueError(f"Invalid args {invalid_args} in {args.config}.")
        for k in list(cfg):
            # Command-line values that were actually set win over the file.
            if args.__dict__.get(k) is not None:
                print('=========> overwrite config: {} = {}'.format(k, args.__dict__[k]))
                del cfg[k]
        args.__dict__.update(cfg)
        return args

    def set_seed(self):
        ''' set random seed for random, numpy and torch. '''
        if 'seed' not in self.cfg.__dict__:
            return
        if self.cfg.seed is None:
            # No explicit seed: derive one from the clock and store it in cfg so
            # it is printed (and saved with the config) for later reproduction.
            self.cfg.seed = int(time.time()) % 1000000
        print('=========> set random seed: {}'.format(self.cfg.seed))
        # fix random seeds for reproducibility
        random.seed(self.cfg.seed)
        np.random.seed(self.cfg.seed)
        torch.manual_seed(self.cfg.seed)
        torch.cuda.manual_seed(self.cfg.seed)

    def save_codes_and_config(self, save_path):
        """
        Save a snapshot of the code tree and the merged config to ``save_path``.

        WARNING: an existing ``save_path`` directory is deleted first.
        The code tree copied is the grandparent directory of this file,
        excluding data/output-like paths and common data/media/checkpoint files.
        """
        cur_codes_path = osp.dirname(osp.dirname(osp.abspath(__file__)))
        if osp.exists(save_path):
            shutil.rmtree(save_path)
        shutil.copytree(
            cur_codes_path, osp.join(save_path, 'codes'),
            ignore=shutil.ignore_patterns('*debug*', '*data*', '*output*', '*exps*', '*.txt', '*.json', '*.mp4', '*.png', '*.jpg', '*.bvh', '*.csv', '*.pth', '*.tar', '*.npz'))
        with open(osp.join(save_path, 'config.yaml'), 'w') as f:
            f.write(yaml.dump(self.cfg.__dict__))
| # other utils | |
class logger:
    """Progress tracker for a multi-level optimisation, reported through a tqdm bar.

    Counts steps both globally (``steps``) and within the current level
    (``lvl_step``); ``lvl`` starts at -1 so the first ``new_lvl()`` yields 0.
    """

    def __init__(self, n_steps, n_lvls):
        self.n_steps, self.n_lvls = n_steps, n_lvls
        self.lvl = -1
        self.lvl_step = self.steps = 0
        # One bar tick per step across every level.
        total_ticks = n_lvls * n_steps
        self.pbar = tqdm(total=total_ticks, desc='Starting')

    def step(self):
        """Advance the bar and both step counters by one."""
        self.pbar.update(1)
        self.steps = self.steps + 1
        self.lvl_step = self.lvl_step + 1

    def new_lvl(self):
        """Move to the next level and reset the per-level step counter."""
        self.lvl, self.lvl_step = self.lvl + 1, 0

    def print(self):
        """Show the current level/step position in the bar's description."""
        description = 'Lvl {}/{}, step {}/{}'.format(
            self.lvl, self.n_lvls - 1, self.lvl_step, self.n_steps)
        self.pbar.set_description(description)
def set_seed(seed):
    """Seed Python's ``random``, NumPy, torch and torch's CUDA RNG with *seed*.

    Does nothing when *seed* is None.
    """
    if seed is None:
        return
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
| # debug utils | |
def draw_trajectory(trajectory, save_path=None, anim=True):
    """Plot columns 0 and 2 of a trajectory, optionally as a growing-path animation.

    Args:
        trajectory: array-like indexed as ``trajectory[t, col]`` and supporting
            ``.min()``, ``.max()`` and ``.shape``. Column 0 is plotted on the
            horizontal axis and column 2 on the vertical one (presumably the
            ground-plane x/z of a y-up motion — TODO confirm).
        save_path: output path without extension. ``'.png'`` is appended for
            the still image (also reused as a scratch frame when animating) and
            ``'.mp4'`` for the animation. Required when ``anim`` is True.
        anim: if True, write one frame per trajectory prefix to
            ``save_path + '.mp4'``; otherwise draw the full path once.

    Raises:
        ValueError: if ``anim`` is True and ``save_path`` is None.
    """
    if anim and save_path is None:
        raise ValueError("draw_trajectory: save_path is required when anim=True")
    # Symmetric half-extent so the origin stays centred in the plot.
    r = max(abs(trajectory.min()), trajectory.max())
    try:
        if anim:
            png_path = save_path + '.png'
            # A single line artist updated in place: re-plotting every prefix
            # on top of the previous ones made frame i render i lines (O(n^2)).
            line, = plt.plot([], [], color='red')
            plt.xlim(-r-1, r+1)
            plt.ylim(-r-1, r+1)
            imgs = []
            # Prefix lengths 2..N: the last frame contains the final point
            # (range(1, N) stopped one short) and no frame is a lone,
            # invisible single point. The frame count (N-1) is unchanged.
            for i in tqdm(range(2, trajectory.shape[0] + 1)):
                line.set_data(trajectory[:i, 0], trajectory[:i, 2])
                plt.savefig(png_path)
                imgs.append(imageio.imread(png_path))
            imageio.mimwrite(save_path + '.mp4', imgs)
        else:
            plt.plot(trajectory[:, 0], trajectory[:, 2], color='red')
            plt.xlim(-r*1.5, r*1.5)
            plt.ylim(-r*1.5, r*1.5)
            if save_path is not None:
                plt.savefig(save_path + '.png')
    finally:
        # Always release the figure, even if saving/encoding fails.
        plt.close()
| # velo = self.raw_motion[0, self.mask, :].numpy() | |
| # print(velo.shape) | |
| # imgs = [] | |