|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import dataclasses
|
| import json
|
| import os
|
| import sys
|
| import types
|
| from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
|
| from collections.abc import Callable, Iterable
|
| from copy import copy
|
| from enum import Enum
|
| from inspect import isclass
|
| from pathlib import Path
|
| from typing import Any, Literal, NewType, Union, get_type_hints
|
|
|
|
|
# Annotation-only type names for this module: `DataClass` labels the dataclass instances the
# parser returns, `DataClassType` labels the dataclass classes it is configured with.
# Both are `NewType` wrappers around `Any`.
DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
|
|
|
|
|
|
|
def string_to_bool(v):
    """Convert a command-line style truthy/falsy token into a real `bool`.

    Args:
        v: Either a `bool` (returned unchanged) or a value whose string form is one of
            yes/no, true/false, t/f, y/n, 1/0 (case insensitive).

    Returns:
        bool: The boolean the token stands for.

    Raises:
        ArgumentTypeError: If the token is not one of the recognised spellings.
    """
    if isinstance(v, bool):
        return v

    # Normalise once. Going through `str()` means non-string inputs (e.g. the integers 0/1) are
    # either converted or rejected with `ArgumentTypeError` instead of failing on a missing `.lower()`.
    token = str(v).lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise ArgumentTypeError(
        f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
    )
|
|
|
|
|
def make_choice_type_function(choices: list) -> Callable[[str], Any]:
    """
    Build a converter that turns the string form of a choice back into the choice itself, so that a single
    argument can accept values of several types.

    Args:
        choices (list): List of choices.

    Returns:
        Callable[[str], Any]: Function mapping `str(choice)` to `choice` for every entry of `choices`; any other
        input is returned unchanged.
    """
    lookup = {}
    for candidate in choices:
        # If two choices share a string form, the later one wins.
        lookup[str(candidate)] = candidate

    def convert(arg):
        return lookup.get(arg, arg)

    return convert
