| import logging |
| from .constants import * |
|
|
|
|
# Module-level logger named after this module; used for verbose config output.
_logger = logging.getLogger(name=__name__)
|
|
|
|
def _cfg_value(mapping, key):
    """Return ``mapping[key]`` when *key* is present, otherwise ``None``.

    Only ``in`` and ``[]`` are used, so any mapping-like object works.
    """
    return mapping[key] if key in mapping else None


def _resolve_channel_stat(args, default_cfg, key, fallback, in_chans):
    """Resolve a per-channel normalization statistic (``'mean'`` or ``'std'``).

    Precedence: ``args[key]`` > ``default_cfg[key]`` > *fallback*; ``None``
    values are treated as absent. A single-element value from *args* is
    broadcast to *in_chans* channels; otherwise its length must equal
    *in_chans*.
    """
    value = _cfg_value(args, key)
    if value is not None:
        value = tuple(value)
        if len(value) == 1:
            # Broadcast a single scalar to every input channel.
            return value * in_chans
        assert len(value) == in_chans, (
            '%s: expected 1 or %s values, got %d' % (key, in_chans, len(value)))
        return value
    cfg_value = _cfg_value(default_cfg, key)
    return fallback if cfg_value is None else cfg_value


def resolve_data_config(args, default_cfg=None, model=None, use_test_size=False, verbose=False):
    """Resolve the data-processing config for a model.

    Each setting is taken from *args* first, then from the model's default
    config, then from a built-in fallback. Entries in either mapping that are
    missing or ``None`` are skipped (``interpolation`` in *args* is also
    skipped when falsy).

    Args:
        args: mapping of user overrides. Keys read: ``chans``, ``input_size``,
            ``img_size``, ``interpolation``, ``mean``, ``std``, ``crop_pct``.
        default_cfg: mapping of model defaults. When empty or ``None`` and
            *model* has a ``default_cfg`` attribute, that attribute is used.
        model: optional object supplying ``default_cfg``.
        use_test_size: when no size is given in *args*, prefer
            ``default_cfg['test_input_size']`` over ``default_cfg['input_size']``.
        verbose: log the resolved config at INFO level.

    Returns:
        A new dict with keys ``input_size``, ``interpolation``, ``mean``,
        ``std`` and ``crop_pct`` (in that order).

    Raises:
        AssertionError: if ``args['input_size']`` is not a 3-element
            tuple/list, ``args['img_size']`` is not an int, or a multi-value
            ``mean``/``std`` does not match the channel count.
    """
    # ``None`` default avoids sharing one mutable ``{}`` across calls; a model
    # whose ``default_cfg`` is ``None`` is also normalized to an empty dict.
    if not default_cfg and model is not None:
        default_cfg = getattr(model, 'default_cfg', None)
    if default_cfg is None:
        default_cfg = {}

    new_config = {}

    # Number of input channels: args['chans'] if given, else 3.
    in_chans = _cfg_value(args, 'chans')
    if in_chans is None:
        in_chans = 3

    # Input size: args['input_size'] > args['img_size'] > default_cfg > fallback.
    input_size = (in_chans, 224, 224)
    arg_input_size = _cfg_value(args, 'input_size')
    arg_img_size = _cfg_value(args, 'img_size')
    if arg_input_size is not None:
        assert isinstance(arg_input_size, (tuple, list)), 'input_size must be a tuple or list'
        assert len(arg_input_size) == 3, 'input_size must have 3 elements (C, H, W)'
        input_size = tuple(arg_input_size)
        in_chans = input_size[0]  # an explicit input_size overrides 'chans'
    elif arg_img_size is not None:
        assert isinstance(arg_img_size, int), 'img_size must be an int'
        input_size = (in_chans, arg_img_size, arg_img_size)
    else:
        test_input_size = _cfg_value(default_cfg, 'test_input_size')
        cfg_input_size = _cfg_value(default_cfg, 'input_size')
        if use_test_size and test_input_size is not None:
            input_size = test_input_size
        elif cfg_input_size is not None:
            input_size = cfg_input_size
        # NOTE(review): in_chans is not re-derived from a default_cfg input_size,
        # so a single-value mean/std below is broadcast to 'chans' (or 3) — confirm intended.
    new_config['input_size'] = input_size

    # Resize interpolation: truthy args value > default_cfg > 'bicubic'.
    interpolation = 'bicubic'
    if _cfg_value(args, 'interpolation'):
        interpolation = args['interpolation']
    elif _cfg_value(default_cfg, 'interpolation') is not None:
        interpolation = default_cfg['interpolation']
    new_config['interpolation'] = interpolation

    # Normalization statistics, validated against the resolved channel count.
    new_config['mean'] = _resolve_channel_stat(
        args, default_cfg, 'mean', IMAGENET_DEFAULT_MEAN, in_chans)
    new_config['std'] = _resolve_channel_stat(
        args, default_cfg, 'std', IMAGENET_DEFAULT_STD, in_chans)

    # Center-crop fraction: args > default_cfg > DEFAULT_CROP_PCT.
    crop_pct = _cfg_value(args, 'crop_pct')
    if crop_pct is None:
        crop_pct = _cfg_value(default_cfg, 'crop_pct')
    new_config['crop_pct'] = DEFAULT_CROP_PCT if crop_pct is None else crop_pct

    if verbose:
        _logger.info('Data processing configuration for current model + dataset:')
        for name, value in new_config.items():
            # Lazy %-args: formatting is skipped when INFO is disabled.
            _logger.info('\t%s: %s', name, value)

    return new_config
|
|