Spaces:
Running
Running
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT license. | |
| import copy | |
| import functools | |
| from enum import Enum, unique | |
| import json_tricks | |
| from schema import And | |
| from . import parameter_expressions | |
# Serialize an object to a JSON string via json_tricks.dumps. allow_nan=True is
# passed so that non-finite floats (NaN/inf) are presumably emitted rather than
# rejected -- confirm against the json_tricks version in use.
to_json = functools.partial(json_tricks.dumps, allow_nan=True)
@unique
class OptimizeMode(Enum):
    """Optimize Mode class

    If OptimizeMode is 'minimize', the tuner needs to minimize the reward
    received from the trial.
    If OptimizeMode is 'maximize', the tuner needs to maximize the reward
    received from the trial.
    """
    # @unique guarantees the two modes can never silently alias each other.
    Minimize = 'minimize'
    Maximize = 'maximize'
class NodeType:
    """Keys and names used when walking a JSON search space.

    TYPE, VALUE, INDEX and NAME are dictionary keys looked up in search-space
    and generated-parameter dicts by the helpers in this module; ROOT is the
    default starting name of the node paths those helpers build.
    """
    ROOT = 'root'  # default `name` argument of json2space / json2parameter
    TYPE = '_type'  # key holding a node's sampling type, e.g. 'choice'
    VALUE = '_value'  # key holding the node's arguments/options, or the sampled sub-value
    INDEX = '_index'  # key holding the index of the option picked for a 'choice'
    NAME = '_name'  # key required in every dict of a nested search-space list
class MetricType:
    """The types of metric data.

    Each attribute is a plain string tag; how the tags are produced and
    consumed is not visible in this module.
    """
    FINAL = 'FINAL'  # presumably a trial's final result -- confirm with callers
    PERIODICAL = 'PERIODICAL'  # presumably an intermediate result reported during a trial -- confirm
    REQUEST_PARAMETER = 'REQUEST_PARAMETER'  # presumably a request for a new parameter set -- confirm
def split_index(params):
    """
    Delete index information from params.

    Recursively replace every ``{'_index': i, '_value': v}`` node (the shape
    ``json2parameter`` builds for a sampled ``choice``) by the stripped ``v``,
    so only the chosen values remain.

    Parameters
    ----------
    params : dict, list or any
        Parameter structure that may contain index nodes.

    Returns
    -------
    The same structure with index nodes collapsed. Dicts and lists are
    rebuilt; any other value is returned unchanged.
    """
    if isinstance(params, dict):
        if NodeType.INDEX in params:
            # A sampled 'choice': keep only the chosen value (which may nest further).
            return split_index(params[NodeType.VALUE])
        return {key: split_index(value) for key, value in params.items()}
    if isinstance(params, list):
        # json2parameter also produces lists (nested search spaces) whose elements
        # can hold index nodes, so they must be stripped too, not passed through.
        return [split_index(item) for item in params]
    return params
def extract_scalar_reward(value, scalar_key='default'):
    """
    Extract scalar reward from trial result.

    Parameters
    ----------
    value : int, float, dict
        the reported final metric data
    scalar_key : str
        the key name that indicates the numeric number when ``value`` is a dict

    Returns
    -------
    int or float
        ``value`` itself when it is a number, otherwise ``value[scalar_key]``.

    Raises
    ------
    RuntimeError
        Incorrect final result: the final result should be float/int,
        or a dict which has a key named ``scalar_key`` whose value is float/int.
    """
    if isinstance(value, (float, int)):
        return value
    if isinstance(value, dict) and scalar_key in value and isinstance(value[scalar_key], (float, int)):
        return value[scalar_key]
    # Report the key that was actually looked up (the message used to always say
    # "default"); the 1-tuple keeps `%` safe even if the key is itself a tuple.
    raise RuntimeError('Incorrect final result: the final result should be float/int, '
                       'or a dict which has a key named "%s" whose value is float/int.' % (scalar_key,))
def extract_scalar_history(trial_history, scalar_key='default'):
    """
    Extract scalar value from a list of intermediate results.

    Every entry is passed through ``extract_scalar_reward`` with the same
    ``scalar_key``.

    Parameters
    ----------
    trial_history : list
        accumulated intermediate results of a trial
    scalar_key : str
        the key name that indicates the numeric number

    Returns
    -------
    list
        One scalar per entry of ``trial_history``, in the same order.

    Raises
    ------
    RuntimeError
        If an entry is neither a float/int nor a dict whose ``scalar_key``
        value is a float/int (raised by ``extract_scalar_reward``).
    """
    scalars = []
    for record in trial_history:
        scalars.append(extract_scalar_reward(record, scalar_key))
    return scalars
