| |
|
| | import os,math,yaml,time,random,glob |
| | from datetime import datetime |
| | import logging |
| | from zoneinfo import ZoneInfo |
| | from omegaconf import OmegaConf |
| | import colored |
| | from rich.progress import BarColumn, Progress, ProgressColumn, Text, TimeElapsedColumn, TimeRemainingColumn, filesize |
| | from tqdm.std import tqdm as std_tqdm |
| | import numpy as np |
| | import torch |
class ConfigDict(dict):
    """Experiment configuration: a plain ``dict`` plus read-only dot access.

    The mapping itself holds the raw (nested) configuration. A read-only
    OmegaConf copy of it is kept in ``_dot_config`` so that attribute access
    such as ``cfg.MODEL.NAME`` works.
    """

    def __init__(self, model_config_path=None, data_config_path=None, init_dict=None):
        """Build the config from YAML files or from an existing dict.

        Args:
            model_config_path: YAML file with the model config; only read when
                ``init_dict`` is None.
            data_config_path: optional YAML file whose contents are merged
                recursively into (and override) the model config.
            init_dict: pre-built config dict. When given, the file paths are
                ignored and no ``TRAIN.EXP_STR`` / ``TRAIN.TIME_STR`` keys are
                generated.
        """
        if init_dict is None:
            config_dict = read_config(model_config_path)
            if data_config_path is not None:
                dataset_dict = read_config(data_config_path)
                merge_a_into_b(dataset_dict, config_dict)

            experiment_string = '{}_{}'.format(
                config_dict['MODEL']['NAME'], config_dict['DATASET']['NAME']
            )
            # Wall-clock time in Asia/Tokyo plus 5 random lowercase letters, so
            # runs started within the same minute are unlikely to collide.
            time_in_tokyo = datetime.now(ZoneInfo('Asia/Tokyo'))
            time_string = (
                time_in_tokyo.strftime("%b%d_%H%M_")
                + "".join(random.sample('zyxwvutsrqponmlkjihgfedcba', 5))
            )
            config_dict['TRAIN']['EXP_STR'] = experiment_string
            config_dict['TRAIN']['TIME_STR'] = time_string
        else:
            config_dict = init_dict
        super().__init__(config_dict)
        self._dot_config = OmegaConf.create(dict(self))
        OmegaConf.set_readonly(self._dot_config, True)

    def __getattr__(self, name):
        """Resolve attributes that normal lookup did not find.

        ``_dump`` returns a plain ``dict`` copy, ``_raw_string`` returns the
        pretty-printed config with ANSI escape codes removed, and any other
        name is looked up on the read-only OmegaConf view.
        """
        # __getattr__ is also consulted for `_dot_config` itself (e.g. on an
        # instance built via __new__ by copy/pickle before its state is
        # restored) and for dunder probes such as `__setstate__`. Both must
        # fail with AttributeError; otherwise `self._dot_config` below
        # re-enters __getattr__ forever and raises RecursionError.
        if name == '_dot_config' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(name)
        if name == '_dump':
            return dict(self)
        if name == '_raw_string':
            import re
            ansi_escape = re.compile(r'''
                \x1B  # ESC
                (?:   # 7-bit C1 Fe (except CSI)
                    [@-Z\\-_]
                |     # or [ for CSI, followed by a control sequence
                    \[
                    [0-?]*  # Parameter bytes
                    [ -/]*  # Intermediate bytes
                    [@-~]   # Final byte
                )
            ''', re.VERBOSE)
            result = '\n' + ansi_escape.sub('', pretty_dict(self))
            return result
        return getattr(self._dot_config, name)

    def __str__(self):
        """Return the coloured, aligned rendering produced by ``pretty_dict``."""
        return pretty_dict(self)

    def update(self, key, value):
        """Set top-level ``key`` to ``value`` in both the dict and the dot view.

        Note: this shadows ``dict.update`` with a different ``(key, value)``
        signature.
        """
        OmegaConf.set_readonly(self._dot_config, False)
        try:
            self._dot_config[key] = value
        finally:
            # Restore read-only even if OmegaConf rejects the value.
            OmegaConf.set_readonly(self._dot_config, True)
        self[key] = value
