Spaces:
Build error
Build error
| """Utilities for modifying the configuration.""" | |
| # Copyright (C) 2020 Intel Corporation | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions | |
| # and limitations under the License. | |
| import itertools | |
| import operator | |
| from functools import reduce | |
| from typing import Any, Generator, List | |
| from omegaconf import DictConfig | |
def flatten_sweep_params(params_dict: DictConfig) -> DictConfig:
    """Collapse a nested sweep configuration into a single-level one.

    Every leaf of ``params_dict`` is stored under a key built by joining the names on the
    path to that leaf with ``"."``. A flat layout is convenient when

    - building the cartesian product of all parameter choices for a grid search,
    - using the keys as headers of a csv file,
    - passing the configuration to a `wandb` sweep.

    Args:
        params_dict: DictConfig: The hpo parameters in their original, nested, structure.

    Returns:
        A new DictConfig that maps the dotted key paths to the leaf values.
    """

    def iter_leaves(node: DictConfig, prefix: List[str]) -> Generator:
        """Yield a ``(dotted_key, value)`` pair for every leaf found below ``node``.

        Args:
            node: DictConfig: sub-config that is currently being traversed.
            prefix: List[str]: names of the keys leading from the root to ``node``.
        """
        for name, child in node.items():
            path = [*prefix, str(name)]
            if isinstance(child, DictConfig):
                # Descend: nested sections only contribute their leaves, never themselves.
                yield from iter_leaves(child, path)
            else:
                yield ".".join(path), child

    flat = DictConfig({})
    for dotted_key, leaf in iter_leaves(params_dict, []):
        flat[dotted_key] = leaf
    return flat
def get_run_config(params_dict: DictConfig) -> Generator[DictConfig, None, None]:
    """Yield the configuration for each run of a grid search.

    The nested search space is first flattened with ``flatten_sweep_params``. One
    configuration is then produced for every element of the cartesian product of the
    per-parameter choices. Combinations are generated lazily, so the full grid is never
    held in memory at once.

    A parameter whose value is not a collection (for example a single string or number)
    is treated as a fixed value that is repeated unchanged in every run.

    Args:
        params_dict (DictConfig): Configuration for grid search.

    Example:
        >>> dummy_config = DictConfig({
        ...     "parent1": {
        ...         "child1": ['a', 'b', 'c'],
        ...         "child2": [1, 2, 3]
        ...     },
        ...     "parent2": ['model1', 'model2']
        ... })
        >>> for run_config in get_run_config(dummy_config):
        ...     print(run_config)
        {'parent1.child1': 'a', 'parent1.child2': 1, 'parent2': 'model1'}
        {'parent1.child1': 'a', 'parent1.child2': 1, 'parent2': 'model2'}
        {'parent1.child1': 'a', 'parent1.child2': 2, 'parent2': 'model1'}
        ...

    Yields:
        Generator[DictConfig]: Dictionary containing flattened keys
            and values for current run.
    """

    def as_choices(value: Any) -> tuple:
        """Return the candidate values of one parameter as a tuple."""
        # Strings are iterable, but iterating them would sweep over single characters.
        if isinstance(value, (str, bytes)) or not hasattr(value, "__iter__"):
            return (value,)
        return tuple(value)

    params = flatten_sweep_params(params_dict)
    keys = list(params.keys())
    choices = [as_choices(value) for value in params.values()]

    # Iterate the product directly instead of materialising every combination up front.
    for combination in itertools.product(*choices):
        run_config = DictConfig({})
        for key, val in zip(keys, combination):
            run_config[key] = val
        yield run_config
def get_from_nested_config(config: DictConfig, keymap: List) -> Any:
    """Look up an item in a nested config object by following a list of keys.

    Args:
        config: DictConfig: nested DictConfig object
        keymap: List[str]: keys to follow, outermost first, to reach the item that should be retrieved.

    Returns:
        The item found after indexing with every key in turn; ``config`` itself when
        ``keymap`` is empty.
    """
    node = config
    for key in keymap:
        # Step one level deeper on every key.
        node = node[key]
    return node
def set_in_nested_config(config: DictConfig, keymap: List, value: Any):
    """Assign a value inside a nested config object, addressed by a list of keys.

    All keys except the last one are used to locate the parent container; the last key
    is the entry of that container which receives ``value``. The config is modified in place.

    Args:
        config: DictConfig: nested DictConfig object
        keymap: List[str]: list of keys corresponding to item that should be set.
        value: Any: Value that should be assigned to the dictionary item at the specified location.

    Example:
        >>> dummy_config = DictConfig({
        ...     "parent1": {
        ...         "child1": ['a', 'b', 'c'],
        ...         "child2": [1, 2, 3]
        ...     },
        ...     "parent2": ['model1', 'model2']
        ... })
        >>> model_config = DictConfig({
        ...     "parent1": {
        ...         "child1": 'e',
        ...         "child2": 4,
        ...     },
        ...     "parent3": False
        ... })
        >>> for run_config in get_run_config(dummy_config):
        ...     print("Original model config", model_config)
        ...     print("Suggested config", run_config)
        ...     for param in run_config.keys():
        ...         set_in_nested_config(model_config, param.split('.'), run_config[param])
        ...     print("Replaced model config", model_config)
        ...     break
        Original model config {'parent1': {'child1': 'e', 'child2': 4}, 'parent3': False}
        Suggested config {'parent1.child1': 'a', 'parent1.child2': 1, 'parent2': 'model1'}
        Replaced model config {'parent1': {'child1': 'a', 'child2': 1}, 'parent3': False, 'parent2': 'model1'}
    """
    # Resolve the container that holds the target entry first, then write into it.
    parent = get_from_nested_config(config, keymap[:-1])
    leaf_key = keymap[-1]
    parent[leaf_key] = value