def convert_dict2tuple(value):
    """
    Turn (possibly nested) dicts into sorted tuples of ``(key, value)`` pairs.

    Used to make dict-shaped data hashable. Only dicts are converted; any
    other value (including lists) is returned as is.

    NOTE: this function changes the original data -- every nested dict is
    replaced in place inside its parent by its tuple form.
    """
    if not isinstance(value, dict):
        return value
    for key, nested in value.items():
        # Overwriting existing keys does not resize the dict, so iterating is safe.
        value[key] = convert_dict2tuple(nested)
    return tuple(sorted(value.items()))
def json2space(x, oldy=None, name=NodeType.ROOT):
    """
    Flatten a JSON search space into the list of its sampled-node names.

    Names are built from ``name``: a dict key appends ``[key]``, a list
    position appends ``[i]`` and a ``_type`` node appends ``-<type>``; the
    options of a ``choice`` are therefore named ``<name>-choice[i]``. Names of
    nodes nested inside a ``choice`` come before the name of that ``choice``.

    When ``oldy`` (a previously generated parameter set of the same shape) is
    given, each ``choice`` is followed only into the option recorded under its
    ``_index``; otherwise every option is walked.

    Raises
    ------
    RuntimeError
        If a dict inside a list has no ``_name`` key.
    """
    collected = []

    if isinstance(x, list):
        for position, child in enumerate(x):
            if isinstance(child, dict) and NodeType.NAME not in child:
                raise RuntimeError('\'_name\' key is not found in this nested search space.')
            previous = oldy[position] if oldy else None
            collected.extend(json2space(child, previous, name + "[%d]" % position))
        return collected

    if not isinstance(x, dict):
        return collected

    if NodeType.TYPE not in x:
        # Plain container dict: recurse into every entry.
        for key in x:
            previous = oldy[key] if oldy else None
            collected.extend(json2space(x[key], previous, name + "[%s]" % str(key)))
        return collected

    node_type = x[NodeType.TYPE]
    name = name + '-' + node_type
    if node_type == 'choice':
        if oldy is None:
            collected.extend(json2space(x[NodeType.VALUE], None, name=name))
        else:
            chosen = oldy[NodeType.INDEX]
            collected.extend(json2space(x[NodeType.VALUE][chosen],
                                        oldy[NodeType.VALUE],
                                        name=name + '[%d]' % chosen))
    collected.append(name)
    return collected
def json2parameter(x, is_rand, random_state, oldy=None, Rand=False, name=NodeType.ROOT):
    """
    Json to parameters.

    Build a concrete parameter set from the search space ``x``, re-sampling the
    nodes flagged in ``is_rand`` and copying the rest from ``oldy``.

    Parameters
    ----------
    x : dict, list or any
        Search space (sub)tree. A dict containing ``_type`` is a sampled node
        whose ``_value`` holds the sampler arguments (for ``choice``: the options).
    is_rand : dict
        Maps node names (built the same way as in ``json2space``) to a flag
        telling whether that node must be re-sampled.
    random_state : object
        Presumably a ``numpy.random.RandomState`` (``randint(n)`` is called on
        it and it is forwarded to the ``parameter_expressions`` samplers) --
        confirm against callers.
    oldy : optional
        Previous parameter set of the same shape; deep-copied for nodes that
        are not re-sampled.
    Rand : bool
        When truthy, every node below this point is re-sampled regardless of
        ``is_rand``.
    name : str
        Name of the current node; used internally by the recursion.

    Returns
    -------
    The new parameter structure; a sampled ``choice`` becomes
    ``{'_index': i, '_value': v}``.

    Raises
    ------
    RuntimeError
        If a dict inside a list has no ``_name`` key.
    """
    if isinstance(x, dict):
        if NodeType.TYPE in x:
            _type = x[NodeType.TYPE]
            _value = x[NodeType.VALUE]
            name = name + '-' + _type
            Rand |= is_rand[name]
            # Truthiness, not identity: `Rand is True` silently skipped re-sampling
            # when the flags are truthy but not the `True` singleton (numpy.bool_, 1).
            if Rand:
                if _type == 'choice':
                    _index = random_state.randint(len(_value))
                    y = {
                        NodeType.INDEX: _index,
                        NodeType.VALUE: json2parameter(
                            _value[_index],
                            is_rand,
                            random_state,
                            None,
                            Rand,
                            name=name + "[%d]" % _index
                        )
                    }
                else:
                    # Star-unpacking accepts both list and tuple arguments.
                    y = getattr(parameter_expressions, _type)(*_value, random_state)
            else:
                y = copy.deepcopy(oldy)
        else:
            y = {
                key: json2parameter(
                    x[key],
                    is_rand,
                    random_state,
                    oldy[key] if oldy else None,
                    Rand,
                    name + "[%s]" % str(key)
                )
                for key in x
            }
    elif isinstance(x, list):
        y = []
        for i, x_i in enumerate(x):
            if isinstance(x_i, dict) and NodeType.NAME not in x_i:
                raise RuntimeError('\'_name\' key is not found in this nested search space.')
            y.append(json2parameter(
                x_i,
                is_rand,
                random_state,
                oldy[i] if oldy else None,
                Rand,
                name + "[%d]" % i
            ))
    else:
        # Constant leaf: copied so the result never aliases the search space.
        y = copy.deepcopy(x)
    return y
