Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """A parameter dictionary class which supports the nest structure.""" | |
| import collections | |
| import copy | |
| import re | |
| import six | |
| import tensorflow as tf, tf_keras | |
| import yaml | |
# Pattern for a single `key=value` entry of a comma-separated override string.
# The key may be dotted (`a.b.c`) and may carry an array index (`a[3]`). The
# value is one of: a single-quoted string, a double-quoted string, a bare
# value that runs up to the next comma (e.g. an int or a float), or a flat
# list in square brackets. The entry ends at the end of the string or at a
# comma followed by optional whitespace.
_PARAM_RE = re.compile(
    r"""
  (?P<name>[a-zA-Z][\w\.]*)(?P<bracketed_index>\[?[0-9]*\]?)  # variable name: "var" or "x" followed by optional index: "[0]" or "[23]"
  \s*=\s*
  ((?P<val>\'(.*?)\'           # single quote
  |
  \"(.*?)\"                    # double quote
  |
  [^,\[]*                      # single value
  |
  \[[^\]]*\]))                 # list of values
  ($|,\s*)""", re.VERBOSE)

# Recognizes a restriction operand that is a literal rather than a parameter
# path: something starting with a digit, with a minus sign and a digit, or
# with `None`.
_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')
# YAML loader with an extra implicit resolver so that numbers written in
# decimal or exponential notation are tagged as floats. The regular expression
# covers the following cases:
#   1- Decimal number with an optional exponential term, e.g. "1.5", "1.5e3".
#   2- Integer number with an exponential term, e.g. "1e-3".
#   3- Decimal number starting with the dot, with an optional signed
#      exponential term, e.g. ".5", ".5e-3".
#   4- Sexagesimal (colon-separated) number with a fraction, e.g. "1:30.5".
# NOTE: `_LOADER` is an alias of `yaml.FullLoader`, so the resolver is
# registered on that class itself rather than on a private subclass.
_LOADER = yaml.FullLoader
_LOADER.add_implicit_resolver(
    'tag:yaml.org,2002:float',
    # The pattern is a raw string, so the literal dot must be spelled `\.`;
    # `\\.` would mean "a backslash followed by any character".
    re.compile(r'''
    ^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |
    [-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |
    \.[0-9_]+(?:[eE][-+][0-9]+)?
    |
    [-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*)$''', re.X),
    list('-+0123456789.'))
