| | |
| | |
| | |
| | |
| |
|
| | |
| | """Hyperparameter values.""" |
| | from __future__ import absolute_import |
| | from __future__ import division |
| | from __future__ import print_function |
| |
|
| | import json |
| | import numbers |
| | import re |
| | import six |
| |
|
| | |
| | |
| | |
| | |
| | |
# Matches a single "name=value" clause of a hyperparameter string, including
# the separator (',' plus optional whitespace, or end of string) that follows.
#
# Named groups:
#   name:  identifier that starts with a letter and may contain word
#          characters and '.'.
#   index: optional non-negative integer written in square brackets after the
#          name; None when absent.
#   val:   scalar right-hand side, i.e. any run of characters (possibly empty)
#          containing neither ',' nor '['.
#   vals:  the text between '[' and ']' when the right-hand side is a list.
#
# `val` and `vals` are alternatives: on a successful match exactly one of them
# is not None.
PARAM_RE = re.compile(
    r"""
  (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
  (\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
  \s*=\s*
  ((?P<val>[^,\[]*) # single value: "a" or None
  |
  \[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3"
  ($|,\s*)""",
    re.VERBOSE,
)
| |
|
| |
|
| | def _parse_fail(name, var_type, value, values): |
| | """Helper function for raising a value error for bad assignment.""" |
| | raise ValueError( |
| | "Could not parse hparam '%s' of type '%s' with value '%s' in %s" |
| | % (name, var_type.__name__, value, values) |
| | ) |
| |
|
| |
|
| | def _reuse_fail(name, values): |
| | """Helper function for raising a value error for reuse of name.""" |
| | raise ValueError("Multiple assignments to variable '%s' in %s" % (name, values)) |
| |
|
| |
|
| | def _process_scalar_value(name, parse_fn, var_type, m_dict, values, results_dictionary): |
| | """Update results_dictionary with a scalar value. |
| | |
| | Used to update the results_dictionary to be returned by parse_values when |
| | encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".) |
| | |
| | Mutates results_dictionary. |
| | |
| | Args: |
| | name: Name of variable in assignment ("s" or "arr"). |
| | parse_fn: Function for parsing the actual value. |
| | var_type: Type of named variable. |
| | m_dict: Dictionary constructed from regex parsing. |
| | m_dict['val']: RHS value (scalar) |
| | m_dict['index']: List index value (or None) |
| | values: Full expression being parsed |
| | results_dictionary: The dictionary being updated for return by the parsing |
| | function. |
| | |
| | Raises: |
| | ValueError: If the name has already been used. |
| | """ |
| | try: |
| | parsed_value = parse_fn(m_dict["val"]) |
| | except ValueError: |
| | _parse_fail(name, var_type, m_dict["val"], values) |
| |
|
| | |
| | if not m_dict["index"]: |
| | if name in results_dictionary: |
| | _reuse_fail(name, values) |
| | results_dictionary[name] = parsed_value |
| | else: |
| | if name in results_dictionary: |
| | |
| | |
| | if not isinstance(results_dictionary.get(name), dict): |
| | _reuse_fail(name, values) |
| | else: |
| | results_dictionary[name] = {} |
| |
|
| | index = int(m_dict["index"]) |
| | |
| | if index in results_dictionary[name]: |
| | _reuse_fail("{}[{}]".format(name, index), values) |
| | results_dictionary[name][index] = parsed_value |
| |
|
| |
|
| | def _process_list_value(name, parse_fn, var_type, m_dict, values, results_dictionary): |
| | """Update results_dictionary from a list of values. |
| | |
| | Used to update results_dictionary to be returned by parse_values when |
| | encountering a clause with a list RHS (e.g. "arr=[1,2,3]".) |
| | |
| | Mutates results_dictionary. |
| | |
| | Args: |
| | name: Name of variable in assignment ("arr"). |
| | parse_fn: Function for parsing individual values. |
| | var_type: Type of named variable. |
| | m_dict: Dictionary constructed from regex parsing. |
| | m_dict['val']: RHS value (scalar) |
| | values: Full expression being parsed |
| | results_dictionary: The dictionary being updated for return by the parsing |
| | function. |
| | |
| | Raises: |
| | ValueError: If the name has an index or the values cannot be parsed. |
| | """ |
| | if m_dict["index"] is not None: |
| | raise ValueError("Assignment of a list to a list index.") |
| | elements = filter(None, re.split("[ ,]", m_dict["vals"])) |
| | |
| | if name in results_dictionary: |
| | raise _reuse_fail(name, values) |
| | try: |
| | results_dictionary[name] = [parse_fn(e) for e in elements] |
| | except ValueError: |
| | _parse_fail(name, var_type, m_dict["vals"], values) |
| |
|
| |
|
| | def _cast_to_type_if_compatible(name, param_type, value): |
| | """Cast hparam to the provided type, if compatible. |
| | |
| | Args: |
| | name: Name of the hparam to be cast. |
| | param_type: The type of the hparam. |
| | value: The value to be cast, if compatible. |
| | |
| | Returns: |
| | The result of casting `value` to `param_type`. |
| | |
| | Raises: |
| | ValueError: If the type of `value` is not compatible with param_type. |
| | * If `param_type` is a string type, but `value` is not. |
| | * If `param_type` is a boolean, but `value` is not, or vice versa. |
| | * If `param_type` is an integer type, but `value` is not. |
| | * If `param_type` is a float type, but `value` is not a numeric type. |
| | """ |
| | fail_msg = "Could not cast hparam '%s' of type '%s' from value %r" % ( |
| | name, |
| | param_type, |
| | value, |
| | ) |
| |
|
| | |
| | if issubclass(param_type, type(None)): |
| | return value |
| |
|
| | |
| | if issubclass(param_type, (six.string_types, six.binary_type)) and not isinstance( |
| | value, (six.string_types, six.binary_type) |
| | ): |
| | raise ValueError(fail_msg) |
| |
|
| | |
| | if issubclass(param_type, bool) != isinstance(value, bool): |
| | raise ValueError(fail_msg) |
| |
|
| | |
| | if issubclass(param_type, numbers.Integral) and not isinstance( |
| | value, numbers.Integral |
| | ): |
| | raise ValueError(fail_msg) |
| |
|
| | |
| | if issubclass(param_type, numbers.Number) and not isinstance(value, numbers.Number): |
| | raise ValueError(fail_msg) |
| |
|
| | return param_type(value) |
| |
|
| |
|
def parse_values(values, type_map, ignore_unknown=False):
    """Parse a string of hyperparameter assignments into a python dictionary.

    `values` holds comma-separated `name=value` clauses. Each clause sets the
    hyperparameter called `name` to `value`.

    A name may be assigned only once: repeating a name ('a=1,a=2'), repeating
    a name/index pair ('a[1]=1,a[1]=2'), or mixing a plain assignment with an
    indexed assignment to the same name ('a=[1,2,3],a[0]=1') raises a
    ValueError.

    Names may contain '.' characters. Such names are only reachable on an
    object through getattr/setattr and must first be added explicitly.

    WARNING: '.' in variable names is allowed, but it is not well supported
    and not recommended.

    The syntax of `value` depends on the type of the parameter:

    *  Scalar integer: A Python-parsable integer value. E.g.: 1, 100, -12.
    *  Scalar float: A Python-parsable floating point value. E.g.: 1.0,
       -.54e89.
    *  Boolean: true/True, false/False, or an integer (interpreted through
       `bool(int(value))`).
    *  Scalar string: A sequence of characters, excluding comma and opening
       square bracket. E.g.: foo, bar_1.
    *  List: A comma separated list of scalar values of the parameter type
       enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].

    For an indexed assignment the `type_map` key is the list name: "arr[1]=0"
    requires the key "arr", not "arr[1]".

    Args:
      values: String. Comma separated list of `name=value` clauses following
        the syntax described above.
      type_map: Dictionary mapping hyperparameter names to types. A value
        conforms to type T if it is a T or a list whose elements are Ts, so
        'x=[0.1,0.2]' parses successfully when type_map['x'] = float.
      ignore_unknown: Bool. When True, clauses whose name is missing from
        `type_map` are skipped instead of raising a ValueError.

    Returns:
      A dictionary mapping each name to one of:
      * a scalar value,
      * a list of scalar values,
      * a dictionary mapping index numbers to scalar values.
      E.g. "x=5,L=[1,2],arr[1]=3" yields {'x': 5, 'L': [1, 2], 'arr': {1: 3}}.

    Raises:
      ValueError: If there is a problem with the input:
        * `values` is malformed or a value cannot be parsed.
        * A name is missing from `type_map` and `ignore_unknown` is False.
        * A list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
        * The same left-hand side is assigned more than once (e.g. 'a=1,a=2',
          'a[1]=1,a[1]=2', or 'a=1,a=[1]').
    """
    parsed = {}

    def parse_bool(text):
        # Closure over `name` and `type_` of the clause currently being
        # processed; it is only ever called within that clause's iteration.
        if text in ("true", "True"):
            return True
        if text in ("false", "False"):
            return False
        try:
            return bool(int(text))
        except ValueError:
            _parse_fail(name, type_, text, values)

    cursor = 0
    total = len(values)
    while cursor < total:
        match = PARAM_RE.match(values, cursor)
        if match is None:
            raise ValueError("Malformed hyperparameter value: %s" % values[cursor:])

        # Advance past this clause (and its trailing separator) up front so
        # that skipping an unknown name below still makes progress.
        cursor = match.end()

        groups = match.groupdict()
        name = groups["name"]
        if name not in type_map:
            if ignore_unknown:
                continue
            raise ValueError("Unknown hyperparameter type for %s" % name)
        type_ = type_map[name]

        # Booleans need a dedicated parser; every other type parses itself.
        parse = parse_bool if type_ == bool else type_

        if groups["val"] is not None:
            # Scalar right-hand side, possibly assigned to an index.
            _process_scalar_value(name, parse, type_, groups, values, parsed)
        elif groups["vals"] is not None:
            # Bracketed list right-hand side.
            _process_list_value(name, parse, type_, groups, values, parsed)
        else:
            _parse_fail(name, type_, "", values)

    return parsed
