# PEAR / utils / general_utils.py
# BestWJH's picture
# Upload 205 files
# 2c68f56 verified
import os,math,yaml,time,random,glob
from datetime import datetime
import logging
from zoneinfo import ZoneInfo
from omegaconf import OmegaConf
import colored
from rich.progress import BarColumn, Progress, ProgressColumn, Text, TimeElapsedColumn, TimeRemainingColumn, filesize
from tqdm.std import tqdm as std_tqdm
import numpy as np
import torch
class ConfigDict(dict):
    """Experiment configuration stored as a dict with a read-only dotted view.

    The instance is a plain ``dict`` holding the configuration; in addition a
    read-only OmegaConf copy is kept in ``self._dot_config`` so that values can
    be read as attributes (``cfg.MODEL.NAME``).

    Two special attribute names are handled by ``__getattr__``:

    * ``_dump``       -- a plain ``dict`` copy of the top level.
    * ``_raw_string`` -- ``pretty_dict`` output with ANSI escapes removed,
      prefixed by a newline.

    Parameters
    ----------
    model_config_path : str, optional
        YAML file read with ``read_config`` when ``init_dict`` is None.
    data_config_path : str, optional
        Second YAML file whose values are merged over the model config.
    init_dict : dict, optional
        Ready-made configuration; when given, no file is read and no
        ``TRAIN.EXP_STR`` / ``TRAIN.TIME_STR`` entries are generated.
    """

    def __init__(self, model_config_path=None, data_config_path=None, init_dict=None):
        if init_dict is None:
            # build new config
            config_dict = read_config(model_config_path)
            if data_config_path is not None:
                dataset_dict = read_config(data_config_path)
                merge_a_into_b(dataset_dict, config_dict)
            # set output path
            experiment_string = '{}_{}'.format(
                config_dict['MODEL']['NAME'], config_dict['DATASET']['NAME']
            )
            # Ask for an aware timestamp directly instead of converting a naive
            # local one, which can be ambiguous around local DST changes.
            timeInTokyo = datetime.now(tz=ZoneInfo('Asia/Tokyo'))
            # Five distinct random lowercase letters keep runs started within
            # the same minute apart.
            time_string = timeInTokyo.strftime("%b%d_%H%M_") + \
                "".join(random.sample('zyxwvutsrqponmlkjihgfedcba', 5))
            config_dict['TRAIN']['EXP_STR'] = experiment_string
            config_dict['TRAIN']['TIME_STR'] = time_string
        else:
            config_dict = init_dict
        super().__init__(config_dict)
        self._dot_config = OmegaConf.create(dict(self))
        OmegaConf.set_readonly(self._dot_config, True)

    def __getattr__(self, name):
        """Resolve attributes that normal lookup did not find.

        Raises
        ------
        AttributeError
            If the dotted view has not been created yet (for example while the
            object is being rebuilt by ``copy``/``pickle``) or if the dotted
            view itself raises it.
        """
        if name == '_dump':
            return dict(self)
        if name == '_raw_string':
            import re
            ansi_escape = re.compile(r'''
                \x1B # ESC
                (?: # 7-bit C1 Fe (except CSI)
                [@-Z\\-_]
                | # or [ for CSI, followed by a control sequence
                \[
                [0-?]* # Parameter bytes
                [ -/]* # Intermediate bytes
                [@-~] # Final byte
                )
            ''', re.VERBOSE)
            result = '\n' + ansi_escape.sub('', pretty_dict(self))
            return result
        # Read the mirror from the instance dict: going through
        # ``self._dot_config`` would re-enter __getattr__ and recurse forever
        # whenever the attribute has not been set.
        try:
            dot_config = self.__dict__['_dot_config']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(dot_config, name)

    def __str__(self, ):
        return pretty_dict(self)

    def update(self, key, value):
        """Set ``key`` to ``value`` in both the dict and the dotted view.

        Note that this replaces ``dict.update`` with a two-argument setter.
        The dotted view is made writable only for the duration of the call.
        """
        OmegaConf.set_readonly(self._dot_config, False)
        try:
            self._dot_config[key] = value
            self[key] = value
        finally:
            # Restore the read-only flag even if the assignment failed.
            OmegaConf.set_readonly(self._dot_config, True)
def add_extra_cfgs(meta_cfg):
    """Fill in optional MODEL settings that the config does not define.

    ``meta_cfg`` is unlocked through OmegaConf, ``MODEL.with_smplx_gaussian``
    is set to True when that key is absent, and the config is locked again.
    The same object is returned.
    """
    model_defaults = {'with_smplx_gaussian': True}

    OmegaConf.set_readonly(meta_cfg, False)
    for option, fallback in model_defaults.items():
        if option not in meta_cfg.MODEL.keys():
            meta_cfg.MODEL[option] = fallback
    OmegaConf.set_readonly(meta_cfg, True)
    return meta_cfg