|
|
|
|
|
| def HfArg(
|
| *,
|
| aliases: str | list[str] | None = None,
|
| help: str | None = None,
|
| default: Any = dataclasses.MISSING,
|
| default_factory: Callable[[], Any] = dataclasses.MISSING,
|
| metadata: dict | None = None,
|
| **kwargs,
|
| ) -> dataclasses.Field:
|
| """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`.
|
|
|
| Example comparing the use of `HfArg` and `dataclasses.field`:
|
| ```
|
| @dataclass
|
| class Args:
|
| regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"})
|
| hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!")
|
| ```
|
|
|
| Args:
|
| aliases (Union[str, list[str]], optional):
|
| Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
|
| Defaults to None.
|
| help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
|
| default (Any, optional):
|
| Default value for the argument. If not default or default_factory is specified, the argument is required.
|
| Defaults to dataclasses.MISSING.
|
| default_factory (Callable[[], Any], optional):
|
| The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide
|
| default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`.
|
| Defaults to dataclasses.MISSING.
|
| metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None.
|
|
|
| Returns:
|
| Field: A `dataclasses.Field` with the desired properties.
|
| """
|
| if metadata is None:
|
|
|
| metadata = {}
|
| if aliases is not None:
|
| metadata["aliases"] = aliases
|
| if help is not None:
|
| metadata["help"] = help
|
|
|
| return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)
|
|
|
|
|
class HfArgumentParser(ArgumentParser):
    """
    This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.

    The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
    arguments to the parser after initialization and you'll get the output back after parsing as an additional
    namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.

    Args:
        dataclass_types (`DataClassType` or `Iterable[DataClassType]`, *optional*):
            Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
        kwargs (`dict[str, Any]`, *optional*):
            Passed to `argparse.ArgumentParser()` in the regular way.
    """

    dataclass_types: Iterable[DataClassType]

    def __init__(self, dataclass_types: DataClassType | Iterable[DataClassType] | None = None, **kwargs):
        """Create the parser and register one argument per `init` field of every given dataclass."""
        # Make sure dataclass_types is an iterable.
        if dataclass_types is None:
            dataclass_types = []
        elif not isinstance(dataclass_types, Iterable):
            dataclass_types = [dataclass_types]

        # Make the defaults appear in `--help` unless the caller picked a formatter explicitly.
        if "formatter_class" not in kwargs:
            kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
        super().__init__(**kwargs)

        # Covers a dataclass that also satisfies `Iterable` and therefore was not wrapped above.
        if dataclasses.is_dataclass(dataclass_types):
            dataclass_types = [dataclass_types]
        self.dataclass_types = list(dataclass_types)
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

    @staticmethod
    def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
        """
        Register one dataclass field on `parser` as a command-line argument.

        The field metadata is used as the base keyword arguments for `add_argument` (with `aliases` extracted as
        extra option strings); `type`, `choices`, `nargs`, `default`, `const` and `required` are then derived from
        the field's type and default. Note that `field.type` is rewritten in place when a `Union` is narrowed to
        a single type.

        Args:
            parser (`ArgumentParser`): Parser or argument group that receives the argument.
            field (`dataclasses.Field`): Field whose `type` has already been resolved to a real type object.

        Raises:
            RuntimeError: If `field.type` is still a string (unresolved annotation).
            ValueError: If the field is a `Union` that is neither `Optional[X]` nor contains `str`.
        """
        # Long options are conventionally separated by hyphens rather than underscores, so both
        # `--long_name` and `--long-name` are accepted for a field called `long_name`.
        long_options = [f"--{field.name}"]
        if "_" in field.name:
            long_options.append(f"--{field.name.replace('_', '-')}")

        # The field metadata doubles as the keyword arguments handed to `add_argument`.
        kwargs = field.metadata.copy()

        if isinstance(field.type, str):
            raise RuntimeError(
                "Unresolved type detected, which should have been done with the help of "
                "`typing.get_type_hints` method by default"
            )

        aliases = kwargs.pop("aliases", [])
        if isinstance(aliases, str):
            aliases = [aliases]

        none_type = type(None)
        origin_type = getattr(field.type, "__origin__", field.type)
        if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)):
            union_args = field.type.__args__
            if str not in union_args and (len(union_args) != 2 or none_type not in union_args):
                raise ValueError(
                    "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because"
                    " the argument parser only supports one type per argument."
                    f" Problem encountered in field '{field.name}'."
                )
            if none_type not in union_args:
                # Filter `str` out of the Union.
                field.type = union_args[0] if union_args[1] is str else union_args[1]
                origin_type = getattr(field.type, "__origin__", field.type)
            elif bool not in union_args:
                # Filter `NoneType` out of the Union (except for `Union[bool, NoneType]`).
                # An identity check is used rather than `isinstance(None, union_args[1])`, which raises
                # `TypeError` when that member is a parameterized generic, e.g. `None | list[int]`.
                field.type = union_args[0] if union_args[1] is none_type else union_args[1]
                origin_type = getattr(field.type, "__origin__", field.type)

        is_bool_field = field.type is bool or field.type == bool | None

        # Holds the kwargs for a boolean field so that a `no_*` complement argument can be created below.
        bool_kwargs = {}
        if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)):
            if origin_type is Literal:
                kwargs["choices"] = field.type.__args__
            else:
                kwargs["choices"] = [x.value for x in field.type]

            kwargs["type"] = make_choice_type_function(kwargs["choices"])

            if field.default is not dataclasses.MISSING:
                kwargs["default"] = field.default
            else:
                kwargs["required"] = True
        elif is_bool_field:
            # Snapshot the kwargs now: the `no_*` complement must be added after the real argument and
            # must not inherit the `type`/`nargs`/`const` keys added below.
            bool_kwargs = copy(kwargs)

            # `type=bool` in argparse does not behave as wanted, so parse the token ourselves.
            kwargs["type"] = string_to_bool
            if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
                # A plain `bool` without default falls back to False.
                default = False if field.default is dataclasses.MISSING else field.default
                # Value used when `--{field.name}` is not given at all.
                kwargs["default"] = default
                # Accept zero or one value after `--{field.name}`.
                kwargs["nargs"] = "?"
                # Value used when `--{field.name}` is given without a value.
                kwargs["const"] = True
        elif isclass(origin_type) and issubclass(origin_type, list):
            kwargs["type"] = field.type.__args__[0]
            kwargs["nargs"] = "+"
            if field.default_factory is not dataclasses.MISSING:
                kwargs["default"] = field.default_factory()
            elif field.default is dataclasses.MISSING:
                kwargs["required"] = True
        else:
            kwargs["type"] = field.type
            if field.default is not dataclasses.MISSING:
                kwargs["default"] = field.default
            elif field.default_factory is not dataclasses.MISSING:
                kwargs["default"] = field.default_factory()
            else:
                kwargs["required"] = True
        parser.add_argument(*long_options, *aliases, **kwargs)

        # Add the `no_*` complement for a boolean field AFTER the field itself has been registered:
        # order matters for arguments that share a destination.
        if field.default is True and is_bool_field:
            bool_kwargs["default"] = False
            parser.add_argument(
                f"--no_{field.name}",
                f"--no-{field.name.replace('_', '-')}",
                action="store_false",
                dest=field.name,
                **bool_kwargs,
            )

    def _add_dataclass_arguments(self, dtype: DataClassType):
        """
        Register every `init` field of `dtype` as an argument, inside its own argument group when the dataclass
        defines `_argument_group_name`.

        Raises:
            RuntimeError: If the type hints of `dtype` cannot be resolved.
        """
        if hasattr(dtype, "_argument_group_name"):
            parser = self.add_argument_group(dtype._argument_group_name)
        else:
            parser = self

        try:
            type_hints: dict[str, type] = get_type_hints(dtype)
        except NameError:
            raise RuntimeError(
                f"Type resolution failed for {dtype}. Try declaring the class in global scope or "
                "removing line of `from __future__ import annotations` which opts in Postponed "
                "Evaluation of Annotations (PEP 563)"
            ) from None

        for field in dataclasses.fields(dtype):
            if not field.init:
                continue
            # Replace the (possibly string) annotation with the resolved type before parsing the field.
            field.type = type_hints[field.name]
            self._parse_dataclass_field(parser, field)

    def parse_args_into_dataclasses(
        self,
        args=None,
        return_remaining_strings=False,
        look_for_args_file=True,
        args_filename=None,
        args_file_flag=None,
    ) -> tuple[DataClass, ...]:
        """
        Parse command-line args into instances of the specified dataclass types.

        This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
        docs.python.org/3/library/argparse.html#argparse.ArgumentParser.parse_args

        Args:
            args:
                List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
            return_remaining_strings:
                If true, also return a list of remaining argument strings.
            look_for_args_file:
                If true, will look for a ".args" file with the same base name as the entry point script for this
                process, and will put its potential content in front of the command line args (so that values given
                on the command line take precedence).
            args_filename:
                If not None, will use this file instead of the ".args" file specified in the previous argument.
            args_file_flag:
                If not None, will look for a file in the command-line args specified with this flag. The flag can be
                specified multiple times and precedence is determined by the order (last one wins).

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
                - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
                  after initialization.
                - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)

        Raises:
            ValueError: If `return_remaining_strings` is false and some arguments were not consumed.
        """
        if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)):
            args_files = []

            if args_filename:
                args_files.append(Path(args_filename))
            elif look_for_args_file and sys.argv and sys.argv[0]:
                # `sys.argv[0]` is an empty string in an interactive interpreter; `with_suffix` would
                # raise `ValueError` on the resulting empty file name, so there is nothing to look for.
                args_files.append(Path(sys.argv[0]).with_suffix(".args"))

            # Files given through the command-line flag must override the default args file, so they go last.
            if args_file_flag:
                # Dedicated parser whose only job is to pull the flag values out of the arguments.
                # `add_help=False` keeps `-h/--help` in the remaining args so the real parser answers it.
                args_file_parser = ArgumentParser(add_help=False)
                flag_action = args_file_parser.add_argument(args_file_flag, type=str, action="append")

                # Only the remaining args are parsed further (the flag and its values are removed).
                cfg, args = args_file_parser.parse_known_args(args=args)
                # Read the value through the dest argparse actually chose: for a flag such as
                # `--args-file` it is `args_file`, not the flag text with its dashes stripped.
                cmd_args_file_paths = getattr(cfg, flag_action.dest, None)

                if cmd_args_file_paths:
                    args_files.extend([Path(p) for p in cmd_args_file_paths])

            file_args = []
            for args_file in args_files:
                if args_file.exists():
                    file_args += args_file.read_text().split()

            # In case of duplicate arguments the last one wins, and the command line must override the
            # files, so its arguments go last. `list(...)` also accepts `args` given as a tuple.
            args = file_args + (list(args) if args is not None else sys.argv[1:])
        namespace, remaining_args = self.parse_known_args(args=args)
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
            # Consume the attributes so that only non-dataclass arguments are left in the namespace.
            for k in keys:
                delattr(namespace, k)
            obj = dtype(**inputs)
            outputs.append(obj)
        if len(namespace.__dict__) > 0:
            # Additional namespace for arguments that are not backed by a dataclass.
            outputs.append(namespace)
        if return_remaining_strings:
            return (*outputs, remaining_args)

        if remaining_args:
            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")

        return (*outputs,)

    def parse_dict(self, args: dict[str, Any], allow_extra_keys: bool = False) -> tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
        types.

        Args:
            args (`dict`):
                dict containing config values
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.

        Raises:
            ValueError: If `allow_extra_keys` is false and some keys match no dataclass field.
        """
        unused_keys = set(args.keys())
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in args.items() if k in keys}
            unused_keys.difference_update(inputs.keys())
            obj = dtype(**inputs)
            outputs.append(obj)
        if not allow_extra_keys and unused_keys:
            raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
        return tuple(outputs)

    def parse_json_file(self, json_file: str | os.PathLike, allow_extra_keys: bool = False) -> tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
        dataclass types.

        Args:
            json_file (`str` or `os.PathLike`):
                File name of the json file to parse
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the json file contains keys that are not
                parsed.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
        """
        with open(Path(json_file), encoding="utf-8") as open_json_file:
            data = json.loads(open_json_file.read())
        outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
        return tuple(outputs)

    def parse_yaml_file(self, yaml_file: str | os.PathLike, allow_extra_keys: bool = False) -> tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
        dataclass types.

        Args:
            yaml_file (`str` or `os.PathLike`):
                File name of the yaml file to parse
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the yaml file contains keys that are not
                parsed.

        Returns:
            Tuple consisting of:

                - the dataclass instances in the same order as they were passed to the initializer.
        """
        import yaml

        # Read as UTF-8, like `parse_json_file`, instead of depending on the locale's default encoding.
        data = yaml.safe_load(Path(yaml_file).read_text(encoding="utf-8"))
        outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
        return tuple(outputs)
|
|
|