| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """A few things commonly used across A LOT of config files.""" |
| |
|
| | import string |
| |
|
| | import ml_collections as mlc |
| |
|
| |
|
def input_for_quicktest(config_input, quicktest):
  """Shrinks input-pipeline settings in place when `quicktest` is truthy.

  Sets `batch_size` to 8, `shuffle_buffer_size` to 10 and `cache_raw` to False
  on `config_input`. When `quicktest` is falsy, nothing is modified.

  Args:
    config_input: object (e.g. a ConfigDict) whose attributes are overwritten.
    quicktest: flag selecting whether to apply the overrides.
  """
  if not quicktest:
    return
  config_input.batch_size = 8
  config_input.shuffle_buffer_size = 10
  config_input.cache_raw = False
| |
|
| |
|
def parse_arg(arg, lazy=False, **spec):
  """Turns the single-string `get_config` argument into a typed ConfigDict.

  Usage inside a config file:

    import big_vision.configs.common as bvcc
    def get_config(arg):
      arg = bvcc.parse_arg(arg,
                           res=(224, int),
                           runlocal=False,
                           schedule='short',
      )

      # ...

      config.shuffle_buffer = 250_000 if not arg.runlocal else 50

  Accepted forms on the command line:

    --config amazing.py:runlocal,schedule=long,res=128
    --config amazing.py:res=128
    --config amazing.py:runlocal  # A bare boolean name means True.
    --config amazing.py:runlocal=False  # Explicit False.
    --config amazing.py:128  # A lone unnamed value goes to the first spec entry.

  Booleans are parsed strictly: 'True'/'true' give True and 'False'/'false'/''
  give False.

  Args:
    arg: the raw string handed to `get_config`; None or '' means no overrides.
    lazy: if True, keys that are not in `spec` are kept as well, each value
      converted by `autotype` (bool, then int, then float, else the string).
      If False, such keys are an error.
    **spec: option names mapped to defaults. A `(default, converter)` tuple
      applies `converter` to the raw string; for any other default, the
      default's own type is used as the converter.

  Returns:
    A ConfigDict (type_safe=False) holding every spec option, parsed or
    defaulted, plus any extra keys when `lazy` is set.

  Raises:
    ValueError: if keys outside `spec` are given and `lazy` is False.
  """
  arg = arg or ''
  parsers = {name: get_type_with_default(default)
             for name, default in spec.items()}

  result = mlc.ConfigDict(type_safe=False)

  # A lone token (no ',' and no '=') is shorthand: a flag name (or any token
  # when there is no spec) meaning True, otherwise the value of the first
  # spec option.
  is_lone_token = arg and ',' not in arg and '=' not in arg
  if is_lone_token:
    if arg in parsers or not parsers:
      arg = f'{arg}=True'
    else:
      first_name = next(iter(parsers))
      arg = f'{first_name}={arg}'

  # Split "k1=v1,k2,k3=v3" into raw strings; a key without '=' means 'True'.
  raw_kv = {}
  for token in arg.split(','):
    if not token:
      continue
    key, sep, value = token.partition('=')
    raw_kv[key] = value if sep else 'True'

  for name, (default, convert) in parsers.items():
    text = raw_kv.pop(name, None)
    result[name] = default if text is None else convert(text)

  if not raw_kv:
    return result
  if not lazy:
    raise ValueError(f'Unhandled config args remain: {raw_kv}')
  for key, text in raw_kv.items():
    result[key] = autotype(text)
  return result
| |
|
| |
|
def get_type_with_default(v):
  """Returns `(default, string_to_value_fn)` for one `parse_arg` spec entry.

  Bool defaults get *strict* parsing: 'true', 'false' and '' (any letter case)
  are accepted, with '' meaning False; anything else raises ValueError.

  Args:
    v: either a 2-element tuple/list `(default, converter)`, where `converter`
      is a type or any other callable applied to the raw string, or a plain
      default value whose own type is used as the converter.

  Returns:
    A `(default, converter)` pair.

  Raises:
    ValueError: if `v` is a tuple/list that is not `(default, callable)`.
      Tuple/list *values* are unsupported because `,` is the arg delimiter.
  """
  if isinstance(v, bool):
    def strict_bool(x):
      lowered = x.lower()
      # Raise rather than `assert`: with `python -O` an assert would vanish
      # and a typo such as `runlocal=yes` would silently become False.
      if lowered not in {'true', 'false', ''}:
        raise ValueError(
            f'Expected a boolean config value (true/false), got {x!r}.')
      return lowered == 'true'
    return (v, strict_bool)

  if isinstance(v, (tuple, list)):
    # `callable` rather than `isinstance(..., type)` so that plain functions
    # and lambdas also work as converters, not only classes.
    if len(v) != 2 or not callable(v[1]):
      raise ValueError(
          'List or tuple types are currently not supported because we use `,` as'
          ' dumb delimiter. Contributions (probably using ast) welcome. You can'
          ' unblock by using a string with eval(s.replace(";", ",")) or similar.'
          f' Got: {v!r}')
    return (v[0], v[1])

  return (v, type(v))
