File size: 7,815 Bytes
2216aae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
from __future__ import annotations
from torchgen.api import dispatcher
from torchgen.api.types import (
BaseCppType,
BaseCType,
Binding,
boolT,
ConstRefCType,
CType,
longT,
NamedCType,
tensorT,
)
from torchgen.model import (
Argument,
BaseTy,
BaseType,
FunctionSchema,
NativeFunction,
NativeFunctionsViewGroup,
)
# This file describes the translation of JIT schema to APIs used
# when creating `ViewMeta` specializations that are used by the functionalization pass.
# These APIs mostly follow the dispatcher API, with one difference:
# - While the forward function just directly calls into the at::_ops API
#   (following the dispatcher convention), the logic here for the reverse function
#   is responsible for generating both the call-site and the declarations
#   (which are implemented manually in the at::functionalization::impl namespace).
# Define some specific lambda input arguments.
# Binding named "base": its C++ type is a const reference to `tensorT`, and it
# wraps a schema `Argument` of type Tensor with no default and no annotation.
base_binding = Binding(
    name="base",
    nctype=NamedCType(
        name="base",
        type=ConstRefCType(BaseCType(tensorT)),
    ),
    argument=Argument(
        name="base",
        type=BaseType(BaseTy.Tensor),
        default=None,
        annotation=None,
    ),
    default=None,
)
# Binding named "has_symbolic_inputs": a plain `boolT` value wrapping a schema
# `Argument` of type bool with no default and no annotation.
has_symbolic_inputs_binding = Binding(
    name="has_symbolic_inputs",
    nctype=NamedCType(
        name="has_symbolic_inputs",
        type=BaseCType(boolT),
    ),
    argument=Argument(
        name="has_symbolic_inputs",
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    default=None,
)
# Binding named "mutated_view": its C++ type is a const reference to `tensorT`.
# NOTE(review): the wrapped schema `Argument` is named "base" rather than
# "mutated_view" -- confirm whether this is intentional before changing it.
mutated_view_binding = Binding(
    name="mutated_view",
    nctype=NamedCType(
        name="mutated_view",
        type=ConstRefCType(BaseCType(tensorT)),
    ),
    argument=Argument(
        name="base",
        type=BaseType(BaseTy.Tensor),
        default=None,
        annotation=None,
    ),
    default=None,
)
# Binding named "out_index": a `longT` value wrapping a schema `Argument` of
# type int with no default and no annotation.
out_index_binding = Binding(
    name="out_index",
    nctype=NamedCType(
        name="out_index",
        type=BaseCType(longT),
    ),
    argument=Argument(
        name="out_index",
        type=BaseType(BaseTy.int),
        default=None,
        annotation=None,
    ),
    default=None,
)
# Binding named "reapply_views": a `boolT` value wrapping a schema `Argument`
# of type bool with no default and no annotation.
reapply_views_binding = Binding(
    name="reapply_views",
    nctype=NamedCType(
        name="reapply_views",
        type=BaseCType(boolT),
    ),
    argument=Argument(
        name="reapply_views",
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    default=None,
)
# C++ type `InverseReturnMode` in the `at::functionalization` namespace.
InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")

# Binding named "inverse_return_mode", typed as `InverseReturnModeT`.
inverse_return_mode_binding = Binding(
    name="inverse_return_mode",
    nctype=NamedCType(
        name="inverse_return_mode",
        type=BaseCType(InverseReturnModeT),
    ),
    argument=Argument(
        name="inverse_return_mode",
        # NB: not actually a bool but it doesn't matter because this isn't used
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    default=None,
)
# Name of the `ViewMeta` specialization class created for `func`.
#
# The name is the schema's unambiguous operator name followed by `_ViewMeta`.
# When `with_namespace` is true it is qualified with `at::functionalization::`.
def classname(func: FunctionSchema, with_namespace: bool = False) -> str:
    unqualified = f"{func.name.unambiguous_name()}_ViewMeta"
    if with_namespace:
        return "at::functionalization::" + unqualified
    return unqualified
# Name of the operation called inside the `forward`/`reverse` implementations.
#
# - Reverse case: delegates to `reverse_name` for `g.view`.
# - Forward case: returns the `at::_ops::<op>::call` entry point, where `<op>` is
#   the view op when `reapply_views` is truthy and the view_copy op otherwise.
#
# `reapply_views` may only be left as None in the reverse case, and the forward
# case requires both `include_namespace` and a non-None `g.view_copy`; these
# preconditions are enforced with assertions.
def name(
    g: NativeFunctionsViewGroup,
    *,
    is_reverse: bool,
    include_namespace: bool,
    reapply_views: bool | None = None,
) -> str:
    # reapply_views is only important for the fwd lambda,
    # since we always plumb the runtime "reapply_views" argument into the reverse function.
    assert reapply_views is not None or is_reverse

    if is_reverse:
        return reverse_name(g.view, include_namespace)

    # The forward case calls straight into the at::_ops API, so the namespace
    # is always required.
    assert include_namespace
    assert g.view_copy is not None

    if reapply_views:
        target = g.view
    else:
        target = g.view_copy
    return f"at::_ops::{target.func.name.unambiguous_name()}::call"
# Name of the view inverse function for `f`: `<op>_inverse`, where `<op>` is the
# schema's unambiguous operator name.
#
# For the reverse: we plumb the "reapply_views" flag into that function and support
# both copy and non-copy variants. (We could avoid doing that, but that would require
# writing out twice as many view inverse functions).
def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
    # We codegen both the call-sites (which need the full namespace) and the
    # declarations (which don't).
    qualifier = (
        "at::functionalization::FunctionalInverses::" if include_namespace else ""
    )
    return f"{qualifier}{f.func.name.unambiguous_name()}_inverse"
# Return type of the lambda generated for `func`.
#
# Asserts that `func` has at least one return and that every return type is
# tensor-like, then always yields a single `tensorT` C type.
def returns_type(func: FunctionSchema) -> CType:
    # Assertion: all view ops return tensor-like outputs
    assert len(func.returns) >= 1
    assert all(ret.type.is_tensor_like() for ret in func.returns)
    # However, the return type of the lambda is always an individual tensor.
    # For multi-tensor outputs, each tensor needs to be tracked individually.
    return BaseCType(tensorT)
# Checks whether `func` might return more than one value.
#
# True when the schema declares several returns, or when it declares exactly one
# return whose type reports itself as list-like (`is_list_like()` is not None).
def is_multi_output(func: FunctionSchema) -> bool:
    returns = func.returns
    if len(returns) > 1:
        return True
    if len(returns) != 1:
        # No returns at all: nothing that could expand into several values.
        return False
    return returns[0].type.is_list_like() is not None
# `ViewMeta` specialization constructor parameters.
#
# Every specialization is parameterized by the `has_symbolic_inputs` flag. When
# `func` might return more than one value (see `is_multi_output`), the output
# index is added as a second parameter. A fresh list is returned on each call.
def base_ctor_arguments(func: FunctionSchema) -> list[Binding]:
    if is_multi_output(func):
        return [has_symbolic_inputs_binding, out_index_binding]
    return [has_symbolic_inputs_binding]
# `ViewMeta` specialized class' constructor arguments.
#
# Values needed specifically by this specialization, that the base class does not need.
# These are the class' attributes requested with `owning=False`, i.e. the
# non-owning form of the same bindings.
def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]:
    non_owning_attributes = attributes(func, owning=False)
    return non_owning_attributes
# `ViewMeta` specialized class' non-static member data.
#
# Essential data for calling the instance's `forward` and `reverse` functions. You can
# think of them as values that should be captured from the functionalization kernel.
#
# The result starts with the `reapply_views` and `inverse_return_mode` bindings,
# followed by one dispatcher binding per schema argument after the first. The
# first argument is asserted to be a Tensor and is not included. `owning` is
# forwarded to `dispatcher.argument` as `remove_non_owning_ref_types`.
def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]:
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)

    members = [reapply_views_binding, inverse_return_mode_binding]
    members.extend(
        dispatcher.argument(arg, remove_non_owning_ref_types=owning)
        for arg in flat_args[1:]
    )
    return members
# Arguments of the operation called inside the `forward`/`reverse` lambdas.
#
# The first schema argument is asserted to be a Tensor and is replaced by the
# lambda-specific bindings; all remaining arguments are translated with
# `dispatcher.argument` and appended unchanged.
def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)

    # The forward lambda calls the at::_ops API, while the reverse lambda calls the view inverse API.
    # Both of these follow the dispatcher API.
    trailing = [dispatcher.argument(arg) for arg in flat_args[1:]]

    if not is_reverse:
        # the forward lambda swaps out the original tensor argument with the lambda arg "base"
        return [base_binding] + trailing

    # the reverse lambda does the same, but with an additional "mutated_view" arg
    leading = [base_binding, mutated_view_binding, inverse_return_mode_binding]
    # additionally, we have a calling convention: for view ops that return multiple tensor outputs
    # their corresponding view_inverse function takes in an additional index argument.
    if is_multi_output(func):
        leading.append(out_index_binding)
    return leading + trailing
|