|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
|
import argparse |
|
|
import importlib |
|
|
from omegaconf import DictConfig |
|
|
|
|
|
def prepare_parser_from_dict(dic, parser=None):
    """Prepare an argparser from a dictionary.

    Args:
        dic (dict): Config dictionary with unique bottom-level keys. Every
            top-level key becomes an argument group. A dict value adds one
            ``--<sub-key>`` argument per entry; a list or str value adds a
            single ``--<key>`` argument. Any other top-level value only
            creates the (empty) group.
        parser (argparse.ArgumentParser, optional): If a parser already
            exists, add the keys from the dictionary on the top of it.

    Returns:
        argparse.ArgumentParser:
            Parser instance with groups corresponding to the first level keys
            and arguments corresponding to the second level keys with default
            values given by the values.
    """

    def standardized_entry_type(value):
        """Pick the argparse ``type`` callable matching a default value.

        If the default value is None, replace NoneType by str_int_float.
        If the default value is boolean (or a boolean-like string), look for
        boolean strings. Otherwise use the type of the default itself.
        """
        if value is None:
            return str_int_float
        if isinstance(str2bool(value), bool):
            return str2bool_arg
        return type(value)

    def list_element_type(values):
        """Pick the ``type`` callable applied to each item of a list argument."""
        kinds = {type(item) for item in values}
        if kinds == {bool}:
            return str2bool_arg
        if len(kinds) == 1:
            (kind,) = kinds
            if kind in (int, float, str):
                return kind
        # Empty, mixed or nested lists: fall back to best-effort conversion.
        return str_int_float

    if parser is None:
        parser = argparse.ArgumentParser()
    for k, value in dic.items():
        group = parser.add_argument_group(k)
        if isinstance(value, list):
            # ``type=list`` would turn the command-line string "abc" into
            # ['a', 'b', 'c']; collect the separate items instead and convert
            # each of them individually.
            group.add_argument(
                "--" + k,
                default=value,
                nargs="*",
                type=list_element_type(value),
            )
        elif isinstance(value, dict):
            for kk, sub_value in value.items():
                group.add_argument(
                    "--" + kk,
                    default=sub_value,
                    type=standardized_entry_type(sub_value),
                )
        elif isinstance(value, str):
            group.add_argument(
                "--" + k,
                default=value,
                type=standardized_entry_type(value),
            )
    return parser
|
|
|
|
|
|
|
|
def str_int_float(value):
    """Type to convert strings to int, float (in this order) if possible.

    Args:
        value (str): Value to convert.

    Returns:
        int, float, str: Converted value. A value that is neither
        convertible nor a string yields ``None``.
    """
    # Try the numeric casts in priority order: int first, then float.
    for can_cast, cast in ((isint, int), (isfloat, float)):
        if can_cast(value):
            return cast(value)
    # Strings that are not numbers are handed back untouched.
    return value if isinstance(value, str) else None
|
|
|
|
|
|
|
|
def str2bool(value):
    """Type to convert strings to Boolean (returns input if not boolean).

    Matching is case-insensitive. Non-string inputs and strings that are not
    one of the recognised spellings are returned unchanged.
    """
    if isinstance(value, str):
        lowered = value.lower()
        if lowered in ("yes", "true", "y", "1"):
            return True
        if lowered in ("no", "false", "n", "0"):
            return False
    return value
|
|
|
|
|
|
|
|
def str2bool_arg(value):
    """Argparse type to convert strings to Boolean.

    Raises:
        argparse.ArgumentTypeError: If ``str2bool`` does not turn `value`
            into a boolean.
    """
    converted = str2bool(value)
    if not isinstance(converted, bool):
        raise argparse.ArgumentTypeError("Boolean value expected.")
    return converted
|
|
|
|
|
|
|
|
def isfloat(value):
    """Computes whether `value` can be cast to a float.

    Args:
        value (str): Value to check.

    Returns:
        bool: Whether `value` can be cast to a float.

    """
    try:
        float(value)
    except (ValueError, TypeError, OverflowError):
        # ValueError: malformed string; TypeError: unsupported type such as
        # None or a list; OverflowError: integer too large for a float.
        return False
    return True
|
|
|
|
|
|
|
|
def isint(value):
    """Computes whether `value` can be cast to an int

    Args:
        value (str): Value to check.

    Returns:
        bool: Whether `value` can be cast to an int.

    """
    try:
        int(value)
    except (ValueError, TypeError, OverflowError):
        # ValueError: malformed string or NaN; TypeError: unsupported type
        # such as None or a list; OverflowError: infinite float.
        return False
    return True
|
|
|
|
|
|
|
|
def parse_args_as_dict(parser, return_plain_args=False, args=None):
    """Get a dict of dicts out of process `parser.parse_args()`

    Top-level keys corresponding to groups and bottom-level keys corresponding
    to arguments. Under `'main_args'`, the arguments which don't belong to a
    argparse group (i.e main arguments defined before parsing from a dict) can
    be found.

    Args:
        parser (argparse.ArgumentParser): ArgumentParser instance containing
            groups. Output of `prepare_parser_from_dict`.
        return_plain_args (bool): Whether to return the output or
            `parser.parse_args()`.
        args (list): List of arguments as read from the command line.
            Used for unit testing.

    Returns:
        dict:
            Dictionary of dictionaries containing the arguments. Optionally the
            direct output `parser.parse_args()`.
    """
    args = parser.parse_args(args=args)
    args_dic = {}
    main_args = {}
    for group in parser._action_groups:
        group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions}
        # The title of the default optional group depends on the Python
        # version ("optional arguments" before 3.10, "options" afterwards),
        # so recognise the group by identity rather than by name.
        if group is parser._optionals:
            main_args = group_dict
        else:
            args_dic[group.title] = group_dict
    args_dic["main_args"] = main_args
    if return_plain_args:
        return args_dic, args
    return args_dic
|
|
|
|
|
def instantiate(config, **kwargs):
    """Build an object (or a dict of values) from a config mapping.

    If `config` has a ``'__target__'`` entry, it is read as a dotted
    ``module.attribute`` path. The attribute is imported and called with the
    remaining entries as keyword arguments, overridden by `kwargs`. Nested
    ``DictConfig`` entries that carry their own ``'__target__'`` are
    instantiated first (without `kwargs`).

    Without a ``'__target__'`` entry, a plain dict is returned in which every
    ``DictConfig`` value has been passed through ``instantiate`` (with
    `kwargs`) and every other value is kept as is.

    Args:
        config: Mapping supporting ``in`` and ``.items()``.
        **kwargs: Extra keyword arguments for the target call.

    Returns:
        The result of calling the target, or a dict when there is no target.
    """
    if '__target__' not in config:
        resolved = {}
        for name, entry in config.items():
            if isinstance(entry, DictConfig):
                entry = instantiate(entry, **kwargs)
            resolved[name] = entry
        return resolved

    # Split "package.module.Name" into the module path and attribute name.
    module_path, class_name = config['__target__'].rsplit('.', 1)
    cls = getattr(importlib.import_module(module_path), class_name)

    params = {
        key: (
            instantiate(value)
            if isinstance(value, DictConfig) and '__target__' in value
            else value
        )
        for key, value in config.items()
        if key != '__target__'
    }
    # Explicit keyword arguments win over values coming from the config.
    params.update(kwargs)
    return cls(**params)