# sleepyhead111's picture
# Add files using upload-large-folder tool
# b12c168 verified
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from argparse import ArgumentParser
from dataclasses import MISSING, dataclass
from enum import Enum
from typing import Any, Dict, List, Optional
def eval_str_list(x, x_type=float):
    """Coerce *x* into a list whose items have been converted with *x_type*.

    Args:
        x: ``None``, an iterable, a scalar, or a string. A string is first
            evaluated as a Python expression (e.g. ``"1,2,3"`` or ``"[1, 2]"``).
        x_type: callable applied to every element (``float`` by default).

    Returns:
        ``None`` when *x* is ``None``; otherwise a list of converted items.
        A value that cannot be iterated is wrapped into a single-item list.
    """
    if x is None:
        return None
    # NOTE(review): eval() executes arbitrary code; only pass trusted strings
    # (e.g. command-line values typed by the operator) through this helper.
    parsed = eval(x) if isinstance(x, str) else x
    try:
        return [x_type(item) for item in parsed]
    except TypeError:
        # e.g. a scalar such as 0.5 is not iterable: wrap it in a list
        return [x_type(parsed)]
class StrEnum(Enum):
    """Enum whose members print as, and compare equal to, their string value."""

    def __str__(self):
        return self.value

    def __eq__(self, other: str):
        # Compare by value so a member equals its plain string, e.g. Choice.a == "a".
        return self.value == other

    def __hash__(self):
        # Defining __eq__ without __hash__ sets __hash__ to None, which would make
        # members unhashable (unusable in sets / as dict keys). Hash by value so
        # the hash agrees with __eq__ (a member equals its string value).
        return hash(self.value)

    def __repr__(self):
        return self.value
def ChoiceEnum(choices: List[str]):
    """Create a ``StrEnum`` class named "Choices" with one member per choice.

    Every member uses the choice string as both its name and its value, so the
    resulting enum can be used to restrict a field to the given strings.
    """
    members = {}
    for choice in choices:
        members[choice] = choice
    return StrEnum("Choices", members)
@dataclass
class FairseqDataclass:
    """Base dataclass for fairseq configs, with helpers to read field attributes and metadata."""

    _name: Optional[str] = None

    @staticmethod
    def name():
        """Return this config's name; the base class has none (``None``)."""
        return None

    def _get_all_attributes(self) -> List[str]:
        """Return the names of all dataclass fields, in field order."""
        return list(self.__dataclass_fields__)

    def _get_meta(
        self, attribute_name: str, meta: str, default: Optional[Any] = None
    ) -> Any:
        """Return ``metadata[meta]`` of field *attribute_name*, or *default* if absent."""
        return self.__dataclass_fields__[attribute_name].metadata.get(meta, default)

    def _get_name(self, attribute_name: str) -> str:
        """Return the declared name of field *attribute_name*."""
        return self.__dataclass_fields__[attribute_name].name

    def _get_default(self, attribute_name: str) -> Any:
        """Return the effective default of a field declared with ``default=``.

        Precedence: a ``"${"``-prefixed instance value (as a string), then a
        ``"${"``-prefixed declared default (as a string), then the instance
        value when it differs from the declared default, and finally the
        declared default itself (which may be ``dataclasses.MISSING``).
        """
        dc_field = self.__dataclass_fields__[attribute_name]
        if hasattr(self, attribute_name):
            value = getattr(self, attribute_name)
            if str(value).startswith("${"):
                return str(value)
            if str(dc_field.default).startswith("${"):
                return str(dc_field.default)
            if value != dc_field.default:
                return value
        return dc_field.default

    def _get_default_factory(self, attribute_name: str) -> Any:
        """Return the effective default of a field declared with ``default_factory=``.

        Same precedence as :meth:`_get_default`, except the comparison and the
        fallback use a value produced by the field's ``default_factory``. The
        factory is invoked at most once per call.
        """
        dc_field = self.__dataclass_fields__[attribute_name]
        if hasattr(self, attribute_name):
            value = getattr(self, attribute_name)
            if str(value).startswith("${"):
                return str(value)
            if str(dc_field.default).startswith("${"):
                return str(dc_field.default)
            # Build the factory default once and reuse it for both the
            # comparison and the return value (previously built twice).
            factory_value = dc_field.default_factory()
            if value != factory_value:
                return value
            return factory_value
        return dc_field.default_factory()

    def _get_type(self, attribute_name: str) -> Any:
        """Return the declared type annotation of field *attribute_name*."""
        return self.__dataclass_fields__[attribute_name].type

    def _get_help(self, attribute_name: str) -> Any:
        """Return the ``help`` metadata of the field, or ``None``."""
        return self._get_meta(attribute_name, "help")

    def _get_argparse_const(self, attribute_name: str) -> Any:
        """Return the ``argparse_const`` metadata of the field, or ``None``."""
        return self._get_meta(attribute_name, "argparse_const")

    def _get_argparse_alias(self, attribute_name: str) -> Any:
        """Return the ``argparse_alias`` metadata of the field, or ``None``."""
        return self._get_meta(attribute_name, "argparse_alias")

    def _get_choices(self, attribute_name: str) -> Any:
        """Return the ``choices`` metadata of the field, or ``None``."""
        return self._get_meta(attribute_name, "choices")