def read_config(path):
    """Parse the YAML file at ``path`` and return the loaded object.

    Parameters
    ----------
    path : str
        Location of the YAML file.

    Returns
    -------
    Whatever ``yaml.load`` produces for the document.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"{path} was not found.")
    # Decode explicitly as UTF-8 rather than with the platform locale, so the
    # same file parses identically on every machine.
    with open(path, encoding="utf-8") as f:
        # SECURITY NOTE: yaml.Loader can construct arbitrary Python objects.
        # Only use this on trusted config files; switch to yaml.SafeLoader if
        # the configs never rely on Python-specific tags.
        config = yaml.load(f, Loader=yaml.Loader)
    return config
def merge_a_into_b(a, b):
    """Recursively overlay dict ``a`` onto dict ``b`` in place.

    Values from ``a`` win. A dict value in ``a`` whose key already exists in
    ``b`` is merged into ``b``'s value key by key; every other value is
    assigned directly (so new nested dicts are shared with ``a``, not copied).
    Nothing is returned.
    """
    for name, incoming in a.items():
        needs_deep_merge = isinstance(incoming, dict) and name in b
        if not needs_deep_merge:
            b[name] = incoming
            continue
        # A dict may only be merged into another dict.
        assert isinstance(
            b[name], dict
        ), "Cannot inherit key '{}' from base!".format(name)
        merge_a_into_b(incoming, b[name])
def pretty_dict(input_dict, indent=0, highlight_keys=()):
    """Render a (possibly nested) dict as coloured, column-aligned text.

    Parameters
    ----------
    input_dict : dict
        Mapping to render; dict values are rendered recursively one level
        deeper.
    indent : int
        Current nesting depth. Only the top-level call (``indent == 0``)
        performs the final column alignment.
    highlight_keys : collection, optional
        Keys whose name and value are drawn with colour 1 instead of
        colour 2. The default is an immutable empty tuple so no mutable
        object is shared between calls.

    Returns
    -------
    str
        For ``indent == 0`` the aligned text without a trailing newline,
        otherwise the raw lines with a tab between each key and its value.
    """
    tab = " "
    pieces = []
    for key, value in input_dict.items():
        # Key and value of one entry always share the same colour.
        style = colored.fg(1) if key in highlight_keys else colored.fg(2)
        pieces.append(tab * indent + colored.stylize(str(key), style))
        if isinstance(value, dict):
            pieces.append(':\n')
            pieces.append(pretty_dict(value, indent + 1, highlight_keys))
        else:
            # The tab is a placeholder that the top-level call swaps for
            # padding spaces.
            pieces.append(":" + "\t" + colored.stylize(str(value), style) + '\n')
    out_line = "".join(pieces)
    if indent != 0:
        return out_line

    lines = out_line.split('\n')
    # Widest "key:" prefix (measured with colour codes included) plus a gap of 4.
    max_length = max(len(line.split('\t')[0]) for line in lines) + 4
    aligned = []
    for line in lines:
        if '\t' in line:
            aligned_number = max_length - len(line.split('\t')[0])
            line = line.replace('\t', aligned_number * ' ')
        aligned.append(line + '\n')
    # Drop the two trailing newlines produced by the final empty split item.
    return "".join(aligned)[:-2]
class rtqdm(std_tqdm):  # pragma: no cover
    """Experimental rich.progress GUI version of tqdm!"""
    # TODO: @classmethod: write()?

    def __init__(self, *args, **kwargs):
        """
        This class accepts the following parameters *in addition* to
        the parameters accepted by `tqdm`.

        Parameters
        ----------
        progress : tuple, optional
            arguments for `rich.progress.Progress()`.
        options : dict, optional
            keyword arguments for `rich.progress.Progress()`.
        """
        kwargs = kwargs.copy()
        kwargs['gui'] = True
        # convert disable = None to False
        kwargs['disable'] = bool(kwargs.get('disable', False))
        progress = kwargs.pop('progress', None)
        options = kwargs.pop('options', {}).copy()
        super(rtqdm, self).__init__(*args, **kwargs)
        if self.disable:
            return
        # warn("rich is experimental/alpha", TqdmExperimentalWarning, stacklevel=2)
        d = self.format_dict
        if progress is None:
            # Default layout: percentage, bar, fraction, timings, rate, and the
            # task description (used for the postfix text) at the very end.
            progress = (
                "[progress.description]"
                "[progress.percentage]{task.percentage:>4.0f}%",
                BarColumn(bar_width=66),
                FractionColumn(unit_scale=d['unit_scale'], unit_divisor=d['unit_divisor']),
                "[",
                TimeElapsedColumn(), "<", TimeRemainingColumn(), ",",
                RateColumn(unit=d['unit'], unit_scale=d['unit_scale'], unit_divisor=d['unit_divisor']),
                "{task.description}",
                "]",
            )
        options.setdefault('transient', not self.leave)
        self._prog = Progress(*progress, **options)
        self._prog.__enter__()
        self._task_id = self._prog.add_task(self.desc or "", **d)

    def close(self):
        """Finish the bar and leave the rich live display."""
        if self.disable:
            return
        # Push the final count to rich before closing; the base close() does
        # not draw anything for GUI-mode bars.
        self.display()
        super(rtqdm, self).close()
        self._prog.__exit__(None, None, None)

    def clear(self, *_, **__):
        """No-op: rich owns the terminal output."""
        pass

    def set_postfix(self, desc):
        """Show the items of mapping ``desc`` as coloured ``key = value`` text.

        The formatted string replaces ``self.desc`` and the bar is redrawn.
        """
        desc_str = ", "+" , ".join([
            colored.stylize(str(f"{k}"), colored.fg(3)) + " = " +
            colored.stylize(str(f"{v}"), colored.fg(4))
            for k, v in desc.items()]
        )
        self.desc = desc_str
        self.display()

    def display(self, *_, **__):
        """Forward the current count and description to the rich task."""
        if not hasattr(self, '_prog'):
            return
        self._prog.update(self._task_id, completed=self.n, description=self.desc)

    def reset(self, total=None):
        """
        Resets to 0 iterations for repeated use.

        Parameters
        ----------
        total : int or float, optional. Total to use for the new bar.
        """
        if hasattr(self, '_prog'):
            # Progress.reset() needs the task id as its first argument;
            # omitting it raises TypeError.
            self._prog.reset(self._task_id, total=total)
        super(rtqdm, self).reset(total=total)
class FractionColumn(ProgressColumn):
    """Renders completed/total, e.g. '0.5/2.3 G'."""

    def __init__(self, unit_scale=False, unit_divisor=1000):
        self.unit_scale = unit_scale
        self.unit_divisor = unit_divisor
        super().__init__()

    def render(self, task):
        """Calculate common unit for completed and total.

        When the task has no total (``task.total is None``) the unit is picked
        from the completed count and the total is shown as ``?`` instead of
        failing on ``int(None)``.
        """
        completed = int(task.completed)
        total = None if task.total is None else int(task.total)
        # Value that decides which magnitude suffix is used.
        reference = completed if total is None else total
        if self.unit_scale:
            unit, suffix = filesize.pick_unit_and_suffix(
                reference,
                ["", "K", "M", "G", "T", "P", "E", "Z", "Y"],
                self.unit_divisor,
            )
        else:
            unit, suffix = filesize.pick_unit_and_suffix(reference, [""], 1)
        # Unscaled values are whole numbers; scaled ones get one decimal.
        precision = 0 if unit == 1 else 1
        total_text = "?" if total is None else f"{total/unit:,.{precision}f}"
        return Text(
            f"{completed/unit:,.{precision}f}/{total_text} {suffix}",
            style="progress.download")
class RateColumn(ProgressColumn):
    """Progress column that prints the task speed as '<value> <suffix><unit>/s'."""

    def __init__(self, unit="", unit_scale=False, unit_divisor=1000):
        self.unit = unit
        self.unit_scale = unit_scale
        self.unit_divisor = unit_divisor
        super().__init__()

    def render(self, task):
        """Format ``task.speed``; a speed of None is rendered as '?'."""
        rate = task.speed
        if rate is None:
            return Text(f"? {self.unit}/s", style="progress.data.speed")

        # With scaling enabled, pick a magnitude suffix; otherwise keep the
        # raw number.
        if self.unit_scale:
            suffixes = ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]
            base = self.unit_divisor
        else:
            suffixes = [""]
            base = 1
        unit, suffix = filesize.pick_unit_and_suffix(rate, suffixes, base)

        precision = 1 if unit != 1 else 0
        rendered = f"{rate/unit:,.{precision}f} {suffix}{self.unit}/s"
        return Text(rendered, style="progress.data.speed")
def device_parser(str_device):
    """Turn a device specification string into a flat list of devices.

    NOTE: this file defines ``device_parser`` a second time further down; the
    later definition is the one bound to the name after import.

    Parameters
    ----------
    str_device : str
        Either any string containing ``'cpu'``, or comma-separated entries
        where each entry is an integer (``'2'``) or an inclusive range
        (``'0-3'``).

    Returns
    -------
    list
        ``['cpu']`` for a CPU specification, otherwise the device indices in
        the order they were listed, e.g. ``'0,2-4'`` -> ``[0, 2, 3, 4]``.

    Raises
    ------
    ValueError
        If an entry cannot be converted with ``int``.
    """
    def parser_dash(entry):
        # 'a-b' expands to a..b inclusive; a single number 'a' is the range a..a.
        bounds = entry.split('-')
        return list(range(int(bounds[0]), int(bounds[-1]) + 1))

    if 'cpu' in str_device:
        # Return straight away: pushing the string 'cpu' through the
        # flattening loop below would split it into single characters.
        return ['cpu']

    res = []
    for entry in str_device.split(','):
        res += parser_dash(entry)
    return res
def device_parser(str_device):
    """Parse a device string such as ``'cpu'``, ``'1'`` or ``'0-1,3'``.

    NOTE: ``device_parser`` is also defined earlier in this file; being the
    later definition, this one is what callers actually get.

    Parameters
    ----------
    str_device : str
        A string containing ``'cpu'``, or a comma-separated list of integer
        indices and inclusive ``start-end`` ranges.

    Returns
    -------
    list
        ``['cpu']`` when the string contains ``'cpu'``; otherwise a flat list
        of ints in listing order (``'0-1,3'`` -> ``[0, 1, 3]``).

    Raises
    ------
    ValueError
        If a piece of the string is not an integer.
    """
    def parser_dash(piece):
        # The first and last dash-separated fields are the inclusive bounds.
        fields = piece.split('-')
        first, last = int(fields[0]), int(fields[-1])
        return list(range(first, last + 1))

    if 'cpu' in str_device:
        # Do not flatten the CPU marker: extending a list with the string
        # 'cpu' would yield ['c', 'p', 'u'].
        return ['cpu']

    groups = [parser_dash(piece) for piece in str_device.split(',')]
    res = []
    for group in groups:
        res.extend(group)
    return res
def calc_parameters(models):
    """Count parameter elements across an iterable of models.

    Each model must provide ``parameters()`` yielding objects that have
    ``numel()`` and ``requires_grad``.

    Returns
    -------
    tuple
        ``(trainable, total)``: elements of parameters with
        ``requires_grad`` set, and elements of all parameters.
    """
    trainable_count = 0
    total_count = 0
    for model in models:
        for param in model.parameters():
            size = param.numel()
            total_count += size
            if param.requires_grad:
                trainable_count += size
    return trainable_count, total_count
def biuld_logger(log_path, name='test_logger'):
    """Create (or fetch) a logger that writes to a file and to the console.

    The file handler records DEBUG and above, the console handler INFO and
    above; both use the format ``"<timestamp>: <message>"``. Handlers are
    added on every call, so calling this twice with the same ``name``
    attaches a second pair of handlers to the same logger.

    Parameters
    ----------
    log_path : str
        Path of the log file. Missing parent directories are created; a bare
        file name (no directory part) is written to the current directory.
    name : str
        Name passed to ``logging.getLogger``.

    Returns
    -------
    logging.Logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    log_dir = os.path.dirname(log_path)
    # dirname is '' for a bare file name and os.makedirs('') raises, so only
    # create directories when there is one. exist_ok avoids a race between
    # checking for the directory and creating it.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    file_handler = logging.FileHandler(log_path)
    file_handler.setLevel(logging.DEBUG)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter("%(asctime)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    return logger
def find_pt_file(base_path, prefix):
    """Return the most recently modified ``<prefix>*.pt`` file in ``base_path``.

    ``prefix`` is inserted into a glob pattern as-is. Returns None when no
    file matches.
    """
    pattern = os.path.join(base_path, f"{prefix}*.pt")
    candidates = glob.glob(pattern)
    if not candidates:
        return None
    # Newest by modification time.
    return max(candidates, key=os.path.getmtime)
def to8b(img):
    """Clip ``img`` to [0, 1], scale by 255 and cast to ``uint8``.

    The cast uses ``astype``, so fractional results are truncated rather than
    rounded (0.5 -> 127).
    """
    clipped = np.clip(img, 0, 1)
    scaled = 255 * clipped
    return scaled.astype(np.uint8)
def inverse_sigmoid(x):
    """Return ``log(x / (1 - x))`` computed with ``torch.log``.

    No clamping is applied: 0 maps to -inf and 1 maps to +inf.
    """
    odds = x / (1 - x)
    return torch.log(odds)