|
|
|
|
|
import os |
|
|
from functools import partial |
|
|
from torch._jit_internal import is_optional, FAKE_FILENAME_PREFIX |
|
|
from torch._sources import ParsedDef, SourceContext |
|
|
from typing import Callable, Dict, List |
|
|
import ast |
|
|
import dataclasses |
|
|
import inspect |
|
|
import sys |
|
|
|
|
|
def _get_fake_filename(cls, method_name):
    """Build the placeholder filename attributed to a synthesized method.

    The result joins ``FAKE_FILENAME_PREFIX``, the class name and the method
    name as successive path components.
    """
    path_parts = (FAKE_FILENAME_PREFIX, cls.__name__, method_name)
    return os.path.join(*path_parts)
|
|
|
|
|
|
|
|
def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
    """Assemble the source of a method from fragments and parse it.

    The generated source is ``def {name}{signature}:`` followed by every entry
    of ``body_lines`` on its own line, each indented one level.

    Args:
        cls: The class the method is synthesized for; only ``cls.__name__`` is
            read here (for the error message and the fake filename).
        name: Name of the function to declare.
        body_lines: Statements making up the function body, one per entry.
        signature: Parenthesized parameter list plus optional return
            annotation, e.g. ``'(self) -> str'``.

    Returns:
        A ``ParsedDef`` holding the parsed AST, the generated source and a
        ``SourceContext``; both are labelled with the fake filename and line 0.

    Raises:
        RuntimeError: If the assembled source is not valid Python. The
            underlying ``SyntaxError`` is attached as ``__cause__``.
    """
    body = '\n'.join(f' {b}' for b in body_lines)
    decl = f'def {name}{signature}:\n{body}'

    try:
        py_ast = ast.parse(decl)
    except SyntaxError as err:
        # Chain the parse error explicitly so the offending line/offset stays
        # attached to the user-facing RuntimeError as its direct cause.
        raise RuntimeError(
            f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
            "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
        ) from err

    fake_filename = _get_fake_filename(cls, name)
    return ParsedDef(
        py_ast,
        ctx=SourceContext(
            source=decl,
            filename=fake_filename,
            file_lineno=0,
            leading_whitespace_len=0
        ),
        source=decl,
        filename=fake_filename,
        file_lineno=0
    )
|
|
|
|
|
|
|
|
def synthesize__init__(cls) -> ParsedDef:
    """Generate an ``__init__`` definition for the dataclass ``cls``.

    The parameter list is taken from ``inspect.signature(cls.__init__)``. On
    Python 3.8+ every parameter annotated with a ``dataclasses.InitVar`` has
    that annotation replaced by the wrapped ``InitVar.type``. The body assigns
    each ``init`` field to ``self`` and, when the class has a
    ``__post_init__``, calls it with the collected InitVar names.

    Raises:
        NotImplementedError: If any field declares a ``default_factory``.
    """
    cls_fields = dataclasses.fields(cls)
    for candidate in cls_fields:
        if candidate.default_factory is not dataclasses.MISSING:
            raise NotImplementedError("Default factory initializers are not supported in TorchScript dataclasses")

    sig = inspect.signature(cls.__init__)

    # Names of InitVar pseudo-fields, in signature order; they are forwarded
    # to __post_init__ instead of being stored on self.
    init_vars: List[str] = []
    if sys.version_info >= (3, 8):
        rewritten = []
        for param_name, param in sig.parameters.items():
            annotation = param.annotation
            if not isinstance(annotation, dataclasses.InitVar):
                rewritten.append(param)
                continue
            # Unwrap InitVar[T] to plain T in the emitted signature.
            init_vars.append(param_name)
            rewritten.append(param.replace(annotation=annotation.type))
        sig = sig.replace(parameters=rewritten)

    statements = []
    for candidate in cls_fields:
        if candidate.init and candidate.name not in init_vars:
            statements.append(f'self.{candidate.name} = {candidate.name}')

    if hasattr(cls, '__post_init__'):
        post_init_args = ', '.join(init_vars)
        statements.append(f'self.__post_init__({post_init_args})')

    # A function body cannot be empty, so fall back to a lone `pass`.
    if not statements:
        statements = ['pass']

    return compose_fn(cls, '__init__', statements, signature=str(sig))
|
|
|
|
|
|
|
|
def synthesize__repr__(cls) -> ParsedDef:
    """Generate a ``__repr__`` definition for the dataclass ``cls``.

    The generated body is a single ``return`` of a fixed string literal of the
    form ``'ClassName(a=self.a, b=self.b)'`` covering the fields whose
    ``repr`` flag is set. Note that the field values are not interpolated;
    the text ``self.<field>`` appears verbatim inside the literal.
    """
    shown = []
    for field in dataclasses.fields(cls):
        if field.repr:
            shown.append(f"{field.name}=self.{field.name}")

    joined = ", ".join(shown)
    return_stmt = f"return '{cls.__name__}({joined})'"
    return compose_fn(cls, '__repr__', [return_stmt], signature='(self) -> str')
