| | |
| | |
| | |
| | |
| |
|
| | from argparse import ArgumentParser |
| | from dataclasses import MISSING, dataclass |
| | from enum import Enum |
| | from typing import Any, Dict, List, Optional |
| |
|
| |
|
def eval_str_list(x, x_type=float):
    """Coerce ``x`` into a list whose items were converted with ``x_type``.

    Args:
        x: ``None``, a string holding a Python expression (e.g. ``"[1, 2]"``
            or ``"0.5"``), an iterable, or a single scalar value.
        x_type: callable applied to every item (defaults to ``float``).

    Returns:
        ``None`` when ``x`` is ``None``; otherwise a list of converted items.
        A value that cannot be iterated is wrapped in a one-element list.
    """
    if x is None:
        return None
    # NOTE(security): eval() executes arbitrary expressions. This is only safe
    # while the string comes from a trusted source; review before exposing it
    # to untrusted input.
    value = eval(x) if isinstance(x, str) else x
    try:
        return [x_type(item) for item in value]
    except TypeError:
        # Not iterable (or an item rejected the conversion): treat as a scalar.
        return [x_type(value)]
| |
|
| |
|
class StrEnum(Enum):
    """Enum whose members print as, and compare equal to, their values.

    ``str()`` and ``repr()`` both return the member's value, and ``==``
    compares that value with the other operand, so a member can be used
    interchangeably with the plain string it wraps.
    """

    def __str__(self):
        return self.value

    def __eq__(self, other: str):
        return self.value == other

    def __hash__(self):
        # A class that defines __eq__ without __hash__ becomes unhashable
        # (Python sets __hash__ to None). Hash like the string form so members
        # work in sets and as dict keys, consistently with __eq__.
        return hash(str(self))

    def __repr__(self):
        return self.value
| |
|
| |
|
def ChoiceEnum(choices: List[str]):
    """Create the ``StrEnum`` class named "Choices" that enforces ``choices``.

    Each entry of ``choices`` becomes a member whose name and value are both
    that entry.
    """
    names = list(choices)
    members = dict(zip(names, names))
    return StrEnum("Choices", members)
| |
|
| |
|
@dataclass
class FairseqDataclass:
    """Base dataclass for fairseq configs with helpers to inspect its fields.

    The ``_get_*`` helpers read from ``__dataclass_fields__``: the declared
    name, type, default and the entries of each field's ``metadata`` mapping.
    """

    _name: Optional[str] = None

    @staticmethod
    def name():
        """Return ``None``; the base class carries no name."""
        return None

    def _get_all_attributes(self) -> List[str]:
        """Return the names of all dataclass fields, in declaration order."""
        return list(self.__dataclass_fields__)

    def _get_meta(
        self, attribute_name: str, meta: str, default: Optional[Any] = None
    ) -> Any:
        """Return metadata entry ``meta`` of a field, or ``default`` if absent."""
        field_spec = self.__dataclass_fields__[attribute_name]
        return field_spec.metadata.get(meta, default)

    def _get_name(self, attribute_name: str) -> str:
        """Return the declared name of the field."""
        field_spec = self.__dataclass_fields__[attribute_name]
        return field_spec.name

    def _get_default(self, attribute_name: str) -> Any:
        """Return the effective default of a field.

        Precedence: the instance value if it is a ``${...}`` string, then the
        declared default if it is a ``${...}`` string, then the instance value
        if it differs from the declared default, else the declared default.
        """
        declared = self.__dataclass_fields__[attribute_name].default
        if not hasattr(self, attribute_name):
            return declared

        current = getattr(self, attribute_name)
        current_text = str(current)
        if current_text.startswith("${"):
            return current_text

        declared_text = str(declared)
        if declared_text.startswith("${"):
            return declared_text

        if current != declared:
            return current
        return declared

    def _get_default_factory(self, attribute_name: str) -> Any:
        """Return the effective default of a field built by ``default_factory``.

        Same precedence as ``_get_default``, except that the fallback value is
        a fresh ``default_factory()`` result.
        """
        field_spec = self.__dataclass_fields__[attribute_name]
        if hasattr(self, attribute_name):
            current = getattr(self, attribute_name)
            current_text = str(current)
            if current_text.startswith("${"):
                return current_text

            declared_text = str(field_spec.default)
            if declared_text.startswith("${"):
                return declared_text

            if current != field_spec.default_factory():
                return current
        return field_spec.default_factory()

    def _get_type(self, attribute_name: str) -> Any:
        """Return the annotated type of the field."""
        field_spec = self.__dataclass_fields__[attribute_name]
        return field_spec.type

    def _get_help(self, attribute_name: str) -> Any:
        """Return the ``help`` metadata of the field, if any."""
        return self._get_meta(attribute_name, "help")

    def _get_argparse_const(self, attribute_name: str) -> Any:
        """Return the ``argparse_const`` metadata of the field, if any."""
        return self._get_meta(attribute_name, "argparse_const")

    def _get_argparse_alias(self, attribute_name: str) -> Any:
        """Return the ``argparse_alias`` metadata of the field, if any."""
        return self._get_meta(attribute_name, "argparse_alias")

    def _get_choices(self, attribute_name: str) -> Any:
        """Return the ``choices`` metadata of the field, if any."""
        return self._get_meta(attribute_name, "choices")
