| import collections |
| import functools |
| import os |
| import re |
|
|
| import yaml |
|
|
class AttrDict(dict):
    """Dict as attribute trick.

    Keys are readable and writable both as items (``d['k']``) and as
    attributes (``d.k``). On construction, nested dicts are converted to
    ``AttrDict``. Non-empty lists or tuples whose first element is a dict are
    converted to a list of ``AttrDict``. Every other value is stored as given.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Aliasing the instance __dict__ to the dict itself makes attribute
        # access and item access share a single storage.
        self.__dict__ = self
        for key, value in self.__dict__.items():
            if isinstance(value, dict):
                self.__dict__[key] = AttrDict(value)
            elif isinstance(value, (list, tuple)):
                # Check emptiness first: an empty sequence has no [0].
                if value and isinstance(value[0], dict):
                    self.__dict__[key] = [AttrDict(item) for item in value]
                else:
                    self.__dict__[key] = value

    def yaml(self):
        """Convert object to yaml dict and return.

        Returns:
            A new plain ``dict``. Nested ``AttrDict`` values, whether direct or
            inside a list that starts with an ``AttrDict``, are converted
            recursively to plain dicts. Other values are passed through
            unchanged.
        """
        yaml_dict = {}
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                yaml_dict[key] = value.yaml()
            elif (isinstance(value, list) and value
                  and isinstance(value[0], AttrDict)):
                # Convert per item so a mixed list cannot raise AttributeError.
                yaml_dict[key] = [
                    item.yaml() if isinstance(item, AttrDict) else item
                    for item in value]
            else:
                yaml_dict[key] = value
        return yaml_dict

    def __repr__(self):
        """Print all variables.

        Returns a newline-separated ``key: value`` listing. Nested
        ``AttrDict`` values, and items of lists that start with an
        ``AttrDict``, are printed under a ``key:`` header with their lines
        indented.
        """
        ret_str = []
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                ret_str.append('{}:'.format(key))
                for line in value.__repr__().split('\n'):
                    ret_str.append(' ' + line)
            elif (isinstance(value, list) and value
                  and isinstance(value[0], AttrDict)):
                ret_str.append('{}:'.format(key))
                for item in value:
                    for line in item.__repr__().split('\n'):
                        ret_str.append(' ' + line)
            else:
                ret_str.append('{}: {}'.format(key, value))
        return '\n'.join(ret_str)
|
|
|
|
def _load_config_yaml(filename):
    """Read ``filename`` as YAML and return the parsed document.

    The loader is a private ``yaml.SafeLoader`` subclass with one extra
    implicit float resolver, so values written without a decimal point (for
    example ``1e-8``) load as floats. The resolver is registered on the
    subclass, not on ``yaml.SafeLoader`` itself. ``add_implicit_resolver``
    mutates the class it is called on, so registering on the shared loader
    would change ``yaml.safe_load`` process-wide and append a duplicate entry
    on every call.

    Args:
        filename: Path to the YAML file.

    Returns:
        The parsed document. Returns an empty dict when the file is empty,
        because ``yaml.load`` returns None in that case.

    Raises:
        AssertionError: If ``filename`` does not exist.
        EnvironmentError: If the file cannot be opened or read. A hint naming
            the file is printed before the error is re-raised.
    """
    assert os.path.exists(filename), 'File {} not exist.'.format(filename)

    class _ConfigLoader(yaml.SafeLoader):
        pass

    _ConfigLoader.add_implicit_resolver(
        u'tag:yaml.org,2002:float',
        re.compile(u'''^(?:
         [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        |\\.[0-9_]+(?:[eE][-+][0-9]+)?
        |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
        |[-+]?\\.(?:inf|Inf|INF)
        |\\.(?:nan|NaN|NAN))$''', re.X),
        list(u'-+0123456789.'))
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            cfg_dict = yaml.load(f, Loader=_ConfigLoader)
    except EnvironmentError:
        # Previously the error was swallowed and execution continued with an
        # unbound cfg_dict. Report it and propagate instead.
        print('Please check the file with name of "%s"' % filename)
        raise
    return cfg_dict if cfg_dict is not None else {}


class Config(AttrDict):
    r"""Configuration class. This should include every human-specifiable
    hyperparameter value for your training.

    Built-in defaults are set first. If ``filename`` is given, the YAML file
    is then merged on top of them with ``recursive_update``.

    Args:
        filename: Path to a YAML config file. When None, only the built-in
            defaults are used.
        args: Accepted for interface compatibility; not used here.
        verbose: If True, print the resulting configuration.
        is_train: Sets ``phase`` to 'train' when True, otherwise 'test'.
    """

    def __init__(self, filename=None, args=None, verbose=False, is_train=True):
        super(Config, self).__init__()

        # Sentinel meaning "effectively never" for save/eval/stop intervals.
        large_number = 1000000000
        self.snapshot_save_iter = large_number
        self.snapshot_save_epoch = large_number
        self.snapshot_save_start_iter = 0
        self.snapshot_save_start_epoch = 0
        self.image_save_iter = large_number
        self.eval_epoch = large_number
        self.start_eval_epoch = large_number
        self.max_epoch = large_number
        self.max_iter = large_number
        self.logging_iter = 100
        self.image_to_tensorboard = False
        self.which_iter = 0
        self.resume = False

        # NOTE(review): machine-specific absolute path; override it in the
        # YAML file.
        self.checkpoints_dir = '/Users/shadowcun/Downloads/'
        self.name = 'face'
        self.phase = 'train' if is_train else 'test'

        # Networks.
        self.gen = AttrDict(type='generators.dummy')
        self.dis = AttrDict(type='discriminators.dummy')

        # Optimizers.
        self.gen_optimizer = AttrDict(type='adam',
                                      lr=0.0001,
                                      adam_beta1=0.0,
                                      adam_beta2=0.999,
                                      eps=1e-8,
                                      lr_policy=AttrDict(iteration_mode=False,
                                                         type='step',
                                                         step_size=large_number,
                                                         gamma=1))
        self.dis_optimizer = AttrDict(type='adam',
                                      lr=0.0001,
                                      adam_beta1=0.0,
                                      adam_beta2=0.999,
                                      eps=1e-8,
                                      lr_policy=AttrDict(iteration_mode=False,
                                                         type='step',
                                                         step_size=large_number,
                                                         gamma=1))
        # Data and trainer.
        self.data = AttrDict(name='dummy',
                             type='datasets.images',
                             num_workers=0)
        self.test_data = AttrDict(name='dummy',
                                  type='datasets.images',
                                  num_workers=0,
                                  test=AttrDict(is_lmdb=False,
                                                roots='',
                                                batch_size=1))
        self.trainer = AttrDict(
            model_average=False,
            model_average_beta=0.9999,
            model_average_start_iteration=1000,
            model_average_batch_norm_estimation_iteration=30,
            model_average_remove_sn=True,
            image_to_tensorboard=False,
            hparam_to_tensorboard=False,
            distributed_data_parallel='pytorch',
            delay_allreduce=True,
            gan_relativistic=False,
            gen_step=1,
            dis_step=1)

        # cuDNN.
        self.cudnn = AttrDict(deterministic=False,
                              benchmark=True)

        # Miscellaneous.
        self.pretrained_weight = ''
        self.inference_args = AttrDict()

        # Overlay the YAML file, if one was given, on top of the defaults.
        if filename is not None:
            cfg_dict = _load_config_yaml(filename)
            recursive_update(self, cfg_dict)

            # Share the 'common' section with both networks.
            if 'common' in cfg_dict:
                self.common = AttrDict(**cfg_dict['common'])
                self.gen.common = self.common
                self.dis.common = self.common

        if verbose:
            print(' config '.center(80, '-'))
            print(self.__repr__())
            print(''.center(80, '-'))
|
|
|
|
def rsetattr(obj, attr, val):
    """Recursively find object and set value.

    ``attr`` may be a dotted path such as ``'a.b.c'``. Every component except
    the last is resolved with ``rgetattr``; the last one is assigned with
    ``setattr``. Returns the result of ``setattr`` (None).
    """
    parent_path, _sep, leaf = attr.rpartition('.')
    target = obj
    if parent_path:
        target = rgetattr(obj, parent_path)
    return setattr(target, leaf, val)
|
|
|
|
def rgetattr(obj, attr, *args):
    """Recursively find object and return value.

    Walks the dotted path ``attr`` (e.g. ``'a.b.c'``) one component at a
    time, starting from ``obj``. Any extra positional argument is forwarded
    to every ``getattr`` call, so a single default makes a missing component
    resolve to that default instead of raising AttributeError.
    """
    current = obj
    for part in attr.split('.'):
        current = getattr(current, part, *args)
    return current
|
|
|
|
def recursive_update(d, u):
    """Recursively update AttrDict d with AttrDict u, in place; return d.

    For each ``key, value`` in ``u``:

    - If ``value`` is a mapping, it is merged recursively into ``d[key]``.
      An existing ``AttrDict`` there is updated in place. An existing plain
      mapping is first copied into an ``AttrDict``. If ``d[key]`` is missing
      or is not a mapping, a fresh ``AttrDict`` is used instead, so no
      AttributeError is raised.
    - If ``value`` is a non-empty list or tuple whose first element is a
      dict, it is stored as a list of ``AttrDict``.
    - Otherwise ``value`` overwrites ``d[key]``.

    ``d`` may be an ``AttrDict`` or any mutable mapping. For ``AttrDict``,
    item assignment and ``__dict__`` assignment hit the same storage.
    """
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            current = d.get(key)
            if isinstance(current, AttrDict):
                target = current
            elif isinstance(current, collections.abc.Mapping):
                target = AttrDict(current)
            else:
                target = AttrDict({})
            d[key] = recursive_update(target, value)
        elif (isinstance(value, (list, tuple)) and value
              and isinstance(value[0], dict)):
            # The emptiness check avoids IndexError on [] / ().
            d[key] = [AttrDict(item) for item in value]
        else:
            d[key] = value
    return d
|
|