| |
|
| |
|
class HParams(object):
    """Class to hold a set of hyperparameters as name-value pairs.

    A `HParams` object holds hyperparameters used to build and train a model,
    such as the number of hidden units in a neural net layer or the learning
    rate to use when training.

    You first create a `HParams` object by specifying the names and values of
    the hyperparameters.

    To make them easily accessible the parameter names are added as direct
    attributes of the object. A typical usage is as follows:

    ```python
    # Create a HParams object specifying names and values of the model
    # hyperparameters:
    hparams = HParams(learning_rate=0.1, num_hidden_units=100)

    # The hyperparameters are available as attributes of the HParams object:
    hparams.learning_rate ==> 0.1
    hparams.num_hidden_units ==> 100
    ```

    Hyperparameters have a type, which is inferred from the type of the value
    passed at construction time. The currently supported types are: integer,
    float, boolean, string, and list of integer, float, boolean, or string.

    You can override hyperparameter values by calling the `parse()` method,
    passing a string of comma separated `name=value` pairs. This is intended
    to make it possible to override any hyperparameter values from a single
    command-line flag to which the user passes 'hyper-param=value' pairs. It
    avoids having to define one flag for each hyperparameter.

    The syntax expected for each value depends on the type of the parameter.
    See `parse()` for a description of the syntax.

    Example:

    ```python
    # Define a command line flag to pass name=value pairs.
    # For example using argparse:
    import argparse
    parser = argparse.ArgumentParser(description='Train my model.')
    parser.add_argument('--hparams', type=str,
                        help='Comma separated list of "name=value" pairs.')
    args = parser.parse_args()
    ...
    def my_program():
      # Create a HParams object specifying the names and values of the
      # model hyperparameters:
      hparams = HParams(learning_rate=0.1, num_hidden_units=100,
                        activations=['relu', 'tanh'])

      # Override hyperparameters values by parsing the command line
      hparams.parse(args.hparams)

      # If the user passed `--hparams=learning_rate=0.3` on the command line
      # then 'hparams' has the following attributes:
      hparams.learning_rate ==> 0.3
      hparams.num_hidden_units ==> 100
      hparams.activations ==> ['relu', 'tanh']

      # If the hyperparameters are in json format use parse_json. A
      # list-valued hyperparameter must be given a json list:
      hparams.parse_json('{"learning_rate": 0.3, "activations": ["relu"]}')
    ```
    """

    # NOTE(review): not read anywhere in this class; presumably a marker that
    # tells static analysis tools attributes are created at runtime -- confirm.
    _HAS_DYNAMIC_ATTRIBUTES = True

    def __init__(self, model_structure=None, **kwargs):
        """Create an instance of `HParams` from keyword arguments.

        The keyword arguments specify name-value pairs for the
        hyperparameters. The parameter types are inferred from the type of the
        values passed.

        The parameter names are added as attributes of the `HParams` object,
        so they can be accessed directly with the dot notation
        `hparams._name_`.

        Example:

        ```python
        # Define 3 hyperparameters: 'learning_rate' is a float parameter,
        # 'num_hidden_units' an integer parameter, and 'activation' a string
        # parameter.
        hparams = HParams(
            learning_rate=0.1, num_hidden_units=100, activation='relu')

        hparams.activation ==> 'relu'
        ```

        Note that a few names are reserved and cannot be used as
        hyperparameter names. If you use one of the reserved names the
        constructor raises a `ValueError`.

        Args:
          model_structure: Optional object stored as-is and exposed through
            `get_model_structure()` / `set_model_structure()`.
          **kwargs: Key-value pairs where the key is the hyperparameter name
            and the value is the value for the parameter.

        Raises:
          ValueError: If one of the arguments is invalid (reserved name or
            empty list value).
        """
        # Maps each hyperparameter name to a `(type, is_list)` tuple. For a
        # scalar hyperparameter `type` is the type of its value; for a
        # multi-valued one it is the type of the first element. `is_list`
        # tells the two cases apart. parse() and set_hparam() rely on it.
        self._hparam_types = {}
        self._model_structure = model_structure
        for name, value in kwargs.items():
            self.add_hparam(name, value)

    def add_hparam(self, name, value):
        """Adds {name, value} pair to hyperparameters.

        Args:
          name: Name of the hyperparameter.
          value: Value of the hyperparameter. Can be one of the following
            types: int, float, string, int list, float list, or string list.

        Raises:
          ValueError: If `name` is already an attribute of this object with a
            non-None value, or if `value` is an empty list or tuple.
        """
        # `name` may collide with an existing attribute of this object (a
        # method, internal state, or a previously added hyperparameter). In
        # that case it is refused as a hyperparameter name.
        if getattr(self, name, None) is not None:
            raise ValueError("Hyperparameter name is reserved: %s" % name)
        if isinstance(value, (list, tuple)):
            if not value:
                # The element type is inferred from the first element.
                raise ValueError(
                    "Multi-valued hyperparameters cannot be empty: %s" % name
                )
            self._hparam_types[name] = (type(value[0]), True)
        else:
            self._hparam_types[name] = (type(value), False)
        setattr(self, name, value)

    def set_hparam(self, name, value):
        """Set the value of an existing hyperparameter.

        This function verifies that the type of the value matches the type of
        the existing hyperparameter.

        Args:
          name: Name of the hyperparameter.
          value: New value of the hyperparameter.

        Raises:
          KeyError: If the hyperparameter doesn't exist.
          ValueError: If there is a type mismatch.
        """
        param_type, is_list = self._hparam_types[name]
        if isinstance(value, list):
            if not is_list:
                raise ValueError(
                    "Must not pass a list for single-valued parameter: %s" % name
                )
            setattr(
                self,
                name,
                [_cast_to_type_if_compatible(name, param_type, v) for v in value],
            )
        else:
            if is_list:
                raise ValueError(
                    "Must pass a list for multi-valued parameter: %s." % name
                )
            setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))

    def del_hparam(self, name):
        """Removes the hyperparameter with key 'name'.

        Does nothing if it isn't present.

        Args:
          name: Name of the hyperparameter.
        """
        # Decide on the registry, not on hasattr(): methods and internal
        # attributes also satisfy hasattr() but are not hyperparameters, and
        # must be neither deleted nor cause an error here.
        if name in self._hparam_types:
            if hasattr(self, name):
                delattr(self, name)
            del self._hparam_types[name]

    def parse(self, values):
        """Override existing hyperparameter values, parsing new values from a string.

        See parse_values for more detail on the allowed format for values.

        Args:
          values: String. Comma separated list of `name=value` pairs where
            'value' must follow the syntax described in parse_values.

        Returns:
          The `HParams` instance.

        Raises:
          ValueError: If `values` cannot be parsed or a hyperparameter in
            `values` doesn't exist.
        """
        # parse_values only needs the element type of each hyperparameter.
        type_map = {
            name: param_type
            for name, (param_type, _) in self._hparam_types.items()
        }
        values_map = parse_values(values, type_map)
        return self.override_from_dict(values_map)

    def override_from_dict(self, values_dict):
        """Override existing hyperparameter values, parsing new values from a dictionary.

        Args:
          values_dict: Dictionary of name:value pairs.

        Returns:
          The `HParams` instance.

        Raises:
          KeyError: If a hyperparameter in `values_dict` doesn't exist.
          ValueError: If a value in `values_dict` has an incompatible type.
        """
        for name, value in values_dict.items():
            self.set_hparam(name, value)
        return self

    def set_model_structure(self, model_structure):
        """Store `model_structure` on this object, replacing any previous one."""
        self._model_structure = model_structure

    def get_model_structure(self):
        """Return the model structure given at construction or set since."""
        return self._model_structure

    def to_json(self, indent=None, separators=None, sort_keys=False):
        """Serializes the hyperparameters into JSON.

        Callable values (at any nesting depth inside dicts and lists) are
        omitted from the output.

        Args:
          indent: If a non-negative integer, JSON array elements and object
            members will be pretty-printed with that indent level. An indent
            level of 0, or negative, will only insert newlines. `None` (the
            default) selects the most compact representation.
          separators: Optional `(item_separator, key_separator)` tuple.
            Default is `(', ', ': ')`.
          sort_keys: If `True`, the output dictionaries will be sorted by key.

        Returns:
          A JSON string.
        """

        def remove_callables(x):
            """Omit callable elements from input with arbitrary nesting."""
            if isinstance(x, dict):
                return {
                    k: remove_callables(v) for k, v in x.items() if not callable(v)
                }
            elif isinstance(x, list):
                return [remove_callables(i) for i in x if not callable(i)]
            return x

        return json.dumps(
            remove_callables(self.values()),
            indent=indent,
            separators=separators,
            sort_keys=sort_keys,
        )

    def parse_json(self, values_json):
        """Override existing hyperparameter values, parsing new values from a json object.

        Args:
          values_json: String containing a json object of name:value pairs.

        Returns:
          The `HParams` instance.

        Raises:
          KeyError: If a hyperparameter in `values_json` doesn't exist.
          ValueError: If `values_json` cannot be parsed.
        """
        values_map = json.loads(values_json)
        return self.override_from_dict(values_map)

    def values(self):
        """Return the hyperparameter values as a Python dictionary.

        Returns:
          A dictionary with hyperparameter names as keys. The values are the
          hyperparameter values.
        """
        return {n: getattr(self, n) for n in self._hparam_types}

    def get(self, key, default=None):
        """Returns the value of `key` if it exists, else `default`.

        Args:
          key: Name of the hyperparameter.
          default: Value returned when `key` is not a hyperparameter. When
            `key` exists and `default` is not None, `default` must be
            compatible with the hyperparameter's type.

        Returns:
          The hyperparameter value, or `default` if `key` is unknown.

        Raises:
          ValueError: If `key` exists and a non-None `default` is incompatible
            with its type (scalar vs. list mismatch, or a failed cast).
        """
        if key in self._hparam_types:
            # Validate the default against the declared type even though the
            # stored value is what gets returned.
            if default is not None:
                param_type, is_param_list = self._hparam_types[key]
                type_str = "list<%s>" % param_type if is_param_list else str(param_type)
                fail_msg = (
                    "Hparam '%s' of type '%s' is incompatible with "
                    "default=%s" % (key, type_str, default)
                )

                is_default_list = isinstance(default, list)
                if is_param_list != is_default_list:
                    raise ValueError(fail_msg)

                try:
                    if is_default_list:
                        for value in default:
                            _cast_to_type_if_compatible(key, param_type, value)
                    else:
                        _cast_to_type_if_compatible(key, param_type, default)
                except ValueError as e:
                    raise ValueError("%s. %s" % (fail_msg, e))

            return getattr(self, key)

        return default

    def __contains__(self, key):
        """Return True if `key` is the name of a hyperparameter."""
        return key in self._hparam_types

    def __str__(self):
        """Return the (name, value) pairs, sorted by name, as a string."""
        return str(sorted(self.values().items()))

    def __repr__(self):
        """Return the class name followed by the `__str__` representation."""
        return "%s(%s)" % (type(self).__name__, self.__str__())

    @staticmethod
    def _get_kind_name(param_type, is_list):
        """Returns the field name given parameter type and is_list.

        Args:
          param_type: Data type of the hparam.
          is_list: Whether this is a list.

        Returns:
          A string of the form "<typename>_value" or "<typename>_list", where
          typename is one of "bool", "int64", "bytes" or "float".

        Raises:
          ValueError: If parameter type is not recognized.
        """
        if issubclass(param_type, bool):
            # Must be tested before the integer types: bool is a subclass of
            # int in Python.
            typename = "bool"
        elif issubclass(param_type, six.integer_types):
            typename = "int64"
        elif issubclass(param_type, (six.string_types, six.binary_type)):
            # Text and binary strings share one kind name.
            typename = "bytes"
        elif issubclass(param_type, float):
            typename = "float"
        else:
            raise ValueError("Unsupported parameter type: %s" % str(param_type))

        suffix = "list" if is_list else "value"
        return "_".join([typename, suffix])
| |
|