| |
|
def add_extra_cfgs(meta_cfg):
    """Default ``MODEL.with_smplx_gaussian`` to ``True`` when it is absent.

    The read-only flag on ``meta_cfg`` (an OmegaConf node) is lifted for the
    edit and set back to read-only afterwards. Note that it is always left
    read-only on return, whatever its state was on entry. The same object is
    returned.
    """
    OmegaConf.set_readonly(meta_cfg, False)

    already_set = 'with_smplx_gaussian' in meta_cfg.MODEL.keys()
    if not already_set:
        meta_cfg.MODEL['with_smplx_gaussian'] = True

    OmegaConf.set_readonly(meta_cfg, True)
    return meta_cfg
| |
|
def read_config(path):
    """Load and return the contents of a YAML file.

    Args:
        path: filesystem path of the YAML file.

    Returns:
        The deserialised YAML document (``None`` for an empty file).

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    try:
        # Binary mode: PyYAML detects UTF-8/UTF-16 itself instead of the file
        # being decoded with the platform's locale encoding.
        f = open(path, 'rb')
    except FileNotFoundError as err:
        raise FileNotFoundError(f"{path} was not found.") from err
    with f:
        # SECURITY NOTE(review): yaml.Loader can construct arbitrary Python
        # objects from tags such as !!python/object. Only load trusted config
        # files, or switch to yaml.safe_load if no Python tags are needed.
        config = yaml.load(f, Loader=yaml.Loader)
    return config
| |
|
def merge_a_into_b(a, b):
    """Recursively merge mapping ``a`` into mapping ``b``, in place.

    Values in ``a`` override those in ``b``. When both sides hold a dict under
    the same key, the dicts are merged recursively instead of replaced.
    Dict values from ``a`` under keys absent from ``b`` are inserted by
    reference, not copied.

    Raises:
        AssertionError: if ``a`` has a dict where ``b`` has a non-dict value.
    """
    for key, new_value in a.items():
        if not isinstance(new_value, dict) or key not in b:
            b[key] = new_value
            continue
        base_value = b[key]
        assert isinstance(base_value, dict), "Cannot inherit key '{}' from base!".format(key)
        merge_a_into_b(new_value, base_value)
| |
|
def pretty_dict(input_dict, indent=0, highlight_keys=None):
    """Render a (nested) dict as coloured ``key: value`` lines.

    Keys and values whose key is in ``highlight_keys`` are coloured with
    ``colored.fg(1)``; all others use ``colored.fg(2)``. Nested dicts are
    rendered recursively, one ``tab`` deeper per level.

    At the top level (``indent == 0``), values are column-aligned and the
    trailing newline is dropped. Nested calls return the raw text with a
    ``\\t`` between each key and its value.

    Args:
        input_dict: mapping to render.
        indent: nesting depth (internal; callers normally leave it at 0).
        highlight_keys: keys to highlight; ``None`` means none.

    Returns:
        The rendered string.
    """
    # None instead of a mutable [] default, which would be shared across calls.
    if highlight_keys is None:
        highlight_keys = ()
    tab = " "
    parts = []
    for key, value in input_dict.items():
        color = colored.fg(1) if key in highlight_keys else colored.fg(2)
        parts.append(tab * indent + colored.stylize(str(key), color))
        if isinstance(value, dict):
            parts.append(':\n')
            parts.append(pretty_dict(value, indent + 1, highlight_keys))
        else:
            parts.append(":" + "\t" + colored.stylize(str(value), color) + '\n')
    out_line = "".join(parts)

    if indent != 0:
        return out_line

    # Top level: pad every "key<TAB>value" line so the values line up.
    lines = out_line.split('\n')
    max_length = max(len(line.split('\t')[0]) for line in lines) + 4
    aligned = []
    for line in lines:
        if '\t' in line:
            aligned_number = max_length - len(line.split('\t')[0])
            line = line.replace('\t', aligned_number * ' ')
        aligned.append(line)
    # out_line ends with '\n', so the join ends with "\n\n"; drop both.
    return ('\n'.join(aligned) + '\n')[:-2]
| |
|
| |
|
class rtqdm(std_tqdm):
    """Experimental rich.progress GUI version of tqdm!

    Renders through a ``rich.progress.Progress`` instance created in
    ``__init__``. ``set_postfix`` here takes a single mapping and shows it,
    coloured, as the task description.
    """

    def __init__(self, *args, **kwargs):
        """
        This class accepts the following parameters *in addition* to
        the parameters accepted by `tqdm`.

        Parameters
        ----------
        progress : tuple, optional
            arguments for `rich.progress.Progress()`.
        options : dict, optional
            keyword arguments for `rich.progress.Progress()`.
        """
        kwargs = kwargs.copy()
        kwargs['gui'] = True
        # Coerce to bool (note: this turns tqdm's disable=None "auto" mode into False).
        kwargs['disable'] = bool(kwargs.get('disable', False))
        progress = kwargs.pop('progress', None)
        options = kwargs.pop('options', {}).copy()
        super(rtqdm, self).__init__(*args, **kwargs)

        if self.disable:
            return

        d = self.format_dict
        if progress is None:
            progress = (
                "[progress.description]"
                "[progress.percentage]{task.percentage:>4.0f}%",
                BarColumn(bar_width=66),
                FractionColumn(unit_scale=d['unit_scale'], unit_divisor=d['unit_divisor']),
                "[",
                TimeElapsedColumn(), "<", TimeRemainingColumn(), ",",
                RateColumn(unit=d['unit'], unit_scale=d['unit_scale'], unit_divisor=d['unit_divisor']),
                "{task.description}",
                "]",
            )
        options.setdefault('transient', not self.leave)
        self._prog = Progress(*progress, **options)
        self._prog.__enter__()
        self._task_id = self._prog.add_task(self.desc or "", **d)

    def close(self):
        """Close the bar and exit the rich live display."""
        if self.disable:
            return
        # In GUI mode tqdm's own close() does not render a final frame, so push
        # the final count/description to rich before shutting it down.
        self.display()
        super(rtqdm, self).close()
        self._prog.__exit__(None, None, None)

    def clear(self, *_, **__):
        """No-op: rich manages clearing of its own display."""
        pass

    def set_postfix(self, desc):
        """Show mapping ``desc`` as coloured ``key = value`` pairs.

        Unlike ``tqdm.set_postfix``, this takes one mapping argument and
        replaces ``self.desc`` with the rendered string (which begins with
        ", "), then refreshes the display.
        """
        desc_str = ", " + " , ".join([
            colored.stylize(str(f"{k}"), colored.fg(3)) + " = " +
            colored.stylize(str(f"{v}"), colored.fg(4))
            for k, v in desc.items()
        ])
        self.desc = desc_str
        self.display()

    def display(self, *_, **__):
        """Push the current count and description to the rich task."""
        if not hasattr(self, '_prog'):
            return
        self._prog.update(self._task_id, completed=self.n, description=self.desc)

    def reset(self, total=None):
        """
        Resets to 0 iterations for repeated use.

        Parameters
        ----------
        total : int or float, optional. Total to use for the new bar.
        """
        if hasattr(self, '_prog'):
            # rich's Progress.reset() takes the task id as a required first
            # argument; calling it without one raises TypeError.
            self._prog.reset(self._task_id, total=total)
        super(rtqdm, self).reset(total=total)
| | |
class FractionColumn(ProgressColumn):
    """Renders completed/total, e.g. '0.5/2.3 G'.

    When the task total is unknown (``None``), renders 'completed/? suffix'.
    """

    def __init__(self, unit_scale=False, unit_divisor=1000):
        self.unit_scale = unit_scale
        self.unit_divisor = unit_divisor
        super().__init__()

    def render(self, task):
        """Calculate common unit for completed and total."""
        completed = int(task.completed)
        # tqdm passes total=None for iterables without len(); int(None) would raise.
        total = None if task.total is None else int(task.total)
        # Pick the unit from the total when known, otherwise from the count so far.
        reference = completed if total is None else total
        if self.unit_scale:
            unit, suffix = filesize.pick_unit_and_suffix(
                reference,
                ["", "K", "M", "G", "T", "P", "E", "Z", "Y"],
                self.unit_divisor,
            )
        else:
            unit, suffix = filesize.pick_unit_and_suffix(reference, [""], 1)
        precision = 0 if unit == 1 else 1
        total_text = "?" if total is None else f"{total/unit:,.{precision}f}"
        return Text(
            f"{completed/unit:,.{precision}f}/{total_text} {suffix}",
            style="progress.download")
| |
|
| |
|
class RateColumn(ProgressColumn):
    """Rich progress column that shows the task's current speed as text.

    With ``unit_scale`` enabled, the speed is scaled by powers of
    ``unit_divisor`` and given a K/M/G/... prefix; otherwise it is shown
    unscaled. ``unit`` is appended before "/s".
    """

    def __init__(self, unit="", unit_scale=False, unit_divisor=1000):
        self.unit = unit
        self.unit_scale = unit_scale
        self.unit_divisor = unit_divisor
        super().__init__()

    def render(self, task):
        """Return the speed text, or '? <unit>/s' while no speed is available."""
        rate = task.speed
        if rate is None:
            return Text(f"? {self.unit}/s", style="progress.data.speed")

        if self.unit_scale:
            suffixes = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]
            base = self.unit_divisor
        else:
            suffixes = [""]
            base = 1
        unit, suffix = filesize.pick_unit_and_suffix(rate, suffixes, base)

        digits = 1 if unit != 1 else 0
        label = f"{rate/unit:,.{digits}f} {suffix}{self.unit}/s"
        return Text(label, style="progress.data.speed")
| |
|
def device_parser(str_device):
    """Parse a device specification string into a list of device ids.

    If the string contains ``'cpu'`` the result is ``['cpu']``. Otherwise it
    is a comma-separated list whose items are either a single integer
    (``'3'``) or an inclusive range (``'0-3'``; the first and last
    dash-separated numbers are used as bounds).

    Examples: ``'cpu'`` -> ``['cpu']``, ``'0-2,5'`` -> ``[0, 1, 2, 5]``.

    Raises:
        ValueError: if an item is not an integer or integer range.
    """
    if 'cpu' in str_device:
        # Return directly: extending a list with the string 'cpu' would add
        # its characters one by one (['c', 'p', 'u']).
        return ['cpu']
    device_ids = []
    for item in str_device.split(','):
        bounds = item.split('-')
        device_ids.extend(range(int(bounds[0]), int(bounds[-1]) + 1))
    return device_ids
| |
|
def calc_parameters(models):
    """Count parameters across several models.

    Args:
        models: iterable of objects exposing ``.parameters()`` (e.g. torch
            modules). A parameter shared between models is counted once per
            model.

    Returns:
        ``(trainable, total)`` element counts, where ``trainable`` covers only
        parameters with ``requires_grad`` set.
    """
    trainable_total = 0
    overall_total = 0
    for module in models:
        for param in module.parameters():
            count = param.numel()
            overall_total += count
            if param.requires_grad:
                trainable_total += count
    return trainable_total, overall_total
| |
|
def biuld_logger(log_path, name='test_logger'):
    """Return a logger writing DEBUG+ to ``log_path`` and INFO+ to the console.

    The parent directory of ``log_path`` is created if needed. Calling this
    again with the same ``name`` reuses the logger and does not attach a
    second console handler, or a second file handler for the same file, so
    messages are not duplicated.

    Args:
        log_path: path of the log file (a bare filename goes in the CWD).
        name: logger name passed to ``logging.getLogger``.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    log_dir = os.path.dirname(log_path)
    # dirname() is '' for a bare filename, and os.makedirs('') would raise.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    formatter = logging.Formatter("%(asctime)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    target = os.path.abspath(os.fspath(log_path))
    has_file = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target
        for h in logger.handlers
    )
    # FileHandler subclasses StreamHandler, so match the console handler by exact type.
    has_console = any(type(h) is logging.StreamHandler for h in logger.handlers)

    if not has_file:
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if not has_console:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    return logger
| |
|
def find_pt_file(base_path, prefix):
    """Return the most recently modified ``<prefix>*.pt`` file in ``base_path``.

    ``base_path`` and ``prefix`` are matched literally: glob metacharacters
    (``*``, ``?``, ``[``) in them are escaped, so a prefix like ``run[1]``
    matches ``run[1]_best.pt``. Ties on modification time are broken by path,
    so the result is deterministic.

    Returns:
        The matching path, or ``None`` if there is no match.
    """
    pattern = os.path.join(
        glob.escape(os.fspath(base_path)), glob.escape(prefix) + "*.pt"
    )
    pt_files = glob.glob(pattern)
    if not pt_files:
        return None
    return max(pt_files, key=lambda p: (os.path.getmtime(p), p))
| |
|
def to8b(img):
    """Convert a float image to uint8.

    Values are clipped to [0, 1], scaled by 255, and truncated (not rounded)
    by the ``astype`` cast.
    """
    clipped = np.clip(img, 0, 1)
    scaled = clipped * 255
    return scaled.astype(np.uint8)
| |
|
def inverse_sigmoid(x):
    """Element-wise logit ``log(x / (1 - x))``, the inverse of sigmoid.

    Meaningful for x in (0, 1). It yields -inf at 0, +inf at 1, and nan
    outside [0, 1].
    """
    odds = x / (1 - x)
    return torch.log(odds)