def merge_parameter(base_params, override_params):
    """
    Update the parameters in ``base_params`` with ``override_params``.
    Can be useful to override parsed command line arguments.

    Parameters
    ----------
    base_params : namespace or dict
        Base parameters. A key-value mapping.
    override_params : dict or None
        Parameters to override. Usually the parameters got from ``get_next_parameters()``.
        When it is none, nothing will happen.

    Returns
    -------
    namespace or dict
        The updated ``base_params``. Note that ``base_params`` will be updated inplace. The return value is
        only for convenience.

    Raises
    ------
    ValueError
        If a key of ``override_params`` does not exist in ``base_params``.
    TypeError
        If an override value's type differs from the (non-None) base value's
        type. An int overriding a float is accepted and converted to float.
    """
    if override_params is None:
        return base_params
    is_dict = isinstance(base_params, dict)
    for k, v in override_params.items():
        if is_dict:
            if k not in base_params:
                raise ValueError('Key \'%s\' not found in base parameters.' % (k,))
            base_params[k] = _check_override_type(k, base_params[k], v)
        else:
            if not hasattr(base_params, k):
                raise ValueError('Key \'%s\' not found in base parameters.' % (k,))
            setattr(base_params, k, _check_override_type(k, getattr(base_params, k), v))
    return base_params


def _check_override_type(key, base_value, override_value):
    """Return the value to store for ``key``, or raise TypeError if the types are incompatible."""
    if base_value is None or type(base_value) is type(override_value):
        return override_value
    # Tuner output comes from JSON, where e.g. a choice of [0.1, 1] yields int 1 for a
    # float parameter; widen it instead of rejecting. bool is an int subclass but is
    # never a meaningful number here, so it is still rejected.
    if isinstance(base_value, float) and isinstance(override_value, int) \
            and not isinstance(override_value, bool):
        return float(override_value)
    raise TypeError('Expected \'%s\' in override parameters to have type \'%s\', but found \'%s\'.' %
                    (key, type(base_value), type(override_value)))
class ClassArgsValidator(object):
    """
    NNI tuners/assessors/advisors accept a `classArgs` parameter in the experiment configuration file.
    This ClassArgsValidator interface is used to validate the classArgs section of the experiment
    configuration file.
    """
    def validate_class_args(self, **kwargs):
        """
        Validate the classArgs configuration in experiment configuration file.

        This base implementation accepts any kwargs and returns None;
        presumably concrete tuners/assessors/advisors override it.

        Parameters
        ----------
        kwargs: dict
            kwargs passed to tuner/assessor/advisor constructor

        Raises
        ------
        Exception
            Raise an exception if the kwargs is invalid.
        """
        pass

    def choices(self, key, *args):
        """
        Utility method to create a schema to check whether the `key` is one of the `args`.

        Parameters
        ----------
        key: str
            key name of the data to be validated
        args: list of str
            list of the choices

        Returns
        -------
        Schema
            A schema to check whether the `key` is one of the `args`.
        """
        return And(lambda n: n in args, error='%s should be in [%s]!' % (key, str(args)))

    def range(self, key, keyType, start, end):
        """
        Utility method to create a schema to check whether the `key` is in the range of [start, end].

        Parameters
        ----------
        key: str
            key name of the data to be validated
        keyType: type
            python data type, such as int, float
        start: type is specified by keyType
            start of the range (inclusive)
        end: type is specified by keyType
            end of the range (inclusive)

        Returns
        -------
        Schema
            A schema to check whether the `key` is in the range of [start, end].
        """
        return And(
            And(keyType, error='%s should be %s type!' % (key, keyType.__name__)),
            # The check is a closed interval, so the message must say [start, end]
            # (it used to print "(start, end)", which reads as exclusive bounds).
            And(lambda n: start <= n <= end, error='%s should be in range of [%s, %s]!' % (key, start, end))
        )