def gen_parser_from_dataclass(
    parser: ArgumentParser,
    dataclass_instance: FairseqDataclass,
    delete_default: bool = False,
) -> None:
    """Add one argparse argument to *parser* for each field of *dataclass_instance*.

    The ``_name`` field is skipped, a field named ``data`` is added as a
    positional argument, and any other field ``foo_bar`` becomes ``--foo-bar``.
    Field metadata supplies ``help``, ``argparse_const`` and ``argparse_alias``.

    Args:
        parser: the parser to extend in place.
        dataclass_instance: the config whose fields (and current values, when
            they differ from the declared defaults) describe the arguments.
        delete_default: when True, no ``default`` is passed to argparse.
    """
    import re
    from argparse import ArgumentError

    def argparse_name(name: str):
        """Map a field name to its argparse name, or ``None`` to skip it."""
        if name == "data":
            # normally data is positional args
            return name
        if name == "_name":
            # private member, skip
            return None
        return "--" + name.replace("_", "-")

    def interpret_dc_type(field_type):
        """Unwrap ``Optional[X]`` to ``X`` and map ``typing.Any`` to ``str``."""
        if isinstance(field_type, str):
            raise RuntimeError("field should be a type")
        if field_type == Any:
            # typing.Any cannot be used as an argparse converter; treat as str
            return str
        typestring = str(field_type)
        # Optional[X] prints as "typing.Union[X, NoneType]" on older Pythons
        # and as "typing.Optional[X]" on Python 3.9+; recognise both forms.
        if re.match(
            r"(typing.|^)Union\[(.*), NoneType\]$", typestring
        ) or typestring.startswith("typing.Optional"):
            return field_type.__args__[0]
        return field_type

    def get_kwargs_from_dc(
        dataclass_instance: FairseqDataclass, k: str
    ) -> Dict[str, Any]:
        """Build the ``add_argument`` keyword arguments for dataclass field *k*."""
        field_type = dataclass_instance._get_type(k)
        inter_type = interpret_dc_type(field_type)
        if isinstance(inter_type, type) and issubclass(inter_type, List):
            field_default = dataclass_instance._get_default_factory(k)
        else:
            field_default = dataclass_instance._get_default(k)
        if isinstance(inter_type, type) and issubclass(inter_type, Enum):
            field_choices = [t.value for t in list(inter_type)]
        else:
            field_choices = None
        field_help = dataclass_instance._get_help(k)
        field_const = dataclass_instance._get_argparse_const(k)
        kwargs = {}
        if isinstance(field_default, str) and field_default.startswith("${"):
            # "${...}" values (presumably interpolation references) are passed
            # through unchanged; the caller decides what to do with them.
            kwargs["default"] = field_default
        else:
            if field_default is MISSING:
                kwargs["required"] = True
            if field_choices is not None:
                kwargs["choices"] = field_choices
            if (isinstance(inter_type, type) and issubclass(inter_type, List)) or (
                "List" in str(inter_type)
            ):
                # list values are given on the command line as one string and
                # parsed element-wise by eval_str_list
                if "int" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, int)
                elif "float" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, float)
                elif "str" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, str)
                else:
                    raise NotImplementedError()
                if field_default is not MISSING:
                    # a None default (e.g. Optional[List[int]] = None) stays None
                    # instead of crashing in ",".join(...)
                    kwargs["default"] = (
                        ",".join(map(str, field_default))
                        if field_default is not None
                        else None
                    )
            elif (
                isinstance(inter_type, type) and issubclass(inter_type, Enum)
            ) or "Enum" in str(inter_type):
                kwargs["type"] = str
                if field_default is not MISSING:
                    if isinstance(field_default, Enum):
                        kwargs["default"] = field_default.value
                    else:
                        kwargs["default"] = field_default
            elif inter_type is bool:
                # flags toggle away from their default value
                kwargs["action"] = (
                    "store_false" if field_default is True else "store_true"
                )
                kwargs["default"] = field_default
            else:
                kwargs["type"] = inter_type
                if field_default is not MISSING:
                    kwargs["default"] = field_default
        kwargs["help"] = field_help
        if field_const is not None:
            kwargs["const"] = field_const
            kwargs["nargs"] = "?"
        return kwargs

    for k in dataclass_instance._get_all_attributes():
        field_name = argparse_name(dataclass_instance._get_name(k))
        if field_name is None:
            continue
        kwargs = get_kwargs_from_dc(dataclass_instance, k)
        field_args = [field_name]
        alias = dataclass_instance._get_argparse_alias(k)
        if alias is not None:
            field_args.append(alias)
        if "default" in kwargs:
            if isinstance(kwargs["default"], str) and kwargs["default"].startswith(
                "${"
            ):
                if kwargs["help"] is None:
                    # this is a field with a name that will be added elsewhere
                    continue
                else:
                    del kwargs["default"]
            if delete_default:
                # the branch above may already have removed "default"; pop()
                # avoids the KeyError a second `del` would raise
                kwargs.pop("default", None)
        try:
            parser.add_argument(*field_args, **kwargs)
        except ArgumentError:
            # best effort: skip arguments argparse rejects, e.g. an option
            # string that is already registered on this parser
            pass