| |
|
| |
|
def autotype(x):
  """Converts a string to bool, int or float when possible, else returns it.

  The checks run in this order: 'true'/'false' (any letter case) become a
  bool, then `int(x)` is tried, then `float(x)`. If both fail, the string is
  returned unchanged.
  """
  assert isinstance(x, str)
  lowered = x.lower()
  if lowered in ('true', 'false'):
    return lowered == 'true'
  for convert in (int, float):
    try:
      return convert(x)
    except ValueError:
      pass
  return x
| |
|
| |
|
def pack_arg(**kw):
  """Serializes keyword args into a "k1=v1,k2=v2" string for `parse_arg()`.

  Values are rendered with f-string formatting. A rendered value containing
  `,` is rejected, since `,` separates the entries.
  """
  for value in kw.values():
    assert ',' not in f'{value}', f"Can't use `,` in config_arg value: {value}"
  pairs = (f'{key}={value}' for key, value in kw.items())
  return ','.join(pairs)
| |
|
| |
|
def arg(**kw):
  """Use like `add(**bvcc.arg(res=256, foo=bar), lr=0.1)` to pass config_arg.

  Returns:
    A dict whose first entry is 'config_arg' (the `pack_arg(**kw)` string),
    followed by every entry of `kw`. A 'config_arg' key inside `kw` would
    overwrite the packed string.
  """
  merged = {'config_arg': pack_arg(**kw)}
  merged.update(kw)
  return merged
| |
|
| |
|
| | def _get_field_ref(config_dict, field_name): |
| | path = field_name.split('.') |
| | for field in path[:-1]: |
| | config_dict = getattr(config_dict, field) |
| | return config_dict.get_ref(path[-1]) |
| |
|
| |
|
def format_str(format_string, config):
  """Format string with reference fields from config.

  This makes it easy to build preprocess strings that contain references to
  fields that are edited afterwards. E.g.:

  ```
  config = mlc.ConfigDict()
  config.res = (256, 256)
  config.pp = bvcc.format_str('resize({res})', config)
  ...
  # if config.res is modified (e.g. via sweeps) it will propagate to pp field:
  config.res = (512, 512)
  assert config.pp == 'resize((512, 512))'
  ```

  Each placeholder names a field of `config`, optionally dotted
  (e.g. `{model.res}`). Format specs (`{res:>5}`), conversions (`{res!r}`) and
  empty/positional placeholders (`{}`) are not supported. Escaped braces
  (`{{`, `}}`) behave as in `str.format`.

  Args:
    format_string: string to format with references.
    config: ConfigDict to get references to format the string.

  Returns:
    If `format_string` has at least one placeholder, the literal text
    concatenated with each referenced field's `to_str()`. For ml_collections
    field references this is itself a reference that re-renders when the
    fields change, as in the example above. Without placeholders, it is the
    plain literal string.

  Raises:
    ValueError: on an unsupported placeholder (format spec, conversion, or
      empty field name).
  """
  output = ''
  parts = string.Formatter().parse(format_string)
  for literal_text, field_name, format_spec, conversion in parts:
    # Explicit raises instead of `assert`: under `python -O` an assert would
    # vanish and the spec/conversion would be silently ignored.
    if format_spec or conversion:
      raise ValueError(
          'Format specs and conversions are not supported in format_str: '
          f'{format_string!r}')
    if field_name == '':
      raise ValueError(
          f'Placeholders must name a config field: {format_string!r}')
    # `+=` rather than ''.join on purpose: once a reference has been added,
    # `output` is no longer a plain str, and join accepts only str items.
    output += literal_text
    if field_name is not None:
      output += _get_field_ref(config, field_name).to_str()
  return output
| |
|