|
|
|
|
|
def synthesize__hash__(cls) -> ParsedDef:
    """Generate a ``__hash__`` definition for the dataclass ``cls``.

    The generated method has signature ``(self) -> int`` and its only
    statement raises ``NotImplementedError``.
    """
    raise_stmt = "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
    return compose_fn(cls, '__hash__', [raise_stmt], signature='(self) -> int')
|
|
|
|
|
|
|
|
def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
    """Generate ``__eq__`` or ``__ne__`` for the dataclass ``cls``.

    Args:
        cls: The dataclass to generate the method for.
        name: Name of the generated method.
        converse: The operator that is the logical opposite of the one being
            implemented: ``"!="`` when generating ``__eq__`` and ``"=="`` when
            generating ``__ne__``.

    Returns:
        The parsed definition produced by ``synthesize_comparison``.
    """
    if converse == "==":
        # Implementing "!=": two instances differ as soon as ONE compared
        # field differs, and are "not unequal" only when every field matched.
        # Returning False on the first *matching* field (the per-field
        # short-circuit used for __eq__) would wrongly report
        # P(1, 2) != P(1, 3) as False.
        return synthesize_comparison(
            cls, name, allow_eq=False, raise_on_none=False,
            inner=["if val1 != val2: return True"],
            # Exactly one side being None means the objects differ.
            none_mismatch_result=True,
        )

    return synthesize_comparison(cls, name, allow_eq=True, raise_on_none=False, inner=[
        f"if val1 {converse} val2: return False"
    ])


def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
    """Generate an ordering method (``__lt__``, ``__le__``, ``__gt__``, ``__ge__``).

    Fields are compared in declaration order: the first field for which
    ``val1 {op} val2`` holds yields True, the first for which the reverse
    holds yields False, and if no field decides, ``allow_eq`` is returned.

    Args:
        cls: The dataclass to generate the method for.
        name: Name of the generated method.
        op: Strict comparison operator to emit (``"<"`` or ``">"``).
        allow_eq: Result when all compared fields are equal.
    """
    return synthesize_comparison(cls, name, allow_eq, raise_on_none=True, inner=[
        f"if val1 {op} val2: return True",
        f"elif val2 {op} val1: return False",
    ])


def synthesize_comparison(
    cls,
    name: str,
    allow_eq: bool,
    raise_on_none: bool,
    inner: List[str],
    *,
    none_mismatch_result: bool = False,
) -> ParsedDef:
    """Generate a field-by-field comparison method for the dataclass ``cls``.

    For every field with ``compare`` set, the generated body loads the field
    from ``self`` and ``other`` into ``val1`` / ``val2`` and then runs the
    ``inner`` statements. For fields that ``is_optional`` reports as optional,
    ``inner`` only runs when both values are not None; if exactly one of them
    is None the method either raises or returns a fixed result (see below).
    After the last field the method returns ``allow_eq``.

    Args:
        cls: The dataclass to generate the method for.
        name: Name of the generated method.
        allow_eq: Value returned when no field's ``inner`` check returned.
        raise_on_none: If True, a None/non-None mismatch on an optional field
            raises ``TypeError`` in the generated code.
        inner: Statements comparing ``val1`` and ``val2`` for a single field.
        none_mismatch_result: Value returned on a None/non-None mismatch when
            ``raise_on_none`` is False. Defaults to False.

    Returns:
        The parsed definition produced by ``compose_fn``, with signature
        ``(self, other: <cls name>) -> bool``.
    """
    if raise_on_none:
        on_none_mismatch = f" raise TypeError('Cannot compare {cls.__name__} with None')"
    else:
        on_none_mismatch = f" return {none_mismatch_result}"

    body = []
    for field in dataclasses.fields(cls):
        if not field.compare:
            continue

        body.extend([
            f"val1 = self.{field.name}",
            f"val2 = other.{field.name}",
        ])

        if is_optional(field.type):
            # Check for None first so the inner comparison only ever runs
            # with both operands known to be non-None.
            body.extend([
                "if val1 is not None and val2 is not None:",
                *[' ' + line for line in inner],
                "elif (val1 is None) != (val2 is None):",
                on_none_mismatch,
            ])
        else:
            body.extend(inner)

    body.append(f"return {allow_eq}")
    return compose_fn(cls, name, body, signature=f'(self, other: {cls.__name__}) -> bool')
|
|
|
|
|
# Maps each magic-method name to a callable that takes the dataclass and
# returns the synthesized definition for that method. The comparison entries
# are partials that leave only the class argument unbound.
DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
    "__init__": synthesize__init__,
    "__repr__": synthesize__repr__,
    "__hash__": synthesize__hash__,
    **{
        dunder: partial(synthesize_equality, name=dunder, converse=converse)
        for dunder, converse in (
            ("__eq__", "!="),
            ("__ne__", "=="),
        )
    },
    **{
        dunder: partial(synthesize_inequality, name=dunder, op=op, allow_eq=allow_eq)
        for dunder, op, allow_eq in (
            ("__lt__", "<", False),
            ("__le__", "<", True),
            ("__gt__", ">", False),
            ("__ge__", ">", True),
        )
    },
}
|
|
|