Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Base configurations to standardize experiments.""" | |
| import copy | |
| import dataclasses | |
| import functools | |
| import inspect | |
| import typing | |
| from typing import Any, List, Mapping, Optional, Type, Union | |
| from absl import logging | |
| import tensorflow as tf, tf_keras | |
| import yaml | |
| from official.modeling.hyperparams import params_dict | |
# Config classes that already have a builder attached through `bind`.
_BOUND = set()


def bind(config_cls):
  """Returns a decorator that binds a builder to the config class.

  The decorated class or function is stored on `config_cls._BUILDER`. A plain
  function is wrapped so that, when it is looked up through a config instance
  and called as a bound method, the instance itself is dropped and only the
  remaining arguments are forwarded to the function.

  Args:
    config_cls: the config class to attach the builder to.

  Returns:
    A decorator that registers its argument as the builder of `config_cls` and
    returns that argument unchanged.

  Raises:
    ValueError: if `config_cls` is not a class; or, when the decorator is
      applied, if `config_cls` has already been bound or the decorated object
      is neither a class nor a function.
  """
  if not inspect.isclass(config_cls):
    raise ValueError('The bind decorator is supposed to apply on the class '
                     f'attribute. Received {config_cls}, not a class.')

  def decorator(builder):
    # The check happens at decoration time, so one config class can be paired
    # with at most one builder for the lifetime of the process.
    if config_cls in _BOUND:
      raise ValueError('Inside a program, we should not bind the config with a'
                       ' class twice.')
    if inspect.isclass(builder):
      config_cls._BUILDER = builder  # pylint: disable=protected-access
    elif inspect.isfunction(builder):

      def _wrapper(self, *args, **kwargs):  # pylint: disable=unused-argument
        # `self` is the config instance the wrapper is accessed through; the
        # builder function itself does not receive it.
        return builder(*args, **kwargs)

      config_cls._BUILDER = _wrapper  # pylint: disable=protected-access
    else:
      raise ValueError(f'The `BUILDER` type is not supported: {builder}')
    # Only record the class once the builder has actually been attached.
    _BOUND.add(config_cls)
    return builder

  return decorator
| def _is_optional(field): | |
| return typing.get_origin(field) is Union and type(None) in typing.get_args( | |
| field) | |
@dataclasses.dataclass
class Config(params_dict.ParamsDict):
  """The base configuration class that supports YAML/JSON based overrides.

  Because of YAML/JSON serialization limitations, some semantics of dataclass
  are not supported:
  * It recursively enforces an allowlist of basic types and container types, so
    it avoids surprises with copy and reuse caused by unanticipated types.
  * Warning: it converts Dict to `Config` even within sequences,
    e.g. for config = Config({'key': [([{'a': 42}],)]),
         type(config.key[0][0][0]) is Config rather than dict.
    If you define/annotate some field as Dict, the field will convert to a
    `Config` instance and lose the dictionary type.
  """

  # The class or method to bind with the params class.
  _BUILDER = None
  # It's safe to add bytes and other immutable types here.
  IMMUTABLE_TYPES = (str, int, float, bool, type(None))
  # It's safe to add set, frozenset and other collections here.
  SEQUENCE_TYPES = (list, tuple)

  # Init-only arguments: they are handed to `__post_init__` and forwarded to
  # the parent constructor rather than stored as dataclass fields.
  default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None
  restrictions: dataclasses.InitVar[Optional[List[str]]] = None

  def __post_init__(self, default_params, restrictions):
    """Forwards the init-only arguments to the parent class constructor."""
    super().__init__(
        default_params=default_params,
        restrictions=restrictions)

  @property
  def BUILDER(self):  # pylint: disable=invalid-name
    """The builder stored in the `_BUILDER` class attribute, or None."""
    return self._BUILDER

  @classmethod
  def _get_annotations(cls):
    """Returns valid annotations.

    Note: this is similar to dataclasses.__annotations__ except it also includes
    annotations from its parent classes.
    """
    all_annotations = typing.get_type_hints(cls)
    # Removes Config class annotation from the value, e.g., default_params,
    # restrictions, etc. `pop` tolerates a key that is already absent.
    for k in Config.__annotations__:
      all_annotations.pop(k, None)
    return all_annotations

  @classmethod
  def _isvalidsequence(cls, v):
    """Check if the input values are valid sequences.

    Args:
      v: Input sequence.

    Returns:
      True if the sequence is valid. Valid sequence includes the sequence
      type in cls.SEQUENCE_TYPES and element type is in cls.IMMUTABLE_TYPES or
      is dict or ParamsDict. The elements must all belong to the same one of
      those three groups; an empty sequence is valid.
    """
    if not isinstance(v, cls.SEQUENCE_TYPES):
      return False
    element_kinds = (cls.IMMUTABLE_TYPES, dict, params_dict.ParamsDict)
    return any(
        all(isinstance(e, kind) for e in v) for kind in element_kinds)

  @classmethod
  def _import_config(cls, v, subconfig_type):
    """Returns v with dicts converted to Configs, recursively.

    Args:
      v: the value to import.
      subconfig_type: the ParamsDict subclass used to wrap plain dicts.

    Returns:
      `v` itself for immutable values, a new sequence of the same type with
      every element imported, a deep copy for ParamsDict values, or a
      `subconfig_type` instance built from a dict.

    Raises:
      TypeError: if `subconfig_type` is not a ParamsDict subclass, if `v` is a
        sequence that `_isvalidsequence` rejects, or if the type of `v` is not
        supported.
    """
    if not issubclass(subconfig_type, params_dict.ParamsDict):
      raise TypeError(
          'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
              subconfig_type))
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    if isinstance(v, cls.SEQUENCE_TYPES):
      # Only support one layer of sequence.
      if not cls._isvalidsequence(v):
        raise TypeError(
            'Invalid sequence: only supports single level {!r} of {!r} or '
            'dict or ParamsDict found: {!r}'.format(cls.SEQUENCE_TYPES,
                                                    cls.IMMUTABLE_TYPES, v))
      import_fn = functools.partial(
          cls._import_config, subconfig_type=subconfig_type)
      # Rebuild with the same container type (list stays list, tuple stays
      # tuple).
      return type(v)(map(import_fn, v))
    if isinstance(v, params_dict.ParamsDict):
      # Deepcopy here is a temporary solution for preserving type in nested
      # Config object.
      return copy.deepcopy(v)
    if isinstance(v, dict):
      return subconfig_type(v)
    raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _export_config(cls, v):
    """Returns v with Configs converted to dicts, recursively.

    Raises:
      TypeError: if `v` is a plain dict or of an unsupported type.
    """
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    if isinstance(v, cls.SEQUENCE_TYPES):
      return type(v)(map(cls._export_config, v))
    if isinstance(v, params_dict.ParamsDict):
      return v.as_dict()
    if isinstance(v, dict):
      raise TypeError('dict value not supported in converting.')
    raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _get_subconfig_type(
      cls, k, subconfig_type=None
  ) -> Type[params_dict.ParamsDict]:
    """Get element type by the field name.

    Args:
      k: the key/name of the field.
      subconfig_type: default subconfig_type. If None, it is set to
        Config.

    Returns:
      Config as default. If a type annotation is found for `k`,
      1) returns the type of the annotation if it is subtype of ParamsDict;
      2) returns the element type if the annotation of `k` is List[SubType]
         or Tuple[SubType].
      `Optional[...]` wrappers around either form are stripped first.
    """
    if not subconfig_type:
      subconfig_type = Config

    annotations = cls._get_annotations()
    if k not in annotations:
      return subconfig_type

    type_annotation = annotations[k]
    # Loop for stripping the Optional annotation.
    while True:
      # Directly Config subtype.
      if (isinstance(type_annotation, type) and
          issubclass(type_annotation, Config)):
        return type_annotation

      # Check if the field is a sequence of subtypes.
      field_type = typing.get_origin(type_annotation)
      if (isinstance(field_type, type) and
          issubclass(field_type, cls.SEQUENCE_TYPES)):
        # A bare `List` / `Tuple` has no type arguments, and an element
        # annotation such as `Optional[int]` is not a class; neither can name
        # a subconfig type, so both fall back to the default.
        type_args = typing.get_args(type_annotation)
        element_type = type_args[0] if type_args else None
        if (isinstance(element_type, type) and
            issubclass(element_type, params_dict.ParamsDict)):
          return element_type
        return subconfig_type

      if _is_optional(type_annotation):
        # Strip the `Optional` annotation and process the subtype. Take the
        # first member that is not NoneType so `Union[None, X]` resolves to X.
        type_annotation = next(
            arg for arg in typing.get_args(type_annotation)
            if arg is not type(None))
        continue

      return subconfig_type

  def _set(self, k, v):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods.

    Args:
      k: key to set.
      v: value.

    Raises:
      RuntimeError
    """
    subconfig_type = self._get_subconfig_type(k)
    current = self.__dict__.get(k)
    # "Null" means the key is missing or its current value is falsy.
    is_null = not current

    if isinstance(v, dict):
      if is_null:
        # If the key not exist or the value is None, a new Config-family object
        # should be created for the key.
        self.__dict__[k] = subconfig_type(v)
      else:
        current.override(v)
      return

    if (not is_null and isinstance(v, self.SEQUENCE_TYPES) and
        all(not isinstance(e, self.IMMUTABLE_TYPES) for e in v)):
      if len(current) == len(v):
        # Same length: update the existing elements in place, pairwise.
        for existing, update in zip(current, v):
          existing.override(update)
        return
      if not all(isinstance(e, self.IMMUTABLE_TYPES) for e in v):
        logging.warning(
            "The list/tuple don't match the value dictionaries provided. Thus, "
            'the list/tuple is determined by the type annotation and '
            'values provided. This is error-prone.')

    self.__dict__[k] = self._import_config(v, subconfig_type)

  def __setattr__(self, k, v):
    """Routes attribute assignment through `_set`.

    Raises:
      AttributeError: when assigning `BUILDER` or `_BUILDER` on an instance.
      ValueError: when the config is locked and `k` is not a reserved
        attribute.
    """
    if k == 'BUILDER' or k == '_BUILDER':
      raise AttributeError('`BUILDER` is a property and `_BUILDER` is the '
                           'reserved class attribute. We should only assign '
                           '`_BUILDER` at the class level.')
    if k not in self.RESERVED_ATTR:
      if getattr(self, '_locked', False):
        raise ValueError('The Config has been locked. ' 'No change is allowed.')
    # Reserved attributes (e.g. `_locked`, which `replace` resets) are written
    # even when the config is locked.
    self._set(k, v)

  def _override(self, override_dict, is_strict=True):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods.

    Args:
      override_dict: dictionary to write to .
      is_strict: If True, not allows to add new keys.

    Raises:
      KeyError: overriding reserved keys or keys not exist (is_strict=True).
    """
    for k, v in sorted(override_dict.items()):
      if k in self.RESERVED_ATTR:
        raise KeyError('The key {!r} is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__:
        if is_strict:
          raise KeyError('The key {!r} does not exist in {!r}. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(
                             k, type(self)))
        self._set(k, v)
        continue

      current = self.__dict__[k]
      # A truthy existing value is updated recursively in place; otherwise the
      # key is (re)assigned through `_set`.
      if isinstance(v, dict) and current:
        current._override(v, is_strict)  # pylint: disable=protected-access
      elif isinstance(v, params_dict.ParamsDict) and current:
        current._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
      else:
        self._set(k, v)

  def as_dict(self):
    """Returns a dict representation of params_dict.ParamsDict.

    For the nested params_dict.ParamsDict, a nested dict will be returned.
    """
    return {
        k: self._export_config(v)
        for k, v in self.__dict__.items()
        if k not in self.RESERVED_ATTR
    }

  def replace(self, **kwargs):
    """Overrides/returns an unlocked copy with the current config unchanged."""
    # pylint: disable=protected-access
    params = copy.deepcopy(self)
    params._locked = False
    params._override(kwargs, is_strict=True)
    # pylint: enable=protected-access
    return params

  @classmethod
  def from_yaml(cls, file_path: str):
    """Builds a default config and overrides it with values from a YAML file.

    Args:
      file_path: path of the YAML file, opened through `tf.io.gfile.GFile`.

    Returns:
      A `cls` instance with the loaded values applied via `override`. An empty
      document leaves the defaults untouched.
    """
    # Note: This only works if the Config has all default values.
    with tf.io.gfile.GFile(file_path, 'r') as f:
      # NOTE(review): `FullLoader` is not intended for untrusted input; only
      # load files from trusted sources, or confirm `yaml.safe_load` covers the
      # configs in use before switching.
      loaded = yaml.load(f, Loader=yaml.FullLoader)
    config = cls()
    # An empty YAML document parses to None: there is nothing to override.
    if loaded is not None:
      config.override(loaded)
    return config

  @classmethod
  def from_json(cls, file_path: str):
    """Wrapper for `from_yaml`."""
    return cls.from_yaml(file_path)

  @classmethod
  def from_args(cls, *args, **kwargs):
    """Builds a config from the given list of arguments.

    Positional arguments are matched, in order, to the names in
    `cls.__annotations__`; keyword arguments are applied afterwards and win on
    conflict.
    """
    # Note we intend to keep `__annotations__` instead of `_get_annotations`.
    # Assuming a parent class of (a, b) with the sub-class of (c, d), the
    # sub-class will take (c, d) for args, rather than starting from (a, b).
    default_params = dict(zip(cls.__annotations__, args))
    default_params.update(kwargs)
    return cls(default_params=default_params)