|
|
from __future__ import annotations |
|
|
|
|
|
from torchgen.api import dispatcher |
|
|
from torchgen.api.types import ( |
|
|
BaseCppType, |
|
|
BaseCType, |
|
|
Binding, |
|
|
boolT, |
|
|
ConstRefCType, |
|
|
CType, |
|
|
longT, |
|
|
NamedCType, |
|
|
tensorT, |
|
|
) |
|
|
from torchgen.model import ( |
|
|
Argument, |
|
|
BaseTy, |
|
|
BaseType, |
|
|
FunctionSchema, |
|
|
NativeFunction, |
|
|
NativeFunctionsViewGroup, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Binding for the ``base`` tensor, taken by const reference (``const Tensor&``
# on the C++ side) and described at schema level as a plain ``Tensor``.
# ``op_arguments`` places it first in both the forward and reverse lists.
base_binding = Binding(
    argument=Argument(
        name="base",
        type=BaseType(BaseTy.Tensor),
        default=None,
        annotation=None,
    ),
    name="base",
    nctype=NamedCType(
        name="base",
        type=ConstRefCType(BaseCType(tensorT)),
    ),
    default=None,
)
|
|
|
|
|
# Boolean ``has_symbolic_inputs`` binding; ``base_ctor_arguments`` always puts
# it first among the base constructor arguments.
has_symbolic_inputs_binding = Binding(
    argument=Argument(
        name="has_symbolic_inputs",
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    name="has_symbolic_inputs",
    nctype=NamedCType(
        name="has_symbolic_inputs",
        type=BaseCType(boolT),
    ),
    default=None,
)
|
|
# Binding for the ``mutated_view`` tensor, taken by const reference.
# ``op_arguments`` passes it right after ``base`` for reverse (inverse) ops.
mutated_view_binding = Binding(
    name="mutated_view",
    nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
    # The schema-level argument carries the same name as the binding. Naming it
    # "base" would duplicate ``base_binding``'s argument name, and both bindings
    # appear side by side in the reverse argument list.
    argument=Argument(
        name="mutated_view",
        type=BaseType(BaseTy.Tensor),
        default=None,
        annotation=None,
    ),
    default=None,
)
|
|
# ``out_index`` binding (C++ ``longT``, schema ``int``). It is only added for
# multi-output ops; see ``base_ctor_arguments`` and ``op_arguments``.
out_index_binding = Binding(
    argument=Argument(
        name="out_index",
        type=BaseType(BaseTy.int),
        default=None,
        annotation=None,
    ),
    name="out_index",
    nctype=NamedCType(
        name="out_index",
        type=BaseCType(longT),
    ),
    default=None,
)
|
|
# Boolean ``reapply_views`` binding; ``attributes`` lists it first.
reapply_views_binding = Binding(
    argument=Argument(
        name="reapply_views",
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    name="reapply_views",
    nctype=NamedCType(
        name="reapply_views",
        type=BaseCType(boolT),
    ),
    default=None,
)
|
|
|
|
|
# C++ type ``at::functionalization::InverseReturnMode``.
InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")

# Binding for ``inverse_return_mode``. Its C++ type comes from
# ``InverseReturnModeT``.
# NOTE(review): the schema-level ``Argument`` type is ``bool``, which does not
# match the C++ type; presumably a placeholder because only the C++ nctype
# matters for generated code — confirm before relying on ``argument.type``.
inverse_return_mode_binding = Binding(
    argument=Argument(
        name="inverse_return_mode",
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    name="inverse_return_mode",
    nctype=NamedCType(
        name="inverse_return_mode",
        type=BaseCType(InverseReturnModeT),
    ),
    default=None,
)
|
|
|
|
|
|
|
|
|
|
|
def classname(func: FunctionSchema, with_namespace: bool = False) -> str:
    """Return the name of the generated ``ViewMeta`` class for ``func``.

    The name is ``<unambiguous op name>_ViewMeta``. When ``with_namespace`` is
    true it is qualified with ``at::functionalization::``.
    """
    stem = f"{func.name.unambiguous_name()}_ViewMeta"
    return f"at::functionalization::{stem}" if with_namespace else stem
|
|
|
|
|
|
|
|
|
|
|
def name(
    g: NativeFunctionsViewGroup,
    *,
    is_reverse: bool,
    include_namespace: bool,
    reapply_views: bool | None = None,
) -> str:
    """Return the C++ callable name for view group ``g``.

    Args:
        g: the view group; ``g.view`` and (for forward calls) ``g.view_copy``
            are consulted.
        is_reverse: if true, return the inverse-function name produced by
            ``reverse_name(g.view, include_namespace)``.
        include_namespace: must be true for forward names.
        reapply_views: forward only; selects ``g.view`` when truthy and
            ``g.view_copy`` otherwise. It may be ``None`` only when
            ``is_reverse`` is true.

    Returns:
        For reverse calls, the inverse name. Otherwise
        ``at::_ops::<unambiguous name>::call``.
    """
    # Leaving reapply_views unset is only meaningful for the reverse path.
    if reapply_views is None:
        assert is_reverse
    if is_reverse:
        return reverse_name(g.view, include_namespace)

    # Forward names are always emitted fully qualified under at::_ops.
    assert include_namespace
    assert g.view_copy is not None
    chosen = g.view if reapply_views else g.view_copy
    return f"at::_ops::{chosen.func.name.unambiguous_name()}::call"
|
|
|
|
|
|
|
|
def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
    """Return the name of the inverse function for ``f``.

    The bare form is ``<unambiguous name>_inverse``. With ``include_namespace``
    it is qualified as
    ``at::functionalization::FunctionalInverses::<unambiguous name>_inverse``.
    """
    inverse = f"{f.func.name.unambiguous_name()}_inverse"
    if not include_namespace:
        return inverse
    return f"at::functionalization::FunctionalInverses::{inverse}"
|
|
|
|
|
|
|
|
def returns_type(func: FunctionSchema) -> CType:
    """Return the C++ type used for the result of ``func``.

    ``func`` must have at least one return, and every return must be
    tensor-like. The result is always ``BaseCType(tensorT)``, whether the op
    has one return or several.
    """
    returns = func.returns
    assert len(returns) >= 1
    assert all(r.type.is_tensor_like() for r in returns)
    return BaseCType(tensorT)
|
|
|
|
|
|
|
|
|
|
|
def is_multi_output(func: FunctionSchema) -> bool:
    """Report whether ``func`` yields more than one output.

    This is true when the schema has several returns, or exactly one return
    whose type is list-like (``is_list_like()`` returns something other than
    ``None``). Zero returns yields ``False``.
    """
    rets = func.returns
    if len(rets) != 1:
        return len(rets) > 1
    return rets[0].type.is_list_like() is not None
|
|
|
|
|
|
|
|
|
|
|
def base_ctor_arguments(func: FunctionSchema) -> list[Binding]:
    """Return the base constructor bindings for ``func``.

    The list always starts with ``has_symbolic_inputs``. For multi-output ops
    (see ``is_multi_output``) it is followed by ``out_index``. A fresh list is
    returned on every call.
    """
    multi_output_extras = [out_index_binding] if is_multi_output(func) else []
    return [has_symbolic_inputs_binding, *multi_output_extras]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]:
    """Return the additional constructor bindings for ``func``.

    This is ``attributes(func)`` with ``owning=False``, so reference types
    produced by the dispatcher are kept as they are.
    """
    return attributes(func=func, owning=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]:
    """Return the attribute bindings for ``func``.

    The list is ``reapply_views`` and ``inverse_return_mode``, followed by one
    dispatcher binding for each argument after the leading ``self`` tensor.
    ``owning`` is forwarded to ``dispatcher.argument`` as
    ``remove_non_owning_ref_types``. Presumably ``True`` swaps non-owning
    reference types for owning ones so the values can be stored — confirm
    against ``dispatcher.argument``.

    The first flat argument of ``func`` must be a ``Tensor``.
    """
    flat = func.arguments.flat_all
    assert flat[0].type == BaseType(BaseTy.Tensor)
    trailing = [
        dispatcher.argument(arg, remove_non_owning_ref_types=owning)
        for arg in flat[1:]
    ]
    return [reapply_views_binding, inverse_return_mode_binding] + trailing
|
|
|
|
|
|
|
|
def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    """Return the call-argument bindings for ``func``.

    The first flat argument of ``func`` must be a ``Tensor``. It is replaced by
    ``base``, and every other argument becomes a dispatcher binding.

    - Forward (``is_reverse`` false): ``[base, *rest]``.
    - Reverse: ``[base, mutated_view, inverse_return_mode]``, plus
      ``out_index`` for multi-output ops, followed by ``*rest``.
    """
    flat = func.arguments.flat_all
    assert flat[0].type == BaseType(BaseTy.Tensor)
    rest = [dispatcher.argument(arg) for arg in flat[1:]]

    if not is_reverse:
        return [base_binding, *rest]

    leading = [base_binding, mutated_view_binding, inverse_return_mode_binding]
    if is_multi_output(func):
        # Multi-output inverses also need to know which output was mutated.
        leading.append(out_index_binding)
    return leading + rest
|
|
|