class ParamsDict(object):
  """A hyperparameter container class.

  Parameters live in the instance `__dict__` and are read and written as
  attributes, e.g. `params.a.a1`. Values that are Python dicts are stored as
  nested `ParamsDict` objects; every other value is stored as a deep copy.
  """

  # Bookkeeping attributes; these names can never be used as parameter keys.
  RESERVED_ATTR = ['_locked', '_restrictions']

  def __init__(self, default_params=None, restrictions=None):
    """Instantiate a ParamsDict.

    Instantiate a ParamsDict given a set of default parameters and a list of
    restrictions. The restrictions are only stored here; call `validate()` to
    check them against the current parameter values.

    Args:
      default_params: a Python dict or another ParamsDict object including the
        default parameters to initialize.
      restrictions: a list of strings, which define a list of restrictions to
        ensure the consistency of different parameters internally. Each
        restriction string is defined as a binary relation with a set of
        operators, including {'==', '!=', '<', '<=', '>', '>='}.
    """
    self._locked = False
    self._restrictions = []
    if restrictions:
      self._restrictions = restrictions
    if default_params is None:
      default_params = {}
    self.override(default_params, is_strict=False)

  def _set(self, k, v):
    """Stores `v` under `k`, wrapping dicts and deep-copying everything else."""
    if isinstance(v, dict):
      self.__dict__[k] = ParamsDict(v)
    else:
      self.__dict__[k] = copy.deepcopy(v)

  def __setattr__(self, k, v):
    """Sets the value of the existing key.

    Note that this does not allow directly defining a new key. Use the
    `override` method with `is_strict=False` instead.

    Args:
      k: the key string.
      v: the value to be used to set the key `k`.

    Raises:
      KeyError: if k is not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    # Reserved attributes skip both checks: they are assigned in `__init__`
    # before `_locked` exists, and `lock()` must be able to set `_locked`.
    if k not in ParamsDict.RESERVED_ATTR:
      if k not in self.__dict__:
        raise KeyError('The key `{}` does not exist. '
                       'To extend the existing keys, use '
                       '`override` with `is_strict` = False.'.format(k))
      if self._locked:
        raise ValueError('The ParamsDict has been locked. '
                         'No change is allowed.')
    self._set(k, v)

  def __getattr__(self, k):
    """Gets the value of the existing key.

    Args:
      k: the key string.

    Returns:
      the value of the key.

    Raises:
      AttributeError: if k is not defined in the ParamsDict.
    """
    if k not in self.__dict__:
      raise AttributeError('The key `{}` does not exist. '.format(k))
    return self.__dict__[k]

  def __contains__(self, key):
    """Implements the membership test operator."""
    return key in self.__dict__

  def get(self, key, value=None):
    """Accesses through built-in dictionary get method."""
    return self.__dict__.get(key, value)

  def __delattr__(self, k):
    """Deletes the key and removes its values.

    Args:
      k: the key string.

    Raises:
      AttributeError: if k is reserved or not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    if k in ParamsDict.RESERVED_ATTR:
      raise AttributeError(
          'The key `{}` is reserved. No change is allowed. '.format(k))
    if k not in self.__dict__:
      raise AttributeError('The key `{}` does not exist. '.format(k))
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    del self.__dict__[k]

  def override(self, override_params, is_strict=True):
    """Override the ParamsDict with a set of given params.

    Args:
      override_params: a dict or a ParamsDict specifying the parameters to be
        overridden.
      is_strict: a boolean specifying whether override is strict or not. If
        True, keys in `override_params` must be present in the ParamsDict. If
        False, keys in `override_params` can be different from what is currently
        defined in the ParamsDict. In this case, the ParamsDict will be extended
        to include the new keys.

    Raises:
      ValueError: if the ParamsDict instance has been locked.
      KeyError: if a key is reserved, or is missing while `is_strict` is True.
    """
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    if isinstance(override_params, ParamsDict):
      override_params = override_params.as_dict()
    self._override(override_params, is_strict)  # pylint: disable=protected-access

  def _override(self, override_dict, is_strict=True):
    """The implementation of `override`."""
    for k, v in override_dict.items():
      if k in ParamsDict.RESERVED_ATTR:
        raise KeyError('The key `{}` is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__:
        if is_strict:
          raise KeyError('The key `{}` does not exist. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(k))
        self._set(k, v)
      elif isinstance(v, dict):
        # Existing nested parameters are merged key by key, not replaced.
        self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
      elif isinstance(v, ParamsDict):
        self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
      else:
        self.__dict__[k] = copy.deepcopy(v)

  def lock(self):
    """Makes the ParamsDict immutable.

    Only this instance's `_locked` flag is set; nested ParamsDict values keep
    their own flag.
    """
    self._locked = True

  def as_dict(self):
    """Returns a dict representation of ParamsDict.

    For the nested ParamsDict, a nested dict will be returned. Reserved
    attributes are left out and leaf values are deep-copied.
    """
    params_dict = {}
    for k, v in self.__dict__.items():
      if k in ParamsDict.RESERVED_ATTR:
        continue
      if isinstance(v, ParamsDict):
        params_dict[k] = v.as_dict()
      else:
        params_dict[k] = copy.deepcopy(v)
    return params_dict

  def validate(self):
    """Validate the parameters consistency based on the restrictions.

    This method validates the internal consistency using the pre-defined list of
    restrictions. A restriction is defined as a string which specifies a binary
    operation. The supported binary operations are {'==', '!=', '<', '<=', '>',
    '>='}. Note that the meaning of these operators are consistent with the
    underlying Python implementation. Users should make sure the defined
    restrictions on their type make sense. Either side of a restriction may be
    a dotted parameter path, a numeric constant (compared as a float) or
    `None`.

    For example, for a ParamsDict like the following
    ```
    a:
      a1: 1
      a2: 2
    b:
      bb:
        bb1: 10
        bb2: 20
      ccc:
        a1: 1
        a3: 3
    ```
    one can define two restrictions like this
    ['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']

    What it enforces are:
     - a.a1 = 1 == b.ccc.a1 = 1
     - a.a2 = 2 <= b.bb.bb2 = 20

    Raises:
      KeyError: if any of the following happens
        (1) any of parameters in any of restrictions is not defined in
            ParamsDict,
        (2) any inconsistency violating the restriction is found.
      ValueError: if the restriction defined in the string is not supported.
    """

    def _get_kv(dotted_string, params_dict):
      """Get keys and values indicated by dotted_string."""
      if _CONST_VALUE_RE.match(dotted_string) is not None:
        if dotted_string == 'None':
          return None, None
        return None, float(dotted_string)
      tokenized_params = dotted_string.split('.')
      v = params_dict
      for t in tokenized_params:
        v = v[t]
      return tokenized_params[-1], v

    def _get_kvs(tokens, params_dict):
      """Resolves both operands of a binary restriction."""
      if len(tokens) != 2:
        raise ValueError('Only support binary relation in restriction.')
      stripped_tokens = [t.strip() for t in tokens]
      left_k, left_v = _get_kv(stripped_tokens[0], params_dict)
      right_k, right_v = _get_kv(stripped_tokens[1], params_dict)
      return left_k, left_v, right_k, right_v

    # Each entry pairs an operator with the condition that VIOLATES it. The
    # two-character operators come before '<' and '>' on purpose: testing '<'
    # first would also catch 'a <= b' and split it into 'a ' and '= b'.
    violations = (
        ('==', lambda left, right: left != right),
        ('!=', lambda left, right: left == right),
        ('<=', lambda left, right: left > right),
        ('>=', lambda left, right: left < right),
        ('<', lambda left, right: left >= right),
        ('>', lambda left, right: left <= right),
    )

    params_dict = self.as_dict()
    for restriction in self._restrictions:
      for symbol, is_violated in violations:
        if symbol not in restriction:
          continue
        tokens = restriction.split(symbol)
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if is_violated(left_v, right_v):
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
        break
      else:
        raise ValueError('Unsupported relation in restriction.')
def read_yaml_to_params_dict(file_path: str):
  """Loads the YAML file at `file_path` and wraps its content in a ParamsDict.

  Args:
    file_path: path of the YAML file, opened through `tf.io.gfile.GFile`.

  Returns:
    a ParamsDict built from the parsed YAML content.
  """
  # NOTE(review): the file is parsed with the module's FullLoader-based loader;
  # only use this on configuration files from a trusted source.
  with tf.io.gfile.GFile(file_path, 'r') as yaml_file:
    loaded = yaml.load(yaml_file, Loader=_LOADER)
  return ParamsDict(loaded)
def save_params_dict_to_yaml(params, file_path):
  """Writes `params` to `file_path` as YAML.

  Mappings are dumped in block style (`default_flow_style=False`) while Python
  lists are dumped inline in flow style.

  Args:
    params: the ParamsDict to save; it is converted with `as_dict()`.
    file_path: destination path, opened through `tf.io.gfile.GFile`.
  """

  def _inline_list_representer(dumper, data):
    # u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
    return dumper.represent_sequence(
        u'tag:yaml.org,2002:seq', data, flow_style=True)

  with tf.io.gfile.GFile(file_path, 'w') as out_file:
    # Registered through the module-level `yaml.add_representer`, so the
    # registration is not scoped to this call.
    yaml.add_representer(list, _inline_list_representer)
    yaml.dump(params.as_dict(), out_file, default_flow_style=False)
def nested_csv_str_to_json_str(csv_str):
  """Converts a nested (using '.') comma-separated k=v string to a JSON string.

  Converts a comma-separated string of key/value pairs that supports
  nesting of keys to a JSON-like string. Nesting is implemented using
  '.' between levels for a given key. The result is a flow-style mapping such
  as "{a : 1, c : {a : 2}}": keys and bare values are not quoted.

  Spacing between commas and = is supported (e.g. there is no difference between
  "a=1,b=2", "a = 1, b = 2", or "a=1, b=2") but there should be no spaces before
  keys or after values (e.g. " a=1,b=2" and "a=1,b=2 " are not supported).

  Note that this will only support values supported by CSV, meaning
  values such as nested lists (e.g. "a=[[1,2,3],[4,5,6]]") are not
  supported. Strings are supported as well, e.g. "a='hello'".

  Numeric arrays can be given element by element, e.g. "a[0]=1,a[1]=2.5"
  becomes "a : [1, 2.5]". Every index up to the largest one must be provided.
  Element values may be ints or floats, including exponent notation ("1e-3").

  An example conversion would be:

  "a=1, b=2, c.a=2, c.b=3, d.a.a=5"

  to

  "{ a: 1, b : 2, c: {a : 2, b : 3}, d: {a: {a : 5}}}"

  Args:
    csv_str: the comma separated string.

  Returns:
    the converted JSON string, or '' when `csv_str` is empty.

  Raises:
    ValueError: If csv_str is not in a comma separated string or
      if the string is formatted incorrectly.
  """
  if not csv_str:
    return ''

  array_param_map = {}  # array name -> list of element values, None = unset
  formatted_entries = []
  nested_map = collections.defaultdict(list)
  pos = 0
  while pos < len(csv_str):
    m = _PARAM_RE.match(csv_str, pos)
    if not m:
      raise ValueError('Malformed hyperparameter value while parsing '
                       'CSV string: %s' % csv_str[pos:])
    pos = m.end()
    # Parse the values.
    name = m.group('name')
    v = m.group('val')
    bracketed_index = m.group('bracketed_index')

    # An indexed top-level key ("a[3]=...") sets one element of an array.
    # Indexed keys that are still dotted are handled by the recursion below.
    if bracketed_index and '.' not in name:
      # Extract the array's index by removing '[' and ']'.
      index = int(bracketed_index[1:-1])
      # Exponent notation has no '.', yet is not a valid int literal either.
      if '.' in v or 'e' in v.lower():
        numeric_val = float(v)
      else:
        numeric_val = int(v)
      elements = array_param_map.setdefault(name, [])
      if index >= len(elements):
        # Pad with None so that missing indices can be reported at the end.
        elements.extend([None] * (index + 1 - len(elements)))
      elements[index] = numeric_val
      continue

    # If a GCS path (e.g. gs://...) is provided, wrap this in quotes
    # as yaml.load would otherwise throw an exception.
    # NOTE(review): `[gs://]` is a character class, so this quotes every
    # unquoted value starting with 'g', 's', ':' or '/', not only "gs://"
    # paths. Kept unchanged because other values (e.g. paths starting with
    # '/' or 's') may depend on being quoted -- confirm before tightening.
    if re.match(r'(?=[^\"\'])(?=[gs://])', v):
      v = '\'{}\''.format(v)

    name_nested = name.split('.')
    if len(name_nested) > 1:
      # Strip the first level off the key; the remainder is converted by the
      # recursive call once all entries of the group are collected.
      grouping = name_nested[0]
      remainder = '.'.join(name_nested[1:])
      if bracketed_index:
        remainder += bracketed_index
      nested_map[grouping].append(remainder + '=' + v)
    else:
      formatted_entries.append('%s : %s' % (name, v))

  for grouping, value in nested_map.items():
    value = ','.join(value)
    value = nested_csv_str_to_json_str(value)
    formatted_entries.append('%s : %s' % (grouping, value))

  # Add array parameters and check that the array is fully initialized.
  for name, elements in array_param_map.items():
    if any(element is None for element in elements):
      raise ValueError('Did not pass all values of array: %s' % name)
    formatted_entries.append('%s : %s' % (name, elements))

  return '{' + ', '.join(formatted_entries) + '}'
def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
  """Override a given ParamsDict using a dict, JSON/YAML/CSV string or YAML file.

  The input is resolved in this order:
    1. An empty input leaves `params` untouched.
    2. A dict is applied directly.
    3. A string is first tried as a nested CSV string ("a=1,b.c=2"). If that
       conversion raises ValueError, the original string is kept.
    4. The (possibly converted) string is loaded as YAML/JSON. If the result
       is a dict, it is applied.
    5. Otherwise the string is used as a file path and that YAML file is
       loaded and applied.
  Any other input type is rejected.

  Args:
    params: a ParamsDict object to be overridden.
    dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
      a YAML file specifying the parameters to be overridden.
    is_strict: a boolean specifying whether override is strict or not.

  Returns:
    params: the overridden ParamsDict object.

  Raises:
    ValueError: if the input is neither a dict nor a string.
  """
  source = dict_or_string_or_yaml_file
  if not source:
    return params

  if isinstance(source, dict):
    params.override(source, is_strict)
    return params

  if not isinstance(source, six.string_types):
    raise ValueError('Unknown input type to parse.')

  try:
    source = nested_csv_str_to_json_str(source)
  except ValueError:
    # Not a CSV override string; fall through and treat it as YAML or a path.
    pass

  parsed = yaml.load(source, Loader=_LOADER)
  if isinstance(parsed, dict):
    params.override(parsed, is_strict)
  else:
    # The string did not parse to a mapping, so it is taken to be a file path.
    with tf.io.gfile.GFile(source) as yaml_file:
      params.override(yaml.load(yaml_file, Loader=_LOADER), is_strict)
  return params