| |
|
| |
|
def gen_parser_from_dataclass(
    parser: ArgumentParser,
    dataclass_instance: FairseqDataclass,
    delete_default: bool = False,
) -> None:
    """Convert the fields of a dataclass instance into parser arguments.

    Every field reported by ``dataclass_instance._get_all_attributes()`` is
    registered on ``parser`` as ``--field-name`` (underscores become dashes).
    The field named ``data`` keeps its bare name and the field named ``_name``
    is skipped. Type, default, choices, help, const and alias are derived from
    the field's annotation and metadata.

    Args:
        parser: parser that receives the generated arguments.
        dataclass_instance: object exposing the ``_get_*`` field accessors.
        delete_default: when True, arguments are registered without a default.

    Raises:
        RuntimeError: if a field's type is a string rather than a type object.
        NotImplementedError: if a list field's element type is not int, float
            or str.
    """
    import types
    from argparse import ArgumentError
    from typing import Union

    none_type = type(None)
    # ``types.UnionType`` (the ``X | None`` spelling) only exists on 3.10+.
    pep604_union = getattr(types, "UnionType", None)

    def argparse_name(name: str):
        if name == "data":
            # kept without the "--" prefix, i.e. registered as a positional
            return name
        if name == "_name":
            # skipped: no command-line argument is generated for this field
            return None
        return "--" + name.replace("_", "-")

    def interpret_dc_type(field_type):
        """Unwrap ``Optional[X]`` / ``Union[X, ..., None]`` to ``X``."""
        if isinstance(field_type, str):
            raise RuntimeError("field should be a type")
        # Inspect the type structurally instead of matching its repr: the repr
        # of Optional[X] differs between Python versions
        # ("typing.Union[X, NoneType]" vs "typing.Optional[X]").
        is_union = getattr(field_type, "__origin__", None) is Union or (
            pep604_union is not None and isinstance(field_type, pep604_union)
        )
        if is_union:
            members = field_type.__args__
            if members and members[-1] is none_type:
                return members[0]
        return field_type

    def get_kwargs_from_dc(
        dataclass_instance: FairseqDataclass, k: str
    ) -> Dict[str, Any]:
        """Build the ``add_argument`` keyword arguments for attribute ``k``."""
        inter_type = interpret_dc_type(dataclass_instance._get_type(k))
        type_text = str(inter_type)
        is_list_class = isinstance(inter_type, type) and issubclass(
            inter_type, list
        )
        is_enum_class = isinstance(inter_type, type) and issubclass(
            inter_type, Enum
        )

        if is_list_class:
            field_default = dataclass_instance._get_default_factory(k)
        else:
            field_default = dataclass_instance._get_default(k)

        field_help = dataclass_instance._get_help(k)
        field_const = dataclass_instance._get_argparse_const(k)

        kwargs: Dict[str, Any] = {}
        if isinstance(field_default, str) and field_default.startswith("${"):
            # interpolation placeholder: forwarded untouched, no type handling
            kwargs["default"] = field_default
        else:
            if field_default is MISSING:
                kwargs["required"] = True
            if is_enum_class:
                kwargs["choices"] = [member.value for member in inter_type]

            if is_list_class or "List" in type_text:
                # Each lambda closes over a constant type, so there is no
                # late-binding hazard here.
                if "int" in type_text:
                    kwargs["type"] = lambda x: eval_str_list(x, int)
                elif "float" in type_text:
                    kwargs["type"] = lambda x: eval_str_list(x, float)
                elif "str" in type_text:
                    kwargs["type"] = lambda x: eval_str_list(x, str)
                else:
                    raise NotImplementedError()
                if field_default is not MISSING:
                    # An optional list may default to None; joining None
                    # would raise TypeError.
                    kwargs["default"] = (
                        ",".join(map(str, field_default))
                        if field_default is not None
                        else None
                    )
            elif is_enum_class or "Enum" in type_text:
                kwargs["type"] = str
                if field_default is not MISSING:
                    if isinstance(field_default, Enum):
                        kwargs["default"] = field_default.value
                    else:
                        kwargs["default"] = field_default
            elif inter_type is bool:
                # the flag toggles the default value
                kwargs["action"] = (
                    "store_false" if field_default is True else "store_true"
                )
                kwargs["default"] = field_default
            else:
                kwargs["type"] = inter_type
                if field_default is not MISSING:
                    kwargs["default"] = field_default

        kwargs["help"] = field_help
        if field_const is not None:
            kwargs["const"] = field_const
            kwargs["nargs"] = "?"
        return kwargs

    for k in dataclass_instance._get_all_attributes():
        field_name = argparse_name(dataclass_instance._get_name(k))
        if field_name is None:
            continue

        kwargs = get_kwargs_from_dc(dataclass_instance, k)

        field_args = [field_name]
        alias = dataclass_instance._get_argparse_alias(k)
        if alias is not None:
            field_args.append(alias)

        default = kwargs.get("default")
        if isinstance(default, str) and default.startswith("${"):
            if kwargs["help"] is None:
                # interpolated field without help text: not registered here
                # (presumably it is added elsewhere -- TODO confirm)
                continue
            del kwargs["default"]
        if delete_default:
            # The key may already be gone (interpolation branch above) or may
            # never have been set, so remove it tolerantly.
            kwargs.pop("default", None)

        try:
            parser.add_argument(*field_args, **kwargs)
        except ArgumentError:
            # Deliberate best-effort: an option that conflicts with one
            # already registered on this parser is left as it is.
            pass
| |
|