| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import dataclasses |
| | import json |
| | import sys |
| | import types |
| | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError |
| | from copy import copy |
| | from enum import Enum |
| | from inspect import isclass |
| | from pathlib import Path |
| | from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints |
| |
|
| | import yaml |
| |
|
| |
|
| | DataClass = NewType("DataClass", Any) |
| | DataClassType = NewType("DataClassType", Any) |
| |
|
| |
|
| | |
def string_to_bool(v):
    """Interpret *v* as a boolean.

    Booleans pass through unchanged; strings are matched case-insensitively
    against common truthy/falsy spellings (yes/no, true/false, t/f, y/n, 1/0).

    Raises:
        ArgumentTypeError: if the string matches none of the accepted spellings.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in {"yes", "true", "t", "y", "1"}:
        return True
    if lowered in {"no", "false", "f", "n", "0"}:
        return False
    raise ArgumentTypeError(
        f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
    )
| |
|
| |
|
def make_choice_type_function(choices: list) -> Callable[[str], Any]:
    """
    Creates a mapping function from each choices string representation to the actual value. Used to support multiple
    value types for a single argument.

    Args:
        choices (list): List of choices.

    Returns:
        Callable[[str], Any]: Mapping function from string representation to actual value for each choice.
    """
    lookup = {str(choice): choice for choice in choices}

    def convert(arg: str) -> Any:
        # Unknown strings fall through unchanged so argparse can still report them.
        return lookup.get(arg, arg)

    return convert
| |
|
| |
|
def HfArg(
    *,
    aliases: Union[str, List[str]] = None,
    help: str = None,
    default: Any = dataclasses.MISSING,
    default_factory: Callable[[], Any] = dataclasses.MISSING,
    metadata: dict = None,
    **kwargs,
) -> dataclasses.Field:
    """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`.

    Example comparing the use of `HfArg` and `dataclasses.field`:
    ```
    @dataclass
    class Args:
        regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"})
        hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!")
    ```

    Args:
        aliases (Union[str, List[str]], optional):
            Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
            Defaults to None.
        help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
        default (Any, optional):
            Default value for the argument. If neither default nor default_factory is specified, the argument is
            required. Defaults to dataclasses.MISSING.
        default_factory (Callable[[], Any], optional):
            The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide
            default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`.
            Defaults to dataclasses.MISSING.
        metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None.

    Returns:
        Field: A `dataclasses.Field` with the desired properties.
    """
    # Keep the caller's dict object when one was supplied (it is updated in place).
    metadata = {} if metadata is None else metadata
    for key, value in (("aliases", aliases), ("help", help)):
        if value is not None:
            metadata[key] = value

    return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)
| |
|
| |
|
class HfArgumentParser(ArgumentParser):
    """
    This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.

    The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
    arguments to the parser after initialization and you'll get the output back after parsing as an additional
    namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
    """

    # The dataclass types whose fields become CLI arguments; populated in __init__.
    dataclass_types: Iterable[DataClassType]

    def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
        """
        Args:
            dataclass_types:
                Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
            kwargs (`Dict[str, Any]`, *optional*):
                Passed to `argparse.ArgumentParser()` in the regular way.
        """
        # To make the default appear when using --help
        if "formatter_class" not in kwargs:
            kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
        super().__init__(**kwargs)
        # Accept a single dataclass type as well as an iterable of them.
        if dataclasses.is_dataclass(dataclass_types):
            dataclass_types = [dataclass_types]
        self.dataclass_types = list(dataclass_types)
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

    @staticmethod
    def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
        """Register a single dataclass field on `parser` as an argparse argument.

        The field's metadata (minus `aliases`) is forwarded as `add_argument`
        keyword arguments; the field's resolved type drives `type`, `choices`,
        `nargs`, `default` and `required`.
        """
        field_name = f"--{field.name}"
        # field.metadata is ignored by dataclasses itself; here it carries the
        # extra argparse options (help, aliases, ...) set e.g. via `HfArg`.
        kwargs = field.metadata.copy()

        # Field types must already have been resolved to real type objects by
        # `_add_dataclass_arguments` (via `typing.get_type_hints`).
        if isinstance(field.type, str):
            raise RuntimeError(
                "Unresolved type detected, which should have been done with the help of "
                "`typing.get_type_hints` method by default"
            )

        aliases = kwargs.pop("aliases", [])
        if isinstance(aliases, str):
            aliases = [aliases]

        origin_type = getattr(field.type, "__origin__", field.type)
        if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)):
            # Only `Optional[X]` or two-member unions containing `str` are
            # supported, since argparse attaches a single converter per argument.
            if str not in field.type.__args__ and (
                len(field.type.__args__) != 2 or type(None) not in field.type.__args__
            ):
                raise ValueError(
                    "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because"
                    " the argument parser only supports one type per argument."
                    f" Problem encountered in field '{field.name}'."
                )
            if type(None) not in field.type.__args__:
                # filter `str` in Union: keep the non-str member as the parse type
                field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1]
                origin_type = getattr(field.type, "__origin__", field.type)
            elif bool not in field.type.__args__:
                # filter `NoneType` in Union (except for `Union[bool, NoneType]`)
                field.type = (
                    field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]
                )
                origin_type = getattr(field.type, "__origin__", field.type)

        # A variable to store kwargs for a boolean field, if needed, so that a
        # `--no_*` complement argument can be registered after the real one (below).
        bool_kwargs = {}
        if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)):
            if origin_type is Literal:
                kwargs["choices"] = field.type.__args__
            else:
                kwargs["choices"] = [x.value for x in field.type]

            kwargs["type"] = make_choice_type_function(kwargs["choices"])

            if field.default is not dataclasses.MISSING:
                kwargs["default"] = field.default
            else:
                kwargs["required"] = True
        elif field.type is bool or field.type == Optional[bool]:
            # Copy the current kwargs to use to instantiate a `--no_*` complement argument below.
            # Not initialized here because the `no_*` alternative must be added after the real argument.
            bool_kwargs = copy(kwargs)

            # Hack because type=bool in argparse does not behave as we want.
            kwargs["type"] = string_to_bool
            if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
                # Default value is False if we have no default when of type bool.
                default = False if field.default is dataclasses.MISSING else field.default
                # This is the value that will get picked if we don't include --{field.name} in any way.
                kwargs["default"] = default
                # This tells argparse we accept 0 or 1 value after --{field.name}.
                kwargs["nargs"] = "?"
                # This is the value that will get picked if we do --{field.name} (without value).
                kwargs["const"] = True
        elif isclass(origin_type) and issubclass(origin_type, list):
            # List fields: one or more space-separated values of the element type.
            kwargs["type"] = field.type.__args__[0]
            kwargs["nargs"] = "+"
            if field.default_factory is not dataclasses.MISSING:
                kwargs["default"] = field.default_factory()
            elif field.default is dataclasses.MISSING:
                kwargs["required"] = True
        else:
            kwargs["type"] = field.type
            if field.default is not dataclasses.MISSING:
                kwargs["default"] = field.default
            elif field.default_factory is not dataclasses.MISSING:
                kwargs["default"] = field.default_factory()
            else:
                kwargs["required"] = True
        parser.add_argument(field_name, *aliases, **kwargs)

        # Add a complement `--no_*` argument for booleans defaulting to True.
        # Order is important for arguments with the same destination: the copy of
        # the earlier kwargs is used because `kwargs` has been mutated above and
        # those changes/additional keys are not wanted here.
        if field.default is True and (field.type is bool or field.type == Optional[bool]):
            bool_kwargs["default"] = False
            parser.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **bool_kwargs)

    def _add_dataclass_arguments(self, dtype: DataClassType):
        """Add one argparse argument per `init` field of dataclass `dtype`.

        If the dataclass defines `_argument_group_name`, its arguments are placed
        in a dedicated argument group of that name.
        """
        if hasattr(dtype, "_argument_group_name"):
            parser = self.add_argument_group(dtype._argument_group_name)
        else:
            parser = self

        try:
            type_hints: Dict[str, type] = get_type_hints(dtype)
        except NameError:
            raise RuntimeError(
                f"Type resolution failed for {dtype}. Try declaring the class in global scope or "
                "removing line of `from __future__ import annotations` which opts in Postponed "
                "Evaluation of Annotations (PEP 563)"
            )
        except TypeError as ex:
            # Before Python 3.10, evaluating `X | Y` annotations at runtime raises TypeError.
            if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex):
                python_version = ".".join(map(str, sys.version_info[:3]))
                raise RuntimeError(
                    f"Type resolution failed for {dtype} on Python {python_version}. Try removing "
                    "line of `from __future__ import annotations` which opts in union types as "
                    "`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To "
                    "support Python versions that lower than 3.10, you need to use "
                    "`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of "
                    "`X | None`."
                ) from ex
            raise

        for field in dataclasses.fields(dtype):
            if not field.init:
                continue
            # Replace the (possibly string) annotation with the resolved type object.
            field.type = type_hints[field.name]
            self._parse_dataclass_field(parser, field)

    def parse_args_into_dataclasses(
        self,
        args=None,
        return_remaining_strings=False,
        look_for_args_file=True,
        args_filename=None,
        args_file_flag=None,
    ) -> Tuple[DataClass, ...]:
        """
        Parse command-line args into instances of the specified dataclass types.

        This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
        docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

        Args:
            args:
                List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
            return_remaining_strings:
                If true, also return a list of remaining argument strings.
            look_for_args_file:
                If true, will look for a ".args" file with the same base name as the entry point script for this
                process, and will append its potential content to the command line args.
            args_filename:
                If not None, will uses this file instead of the ".args" file specified in the previous argument.
            args_file_flag:
                If not None, will look for a file in the command-line args specified with this flag. The flag can be
                specified multiple times and precedence is determined by the order (last one wins).

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
                - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
                  after initialization.
                - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
        """

        if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)):
            args_files = []

            if args_filename:
                args_files.append(Path(args_filename))
            elif look_for_args_file and len(sys.argv):
                # Default: "<entry_script>.args" next to the entry point script.
                args_files.append(Path(sys.argv[0]).with_suffix(".args"))

            # args files specified via command line flag should overwrite default args files so we add them last
            if args_file_flag:
                # Create a special, throwaway parser just to extract the args_file_flag values.
                args_file_parser = ArgumentParser()
                args_file_parser.add_argument(args_file_flag, type=str, action="append")

                # Use only the remaining args for further parsing (removes the args_file_flag occurrences).
                cfg, args = args_file_parser.parse_known_args(args=args)
                cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None)

                if cmd_args_file_paths:
                    args_files.extend([Path(p) for p in cmd_args_file_paths])

            file_args = []
            for args_file in args_files:
                if args_file.exists():
                    file_args += args_file.read_text().split()

            # In case of duplicate arguments the last one has precedence, so args
            # specified on the command line overwrite args from files.
            args = file_args + args if args is not None else file_args + sys.argv[1:]
        namespace, remaining_args = self.parse_known_args(args=args)
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
            # Consume the attributes belonging to this dataclass so the leftover
            # namespace only carries extra (non-dataclass backed) arguments.
            for k in keys:
                delattr(namespace, k)
            obj = dtype(**inputs)
            outputs.append(obj)
        if len(namespace.__dict__) > 0:
            # Additional namespace for arguments added to the parser after initialization.
            outputs.append(namespace)
        if return_remaining_strings:
            return (*outputs, remaining_args)
        else:
            if remaining_args:
                raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")

            return (*outputs,)

    def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
        types.

        Args:
            args (`dict`):
                dict containing config values
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
        """
        unused_keys = set(args.keys())
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in args.items() if k in keys}
            unused_keys.difference_update(inputs.keys())
            obj = dtype(**inputs)
            outputs.append(obj)
        if not allow_extra_keys and unused_keys:
            raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
        return tuple(outputs)

    def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
        dataclass types.

        Args:
            json_file (`str` or `os.PathLike`):
                File name of the json file to parse
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the json file contains keys that are not
                parsed.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
        """
        with open(Path(json_file), encoding="utf-8") as open_json_file:
            data = json.loads(open_json_file.read())
        outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
        return tuple(outputs)

    def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
        dataclass types.

        Args:
            yaml_file (`str` or `os.PathLike`):
                File name of the yaml file to parse
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the yaml file contains keys that are not
                parsed.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
        """
        outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
        return tuple(outputs)
| |
|