diff --git a/.gitattributes b/.gitattributes index d1633cccc4f25e42dddc72fb3ad125848b7171f4..d01719e98163a2babc6d51c6c6186bddc244e080 100644 --- a/.gitattributes +++ b/.gitattributes @@ -398,3 +398,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/ .venv/lib/python3.11/site-packages/mistral_common/data/mistral_instruct_tokenizer_240216.model.v2 filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/numpy/lib/tests/__pycache__/test_io.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/mistral_common/data/tekken_240718.json filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/mistral_common/data/tekken_240718.json b/.venv/lib/python3.11/site-packages/mistral_common/data/tekken_240718.json new file mode 100644 index 0000000000000000000000000000000000000000..006aceecb567f7bcc2653b3227e17d92d7e43971 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/mistral_common/data/tekken_240718.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eccd1665d2e477697c33cb7f0daa6f6dfefc57a0a6bceb66d4be52952f827516 +size 14801223 diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/util.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cd1a69df98e4b0b1cedf180d22b124f389cf46a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b9e31ae0e3c737b3e1e2a38f5971117efd63003 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/all_to_all_operator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/all_to_all_operator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2039c2889e7580a5150fe8f619af35e9aabba23 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/all_to_all_operator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/from_operators.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/from_operators.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef8d0e77795a817a7745f270d2c7b2bfe7d67207 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/from_operators.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/input_data_operator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/input_data_operator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59019acae9af91f6000d02311f94cf0cc9a136db Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/input_data_operator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/map_operator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/map_operator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9910a124c97713cfb761e87ce8d7668da37d3d51 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/map_operator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/n_ary_operator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/n_ary_operator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ad9c17f568617114a2f1beae405e9c67fffd8e3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/n_ary_operator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/one_to_one_operator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/one_to_one_operator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd385c19a4b4090c38ffadffc26c2e2a2dc9cd35 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/one_to_one_operator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/read_operator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/read_operator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de3af596104522d525e3f2c469cde99cb11677b4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/read_operator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/write_operator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/write_operator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d5c9a445d0c206ee36021611127a42f31bca35c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/write_operator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/rules/randomize_blocks.py b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/rules/randomize_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..8810217258ab5b55e3d83eecd41c8e6c2fe4f99f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/data/_internal/logical/rules/randomize_blocks.py @@ -0,0 +1,77 @@ +import copy +from collections import deque + +from ray.data._internal.logical.interfaces import LogicalOperator, LogicalPlan, Rule +from ray.data._internal.logical.operators.all_to_all_operator import ( + AbstractAllToAll, + RandomizeBlocks, +) + + +class ReorderRandomizeBlocksRule(Rule): + """Rule for reordering RandomizeBlocks logical operator. + + Reordering RandomizeBlocks operators is to help fuse multiple + AbstractUDFMap operators together for better performance. + + 1. Dedupes multiple RandomizeBlocks operators if they are not seeded. + 2. Moves RandomizeBlocks operator to the end of a sequence of AbstractUDFMap + operators. RandomizeBlocks operators are not moved across AbstractAllToAll operator + boundaries. + """ + + def apply(self, plan: LogicalPlan) -> LogicalPlan: + optimized_dag: LogicalOperator = self._apply(plan.dag) + new_plan = LogicalPlan(dag=optimized_dag, context=plan.context) + return new_plan + + def _apply(self, op: LogicalOperator) -> LogicalOperator: + operators = [] + + # Post-order traversal. + nodes = deque() + for node in op.post_order_iter(): + nodes.appendleft(node) + + while len(nodes) > 0: + current_op = nodes.pop() + upstream_ops = current_op.input_dependencies + + # Iterate through all upstream ops, and remove all RandomizeBlocks + # operators. + for i in range(len(upstream_ops)): + if isinstance(upstream_ops[i], RandomizeBlocks): + # If no seeds are provided, then collapse into a single + # RandomizeBlocks operator. + current_seed = upstream_ops[i]._seed + if not operators or current_seed or operators[-1]._seed: + # We need to make a copy of the operator. + # Because the operator instance may be shared by multiple + # Datasets. We shouldn't modify it in place. + operators.append(copy.copy(upstream_ops[i])) + + # Remove RandomizeBlocks operator from the dag and wire in new input + # dependencies. + assert len(upstream_ops[i].input_dependencies) == 1 + upstream_ops[i] = upstream_ops[i].input_dependencies[0] + if isinstance(current_op, AbstractAllToAll) and not isinstance( + current_op, RandomizeBlocks + ): + # If this operator is a an AllToAll Operator, then insert + # RandomizeBlocks right before this operator rather than the end of the + # DAG. + # All-to-all operators can have only 1 input operator. + assert len(upstream_ops) == 1 + input_op = upstream_ops[0] + for random_op in operators: + random_op._input_dependencies = [input_op] + input_op = random_op + upstream_ops[0] = input_op + operators = [] + + # Add RandomizeBlocks operator as the last operator in the DAG if necessary. + for random_op in operators: + random_op._input_dependencies = [op] + op = random_op + + return op diff --git a/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..253814436347c16296011330107ea96e7b6e8f2b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:651bd8f392a2068689c3b3a80e08fda2bab7e27693fdef2c9f01c2c6303ab472 +size 123663 diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/autograd.py b/.venv/lib/python3.11/site-packages/torchgen/api/autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..644069395e1dd86d7bc65c4d69473faa1a068b66 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/autograd.py @@ -0,0 +1,870 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import cast, Sequence + +from torchgen import local +from torchgen.api import cpp +from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT +from torchgen.model import ( + BaseTy, + BaseType, + FunctionSchema, + ListType, + NativeFunction, + NativeFunctionsViewGroup, + SchemaKind, + Type, +) +from torchgen.utils import IDENT_REGEX + + +# Represents a saved attribute involved in backward calculation. +# Note that it can be a derived property of an input argument, e.g.: +# we could save `other.scalar_type()` instead of the entire `other` tensor. +@dataclass(frozen=True) +class SavedAttribute: + # The NamedCType holds the updated name and cpp type of the attribute + # for the name, Suffix is appended if it's derived property, e.g.: `other_scalar_type` + nctype: NamedCType + + # The expression to read the derived property at save time, e.g.: + # `other.scalar_type()`. + expr: str + + +# Represents a backward formula that calculates derivatives for one +# or more tensors. +@dataclass(frozen=True) +class Derivative: + # The formula string (legit C++ expression). + # Note that expressions against input arguments have been replaced with the + # corresponding saved attributes. + # E.g.: + # raw formula: `mul_tensor_backward(grad, self, other.scalar_type())` + # here: `mul_tensor_backward(grad, self, other_scalar_type)` + formula: str + + # The formula string before input argument replacement + original_formula: str + + # Names of the arguments for which this formula calculates derivatives. + var_names: tuple[str, ...] + + # Saved inputs that are referenced by the formula. + saved_inputs: tuple[SavedAttribute, ...] + + # Saved outputs that are referenced by the formula. + saved_outputs: tuple[SavedAttribute, ...] + + # Gradients that are referenced by name in the formula. + named_gradients: set[str] + + +# Represents a forward formula that calculates forward derivatives +# for one tensor. +@dataclass(frozen=True) +class ForwardDerivative: + # The formula string (legit C++ expression). + # Note that special keywords such as "linear" or "element_wise" have been + # replaced by the automatically generated formula. + formula: str + + # Name of the output arguments for which this formula calculates forward + # derivatives + var_names: tuple[str, ...] + + # Type of the output arguments for which this formula calculates forward + # derivatives + var_types: tuple[Type, ...] + + # Inputs for which the forward derivatives are required for this formula + required_inputs_fw_grad: tuple[str, ...] | None + + # Inputs for which the primal is required for this formula + required_inputs_primal: tuple[str, ...] | None + + # Flag to specify if this formula requires the original value of self + # This is only used by inplace operations + required_original_self_value: bool + + # If this formula is specified in derivatives.yaml or if we are re-using the + # out of place formula for inplace + is_reusing_outplace_formula: bool + + +# Represents differentiability info for a NativeFunction. +@dataclass(frozen=True) +class DifferentiabilityInfo: + # The base name read from derivatives.yaml. + name: str + + # The matching native function. + # + # There can be multiple NativeFunction having the same base name: + # - different overloads with different types of input arguments; + # - in-place/out/functional variants of the same function; + # + # We first use the schema string (under the 'name' key) in derivatives.yaml + # to find the NativeFunction having the same schema string. + # Then we find the in-place/out/functional variants of the matching function. + # Among these variants, we choose the one having the same name as the + # derivatives.yaml entry. If there is no exact match, then we choose the + # in-place variant. + # TODO: maybe the logic to search for all variants is no longer necessary? + func: NativeFunction + + # The name of the generated autograd function. + # It's set only if we will calculate a derivative, i.e. + # 'args_with_derivatives' is not empty. + op: str | None + + # The derivatives formulae for this function. + # Note that the length of this sequence is the number of differentiable inputs + derivatives: Sequence[Derivative] + + # The forward derivatives formulae for this function. + # Note that the length of this sequence is the number of differentiable outputs + forward_derivatives: Sequence[ForwardDerivative] + + # The union of 'saved_inputs' of all 'derivatives'. + all_saved_inputs: Sequence[SavedAttribute] + + # The union of 'saved_outputs' of all 'derivatives'. + all_saved_outputs: Sequence[SavedAttribute] + + # All named gradients that are available for use, in the same + # order as in the grads vector. + available_named_gradients: Sequence[str] + + # The named gradients that are used in any of the derivatives. + # Invariant: all(name in available_named_gradients for name in used_named_gradients) + used_named_gradients: set[str] + + # The function's input arguments for which it calculates derivatives. + # It's the union of 'var_names' of all 'derivatives', sorted by the + # argument order in the function schema. + args_with_derivatives: Sequence[Binding] + + # Names of arguments whose derivative formula is 'non_differentiable'. + non_differentiable_arg_names: Sequence[str] + + # Raw data read from derivatives.yaml. + output_differentiability: list[bool] | None + + # output_differentiability in derivatives.yaml can be a list of + # conditions that express if the output is differentiable. In this case, + # the number of conditions must match the number of outputs + # (NB: we only support one condition right now). + # output_differentiability gets populated with True for each condition, + # while output_differentiability_conditions gets populated with the conditions + output_differentiability_conditions: list[str] | None + + @property + def has_derivatives(self) -> bool: + return len(self.args_with_derivatives) > 0 + + # Generates a new DifferentiabilityInfo using the exact same set of derivative information, + # but with a new operator name. + # This is used when generating "copy" variants of view ops, + # which are able to use the exact same derivative formula as the original view op + # See Note [Codegen'd {view}_copy Operators] + def create_view_copy_from_view_derivative( + self, g: NativeFunctionsViewGroup + ) -> DifferentiabilityInfo | None: + if g.view_copy is None: + return None + f = g.view_copy + + name_split_by_period = self.name.split(".", maxsplit=2) + # Append a "_copy" to the base name of the operator (but keep the overload name the same) + view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join( + name_split_by_period[1:] + ) + view_copy_op_name = None if self.op is None else f"{self.op}_copy" + + return DifferentiabilityInfo( + # Use the "_copy" version of name/func/op + name=view_copy_name, + func=f, + op=view_copy_op_name, + # But keep all derivative info the same + derivatives=self.derivatives, + forward_derivatives=self.forward_derivatives, + all_saved_inputs=self.all_saved_inputs, + all_saved_outputs=self.all_saved_outputs, + available_named_gradients=self.available_named_gradients, + used_named_gradients=self.used_named_gradients, + args_with_derivatives=self.args_with_derivatives, + non_differentiable_arg_names=self.non_differentiable_arg_names, + output_differentiability=self.output_differentiability, + output_differentiability_conditions=self.output_differentiability_conditions, + ) + + +def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool: + if info is None: + return False + for derivative in info.derivatives: + formula = derivative.formula + if re.search(IDENT_REGEX.format(ident), formula): + return True + return False + + +def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool: + return uses_ident(info, "retain_variables") + + +def uses_single_grad(info: DifferentiabilityInfo | None) -> bool: + return uses_ident(info, "grad") + + +# Represents a differentiable `Argument`. +# How is it different from the `Argument` type? +# - It's processed Arguments which are differentiable and only used in the +# context of the autograd codegen; +# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument; +@dataclass(frozen=True) +class DifferentiableInput: + name: str + type: Type + + # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove. + cpp_type: str + + +# Represents a differentiable `Return`. +# How it it different from the `Return` type? +# - The name in `Return` is optional. Here it is always populated using the same +# `cpp.return_names()` method. +# TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant? +# - It's processed Returns which are differentiable, in compliance with the +# `output_differentiability` field defined in derivatives.yaml (if specified), +# and are only used in the context of the autograd codegen; +@dataclass(frozen=True) +class DifferentiableOutput: + name: str + type: Type + + # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove. + cpp_type: str + + +@dataclass(frozen=True) +class NativeFunctionWithDifferentiabilityInfo: + func: NativeFunction + info: dict[str, DifferentiabilityInfo] | None + fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None + + +# TODO: Update comment below since it is out of date. +def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str: + """How are we going to call the underlying implementation of a + declaration? There are two strategies: + - use_derived: we want to call the implementation on CPUDoubleType + (or a similar, derived Type instance). Because these derived + instances deal in Tensors, not Variables (it's a completely different + object, so it doesn't dispatch back to VariableType), code on + this dispatch path needs to wrap/unwrap tensors. If the + derived implementation takes and returns tensors, the + implementation is usually differentiable (although we also use + the derived dispatch path for non-differentiable functions + that we still want to dispatch on the derived Type instance; + e.g., size()) + - use_type: we want to call the implementation on Type, because + it is implemented concretely, and the functions it invokes will + get dispatched back to VariableType (which will ensure that they + are differentiable.) + """ + # fn is derived as long as any of its per-key differentiability infos + # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType + # and ADInplaceOrViewType. We want to generate these functions as long as a + # derivative is defined for ANY dispatch key. + if fn.func.is_abstract or ( + fn.info is not None and any(info.has_derivatives for info in fn.info.values()) + ): + # If the function is abstract (not implemented on at::Type), we must + # call the implementation on the derived type with unpacked tensors. + + # If the function has a derivative specified and is concrete, we could + # call either implementation. We prefer the calling the derived + # type's implementation with unpacked tensors because it is more + # performant in some cases: any internal calls to other ATen functions + # won't have the history tracked. + + # If the function has a type dispatched argument (i.e. is a factory), + # we prefer calling the derived type's implementation both because it is + # more performant and to ensure factory functions return tensors with _version + # of 0 (probably not strictly necessary, but nice to have to keeps versions simple + # to understand. + + return "use_derived" + else: + # If the function is concrete (we don't have to override it) and we + # didn't declare it in derivatives.yaml, we'll assume that it is + # actually implemented out of differentiable functions. (This + # assumption might not hold, but then you'll see gradcheck fail.) + return "use_type" + + +def is_foreach_func(f: NativeFunction) -> bool: + return f.func.name.name.base.startswith("_foreach_") + + +# note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind +# is functional for their backward derivatives (and forward derivatives in the future), i.e., +# they would find such one in `functional_info_by_signature`. There however are some exceptions: +_foreach_with_inplace_ref = {"_foreach_zero_"} +_foreach_with_tensor_overload = { + "_foreach_add.Tensor", + "_foreach_mul.Tensor", + "_foreach_div.Tensor", +} +# The following do not support the alpha kwarg, which the nonforeach versions support. +_skip_argument_len_check = { + "_foreach_add.Scalar", + "_foreach_add_.Scalar", + "_foreach_add.ScalarList", + "_foreach_add_.ScalarList", + "_foreach_sub.Scalar", + "_foreach_sub_.Scalar", + "_foreach_sub.ScalarList", + "_foreach_sub_.ScalarList", +} + + +# Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function +# reference to generate derivatives. +def is_reference_for_foreach( + f: NativeFunction, + function_schema: FunctionSchema, +) -> bool: + return ( + f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base + and ( + not function_schema.name.name.inplace + or str(f.func.name) in _foreach_with_inplace_ref + ) + and ( + str(f.func.name) in _skip_argument_len_check + or len(f.func.arguments.flat_non_out) + == len(function_schema.arguments.flat_non_out) + ) + and all( + ref_arg.type in (arg.type, getattr(arg.type, "elem", None)) + for arg, ref_arg in zip( + f.func.arguments.flat_non_out, + function_schema.arguments.flat_non_out, + ) + ) + ) + + +# TODO(crcrpar): Avoid hard coding "Default" ideally. +def gen_foreach_derivativeinfo( + foreach_function: NativeFunction, + functional_info_by_signature: dict[ + FunctionSchema, dict[str, DifferentiabilityInfo] + ], + non_functional_info_by_signature: dict[ + FunctionSchema, dict[str, DifferentiabilityInfo] + ], + dispatch_key: str = "Default", +) -> tuple[DifferentiabilityInfo | None, bool]: + """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place. + + The second return value indicates whether the info is generated in this function. + """ + ref_diff_info: DifferentiabilityInfo | None = None + + for function_schema, diff_info in functional_info_by_signature.items(): + if not is_reference_for_foreach(foreach_function, function_schema): + continue + ref_diff_info = diff_info[dispatch_key] + if ref_diff_info is not None: + break + # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature + # while the info of `zero_` is in non_functional_info_by_signature + if ( + ref_diff_info is None + and foreach_function.func.kind() == SchemaKind.inplace + and str(foreach_function.func.name) in _foreach_with_inplace_ref + ): + for function_schema, diff_info in non_functional_info_by_signature.items(): + if not is_reference_for_foreach(foreach_function, function_schema): + continue + ref_diff_info = diff_info[dispatch_key] + if ref_diff_info is not None: + break + if ref_diff_info is None: + return None, False + + # non out-place uses the existing Derivative. + if foreach_function.func.kind() == SchemaKind.inplace: + return ref_diff_info, False + + map_refarg2foreacharg, map_name2arg = {}, {} + for i, (arg, ref_arg) in enumerate( + zip( + foreach_function.func.arguments.flat_non_out, + function_schema.arguments.flat_non_out, + ) + ): + map_refarg2foreacharg[ref_arg.name] = arg.name + map_name2arg[arg.name] = arg + + all_saved_inputs, all_saved_outputs, all_var_names = [], [], [] + modified_derivative_formulas = [] + for i, derivative in enumerate(ref_diff_info.derivatives): + modified_formula = derivative.formula.replace("grad", "grads[i]").replace( + "result", "result[i]" + ) + saved_inputs, saved_outputs = [], [] + # note(crcrpar): This context seems necessary to call `cpp.argument_type` + with local.parametrize( + use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors, + use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group, + ): + for ref_input in derivative.saved_inputs: + ref_input_jit_name = ref_input.expr.split(".")[0] + mapped_name = map_refarg2foreacharg[ref_input_jit_name] + if isinstance(map_name2arg[mapped_name].type, ListType): + mapped_expr = mapped_name + "[i]" + else: + mapped_expr = mapped_name + new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr) + modified_formula = modified_formula.replace( + cast(str, ref_input.nctype.name), new_expr + ) + + nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name) + canonical_nctype = NamedCType( + nctype.name, nctype.type.remove_const_ref() + ) + saved_inputs.append( + SavedAttribute(nctype=canonical_nctype, expr=mapped_name) + ) + for ref_output in derivative.saved_outputs: + if ref_output.nctype.name == "result": + saved_outputs.append( + SavedAttribute( + nctype=NamedCType( + name="result", type=BaseCType(tensorListT) + ), + expr="result", + ) + ) + else: + raise RuntimeError("") + var_names = [map_refarg2foreacharg[var] for var in derivative.var_names] + all_var_names.extend(var_names) + all_saved_inputs.extend(saved_inputs) + all_saved_outputs.extend(saved_outputs) + modified_derivative = Derivative( + formula=modified_formula, + original_formula=derivative.formula, + var_names=tuple(var_names), + saved_inputs=tuple(saved_inputs), + saved_outputs=tuple(saved_outputs), + named_gradients=set(), + ) + modified_derivative_formulas.append(modified_derivative) + + with local.parametrize( + use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors, + use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group, + ): + args_with_derivatives = [ + Binding( + name=arg.name, + nctype=cpp.argument_type(arg, binds=arg.name), + argument=arg, + default=None, + ) + for arg in foreach_function.func.arguments.flat_non_out + if arg.name in all_var_names + ] + + forward_derivatives: list[ForwardDerivative] = [] + fw_derivative: ForwardDerivative + for fw_derivative in ref_diff_info.forward_derivatives: + var_names: list[str] = list(fw_derivative.var_names) # type: ignore[no-redef] + var_types: list[Type] = list(fw_derivative.var_types) + required_inputs_fw_grad: list[str] = [] + required_inputs_primal: list[str] = [] + if fw_derivative.required_inputs_fw_grad is not None: + required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad) + if fw_derivative.required_inputs_primal: + required_inputs_primal = list(fw_derivative.required_inputs_primal) + modified_formula = fw_derivative.formula + + # Foreach's result is TensorList + if "result" in modified_formula: + modified_formula = fw_derivative.formula.replace("result", "result[i]") + + for foreach_arg, ref_arg in zip( + foreach_function.func.arguments.flat_non_out, + ref_diff_info.func.func.arguments.flat_non_out, + ): + # Modify reference forward formula + if ( + isinstance(foreach_arg.type, ListType) + and not foreach_arg.type.is_tensor_like() + ): + # Assuming ScalarList + modified_formula = modified_formula.replace( + ref_arg.name, foreach_arg.name + "[i]" + ) + elif foreach_arg.type.is_tensor_like(): + # Assuming TensorList / Tensor + # assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}" + assert isinstance(foreach_arg.type, ListType) or ( + foreach_arg.type == BaseType(BaseTy.Tensor) + and str(foreach_function.func.name) in _foreach_with_tensor_overload + ), f"{foreach_function.func.name}, {foreach_arg.type}" + for suffix in ("_p", "_t"): + curr_expr = ref_arg.name + suffix + if curr_expr in modified_formula: + new_expr = foreach_arg.name + suffix + modified_formula = modified_formula.replace(curr_expr, new_expr) + else: + # Assuming Scalar + if foreach_arg.name != ref_arg.name: + modified_formula = modified_formula.replace( + ref_arg.name, foreach_arg.name + ) + + # note(crcrpar): there should exist a cooler way... + for i, name in enumerate(var_names): + if name == ref_arg.name: + var_names[i] = foreach_arg.name + var_types[i] = foreach_arg.type + for i, name in enumerate(required_inputs_fw_grad): + if name == ref_arg.name: + required_inputs_fw_grad[i] = foreach_arg.name + for i, name in enumerate(required_inputs_primal): + if name == ref_arg.name: + required_inputs_primal[i] = foreach_arg.name + forward_derivatives.append( + ForwardDerivative( + formula=modified_formula, + var_names=tuple(var_names), + var_types=tuple(var_types), + required_inputs_fw_grad=tuple(required_inputs_fw_grad), + required_inputs_primal=tuple(required_inputs_primal), + required_original_self_value=fw_derivative.required_original_self_value, + is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula, + ) + ) + + return ( + DifferentiabilityInfo( + name=foreach_function.func.name.name.base, + func=foreach_function, + op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}", + derivatives=modified_derivative_formulas, + forward_derivatives=forward_derivatives, + all_saved_inputs=tuple(set(all_saved_inputs)), + all_saved_outputs=tuple(set(all_saved_outputs)), + available_named_gradients=(), + used_named_gradients=set(), + args_with_derivatives=args_with_derivatives, + non_differentiable_arg_names=[], + output_differentiability=None, + output_differentiability_conditions=None, + ), + True, + ) + + +def match_differentiability_info( + native_functions: list[NativeFunction], + differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], +) -> list[NativeFunctionWithDifferentiabilityInfo]: + """Sets the "derivative" key on declarations to matching autograd function + In-place functions will use the out-of-place derivative definition if there + is no in-place specific derivative. + """ + + functional_info_by_signature = { + schema.signature(strip_default=True): info_dict + for schema, info_dict in differentiability_infos.items() + if schema.kind() == SchemaKind.functional + } + non_functional_info_by_signature = { + schema.signature(strip_default=True): info_dict + for schema, info_dict in differentiability_infos.items() + if schema.kind() != SchemaKind.functional + } + + def find_info( + f: NativeFunction, + ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]: + # Don't bother matching info to generated out= variants + if "generated" in f.tags and f.func.kind() == SchemaKind.out: + return None, False + + # (1) Check for an exact match + if f.func in differentiability_infos: + return differentiability_infos[f.func], True + + # (2) If no exact match, check if the out-of-place variant + # of this operator has a match. + # i.e mul() for mul_() or mul_out() + # note(crcrpar): Check foreach or not because in-place foreach functions use backward defined for the existing + # native functions instead of the out-place counterparts. + f_sig = f.func.signature(strip_default=True) + if f_sig in functional_info_by_signature and not is_foreach_func(f): + return functional_info_by_signature[f_sig], False + + # (3) Some operators have a derivative explicitly defined for the mutable + # variant, but get a code-generated out-of-place variant which does *not* + # come with a derivative formula. + # For the generated out-of-place variant, use the mutable variant's formula + # if it exists. + if "generated" in f.tags and f_sig in non_functional_info_by_signature: + info_dict = non_functional_info_by_signature[f_sig] + # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389 + assert not any( + any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs) + for info in info_dict.values() + ), f"""\ +Attempted to convert a derivative formula for a mutable operator + to be used by automatically by its functional variant ("{str(f.func)}"). + this is not currently supported (we'd need to fix up the formula in the codegen).""" + return info_dict, False + + # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml` + if is_foreach_func(f): + assert f.func not in differentiability_infos + diff_info, is_generated = gen_foreach_derivativeinfo( + f, + functional_info_by_signature, + non_functional_info_by_signature, + ) + if diff_info is None: + return None, False + # TODO(crcrpar): Avoid hard coding "Default" ideally. + diff_info_dict = {"Default": diff_info} + if is_generated: + differentiability_infos[f.func] = diff_info_dict + functional_info_by_signature[f.func] = diff_info_dict + return diff_info_dict, is_generated + + return None, False + + result: list[NativeFunctionWithDifferentiabilityInfo] = [] + for f in native_functions: + info_dict, is_exact_match = find_info(f) + + # Currently, the '.strides()' to 'strides_or_error' replacement does not support + # 'self' derivatives of an inplace function, so we must check for this case. + if f.func.kind() == SchemaKind.inplace and (info_dict is not None): + for info in info_dict.values(): + for derivative in info.derivatives: + if "self" in derivative.var_names: + for saved_input in derivative.saved_inputs: + assert "strides_or_error" not in saved_input.expr, ( + "Calling '.strides()' in the 'self' derivative formula of an " + f"in-place function is not supported: {f.func}" + ) + + if not info_dict: + result.append( + NativeFunctionWithDifferentiabilityInfo( + func=f, info=None, fw_derivatives=None + ) + ) + continue + + fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {} + for key, info in info_dict.items(): + if not info.forward_derivatives: + fw_derivative_dict[key] = [] + continue + + forward_derivatives = info.forward_derivatives + + # For functions that have a single def for out-of-place and inplace (like abs()) + if f.func.kind() == SchemaKind.inplace: + # For inplace functions there is a little bit of work to do: + # 1) Validate the formula and make sure the input that is modified in not used: + # - If there is a formula for the inplace variant of the function (is_exact_match == True) then + # we make sure that the original value of the input that is being modified inplace (self_p) is + # not used in the formula. Note that the formula can use "original_self_p" here and that would + # trigger a clone of the original input. + # - If we are re-using the out of place formula (is_exact_match == False) then we replace every + # occurrence of self_p and self_t by original_self_p and original_self_t. These will be + # populated by cloned version of the original input (either the clone done by the backward AD + # logic if self is also used in a backward formula or a special clone that we add). + # 2) At this point, there cannot be a self_p in the formula. + # 3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is + # simply called self (as it is modified inplace). + # 4) Update the required primals data in case it used to contain "result" but should now contain + # "self" + # 5) If it is not an exact match, the user formula is not modifying the existing forward grad + # inplace as it should. So add some code that makes sure that we do so if the forward grad + # already exists. + + assert ( + len(info.forward_derivatives) == 1 + ) # Only single output inplace should exist + fw_info = info.forward_derivatives[0] + formula = fw_info.formula + + def replace_self_with_original_self(formula: str, postfix: str) -> str: + def repl(m: re.Match[str]) -> str: + return f"{m.group(1)}original_self{postfix}{m.group(2)}" + + return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula) + + if re.search(IDENT_REGEX.format("self_p"), formula): + if is_exact_match: + # For manually defined formulas, don't allow the original value to be used + raise RuntimeError( + f'The formula for "{f.func.name}" is using the original value of self ' + "that is being modified inplace. This would lead to wrong forward gradients. " + 'Please use "result" in the formula only.' + ) + else: + # When the original formula is out of place, we save a clone of the primal + # value to be able to access this value if needed + # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t" + formula = replace_self_with_original_self(formula, "_p") + formula = replace_self_with_original_self(formula, "_t") + + # replace "result" from the formula by "self_p" + def repl(m: re.Match[str]) -> str: + return f"{m.group(1)}self_p{m.group(2)}" + + formula = re.sub(IDENT_REGEX.format("result"), repl, formula) + + required_primals = fw_info.required_inputs_primal + if re.search(IDENT_REGEX.format("self_p"), formula): + required_primals = ( + required_primals + ("self",) if required_primals else ("self",) + ) + + if not is_exact_match: + # NOTE [In-place forward AD formula Optimization] + # + # This optimization transforms the formula to directly do inplace, i.e. + # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met: + # + # 1) the formula satisfies the pattern: "self_t.op(*args)" + # 2) "op" in (1) needs to be the same as the op the derivative is for + # + # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2) + # If there is a need, we can relax (2) to allow any op that has an in-place variant + is_single_method_on_self_t = False + directly_do_inplace = False + op_name: str | None = None + between_parens: str | None = None + match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula) + if match: + op_name, between_parens = match.group(1), match.group(2) + + # We want to... + # Match: self_t.op1(other_p.op2(arg)) + # Avoid: self_t.op1(args) + self_t.op2(args) + # Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args) + def check_parens_nest_level_gt_zero(s: str) -> bool: + level = 1 + for ch in s: + if ch == ")": + level -= 1 + if level == 0: + return False + if ch == "(": + level += 1 + return True + + is_single_method_on_self_t = check_parens_nest_level_gt_zero( + between_parens + ) + directly_do_inplace = ( + is_single_method_on_self_t and op_name == info.name + ) + + if directly_do_inplace: + assert op_name is not None + assert between_parens is not None + formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}" + else: + # Make sure that the forward grad is modified inplace when the original formula + # is out of place + formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}" + + required_original_self_value = bool( + re.search(IDENT_REGEX.format("original_self_p"), formula) + ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula)) + + forward_derivatives = [ + ForwardDerivative( + formula=formula, + var_names=("self",), + var_types=fw_info.var_types, + required_inputs_fw_grad=fw_info.required_inputs_fw_grad, + required_inputs_primal=required_primals, + required_original_self_value=required_original_self_value, + is_reusing_outplace_formula=not is_exact_match, + ), + ] + + fw_derivative_dict[key] = forward_derivatives + + result.append( + NativeFunctionWithDifferentiabilityInfo( + func=f, info=info_dict, fw_derivatives=fw_derivative_dict + ) + ) + + return result + + +def is_differentiable( + name: str, type: Type, info: DifferentiabilityInfo | None +) -> bool: + return type.is_tensor_like() and ( + info is None or name not in info.non_differentiable_arg_names + ) + + +def gen_differentiable_outputs( + fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default" +) -> list[DifferentiableOutput]: + f = fn.func + info = fn.info[key] if fn.info else None + outputs: list[DifferentiableOutput] = [ + DifferentiableOutput( + name=name, + type=ret.type, + cpp_type=cpp.return_type(ret, symint=True).cpp_type(), + ) + for name, ret in zip(cpp.return_names(f), f.func.returns) + ] + output_differentiability = info.output_differentiability if info else None + if output_differentiability is not None: + if len(output_differentiability) != len(outputs): + raise RuntimeError( + f"The length of output_differentiability ({len(output_differentiability)}), " + f"does not match the number of outputs ({len(outputs)})." + ) + differentiable_outputs: list[DifferentiableOutput] = [] + if False in output_differentiability and f.func.kind() == SchemaKind.inplace: + raise RuntimeError( + "output_differentiability=False for inplace operation (version_counter won't get updated)" + ) + for differentiable, output in zip(output_differentiability, outputs): + if differentiable: + differentiable_outputs.append(output) + return differentiable_outputs + candidate_differentiable_outputs = list( + filter(lambda r: is_differentiable(r.name, r.type, info), outputs) + ) + if uses_single_grad(info): + return candidate_differentiable_outputs[:1] + else: + return candidate_differentiable_outputs diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/functionalization.py b/.venv/lib/python3.11/site-packages/torchgen/api/functionalization.py new file mode 100644 index 0000000000000000000000000000000000000000..93667e39b17fa4ba82414fa92bb7200faf6f6515 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/functionalization.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +from torchgen.api import dispatcher +from torchgen.api.types import ( + BaseCppType, + BaseCType, + Binding, + boolT, + ConstRefCType, + CType, + longT, + NamedCType, + tensorT, +) +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + FunctionSchema, + NativeFunction, + NativeFunctionsViewGroup, +) + + +# This file describes the translation of JIT schema to API's used +# when creating view lambdas that are used by the functionalization pass. +# There are two types of lambdas: forward lambdas and reverse lambdas. +# These API's mostly follow the dispatcher API, with a few quirks: +# - The lambda capture has to convert reference types to value types +# - While the forward lambda just directly calls into the at::_ops API +# (following the dispatcher convention), the logic here for the reverse lambda +# is responsible for generating both the call-site, and the declarations +# (which are implemented manually in the at::functionalization::impl namespace). + +# The lambdas generated for each view op in the functionalization pass are of the form +# [capture_arguments](outer_arguments) -> returns_type { +# return name(inner_arguments); +# } + +# Define some specific lambda input arguments. +base_binding = Binding( + name="base", + nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))), + argument=Argument( + name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None + ), + default=None, +) +mutated_view_binding = Binding( + name="mutated_view", + nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))), + argument=Argument( + name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None + ), + default=None, +) +mutated_view_idx_binding = Binding( + name="mutated_view_idx", + nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)), + argument=Argument( + name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None + ), + default=None, +) +reapply_views_binding = Binding( + name="reapply_views", + nctype=NamedCType(name="reapply_views", type=BaseCType(boolT)), + argument=Argument( + name="reapply_views", type=BaseType(BaseTy.bool), default=None, annotation=None + ), + default=None, +) + +InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode") +inverse_return_mode_binding = Binding( + name="inverse_return_mode", + nctype=NamedCType(name="inverse_return_mode", type=BaseCType(InverseReturnModeT)), + argument=Argument( + name="inverse_return_mode", + # NB: not actually a bool but it doesn't matter because this isn't used + type=BaseType(BaseTy.bool), + default=None, + annotation=None, + ), + default=None, +) + + +# The lambda capture itself doesn't have a name. +# The name returned here corresponds to the name of the inner function called by the lambda. +def name( + g: NativeFunctionsViewGroup, + *, + is_reverse: bool, + include_namespace: bool, + reapply_views: bool | None = None, +) -> str: + if reapply_views is None: + # reapply_views is only important for the fwd lambda, + # since we always plumb the runtime "reapply_views" argument into the reverse function. + assert is_reverse + if is_reverse: + return reverse_name(g.view, include_namespace) + # in the forward case, we just directly call into the at::_ops API (so we always need the namespace) + assert include_namespace + assert g.view_copy is not None + api_name = ( + g.view.func.name.unambiguous_name() + if reapply_views + else g.view_copy.func.name.unambiguous_name() + ) + return f"at::_ops::{api_name}::call" + + +def reverse_name(f: NativeFunction, include_namespace: bool) -> str: + # for the reverse: we plumb the "reapply_views" flag into that function and support + # both copy and non-copy variants. (We could avoid doing that, but that would require + # writing out twice as many view inverse functions). + api_name = f.func.name.unambiguous_name() + # in the reverse case, we codegen both the call-sites (which need the full namespace) and the declarations (which don't) + if include_namespace: + return f"at::functionalization::FunctionalInverses::{api_name}_inverse" + else: + return f"{api_name}_inverse" + + +def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]: + # capture arguments include all arguments except `self`. + # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture), + # So any reference types (IntArrayRef) need to be converted to value types (vector) + args = func.arguments.flat_all + assert args[0].type == BaseType(BaseTy.Tensor) + non_self_args = args[1:] + non_self_value_bindings = [ + dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args + ] + + all_bindings = [ + inverse_return_mode_binding if is_reverse else reapply_views_binding + ] + all_bindings.extend(non_self_value_bindings) + return all_bindings + + +def returns_type(func: FunctionSchema) -> CType: + # Assertion: all view ops return tensor-like outputs + assert len(func.returns) >= 1 + for ret in func.returns: + assert ret.type.is_tensor_like() + # However, the return type of the lambda is always an individual tensor. + # For multi-tensor outputs, each tensor needs to be tracked individually. + return BaseCType(tensorT) + + +def outer_arguments(*, is_reverse: bool) -> list[Binding]: + if is_reverse: + return [base_binding, mutated_view_binding, mutated_view_idx_binding] + else: + return [base_binding, mutated_view_idx_binding] + + +def inner_call_index(func: FunctionSchema) -> Binding | None: + # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output. + # When we replay a view op that returns multiple tensors, we need to index into the output appropriately + if len(func.returns) > 1 or ( + len(func.returns) == 1 and func.returns[0].type.is_list_like() + ): + return mutated_view_idx_binding + return None + + +def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: + args = func.arguments.flat_all + assert args[0].type == BaseType(BaseTy.Tensor) + non_self_args = args[1:] + # The forward lambda calls the at::_ops API, while the reverse lambda calls the view inverse API. + # Both of these follow the dispatcher API. + non_self_bindings = [dispatcher.argument(a) for a in non_self_args] + if not is_reverse: + # the forward lambda swaps out the original tensor argument with the lambd arg "base" + return [base_binding] + non_self_bindings + else: + # the reverse lambda does the same, but with an additional "mutated_view" arg + # additionally, we have a calling convention: for view ops that return multiple tensor outputs + # their corresponding view_inverse function takes in an additional index argument. + index_binding = inner_call_index(func) + if index_binding is not None: + return [ + base_binding, + mutated_view_binding, + inverse_return_mode_binding, + index_binding, + ] + non_self_bindings + else: + return [ + base_binding, + mutated_view_binding, + inverse_return_mode_binding, + ] + non_self_bindings diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/lazy.py b/.venv/lib/python3.11/site-packages/torchgen/api/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..cfffa516b656b8f479b5bfe16d4a4620fc35f9b0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/lazy.py @@ -0,0 +1,467 @@ +from __future__ import annotations + +from typing import Any + +from torchgen.api.types import ( + BaseCppType, + BaseCType, + boolT, + CType, + deviceT, + doubleT, + generatorT, + layoutT, + ListCType, + longT, + memoryFormatT, + NamedCType, + OptionalCType, + scalarT, + scalarTypeT, + stringT, + SymIntT, + VectorCType, +) +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + FunctionSchema, + ListType, + OperatorName, + OptionalType, + Return, + TensorOptionsArguments, + Type, +) + + +_valueT: BaseCppType | None = None + + +# A ValueT is an IR type which represents the computation of a Tensor. In other +# words, a PyTorch user will do operations on lazy tensors, and each output lazy +# tensor internally tracks a ValueT representing the IR node that would have +# actually produced the value of this tensor for real. +# +# This is configurable because different lazy tensor backends (LTC vs XLA) will +# have different IR representations. (Though, arguably, after unification they +# shouldn't!) +def getValueT() -> BaseCppType: + global _valueT + if not _valueT: + raise NotImplementedError( + "The value type needs to be set with setValueT() in run_gen_lazy_tensor()" + ) + + return _valueT + + +def setValueT(val: BaseCppType) -> None: + global _valueT + _valueT = val + + +# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object, +# making it easier to represent special properties of an arg. +tensorListValueT = BaseCppType("torch::lazy", "Value") + + +def process_ir_type( + typ: Type, properties: LazyIrProperties, *, symint: bool +) -> BaseCType | VectorCType | OptionalCType | ListCType: + """ + This function takes a type from NativeFunctions and converts it for use with + lazy tensor codegen. + + Type conversion for lazy currently consists of + (1) changing at::Tensors into lazy::Values + (2) wrapping everything in a BaseCType + (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef) + + (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors.) + There is special handling for Optional[Tensor] or List[Tensor], etc- hence 'tensor-like' + + This is incomplete- there are assertions in places that it's expected to need to add + more types as the codegen is used with more operators. + """ + if isinstance(typ, BaseType): + if typ.name == BaseTy.Tensor: + return BaseCType(getValueT()) + elif typ.name == BaseTy.Scalar: + if properties.TreatScalarsAsConstants: + return BaseCType(scalarT) + # at::scalar has special handling, + # and is wrapped in an lazy::Value just like at::tensor + return BaseCType(getValueT()) + elif typ.name == BaseTy.ScalarType: + return BaseCType(scalarTypeT) + elif typ.name == BaseTy.int: + return BaseCType(longT) + elif typ.name == BaseTy.SymInt: + if symint: + return BaseCType(getValueT()) + else: + return BaseCType(longT) + elif typ.name == BaseTy.bool: + return BaseCType(boolT) + elif typ.name == BaseTy.float: + return BaseCType(doubleT) + elif typ.name == BaseTy.str: + return BaseCType(stringT) + elif typ.name == BaseTy.Device: + return BaseCType(deviceT) + elif typ.name == BaseTy.Generator: + return BaseCType(generatorT) + elif typ.name == BaseTy.Layout: + return BaseCType(layoutT) + elif typ.name == BaseTy.MemoryFormat: + return BaseCType(memoryFormatT) + else: + raise AssertionError(f"TODO add support for type {repr(typ)}") + elif isinstance(typ, OptionalType): + return OptionalCType(process_ir_type(typ.elem, properties, symint=symint)) + elif isinstance(typ, ListType): + if str(typ.elem) == "Tensor?": + # TODO(whc) is this actually correct? or should it use a Vector like above + return ListCType(OptionalCType(BaseCType(getValueT()))) + elif str(typ.elem) == "Tensor": + # this is a TensorList which comes in from GetTensorList as a Value + return BaseCType(tensorListValueT) + elif typ.elem == BaseType(BaseTy.SymInt): + # TODO: return a value type. The problem here is analogous to + # the problem with tensorListValueT: if you have SymInt[] you + # cannot conveniently save the list of Value directly, as nodes + # expect to save values as a vector for ALL arguments. So you + # need a separate IR node that represents all of the size nodes + # assembled into a list. I'm not an LTC dev so I don't want to + # figure it out right now. Y'all figure it out... + return VectorCType(BaseCType(longT)) + + else: + return VectorCType(process_ir_type(typ.elem, properties, symint=symint)) + else: + raise AssertionError(f"unrecognized type {repr(typ)}") + + +# TODO: Determining this based off of CType is bad; this should be computed +# from Type directly; then the same logic as process_ir_type can be used +# +# Invariant: passed typ should be an *owning* CType (e.g., we will report +# that ArrayRef is NOT a value type) +def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool: + """ + Given a type, determine if it is a Value-like type. This is equivalent to + being Tensor-like, but assumes the type has already been transformed. + """ + if isinstance(typ, BaseCType): + # I am regretting my naming conventions, but now we are wrapping at::scalar in + # lazy value, while preserving other 'scalar' types as scalars in the IR + treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants + return ( + typ.type == getValueT() + or (typ.type == scalarT and not treat_scalars_as_constants) + or typ.type == SymIntT + ) + elif typ == VectorCType(BaseCType(SymIntT)): + # TODO: report True for this + return False + elif isinstance(typ, (OptionalCType, ListCType, VectorCType)): + return isValueType(typ.elem, properties) + return False + + +def isSymIntType(typ: Type) -> bool: + return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt + + +def isWrappedScalarType(typ: Type) -> bool: + """ + Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value. + Since we literally change the type from scalarT to valueT, information is lost. + This function helps build a list of wrapped scalars to save that information + """ + if isinstance(typ, BaseType): + # I am regretting my naming conventions, but now we are wrapping at::scalar in + # lazy value, while preserving other 'scalar' types as scalars in the IR + return typ.name == BaseTy.Scalar + elif isinstance(typ, (OptionalType, ListType)): + return isWrappedScalarType(typ.elem) + return False + + +# TODO: dedupe with Type.is_generator_like +def isGeneratorType(typ: Type) -> bool: + if isinstance(typ, BaseType): + return typ.name == BaseTy.Generator + elif isinstance(typ, (OptionalType)): + return isGeneratorType(typ.elem) + return False + + +# This class caches a few derived properties computed from an Argument +# and LazyIrProperties +class LazyArgument: + name: str + orig_type: Type + lazy_type_: CType | None + is_wrapped_scalar: bool + is_generator: bool + # TODO: this is lies, it is false for symint list + is_symint_or_list: bool + + # Whether or not we are treating this as symint or not + symint: bool + + # true if this argument is or contains a lazy IR value + is_lazy_value: bool + + def __init__( + self, arg: Argument, properties: LazyIrProperties, *, symint: bool + ) -> None: + self.name = arg.name + self.orig_type = arg.type + self.symint = symint + self.is_optional = isinstance(arg.type, OptionalType) + self.is_generator = isGeneratorType(arg.type) + self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint) + self.is_wrapped_scalar = isWrappedScalarType(arg.type) + self.is_symint_or_list = symint and ( + isSymIntType(arg.type) + or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem)) + # TODO: lists of symints are not currently treated as value types + # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem)) + ) + + self.is_lazy_value = isValueType(self.lazy_type, properties) + + @property + def lazy_type(self) -> CType: + assert ( + self.lazy_type_ is not None + ), f"Attempted to access lazy_type for invalid argument {self.name}" + return self.lazy_type_ + + +class LazyIrProperties: + """Collection of properties for an IR node + + The property groups are listed below. Each group is mutually + exclusive, meaning that only one property from each group can be True + at any one time. The properties can be accessed as if they were normal + attributes. The mutual exclusivity is automatically handled. + """ + + Properties: tuple[tuple[str, ...], ...] = ( + ( + "ShapePrecompute", # Assume shape has been precomputed + "ShapeCompute", # Need to compute the shape on construction + "ShapeCache", # Utilize the shape cache to defer computation + ), + ( + "Lower", # Codegen full lower function + "LowerDeclOnly", # Codegen only lower function declaration + ), + ( + "CanBeReused", # Codegen full reuse function + "CanBeReusedDeclOnly", # Codegen only reuse function declaration + ), + ( + "CreateFn", # Codegen full create function + "CreateFnDeclOnly", # Codegen only create function declaration + ), + ( + "TreatScalarsAsConstants", # Treat Scalars as constants instead of handling like values + ), + ) + + def __init__(self, *default_properties: str) -> None: + properties: dict[tuple[str, ...], str | None] = dict.fromkeys( + LazyIrProperties.Properties + ) + self.__dict__["properties"] = properties + for p in default_properties: + setattr(self, p, True) + + def __getattr__(self, key: str) -> Any: + properties = self.__dict__["properties"] + for values in LazyIrProperties.Properties: + if key in values: + return properties[values] == key + + return self.__getattribute__(key) + + def __setattr__(self, key: str, value: Any) -> Any: + properties = self.__dict__["properties"] + for values in LazyIrProperties.Properties: + if key in values: + properties[values] = key if value else None + return value + + raise KeyError(f"Invalid property: {key}") + + +# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node. +# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML), +# but carries type information from a native FunctionSchema modified for use with IR nodes, +# and preserving original argument names. +# +# TODO: This is not idiomatic with how other torchgen APIs transform on schema. +class LazyIrSchema: + # The name of the operator this function schema describes. + name: OperatorName + + positional_args: tuple[LazyArgument, ...] + keyword_args: tuple[LazyArgument, ...] + + # TODO: Need to handle collisions with argument names at some point + returns: tuple[Return, ...] + + # if this schema has a Generator arg, list its orig ctype/name but don't + # build a LazyArgument since lazy IR doesn't support it + generator_arg: NamedCType | None = None + + # original function schema + func: FunctionSchema + + # Whether or not we are code-genning for SymInt or not + symint: bool + + properties: LazyIrProperties = LazyIrProperties( + # default properties + "ShapePrecompute", + "Lower", + "CanBeReused", + ) + opkind: str | None = None + + def __init__( + self, + func: FunctionSchema, + properties: LazyIrProperties | None = None, + *, + symint: bool, + ) -> None: + if properties: + self.properties = properties + + self.func = func + self.symint = symint + positional_args: list[LazyArgument] = [] + for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]: + if arg_field == "self_arg" and func.arguments.self_arg is not None: + arg = func.arguments.self_arg.argument + positional_args.append( + LazyArgument(arg, self.properties, symint=symint) + ) + elif getattr(func.arguments, arg_field) is not None: + positional_args.extend( + LazyArgument(arg, self.properties, symint=symint) + for arg in getattr(func.arguments, arg_field) + ) + self.positional_args = tuple(positional_args) + + keyword_args: list[LazyArgument] = [] + for arg_field in [ + "pre_tensor_options_kwarg_only", + "tensor_options", + "post_tensor_options_kwarg_only", + "out", + ]: + curr_args = getattr(func.arguments, arg_field) + if curr_args is not None: + if isinstance(curr_args, TensorOptionsArguments): + curr_args = curr_args.all() + for arg in curr_args: + if isGeneratorType(arg.type): + assert ( + self.generator_arg is None + ), "We expect there is only one generator arg" + self.generator_arg = NamedCType( + arg.name, arg.type # type:ignore[arg-type] + ) + keyword_args.extend( + LazyArgument(arg, self.properties, symint=symint) + for arg in curr_args + ) + self.keyword_args = tuple(keyword_args) + self.name = func.name + self.returns = func.returns + + @property + def node_name(self) -> str: + """ + Return camel-case version of op in node. + + Note: This function also appends any `overload_name` in the operation. + For example, if the op is `bitwise_and.Tensor`, the returned name + will be `BitwiseAndTensor`. + """ + op_name = f"{self.name.name}_{self.name.overload_name}".lower() + return "".join(word.capitalize() or "" for word in op_name.split("_")) + + @property + def aten_name(self) -> str: + return str(self.name.name) + + @property + def base_name(self) -> str: + return f"{self.name.name.base}" + + def filtered_args( + self, + positional: bool = True, + keyword: bool = True, + values: bool = True, + scalars: bool = True, + generator: bool = True, + ) -> list[LazyArgument]: + # This function maintains the sorted order of arguments but provides different filtered views. + # Some parts of the code care about kwargs vs args (TS lowerings), + # other parts care about whether they need to wrap the arg in a lazy value or leave it alone. + # Generators are special cased, as they are needed for fallback/shape-inference but not supported + # in TS lowerings and therefore also omitted from lazy IR. + args: list[LazyArgument] = [] + if positional: + args.extend(self.positional_args) + if keyword: + args.extend(self.keyword_args) + + if values and scalars and generator: + return args + elif values and scalars: + return [a for a in args if not a.is_generator] + elif values: + return [a for a in args if a.is_lazy_value] + elif scalars: + return [ + a + for a in args + if not a.is_lazy_value and (generator or not a.is_generator) + ] + + return [] + + @property + def positional_values(self) -> list[LazyArgument]: + return self.filtered_args( + positional=True, keyword=False, values=True, scalars=False + ) + + @property + def positional_scalars(self) -> list[LazyArgument]: + return self.filtered_args( + positional=True, keyword=False, values=False, scalars=True + ) + + @property + def keyword_values(self) -> list[LazyArgument]: + return self.filtered_args( + positional=False, keyword=True, values=True, scalars=False + ) + + @property + def keyword_scalars(self) -> list[LazyArgument]: + return self.filtered_args( + positional=False, keyword=True, values=False, scalars=True + ) diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/meta.py b/.venv/lib/python3.11/site-packages/torchgen/api/meta.py new file mode 100644 index 0000000000000000000000000000000000000000..2e99d151faeaccea7ca47f372fd26f9985ce7249 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/meta.py @@ -0,0 +1,13 @@ +from torchgen.model import NativeFunctionsGroup + + +# Follows dispatcher calling convention, but: +# - Mutable arguments not allowed. Meta functions are always +# written in functional form. Look at FunctionSchema.signature() +# - No tensor returns; instead we return a TensorMeta describing +# the tensor in question + + +def name(g: NativeFunctionsGroup) -> str: + # use the overload name from the functional version + return str(g.functional.func.name).replace(".", "_") diff --git a/.venv/lib/python3.11/site-packages/torchgen/api/python.py b/.venv/lib/python3.11/site-packages/torchgen/api/python.py new file mode 100644 index 0000000000000000000000000000000000000000..eb0f07489887225b1ee0df12815f1e17f506aaf7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/api/python.py @@ -0,0 +1,1519 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence + +from torchgen.api import cpp +from torchgen.api.types import Binding, CppSignature, CppSignatureGroup +from torchgen.gen import pythonify_default +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + FunctionSchema, + ListType, + NativeFunction, + OptionalType, + Return, + Type, + Variant, +) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Data Models +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# [Notes] python binding codegen +# +# The Python binding codegen produces code that takes the input list of +# PyObjects, finds the matching ATen C++ function using PythonArgParser, +# converts the PyObjects into C++ types and calls the ATen C++ function: +# +# +--------+ parsing +------------------------+ binding +-----------------------+ +# | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch | +# +--------+ +------------------------+ +-----------------------+ +# +# The following examples demonstrate the data models the Python binding +# codegen needs to deal with and the tasks it needs to accomplish. It +# helps understand the purpose of the new data types we introduced below. +# +# - Function Schema (source of truth) +# +# aten::empty.names(int[] size, *, Dimname[]? names, +# ScalarType? dtype=None, Layout? layout=None, +# Device? device=None, bool? pin_memory=None, +# MemoryFormat? memory_format=None) -> Tensor +# +# - Python Signature +# +# It's used to generate input schema string for PythonArgParser. +# Note: TensorOptions fields are reordered and the additional +# 'requires_grad' field is added: +# +# empty(IntArrayRef size, *, DimnameList? names, +# MemoryFormat? memory_format=None, ScalarType dtype=None, +# Layout layout=torch.strided, Device device=None, +# bool pin_memory=False, bool requires_grad=False) +# +# - C++ Signature +# +# It's used to generate C++ lambda formals & dispatch call. +# Note: the scattered TensorOptions fields are packed into 'options'. +# +# auto dispatch_empty = +# [](IntArrayRef size, std::optional names, +# const TensorOptions & options, +# std::optional memory_format) -> Tensor { +# pybind11::gil_scoped_release no_gil; +# return torch::empty(size, names, options, memory_format); +# }; +# +# - Binding between Python Arguments and C++ Arguments +# +# Given a set of Python Arguments in scope, we need produce the +# binding expressions that translate the Python API into C++ API: +# +# Python Args Cpp Args Binding Exprs +# ----------------------------------------------------------------- +# 0: size size '_r.intlist(0)' +# 1: names names 'names' [special init] +# 2: memory_format -------+ +# 3: dtype -----+-|--> options 'options' [special packing] +# 4: layout / | +# 5: device / +--> memory_format '_r.memoryformatOptional(2)' +# 6: pin_memory / +# 7: requires_grad -+ +# +# So the full dispatch expression would look like: +# +# dispatch_empty(_r.intlist(0), names, options, +# _r.memoryformatOptional(2)) +# +# Where does 'names' come from? It involves special local init: +# +# auto __names = _r.toDimnameListOptional(1); +# std::optional names = +# __names ? std::make_optional(DimnameList(__names.value())) +# : std::nullopt; +# +# Where does 'options' come from? It involves special local init +# for TensorOptions. Note that Python side has the additional +# 'requires_grad' field: +# +# const auto options = TensorOptions() +# .dtype(_r.scalartype(3)) +# .device(_r.device(5)) +# .layout(_r.layoutOptional(4)) +# .requires_grad(_r.toBool(7)) +# .pinned_memory(_r.toBool(6)); +# +# In some other cases one Python Argument can map to multiple C++ +# Arguments. For example: +# +# aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) +# -> (Tensor values, Tensor indices) +# +# Python Args Cpp Args Binding Exprs +# --------------------------------------------------------------------- +# +----> max 'out[0]' +# /-----> max_values 'out[1] +# 0: input / self '_r.tensor(0)' +# 1: dim / dim '_r.dimname(1)' +# 2: keepdim / keepdim '_r.toBool(2)' +# 3: out -----+ [local init] out '_r.tensorlist_n<2>(3)' +# +# As demonstrated above, the binding can involve reordering, +# packing, unpacking and special local inits. +# +# +# Let's look at a concrete example: +# +# static PythonArgParser parser({ +# "abs(Tensor input, *, Tensor out=None)", +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- Python Schema, represented by PythonSignature and PythonArgument +# +# }, /*traceable=*/true); +# +# ParsedArgs<2> parsed_args; +# auto _r = parser.parse(nullptr, args, kwargs, parsed_args); +# +# ... +# +# if (_r.isNone(1)) { +# ~~~~~~~~~~~~ <--- Scattered PythonArgParser output (arg name = 'out') +# represented by PythonArgParserOutputExpr +# +# // aten::abs(Tensor self) -> Tensor +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- NativeFunction schema, base version +# +# auto dispatch_abs = [](const Tensor & self) -> Tensor { +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- dispatch_lambda_args / dispatch_lambda_return_str +# generated from NativeFunction / CppSignature +# (deprecated PythonSignature is special) +# arguments are represented by DispatchLambdaArgument +# +# pybind11::gil_scoped_release no_gil; +# return self.abs(); +# ~~~~~~~~~~~ <--- cpp_dispatch_target / cpp_dispatch_exprs +# generated from NativeFunction / CppSignature +# }; +# return wrap(dispatch_abs(_r.tensor(0))); +# ~~~~~~~~~~~~~ +# ^ +# +--- dispatch_lambda_exprs +# binding PythonArgParserOutputExpr (python args) +# and DispatchLambdaArgument (c++ args) +# +# } else { +# // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- NativeFunction schema, out-variant +# +# auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor { +# pybind11::gil_scoped_release no_gil; +# return at::abs_out(out, self); +# }; +# return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0))); +# } +# +# +# [Notes] python interface codegen +# The python dataclasses below are used used to generate both python binding code +# and pyi type hint signatures. +# In theory these two should look very similar, but there are number of differences +# in how pyi signatures vs. python_arg_parser signatures are generated. +# These differences have been encapsulated in signature_str() vs. signature_str_pyi() +# to display the full signatures, and argument_str() vs argument_str_pyi() to display arguments. +# For examples, only pyi signatures include return types. + + +@dataclass(frozen=True) +class PythonReturns: + returns: tuple[Return, ...] + + +@dataclass(frozen=True) +class PythonArgument: + name: str + type: Type + default: str | None + + # Used to generate the default init expr for some PythonArgParser outputs, e.g.: + # + # _r.layoutWithDefault(3, layout_from_backend(self.options().backend()))) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # ^ + # +--- default_init str + default_init: str | None + + # Compute argument formal for python argument parsing. + # Needs to be consistent with torch/csrc/utils/python_arg_parser.h. + def argument_str(self, *, method: bool = False, symint: bool = True) -> str: + type_str = ( + argument_type_str(self.type, symint=symint) + .replace("const ", "") + .replace(" &", "") + ) + + name = self.name + # s/self/input/ outside method bindings + # [old codegen] TODO: remove this? doesn't rename in codegen, it's just + # for the parse string + if name == "self" and type_str in ["Tensor", "Number"] and not method: + name = "input" + + # add default + if self.default is not None: + default = { + "nullptr": "None", + "::std::nullopt": "None", + "std::nullopt": "None", + "{}": "None", + }.get(self.default, self.default) + return f"{type_str} {name}={default}" + else: + return f"{type_str} {name}" + + def argument_str_pyi( + self, *, method: bool = False, deprecated: bool = False + ) -> str: + type_str = argument_type_str_pyi(self.type) + + name = self.name + # s/self/input/ outside method bindings + # [old codegen] TODO: remove this? doesn't rename in codegen, it's just + # for the parse string + if name == "self" and type_str == "Tensor" and not method and not deprecated: + name = "input" + + if name == "from": # from is a Python keyword... + name += "_" + + # pyi merges the _out and functional variants into the same signature, with an optional out arg + if name == "out" and type_str == "Tensor" and not deprecated: + type_str = "Optional[" + type_str + "]" + + # pyi deprecated signatures don't get defaults for their out arg + treat_as_no_default = ( + deprecated + and isinstance(self, PythonOutArgument) + and self.default == "None" + ) + + # add default + if self.default is not None and not treat_as_no_default: + if ( + isinstance(self.type, ListType) + and self.type.elem == BaseType(BaseTy.int) + and self.default.startswith("{") + and self.default.endswith("}") + ): + default = ( + "(" + ", ".join(map(str.strip, self.default[1:-1].split(","))) + ")" + ) + else: + default = { + "nullptr": "None", + "::std::nullopt": "None", + "std::nullopt": "None", + "{}": "None", + "c10::MemoryFormat::Contiguous": "contiguous_format", + "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine", + }.get(self.default, self.default) + return f"{name}: {type_str} = {default}" + else: + return f"{name}: {type_str}" + + +@dataclass(frozen=True) +class PythonOutArgument(PythonArgument): + # In Python signature multiple output fields are packed into one 'out' argument. + # When binding to C++, it's first binded to a local 'out' variable: + # 'auto out = _r.tensorlist_n<2>(2);', + # then binded to scattered C++ output arguments as 'out[0]', 'out[1]', and etc. + # TODO: maybe don't need keep scattered out fields for python signature? + outputs: tuple[PythonArgument, ...] + + @staticmethod + def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None: + if not outputs: + return None + + size = len(outputs) + if size == 1: + return PythonOutArgument( + name=outputs[0].name, + type=outputs[0].type, + default="None", + default_init=None, + outputs=outputs, + ) + elif size > 1: + if any(not a.type.is_tensor_like() for a in outputs): + raise RuntimeError(f"Unsupported output type: {outputs}") + return PythonOutArgument( + name="out", + # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None? + type=ListType(BaseType(BaseTy.Tensor), size), + default="None", + default_init=None, + outputs=outputs, + ) + raise AssertionError(r"Unexpected PythonOutArgument size") + + +@dataclass(frozen=True) +class PythonSignature: + # Base operator name, without inplace/outplace suffix. + name: str + + # Positional arguments. + # TODO: create a dedicated SelfArgument type for 'self'? + input_args: tuple[PythonArgument, ...] + + # Keyword arguments excluding the 'out' argument and scattered kwargs belonging + # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc). + input_kwargs: tuple[PythonArgument, ...] + + output_args: PythonOutArgument | None + + # Return types, which are only used by pyi + returns: PythonReturns + + # These are scattered kwargs arguments belonging to TensorOptions. + # When binding to C++, they are packed into a TensorOptions object 'options'. + # It's possible that the C++ signature doesn't take TensorOptions object (e.g. + # for out variant), in which case they will be used as scattered fields without + # being packed into 'options'. + # TODO: maybe create a PythonTensorOptionsArgument? + tensor_options_args: tuple[PythonArgument, ...] + + # method or function signature? + method: bool + + @property + def deprecated(self) -> bool: + return False + + def arguments( + self, *, skip_outputs: bool = False, skip_tensor_options: bool = False + ) -> tuple[PythonArgument | PythonOutArgument, ...]: + result: list[PythonArgument | PythonOutArgument] = [] + result.extend(self.input_args) + result.extend(self.input_kwargs) + if self.output_args is not None and not skip_outputs: + result.append(self.output_args) + if not skip_tensor_options: + result.extend(self.tensor_options_args) + return tuple(result) + + def arguments_count(self) -> int: + return len(self.arguments()) + + def output_idx(self) -> int: + return len(self.input_args) + len(self.input_kwargs) + + # [old codegen] Compute the Python function signature for argument parsing, + # as specified in torch/csrc/utils/python_arg_parser.h. WARNING: + # this is NOT the same type signature as specified by PEP 484 + # as understood by mypy; our format was independently developed + # and has some quirks to make it more suitable specifically + # for error parsing. + # + # For a translation to mypy-valid type signatures, see + # signature_str_pyi(). + def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: list[str] = [ + a.argument_str(method=self.method, symint=symint) for a in args + ] + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, "*") + + return f'{self.name}({", ".join(schema_formals)})' + + def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: list[str] = [ + a.argument_str_pyi(method=self.method) for a in args + ] + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, "*") + + # only pyi signatures include returns + returns_str = returns_str_pyi(self) + # pyi also includes self (with no typing/defaults) for methods + if self.method: + schema_formals.insert(0, "self") + return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' + + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None: + # only pyi uses vararg signatures + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: list[str] = [ + a.argument_str_pyi(method=self.method) for a in args + ] + # vararg only applies to pyi signatures. vararg variants are not generated for all signatures + num_args = self.arguments_count() + num_positionalargs = len(self.input_args) + + have_vararg_version = False + if num_args > 0: + vararg_type = args[0].type + if ( + isinstance(vararg_type, ListType) + and str(vararg_type.elem) in ["int", "SymInt"] + and num_positionalargs == 1 + ): + have_vararg_version = True + + if not have_vararg_version: + return None + + # Below are the major changes in vararg vs. regular pyi signatures + # vararg signatures also omit the asterix + assert isinstance(vararg_type, ListType) + schema_formals[0] = ( + "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem) + ) + + returns_str = returns_str_pyi(self) + # pyi also includes self (with no typing/defaults) for methods + if self.method: + schema_formals.insert(0, "self") + return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' + + +# The deprecated python signature involves some special logic, so create a +# dedicated data model to store these extra properties. +@dataclass(frozen=True) +class PythonSignatureDeprecated(PythonSignature): + # Schema for the deprecated function + deprecated_schema: FunctionSchema + + # The deprecated signature might miss some arguments that the corresponding + # C++ signature expects. We need store the constant default values to pass in. + # For example: + # [deprecate signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) + # [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + # [func call]: self.addmm(mat1, mat2, beta, 1) + # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case. + deprecated_args_exprs: tuple[str, ...] + + @property + def deprecated(self) -> bool: + return True + + def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str: + return ( + PythonSignature.signature_str( + self, skip_outputs=skip_outputs, symint=symint + ) + + "|deprecated" + ) + + def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: list[str] = [ + a.argument_str_pyi(method=self.method, deprecated=True) for a in args + ] + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, "*") + + returns_str = returns_str_pyi(self) + return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...' + + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None: + # the codegen doesn't include vararg variants for deprecated signatures + return None + + +# This struct is used to hold the PythonSignature and its corresponding +# NativeFunction BEFORE grouping base and out-variant functions. +# Why not store NativeFunction in PythonSignature or construct PythonSignature +# from NativeFunction? Because they are not 1-1 mapped. +# One native function could have both deprecated and non-deprecated python +# signatures - NativeFunction doesn't contain information to construct the +# deprecated python signature. +# One python signature is used to handle both the base and the out-variant +# function - see 'PythonSignatureGroup'. +@dataclass(frozen=True) +class PythonSignatureNativeFunctionPair: + signature: PythonSignature + function: NativeFunction + + +# We merge pairs of functions with signatures that are equivalent mod +# output arguments, and use a single entry in the python_arg_parser sig +# list for both (output arguments become optional). +@dataclass(frozen=True) +class PythonSignatureGroup: + # The signature used for Python argument parsing. The outplace signature + # is preferred if exists, because it can be used to parse inputs for both + # the out-place variant and the base version (with output omitted). + signature: PythonSignature + + # The regular ATen declaration (e.g. conv2d) + base: NativeFunction + + # The out variant (e.g. conv2d_out) + outplace: NativeFunction | None + + @classmethod + def from_pairs( + cls, + functional: PythonSignatureNativeFunctionPair, + out: PythonSignatureNativeFunctionPair | None, + ) -> PythonSignatureGroup: + if out is None: + return PythonSignatureGroup( + signature=functional.signature, + base=functional.function, + outplace=None, + ) + + # prefer the signature with optional out=... arguments because it's the + # superset that can be used to parse input for both base and outplace. + signature_kwargs = out.signature.__dict__.copy() + + # Out overloads in C++ don't have TensorOptions arguments, + # so take these from the functional variant + signature_kwargs[ + "tensor_options_args" + ] = functional.signature.tensor_options_args + + return PythonSignatureGroup( + signature=type(out.signature)(**signature_kwargs), + base=functional.function, + outplace=out.function, + ) + + +# C++ function dispatch is wrapped in a lambda function. The lambda function +# has almost the same signature as the C++ function, only with some small +# variants - see details below. +# This data model is used to represent arguments of the lambda function +# signature. +@dataclass(frozen=True) +class DispatchLambdaArgument: + name: str + type_str: str + is_out_arg: bool + + +# To pass PyObjects arguments to C++ function (via the lambda wrapper), +# we need first convert PyObjects into simple C++ objects. This work +# is done by PythonArgParser. +# This data model is used to represent the output of PythonArgParser. +# It has 1-1 mapping with PythonArgument in PythonSignature. +@dataclass(frozen=True) +class PythonArgParserOutputExpr: + # argument name + name: str + + # RHS expression to reference PythonArgParser output. + expr: str + + # In some special cases we need create different expr, e.g.: + # '_r.isNone(1)' instead of '_r.tensor(1)'. + index: int + + # The python argument it maps to. + argument: PythonArgument + + @property + def is_none_expr(self) -> str: + return f"_r.isNone({self.index})" + + +# To pass PythonArgParser output to the lambda wrapper, we need bind +# PythonArgParserOutputExpr to DispatchLambdaArgument. +# They are not always 1-1 mapped, e.g. scattered TensorOptions fields +# need be packed into a TensorOptions object, which is the argument +# that the lambda function wrapper takes. +@dataclass(frozen=True) +class DispatchLambdaArgumentExprs: + # The exprs that provide the binding for lambda arguments, e.g.: + # + # 'self' -> '_r.tensor(0)' + # 'min' -> 'out[0]' / 'min_indices' -> 'out[1]' + # 'options' -> 'options' + # + # It has 1-1 mapping with DispatchLambdaArgument. + exprs: Sequence[str] + + # Special local inits, which might introduce new variables that + # the 'exprs' above reference, e.g.: + # + # 'auto out = _r.tensorlist_n<2>(2);' + # + inits: Sequence[str] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Helper Functions +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature: + return CppSignatureGroup.from_native_function(f, method=method).signature + + +def has_tensor_options(f: NativeFunction) -> bool: + return f.func.arguments.tensor_options is not None + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python Signature +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +# 'simple_type' was introduced by the old codegen, which is slightly +# different from the python schema type, e.g.: doesn't have '?' suffix +# for optional Tensor/TensorList; doesn't have '[size]' suffix for list type. +def argument_type_str( + t: Type, *, simple_type: bool = False, symint: bool = True +) -> str: + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + return "Tensor" + elif t.name == BaseTy.int: + return "int64_t" + elif t.name == BaseTy.float: + return "double" + elif t.name == BaseTy.str: + return "c10::string_view" + elif t.name in [ + BaseTy.bool, + BaseTy.QScheme, + BaseTy.Scalar, + BaseTy.ScalarType, + BaseTy.Generator, + BaseTy.Storage, + BaseTy.Layout, + BaseTy.Device, + BaseTy.DeviceIndex, + BaseTy.MemoryFormat, + BaseTy.Dimname, + BaseTy.Stream, + BaseTy.ConstQuantizerPtr, + BaseTy.SymInt, + ]: + # These python schema type names line up with their function schema names + return t.name.name + + elif isinstance(t, OptionalType): + if str(t.elem) == "Tensor": + # Is it desired to keep '?' for simple_type with new style dispatcher? + return "Tensor?" + elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint) + return f"{elem}?" + elif isinstance(t, ListType): + size = t.size if not simple_type else None + if str(t.elem) == "bool": + assert t.size is not None + return f"::std::array" + elif str(t.elem) == "int": + return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef" + elif str(t.elem) == "SymInt": + if symint: + return ( + f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef" + ) + else: + return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef" + elif str(t.elem) == "Tensor": + return f"TensorList[{size}]" if size is not None else "TensorList" + elif str(t.elem) == "Scalar": + return f"ScalarList[{size}]" if size is not None else "ScalarList" + elif str(t.elem) == "Tensor?": + if simple_type: + return "c10::List<::std::optional>" + else: + return "const c10::List<::std::optional> &" + elif str(t.elem) == "Dimname": + return f"DimnameList[{size}]" if size is not None else "DimnameList" + elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint) + return f"ArrayRef<{elem}>" + + raise RuntimeError(f"unrecognized type {repr(t)}") + + +def argument_type_size(t: Type) -> int | None: + l = t.is_list_like() + if l is not None and str(l.elem) != "bool": + return l.size + else: + return None + + +def argument(a: Argument) -> PythonArgument: + return PythonArgument( + name=a.name, + type=a.type, + # TODO: directly translate a.default to python default + default=( + str(pythonify_default(cpp.default_expr(a.default, a.type, symint=False))) + if a.default is not None + else None + ), + default_init=None, + ) + + +# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen +def signature( + f: NativeFunction, *, method: bool = False, pyi: bool = False +) -> PythonSignature: + return signature_from_schema( + f.func, category_override=f.category_override, method=method, pyi=pyi + ) + + +def signature_from_schema( + func: FunctionSchema, + *, + category_override: str | None, + method: bool = False, + pyi: bool = False, +) -> PythonSignature: + args: list[Argument] = [] + args.extend(func.arguments.pre_self_positional) + # Skip SelfArgument if this is method. + if not method and func.arguments.self_arg is not None: + args.append(func.arguments.self_arg.argument) + args.extend(func.arguments.post_self_positional) + args.extend(func.arguments.pre_tensor_options_kwarg_only) + # Skip TensorOptionsArguments. Python side TensorOptions + # arguments are created based on different rules - see below. + args.extend(func.arguments.post_tensor_options_kwarg_only) + args.extend(func.arguments.out) + + input_arg_set = {a.name for a in func.arguments.flat_positional} + kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only} + out_arg_set = {a.name for a in func.arguments.out} + + input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args))) + input_kwargs = tuple( + map(argument, filter(lambda a: a.name in kwarg_only_set, args)) + ) + outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args))) + + # Reintroduce the scattered fields of TensorOptions for Python. + # Compared to the cpp counterpart, the python arguments have new property + # (default_init) and a new argument 'requires_grad', which require some + # special handlings. + # [old codegen] TODO: because these aren't guaranteed to be 100% faithful + # to the original versions in the yaml, this recreation is a potential + # source of drift between eager and JIT. Pull this logic out to a shared place. + + has_tensor_input_arg = any( + a.type.is_tensor_like() for a in func.arguments.flat_non_out + ) + if any(a.name == "requires_grad" for a in func.schema_order_arguments()): + raise ValueError( + "argument named requires_grad is reserved, should not explicitly add it in the schema" + ) + + # [old codegen] this probably won't work if one of the returns is not a tensor, + # but it will produce a compile-time error that is obvious. + has_tensor_return = any(r.type.is_tensor_like() for r in func.returns) + + name: str = cpp.name(func) + is_factory_function = category_override == "factory" or ( + has_tensor_return and not has_tensor_input_arg + ) + is_like_or_new_function = ( + category_override in ("new", "like") + or name.startswith("new_") + or name.endswith("_like") + ) + is_dummy_function = category_override == "dummy" + + tensor_options_args: list[PythonArgument] = [] + if (is_factory_function or is_like_or_new_function) and not is_dummy_function: + + def topt_default_init(name: str) -> str | None: + topt_args = func.arguments.tensor_options + if topt_args is None: + return None + a = getattr(topt_args, name) + if a.default is None or a.default == "None": + return None + return cpp.default_expr(a.default, a.type, symint=False) + + tensor_options_args.append( + PythonArgument( + name="dtype", + type=OptionalType(BaseType(BaseTy.ScalarType)), + default="None", + default_init=( + None if is_like_or_new_function else topt_default_init("dtype") + ), + ) + ) + tensor_options_args.append( + PythonArgument( + name="layout", + type=OptionalType(BaseType(BaseTy.Layout)), + default="None", + default_init=( + None if is_like_or_new_function else topt_default_init("layout") + ), + ) + ) + tensor_options_args.append( + PythonArgument( + name="device", + type=OptionalType(BaseType(BaseTy.Device)), + default="None", + default_init=( + None + if is_like_or_new_function + else ( + topt_default_init("device") + or "torch::tensors::get_default_device()" + ) + ), + ) + ) + tensor_options_args.append( + PythonArgument( + name="pin_memory", + type=OptionalType(BaseType(BaseTy.bool)), + default="False", + default_init=None, + ) + ) + tensor_options_args.append( + PythonArgument( + name="requires_grad", + type=OptionalType(BaseType(BaseTy.bool)), + default="False", + default_init=None, + ) + ) + + returns = PythonReturns(returns=func.returns) + + return PythonSignature( + name=str(func.name.name), + input_args=input_args, + input_kwargs=input_kwargs, + output_args=PythonOutArgument.from_outputs(outputs), + tensor_options_args=tuple(tensor_options_args), + returns=returns, + method=method, + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python Interface +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]: + if len(returns) <= 1 or all(r.name is None for r in returns): + return [] + else: + if any(r.name is None for r in returns): + # When building on Windows, `PyStructSequence_UnnamedField` could not be + # resolved by the linker for some reason, which cause error in building: + # + # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol + # PyStructSequence_UnnamedField + # + # Thus, at this point in time, we do not support unnamed + # fields in structseq; you must either name all fields, + # or none of them. + raise ValueError("Unnamed field is not supported by codegen") + + return [str(r.name) for r in returns] + + +def argument_type_str_pyi(t: Type) -> str: + add_optional = False + if isinstance(t, OptionalType): + t = t.elem + add_optional = True + + if isinstance(t, BaseType): + if t.name in [BaseTy.int, BaseTy.DeviceIndex]: + ret = "_int" + if t.name == BaseTy.SymInt: + ret = "Union[_int, SymInt]" + elif t.name == BaseTy.float: + ret = "_float" + elif t.name == BaseTy.str: + ret = "str" + elif t.name == BaseTy.Scalar: + ret = "Union[Number, _complex]" + elif t.name == BaseTy.ScalarType: + ret = "_dtype" + elif t.name == BaseTy.bool: + ret = "_bool" + elif t.name == BaseTy.QScheme: + ret = "_qscheme" + elif t.name == BaseTy.Layout: + ret = "_layout" + elif t.name == BaseTy.Device: + ret = "Optional[DeviceLikeType]" + elif t.name == BaseTy.MemoryFormat: + ret = "memory_format" + elif t.name == BaseTy.Dimname: + ret = "Union[str, ellipsis, None]" + elif t.name == BaseTy.Storage: + ret = "Union[Storage, UntypedStorage]" + elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]: + # These python schema type names line up with their function schema names + ret = t.name.name + + elif isinstance(t, ListType): + if str(t.elem) == "int": + ret = "Union[_int, _size]" if t.size is not None else "_size" + elif t.is_tensor_like(): + # TODO: this doesn't seem right... + # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]] + # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]] + if isinstance(t.elem, OptionalType): + add_optional = True + ret = ( + "Union[Tensor, Tuple[Tensor, ...], List[Tensor]]" + if t.size is not None + else "Union[Tuple[Tensor, ...], List[Tensor]]" + ) + elif str(t.elem) == "float": + ret = "Sequence[_float]" + elif str(t.elem) == "SymInt" and t.size is not None: + elem = argument_type_str_pyi(t.elem) + ret = f"Union[{elem}, Sequence[{elem}]]" + else: + elem = argument_type_str_pyi(t.elem) + ret = f"Sequence[{elem}]" + + else: + raise RuntimeError(f"unrecognized type {repr(t)}") + + if add_optional: + ret = "Optional[" + ret + "]" + + return ret + + +def return_type_str_pyi(t: Type) -> str: + # Where arguments are open to accepting Union, return types should return + # concrete types + + if isinstance(t, OptionalType): + inner = return_type_str_pyi(t.elem) + return f"Optional[{inner}]" + + if isinstance(t, BaseType): + if t.name == BaseTy.Device: + return "_device" + elif t.name == BaseTy.Dimname: + ret = "Optional[str]" + else: + return argument_type_str_pyi(t) + + if isinstance(t, ListType): + inner = return_type_str_pyi(t.elem) + return f"Tuple[{inner}, ...]" + + return argument_type_str_pyi(t) + + +def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None: + python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns] + structseq_name = signature.name + field_names = structseq_fieldnames(signature.returns.returns) + if field_names: + # These types are structseq objects which act like named NamedTuples, but + # the constructor acts like the constructor of tuple. Using typing.NamedTuple + # does not allow us to override __init__. + seq_type = f"Tuple[{', '.join(python_returns)}]" + structseq_def_lines = [ + f"class {structseq_name}({seq_type}):", + ] + for name, typ in zip(field_names, python_returns): + structseq_def_lines.extend( + [ + " @property", + f" def {name}(self) -> {typ}: ...", + ] + ) + structseq_def_lines.extend( + [ + f" def __new__(cls, sequence: {seq_type}): ...", + f" n_fields: _int = {len(field_names)}", + f" n_sequeunce_fields: _int = {len(field_names)}", + " n_unnamed_fields: _int = 0", + " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing", + "", # add an extra newline + ] + ) + structseq_def = "\n".join(structseq_def_lines) + # Example: + # structseq_def = ( + # "class max(Tuple[Tensor, Tensor]):\n" + # " @property\n" + # " def values(self) -> Tensor: ...\n" + # " @property\n" + # " def indices(self) -> Tensor: ...\n" + # " def __new__(cls, sequence: Tuple[Tensor, Tensor]): ...\n" + # " n_fields: _int = 2", + # " n_sequeunce_fields: _int = 2", + # " n_unnamed_fields: _int = 0", + # " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing", + # ) + return structseq_name, structseq_def + return None + + +def returns_str_pyi(signature: PythonSignature) -> str: + field_names = structseq_fieldnames(signature.returns.returns) + if field_names: + return f"torch.return_types.{signature.name}" + + python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns] + if len(python_returns) > 1: + return "Tuple[" + ", ".join(python_returns) + "]" + if len(python_returns) == 1: + return python_returns[0] + return "None" + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# C++ Function Dispatch +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# This section provides APIs to generate the code that does C++ function +# dispatch. The C++ function call is wrapped by a lambda function. +# For example: +# +# // aten::selu_(Tensor(a!) self) -> Tensor(a!) +# auto dispatch_selu_ = [](Tensor self) -> Tensor { +# pybind11::gil_scoped_release no_gil; +# return at::selu_(self); +# }; +# +# The lambda function's signature follows the C++ signature in common +# cases, e.g.: +# +# // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor +# [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor +# +# For out variant the 'out' argument's type is changed from 'Tensor &' +# to 'Tensor'. It's because when calling the lambda it passes in the +# PythonArgParser output '_r.tensor(3)', which is stack allocated object +# and needs to pass by value. Also see comments in 'dispatch_lambda_return_str()'. +# +# // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +# [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor +# +# For multi-output case it can keep using reference type because the +# PythonArgParser output has been unpacked to local variables, e.g.: +# +# // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, +# // Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) +# [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple +# +# For deprecated python signature, it should follow deprecated python arg order. +# TODO: This is to keep same byte-for-byte result as the old codegen - maybe unnecessary? + + +def dispatch_lambda_args( + ps: PythonSignature, f: NativeFunction, symint: bool = True +) -> tuple[DispatchLambdaArgument, ...]: + if isinstance(ps, PythonSignatureDeprecated): + schema = ps.deprecated_schema + else: + schema = f.func + + # Start with cpp arguments - dispatch lambda signature always include 'self' + cpp_args = cpp.arguments( + arguments=schema.arguments, + faithful=False, + symint=symint, + method=False, + cpp_no_default_args=f.cpp_no_default_args, + ) + out_args: set[str] = {a.name for a in schema.arguments.out} + + # Convert from cpp argument to lambda argument + def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument: + type_str = cpp_arg.type + is_out_arg = cpp_arg.name in out_args + if ps.method and cpp_arg.name == "self": + # For method's 'self', we can use 'const Tensor &' and simply ignore mutability! + type_str = "const at::Tensor &" + else: + # For other cases we need prevent dangling refs to temps (unless it's + # unpacked scattered output) + # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'. + # TODO: avoid this special handling? + ensure_temp_safe = len(out_args) <= 1 or not is_out_arg + if ensure_temp_safe: + type_str = { + "at::Tensor &": "at::Tensor", + }.get(type_str, type_str) + return DispatchLambdaArgument( + name=cpp_arg.name, + type_str=type_str, + is_out_arg=is_out_arg, + ) + + return tuple(map(dispatch_lambda_arg, cpp_args)) + + +# [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean +# it's enough to just extend the list here. Before you do this, make sure +# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h. +SUPPORTED_RETURN_TYPES = { + "at::Tensor", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple>", + "::std::vector", + # Needed for flash attention forw/backward + "::std::tuple", + "at::Scalar", + "bool", + "int64_t", + "void*", + "void", + "at::QScheme", + "double", + "at::IntArrayRef", + "at::ScalarType", + "at::Stream", +} + + +def dispatch_lambda_return_str(f: NativeFunction) -> str: + # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &') + # because the dispatch lambdas take mutable arguments *by value*, not + # by reference. If you then return a reference to such an argument, you + # will now have a pointer to a dangling stack entry. Not good. + # + # You want: + # + # auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); }; + # ^^^^^^ + # + # *not* + # + # auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); }; + # ^^^^^^^ + # + # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing + # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a + # mutable reference to temporary. Maybe we could assign it to a + # variable itself.) + returns_without_annotation = tuple( + Return(r.name, r.type, None) for r in f.func.returns + ) + return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type() + if return_str not in SUPPORTED_RETURN_TYPES: + raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}") + return return_str + + +def cpp_dispatch_target(f: NativeFunction) -> str: + symint = f.func.has_symint() + name = cpp.name(f.func, symint_overload=symint) + if Variant.method in f.variants: + return f"self.{name}" + if Variant.function in f.variants: + if has_tensor_options(f) or f.func.name.name.base.endswith("_like"): + namespace = "torch" + else: + namespace = "at" + return f"{namespace}::{name}" + raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}") + + +def cpp_dispatch_exprs( + f: NativeFunction, + *, + python_signature: PythonSignature | None = None, +) -> tuple[str, ...]: + cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments() + + exprs: tuple[str, ...] = () + if not isinstance(python_signature, PythonSignatureDeprecated): + # By default the exprs are consistent with the C++ signature. + exprs = tuple(a.name for a in cpp_args) + else: + # For deprecated python signature we may need fill in some constants. + exprs = tuple( + filter( + lambda n: n != "out" or f.func.is_out_fn(), + python_signature.deprecated_args_exprs, + ) + ) + + if Variant.method in f.variants: + exprs = tuple(filter("self".__ne__, exprs)) + + return exprs + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python / C++ Args Binding +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +# We explicitly enumerate the PythonArgParser unpacking methods for all +# supported types. This might be more verbose than necessary, partially +# because of the irregularity of unpacking method naming, partially +# because we want to mimic the old codegen behavior - to reject +# unexpected and/or unsupported cases which the old codegen rejects. +# For certain cases it is intentionally more restrictive than necessary, +# e.g.: it doesn't accepts doublelist with definite size. +def arg_parser_unpack_method( + t: Type, default: str | None, default_init: str | None, *, symint: bool = True +) -> str: + has_default_init = default_init is not None + if has_default_init and str(t) not in ( + "ScalarType?", + "ScalarType", + "Device", + "Device?", + "Layout", + "Layout?", + "bool", + "bool?", + ): + raise RuntimeError(f"type '{t}' does not supported unpacking with default") + + if isinstance(t, BaseType): + if t.name in [ + BaseTy.Tensor, + BaseTy.Stream, + BaseTy.Storage, + BaseTy.Scalar, + BaseTy.Dimname, + ]: + # These unpack methods line up with their schema names + return t.name.name.lower() + elif t.name == BaseTy.ScalarType: + return "scalartypeWithDefault" if has_default_init else "scalartype" + elif t.name == BaseTy.Device: + return "deviceWithDefault" if has_default_init else "device" + elif t.name == BaseTy.DeviceIndex: + return "toInt64" + elif t.name == BaseTy.int: + return "toInt64" + elif t.name == BaseTy.SymInt: + return "toSymInt" if symint else "toInt64" + elif t.name == BaseTy.bool: + return "toBoolWithDefault" if has_default_init else "toBool" + elif t.name == BaseTy.float: + return "toDouble" + elif t.name == BaseTy.str: + return "stringView" + elif t.name == BaseTy.Layout: + return "layoutWithDefault" if has_default_init else "layout" + elif t.name == BaseTy.MemoryFormat: + return "memoryformat" + + elif isinstance(t, OptionalType): + if str(t.elem) == "Tensor": + return "optionalTensor" + elif str(t.elem) == "Generator": + return "generator" + elif str(t.elem) == "Dimname[]": + return "toDimnameListOptional" + elif not has_default_init and default in ( + None, + "None", + "::std::nullopt", + "std::nullopt", + ): + # If default is None: append 'Optional' to elem's unpacking method + return ( + arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional" + ) + else: + # Otherwise, load as underlying type with default + return arg_parser_unpack_method( + t.elem, default, default_init, symint=symint + ) + + elif isinstance(t, ListType): + if str(t.elem) == "Tensor": + # accept and use definite size + return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist" + elif str(t.elem) == "Tensor?": + return "list_of_optional_tensors" + elif str(t.elem) == "Dimname": + # accept definite size + return "dimnamelist" + elif str(t.elem) == "int": + # accept definite size + return "intlist" + elif str(t.elem) == "float": + return "doublelist" + elif str(t.elem) == "SymInt": + # accept definite size + return "symintlist" if symint else "intlist" + elif str(t.elem) == "Scalar": + return "scalarlist" + raise RuntimeError(f"type '{t}' is not supported by PythonArgParser") + + +# Return RHS expression for python argument using PythonArgParser output. +# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)' +def arg_parser_output_expr( + arg_index: int, a: PythonArgument, *, symint: bool = True +) -> PythonArgParserOutputExpr: + has_default = a.default_init is not None + unpack_method = arg_parser_unpack_method( + t=a.type, default=a.default, default_init=a.default_init, symint=symint + ) + default = f", {a.default_init}" if has_default else "" + expr = f"_r.{unpack_method}({arg_index}{default})" + + return PythonArgParserOutputExpr( + name=a.name, + expr=expr, + index=arg_index, + argument=a, + ) + + +# Returns a map with key = arg_name and value = PythonArgParserOutputExpr. +def arg_parser_output_exprs( + ps: PythonSignature, f: NativeFunction, *, symint: bool = True +) -> dict[str, PythonArgParserOutputExpr]: + return { + e.name: e + for i, a in enumerate(ps.arguments()) + for e in (arg_parser_output_expr(i, a, symint=symint),) + } + + +# argument name to type for scattered tensor options fields +TENSOR_OPTIONS_FIELDS = { + "dtype": "ScalarType?", + "device": "Device?", + "layout": "Layout?", + "pin_memory": "bool?", + "requires_grad": "bool?", +} + + +# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args). +def dispatch_lambda_exprs( + ps: PythonSignature, f: NativeFunction, *, symint: bool = True +) -> DispatchLambdaArgumentExprs: + # This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing + # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser + # outputs. + arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint) + lambda_args = dispatch_lambda_args(ps, f, symint=symint) + inits: list[str] = [] + lambda_args_exprs: dict[str, str] = {} + + has_toptions = has_tensor_options(f) + + # 1. special inits/unpacking to provide binding exprs for lambda arguments. + for a in ps.arguments(skip_tensor_options=True): + name = a.name + arg_parser_expr = arg_parser_outputs[a.name].expr + + if has_toptions and name == "self": + # TODO: why this needs to be special case? + inits.extend( + [ + f"auto self = {arg_parser_expr};", + ] + ) + lambda_args_exprs[name] = name + elif ( + isinstance(a, PythonOutArgument) + and len(a.outputs) > 1 + and f.func.is_out_fn() + ): + inits.extend( + [ + f"auto out = {arg_parser_expr};", + ] + ) + for i, out_arg in enumerate(a.outputs): + lambda_args_exprs[out_arg.name] = f"out[{i}]" + elif str(a.type) == "Dimname[]?": + # [old codegen] + # TODO: make this part of something more general, or get rid of it. + # optional> are special. The PythonArgParser returns an + # optional>, which cannot be implicitly converted to + # optional>. One needs to unwrap the optional and rewrap. + inits.extend( + [ + f"auto __{name} = {arg_parser_expr};", + f"::std::optional {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;", # noqa: B950 + ] + ) + lambda_args_exprs[name] = name + else: + # default case - directly using PythonArgParser output expr + lambda_args_exprs[name] = arg_parser_expr + + # method's self is passed directly to python binding, rather than parsed + if ps.method: + lambda_args_exprs["self"] = "self" + + # 2. special packing/checking for TensorOptions. + tensor_options_args_names = [a.name for a in ps.tensor_options_args] + if has_toptions: + if f.func.is_out_fn(): + raise RuntimeError(f"{f.func}: tensor options with output arg") + for a in ps.tensor_options_args: + if a.name not in TENSOR_OPTIONS_FIELDS: + raise RuntimeError( + f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments" + ) + if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name): + raise RuntimeError( + f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'" + ) + if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS): + raise RuntimeError( + f"{f.func}: incomplete tensor options args: {tensor_options_args_names}" + ) + + inits.append( + f"""\ +const auto options = TensorOptions() + .dtype({arg_parser_outputs['dtype'].expr}) + .device({arg_parser_outputs['device'].expr}) + .layout({arg_parser_outputs['layout'].expr}) + .requires_grad({arg_parser_outputs['requires_grad'].expr}) + .pinned_memory({arg_parser_outputs['pin_memory'].expr}); +torch::utils::maybe_initialize_device(options); +""" + ) + lambda_args_exprs["options"] = "options" + + # 3. special case - access scattered TensorOptions fields without packing + # TODO: maybe move to the generator side as it's not related to binding. + if not has_toptions and tensor_options_args_names: + if "dtype" in tensor_options_args_names: + # we're an output-arg variant, check these args against output tensor + if not f.func.is_out_fn(): + raise RuntimeError( + f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}" + ) + if not all(a in tensor_options_args_names for a in ("layout", "device")): + raise RuntimeError( + f"{f.func}: incomplete tensor options for output check" + ) + + inits.append( + f"""\ +check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr}, + {arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr}, + {arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr}); +""" + ) + # we'll set requires_grad on outgoing tensor + if "requires_grad" not in tensor_options_args_names: + raise RuntimeError( + f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]' + ) + + return DispatchLambdaArgumentExprs( + exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args), + inits=inits, + ) diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/README.md b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bfa43899cc590959c2bfd74e38662ec03aaee3d6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/README.md @@ -0,0 +1,3 @@ +If you add a file to this directory, you **MUST** update +`torch/CMakeLists.txt` and add the file as a dependency to +the `add_custom_command` call. diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__init__.py b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec9592367ea99650facfb75113d2729d4cdde433 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a934358f0a742d7a1fab2cb2b979f830c7ad16b5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_annotated_fn_args.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_annotated_fn_args.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17a36bd17a9a9ffb5e552788a652c6348c8f1e83 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_annotated_fn_args.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b815d765a88db47d2f9219663c7522e654a9128d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd_functions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd_functions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d149d378026b1ecc81345a6687c9ddee6543c848 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd_functions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_inplace_or_view_type.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_inplace_or_view_type.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdb023d7ca006fc15c029a6d7d36ad6214b8699d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_inplace_or_view_type.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_python_functions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_python_functions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfb992d2175025b4d762057080c05f52e84a580a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_python_functions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_trace_type.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_trace_type.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..522a43d55a0bf2bdd636e98f25190d2443b29ff8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_trace_type.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_factories.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_factories.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..967411636b0b187636c9ce695e49af825242a2e3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_factories.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_type.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_type.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55ee4ed3517fa248f232b6883ae9834d67e9992a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_type.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_view_funcs.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_view_funcs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb321ff129a3104b950e304a9599c3c027802273 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_view_funcs.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/load_derivatives.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/load_derivatives.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5b18f2c7544657f517136b36fc003dc54d94c86 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/load_derivatives.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/build.bzl b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/build.bzl new file mode 100644 index 0000000000000000000000000000000000000000..588bd5944e29477119782591b231fd80a7a57cf4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/build.bzl @@ -0,0 +1,14 @@ +def define_targets(rules): + rules.py_library( + name = "autograd", + srcs = rules.glob(["*.py"]), + data = rules.glob([ + "*.yaml", + "templates/*", + ]), + visibility = ["//:__subpackages__"], + deps = [ + rules.requirement("PyYAML"), + "//torchgen", + ], + ) diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/context.py b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/context.py new file mode 100644 index 0000000000000000000000000000000000000000..d838aa3c77bbbc0f37cd7fa6e005d85c9e9dd624 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/context.py @@ -0,0 +1,31 @@ +import functools +from typing import Callable + +from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI +from torchgen.context import native_function_manager +from torchgen.utils import T + + +# Like tools.api.context.with_native_function, but for +# NativeFunctionWithDifferentiabilityInfo. +def with_native_function_with_differentiability_info( + func: Callable[[NFWDI], T] +) -> Callable[[NFWDI], T]: + @functools.wraps(func) + def wrapper(f: NFWDI) -> T: + with native_function_manager(f.func): + return func(f) + + return wrapper + + +# Like the above but with an additional dispatch key string argument +def with_native_function_with_differentiability_info_and_key( + func: Callable[[NFWDI, str], T] +) -> Callable[[NFWDI, str], T]: + @functools.wraps(func) + def wrapper(f: NFWDI, key: str) -> T: + with native_function_manager(f.func): + return func(f, key) + + return wrapper diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/derivatives.yaml b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/derivatives.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9f7ea3fbeb4ff4a5ec04578f3c3751b2870fc3b0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/derivatives.yaml @@ -0,0 +1,3206 @@ +# Defines derivative formulas and Python signatures of methods on Variable +# +# Note about possibly confusing nomenclature: An 'output gradient' is the +# gradient of an output of a forward function. Output gradients are used as +# the inputs to backward functions. `grads` is a vector of output gradients, +# and `grad == grads[0]`, in all the derivative formulas in this file. +# An 'input gradient' is the gradient of an input to a forward function. +# Input gradients are the outputs of backward functions, corresponding to the +# input names included in the derivative formulas defined in this file. +# Also, every time we talk computing "gradient" we actually mean computing +# the vector jacobian product using the given 'output gradient' as the vector. +# +# Each entry consists of: +# - A 'name', which specifies the ATen name of the function you +# are defining derivatives for, and an argument specification. +# - An optional 'dispatch' entry which can be used to specify +# per-autograd dispatch key derivatives. If this entry is not +# specified, then the gradient entries will be taken as the +# default gradients (i.e. registered for every backward dispatch +# key). (see _test_autograd_multiple_dispatch for an example +# of how to register separate derivates for different dispatch keys). +# The list of allowed dispatch keys (in addition to 'Default' which +# represents the Autograd alias key) is torchgen/model.py:AUTOGRAD_KEYS. +# - One or more gradients entries, mapping differentiable input +# names to a formula specifying how to compute its gradient. +# Note that a single gradient entry can specify the gradient +# formula for multiple input names, by specifying a key +# "input1, input2" (see atan2 for an example). +# - An argument can be flagged as 'non_differentiable'. +# - Optional entry with key 'output_differentiability' and value a list of the +# same length as the number of outputs from the forward function. The list +# should contain only booleans, specifying whether each of the output Tensor +# is differentiable. +# If it is not specified for a function that returns multiple elements but +# uses `grad` instead of `grads[idx]`, then all but the first output will +# be marked as non-differentiable. +# If None of the output is differentiable, you can also add the function +# name to `gen_variable_type.py`'s `DONT_REQUIRE_DERIVATIVE` list. +# +# There are two cases for Tensor and TensorList arguments here: +# - If that argument is differentiable, in the sense that a gradient with respect +# to that argument could exist. You should either: +# - Specify the formula for that gradient +# - Specify not_implemented("function_name") as a formula to say that this is not +# implemented yet (but might be in the future and the user can request that on an issue) +# - If that argument is not differentiable, because it is not a floating point dtype or the +# function is not differentiable with respect to that argument for +# example. You should either: +# - Do not specify any formula for this argument +# - Specify explicitly that this argument is "non_differentiable". Note that in this case, +# we trust you that this argument will never have requires_grad=True and it will be silently +# ignored if it does. +# +# If a function has out-of-place and in-place variants, then the derivative +# definition for the in-place variant is optional. It will default to the +# definition for the out-of-place variant. Note that _out variants are never +# differentiable. +# +# Gradient expressions are standard C++ expressions operating on ATen +# variables. In a gradient expression, the following variables/functions +# are in scope: +# +# - 'grad', the gradient of the output (often spelled grad_output +# in Python) which we are going to left-multiply. +# +# When a function returns multiple *differentiable* outputs, +# you can refer to the gradients of each outputs using 'grads', +# e.g., 'grads[0]', 'grads[1]'. +# +# When a function returns multiple *differentiable* outputs that +# are named, you can refer to the gradients of each outputs using +# 'grad_{name}', e.g., 'grad_x', 'grad_y'. +# +# When a function returns *one* differentiable output (the +# first output) and some more nondifferentiable outputs, +# you MUST refer to the gradient of the differentiable output with +# 'grad' (this case is special-cased in our code generation). +# +# Note that the number of differentiable outputs can be modified by the +# 'output_differentiability' entry (see above). +# +# Across a differentiable function's derivatives set, it is not +# permitted to mix the use of "grad", "grads", and +# "grad_{name}". You must be consistent for that differentiable +# function. +# +# - Any of the input arguments, tensor or non-tensor, including +# argument names that only appear in Declarations.yaml, e.g. 'output'. +# +# - 'result', representing the result of evaluating the forward +# expression for ATen native function declarations. If the forward +# expression outputs a tuple, use 'resultX' instead to access the +# X-th entry +# +# - 'grad_input_mask', a std::array, specifies which input +# gradients are actually needed. For example, in the entry +# `input0, input1: foo(grad_input_mask)`, `grad_input_mask` is a size +# two array, where `grad_input_mask[0]` is true if `input0` requires +# grad, and `grad_input_mask[1]` is true if `input1` requires grad. +# +# (NB: if your function computes gradient for a list of tensors, +# the `grad_input_mask` will only have a single entry for the list +# specifying if either zero or at least one tensor from the list requires +# grad. If we want to support more fine-grained signalling, +# we'll need some alternate variable which is not a std::array) +# +# - 'retain_variables', a bool which is true if a user has specified +# that saved variables should be retained in case the backwards is +# run again later. This allows an optimization where we can +# destroy saved buffers if we know variables are not going to be retained, +# e.g., it is used by _cudnn_rnn +# +# - `wrap_opt_if`, is a 2-argument function that accepts a tensor +# variable and a boolean condition that dictates whether to save that +# variable in a graph. The result of this function is `c10::optional`, +# and it is `::std::nullopt` when the condition evalutes to `false`, +# otherwise it is the variable wrapped in `c10::optional`. +# For example, wrap_opt_if(var_0, grad_input_mask[1] || grad_input_mask[2]) +# would mean that `var_0` is saved as long as the second (grad_input_mask[1]) +# or the third (grad_input_mask[2]) argument requires gradients. +# Another interpretation of this expression would read as `var_0` is needed +# in the backward computation of the second or the third argument. +# NOTE: the usage of `var_i.requires_grad()` in the conditional expression +# is not supported, use `grad_input_mask[i]` instead. +# NOTE: `wrap_opt_if` could be used to prevent saving redundant variables +# with multi-output backward formulas. +# See https://github.com/pytorch/pytorch/issues/97575 for more details +# on the issue. +# +# If you need a complex expression, e.g., with local variables, +# write a _backward function in torch/csrc/autograd/FunctionsManual.cpp +# and invoke it from here. By the way, go read +# https://github.com/zdevito/ATen/issues/163; this describes an +# important hazard that occurs when porting backwards from Python to C++ +# +# Double backwards gradient expressions can be somewhat confusing; +# the most important thing to remember is: (1) you need to define a +# derivative formula for every input, including inputs named things +# like 'grad_output', and (2) the gradient to multiply with is always +# called 'grad' (even though it really is a grad-grad). +# +# You can also add forward derivative definition by defining a formula for +# a returned value (in general "result" if the name is not specified). This +# formula works the same way as the backward one and advanced implementations +# should also be placed in the FunctionsManual file. +# This formula should compute a single Jacobian vector product using the (primal) +# value of the argument "foo_p", its forward grad "foo_t" and the result of the +# function as "result". +# Note that the forward derivative can be automatically generated in two cases: +# - if your function is linear (NOT affine or multi-linear), then you can +# specify so by just using the string "auto_linear" for the formula. +# - if your function is applied element wise (and has a single input), you +# can specify so by just using the string "auto_element_wise" for the formula. +# +# Note that to avoid unpacking overhead, functions taking TensorList as inputs +# will always have their forward grad formula called. This function is responsible +# to check if any computation is needed and should return an undefined Tensor when +# there is nothing to do. You can check "cat_forward" for a full example. +# +# NB: There are a number of gradient definitions in here which are bogus +# (implemented using zeros_like). These gradients are (hopefully) not +# used by our frontend. You MUST check the frontend code; search for +# OpName.apply to see if it's still using a legacy Python style API. +# +# Note: Returning views. +# The following cases exist: +# - If a function returns no view, it can have arbitrary outputs. +# - If a function return at least one Tensor that is a differentiable view +# of one of its input: +# - If there is only one differentiable output, this Tensor is marked as a +# differentiable view. (alias or transpose for example) +# - If there are more than one differentiable output, by default all the views are +# marked as differentiable views and created with allow_rebase_history=false. +# Meaning that any inplace operation on it will raise an error. (unbind for example) +# +# Notes about undefined output gradients: +# All backward functions must support all combinations of undefined output +# gradient Tensors, where `grad[i].defined() == false`. Depending on the +# number of input and output grads your derivative formula uses, code +# generation may automatically add some level of undefined grad support, +# according to these three cases: +# +# * 1 input grad and 1 output grad: +# Complete undefined grad support is automatically added, so you +# shouldn't have to think about it, unless there is a bug in the code +# generation. +# +# * 1 input grad and multiple output grads: +# Undefined grad support is automatically added ONLY in the case where +# all output grads are undefined. You will have to add explicit support +# for cases where a subset of output grads is undefined. +# +# * multiple input grads: +# No automatic support, so you will need to add it. +# +# If your derivative formula uses more than one output grad, it is usually +# preferable to add undefined grad support in the backward function itself +# (if you're using one), rather than in the derivative formula in this file. +# +# Undefined Tensors are created with the default constructor `at::Tensor()`. +# It is an efficient way to represent a Tensor filled with zeros because +# the Tensor holds no sizing information and no Storage data is allocated. +# But consequentially, Tensor operations cannot be performed on them. +# Therefore, your backward function should treat an undefined output grad as +# a zero, and it needs to be a special case. +# +# If all output grads are undefined, then it should be correct for the +# backward function to return undefined input grads. Since we use the chain +# rule, output grads equal to zero should result in input grads equal to zero, +# unless there is some rare special case. +# +# If a subset of output grads is undefined, then it may be acceptable for +# the backward function to return undefined input grads--it depends on the +# specific function, so you'll have to determine that yourself. If returning +# an undefined Tensor is correct for a given input grad, it is also logically +# correct to return a defined grad full of zeros, but that would not be +# preferable since it would be less efficient. +# +# NB: The parameter names here MUST be consistent with the parameter names +# in native_functions.yaml +- name: abs(Tensor self) -> Tensor + self: grad * self.sgn() + result: handle_r_to_c(result.scalar_type(), self_t.conj() * self_p.sgn()) + +- name: acos(Tensor self) -> Tensor + self: grad * -((-self * self + 1).rsqrt()).conj() + result: auto_element_wise + +- name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + other: handle_r_to_c(other.scalar_type(), maybe_multiply(grad, alpha.conj())) + result: self_t + maybe_multiply(other_t, alpha) + +- name: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + result: self_t.clone() + +- name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + batch1: maybe_multiply(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) }).bmm(batch2.transpose(1, 2).conj()), alpha.conj()) + batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) })), alpha.conj()) + result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p).sum(0), alpha) + maybe_multiply(batch1_p.bmm(batch2_t).sum(0), alpha) + +- name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (value / tensor2).conj()) + tensor2: handle_r_to_c(tensor2.scalar_type(), -grad * (value * tensor1 / (tensor2 * tensor2)).conj()) + result: self_t + maybe_multiply(tensor1_t / tensor2_p, value) - maybe_multiply(tensor2_t * (tensor1_p / tensor2_p) / tensor2_p, value) + +- name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (tensor2 * value).conj()) + tensor2: handle_r_to_c(tensor2.scalar_type(), grad * (tensor1 * value).conj()) + result: self_t + maybe_multiply(tensor1_t * tensor2_p, value) + maybe_multiply(tensor2_t * tensor1_p, value) + +- name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + mat1: mm_mat1_backward(grad, mat2, mat1.sym_sizes(), mat1.sym_strides(), mat1.layout(), alpha) + mat2: mm_mat2_backward(grad, mat1, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), alpha) + result: maybe_multiply(self_t, beta) + maybe_multiply(mat1_t.mm(mat2_p), alpha) + maybe_multiply(mat1_p.mm(mat2_t), alpha) + +- name: _sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta) + mat1: mm_mat1_sparse_backward(grad, mat1, mat2, alpha) + mat2: mm_mat2_backward(grad, mat1, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), alpha) + +- name: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + mat: maybe_multiply(grad.ger(vec.conj()), alpha.conj()) + vec: maybe_multiply(mat.t().conj().mv(grad), alpha.conj()) + result: maybe_multiply(self_t, beta) + maybe_multiply(mat_t.mv(vec_p), alpha) + maybe_multiply(mat_p.mv(vec_t), alpha) + +- name: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + vec1: maybe_multiply(grad.mv(vec2.conj()), alpha.conj()) + vec2: maybe_multiply(grad.t().mv(vec1.conj()), alpha.conj()) + result: maybe_multiply(self_t, beta) + maybe_multiply(vec1_t.outer(vec2_p), alpha) + maybe_multiply(vec1_p.outer(vec2_t), alpha) + +- name: affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor + theta: affine_grid_generator_backward_symint(grad, size, align_corners) + +- name: alias(Tensor(a) self) -> Tensor(a) + self: grad + result: self_t + +- name: angle(Tensor self) -> Tensor + self: angle_backward(grad, self) + result: handle_r_to_c(result.scalar_type(), angle_backward(self_t.conj(), self_p).conj()) + +# The four items below are necessary because TensorIterator doesn't work on +# Variables (codegen does not unwrap the input Tensor for all() and any() ). +- name: any(Tensor self) -> Tensor + output_differentiability: [False] + +- name: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + output_differentiability: [False] + +- name: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor + output_differentiability: [False] + +- name: _is_all_true(Tensor self) -> Tensor + self: non_differentiable + +- name: _is_any_true(Tensor self) -> Tensor + self: non_differentiable + +- name: all(Tensor self) -> Tensor + output_differentiability: [False] + +- name: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + output_differentiability: [False] + +- name: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor + output_differentiability: [False] + +- name: acosh(Tensor self) -> Tensor +# Save one rsqrt in the real case by using that for x real and positive sqrt(x*y) = sqrt(x)*sqrt(y) (not true in the complex case) + self: "self.is_complex() ? grad * ((self + 1).rsqrt() * (self - 1).rsqrt()).conj() : grad * (self * self - 1).rsqrt()" + result: auto_element_wise + +- name: acosh_(Tensor(a!) self) -> Tensor(a!) + self: not_implemented("inplace version of acosh") + +- name: asinh(Tensor self) -> Tensor + self: grad * (self.pow(2) + 1).rsqrt().conj() + result: auto_element_wise + +- name: asinh_(Tensor(a!) self) -> Tensor(a!) + self: not_implemented("inplace version of asinh") + +- name: atanh(Tensor self) -> Tensor + self: grad * 1 / (1 - self.pow(2)).conj() + result: auto_element_wise + +- name: atanh_(Tensor(a!) self) -> Tensor(a!) + self: not_implemented("inplace version of atanh") + +- name: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) + self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset) + result: auto_linear + +- name: as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) + self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset) + result: auto_linear + +- name: asin(Tensor self) -> Tensor + self: grad * (-self * self + 1).rsqrt().conj() + result: auto_element_wise + +- name: atan(Tensor self) -> Tensor + self: grad / (self * self + 1).conj() + result: auto_element_wise + +- name: atan2(Tensor self, Tensor other) -> Tensor + self, other: atan2_backward(grad, self, other, grad_input_mask) + result: (-self_p * other_t + other_p * self_t) / (self_p.pow(2) + other_p.pow(2)) + +- name: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + batch1: maybe_multiply(grad.bmm(batch2.transpose(1, 2).conj()), alpha.conj()) + batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad), alpha.conj()) + result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p), alpha) + maybe_multiply(batch1_p.bmm(batch2_t), alpha) + +- name: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + p: zeros_like(p) + result: self_t.zero_() + +- name: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: bmm(Tensor self, Tensor mat2) -> Tensor + self: grad.bmm(mat2.transpose(1, 2).conj()) + mat2: self.transpose(1, 2).conj().bmm(grad) + result: self_t.bmm(mat2_p) + self_p.bmm(mat2_t) + +- name: matmul(Tensor self, Tensor other) -> Tensor + self, other: matmul_backward(grad, self, other, grad_input_mask) + +- name: cat(Tensor[] tensors, int dim=0) -> Tensor + tensors: cat_tensors_backward(grad, to_args_sizes_symint(tensors), to_args_scalartypes(tensors), dim) + result: cat_jvp(tensors, dim) + +- name: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: ceil(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: cholesky(Tensor self, bool upper=False) -> Tensor + self: cholesky_backward(grad, upper, result) + +- name: chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[] + dispatch: + Default: + # the default case will use the CompositeImplicitAutograd + self: not_implemented("chunk") + AutogradNestedTensor: + self: chunk_backward_nested(grads, self, chunks, dim) + +- name: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info) + self: cholesky_backward(grad, upper, L) + L: cholesky_jvp(self_t, L, upper) + +- name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor + self, input2: cholesky_solve_backward(grad, self, input2, result, upper, grad_input_mask) + result: cholesky_solve_jvp(result, input2_p, input2_t, self_t, upper) + +- name: cholesky_inverse(Tensor self, bool upper=False) -> Tensor + self: cholesky_inverse_backward(grad, self, upper, result) + result: cholesky_inverse_jvp(self_p, self_t, result, upper) + +# For clamp, gradient is not defined at the boundaries. But empirically it's helpful +# to be able to get gradient on min and max, so we return the subgradient 1 for these cases. +- name: clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor + self: clamp_backward(grad, self, min, max) + min, max: clamp_backward_min_max(grad, self, min, max, grad_input_mask) + result: clamp_jvp(self_p, self_t, min_p, min_t, max_p, max_t) + +- name: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor + self: clamp_backward(grad, self, min, max) + result: auto_element_wise + +- name: clamp_min(Tensor self, Scalar min) -> Tensor + self: where(self >= min, grad, at::scalar_tensor(0., grad.options())) + result: auto_element_wise + +- name: clamp_min.Tensor(Tensor self, Tensor min) -> Tensor + self: where(self >= min, grad, at::scalar_tensor(0., grad.options())) + min: where(self < min, grad, at::scalar_tensor(0., grad.options())) + result: where(self_p >= min_p, self_t, min_t) + +- name: clamp_max(Tensor self, Scalar max) -> Tensor + self: where(self <= max, grad, at::scalar_tensor(0., grad.options())) + result: auto_element_wise + +- name: clamp_max.Tensor(Tensor self, Tensor max) -> Tensor + self: where(self <= max, grad, at::scalar_tensor(0., grad.options())) + max: where(self > max, grad, at::scalar_tensor(0., grad.options())) + result: where(self_p <= max_p, self_t, max_t) + +- name: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor + self: grad + result: auto_linear + +- name: _lazy_clone(Tensor self) -> Tensor + self: grad + result: auto_linear + +- name: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor + self: _to_copy_backward(grad, self.options()) + result: _to_copy(self_t, dtype, layout, device, pin_memory, non_blocking, memory_format) + # The condition is: if dtype is not nullopt, then isDifferentiableType(*dtype) + # (If dtype IS nullopt, we rely on the regular check that any input requires grad). + output_differentiability: ["!dtype || isDifferentiableType(*dtype)"] + +- name: _coalesce(Tensor self) -> Tensor + self: grad + +- name: complex(Tensor real, Tensor imag) -> Tensor + real: at::real(grad) + imag: at::imag(grad) + result: at::complex(real_t, imag_t) + +- name: polar(Tensor abs, Tensor angle) -> Tensor + abs, angle: polar_backward(grad, result) + result: at::complex(abs_t*angle_p.cos() - angle_t*abs_p*angle_p.sin(), abs_t*angle_p.sin() + angle_t*abs_p*angle_p.cos()) + +- name: _conj(Tensor(a) self) -> Tensor(a) + self: grad.conj() + result: self_t.conj() + +- name: _neg_view(Tensor(a) self) -> Tensor(a) + self: grad.neg() + result: self_t._neg_view() + +- name: _conj_physical(Tensor self) -> Tensor + self: grad.conj_physical() + result: self_t.conj_physical() + +- name: conj_physical_(Tensor(a!) self) -> Tensor(a!) + self: grad.conj_physical() + result: self_t.conj_physical_() + +- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor + self: copysign_tensor_self_backward(grad, self, result) + other: zeros_like(other) + result: copysign_tensor_self_backward(self_t, self_p, result) + +- name: copysign.Scalar(Tensor self, Scalar other) -> Tensor + self: copysign_tensor_self_backward(grad, self, result) + result: auto_element_wise + +- name: cos(Tensor self) -> Tensor + self: grad * -self.sin().conj() + result: auto_element_wise + +- name: cosh(Tensor self) -> Tensor + self: grad * self.sinh().conj() + result: auto_element_wise + +- name: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor + output_differentiability: [False] + +- name: count_nonzero(Tensor self, int? dim=None) -> Tensor + output_differentiability: [False] + +- name: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor + self: at::linalg_cross(other.conj(), grad, dim) + other: at::linalg_cross(grad, self.conj(), dim) + result: "at::linalg_cross(self_t, other_p, dim) + at::linalg_cross(self_p, other_t, dim)" + +- name: logcumsumexp(Tensor self, int dim) -> Tensor + self: logcumsumexp_backward(grad, self, result, dim) + result: logcumsumexp_jvp(self_p, self_t, dim) + +- name: cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + self: cumprod_backward(grad.to(self.scalar_type()), self, dim, result) + result: "cumprod_jvp(self_t, self_p, result, dim).to(dtype.has_value() ? *dtype : self_p.scalar_type())" + +- name: cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + self: cumsum_backward(grad.to(self.scalar_type()), dim) + result: auto_linear + +- name: cummax(Tensor self, int dim) -> (Tensor values, Tensor indices) + self: cummaxmin_backward(grad, self, indices, dim) + values: self_t.gather(dim, indices) + +- name: cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) + self: cummaxmin_backward(grad, self, indices, dim) + values: self_t.gather(dim, indices) + +- name: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor + self, weight, bias: "grad.defined() ? conv_tbc_backward(grad, self, weight, bias, pad) : std::tuple()" + +- name: _ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) + log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity) + +- name: _ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) + log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity) + +- name: deg2rad(Tensor self) -> Tensor + self: deg2rad_backward(grad) + result: auto_element_wise + +- name: _linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots) + A: linalg_det_backward(grad, result, A, LU, pivots) + result: linalg_det_jvp(A_t, result, LU, pivots, A_p.is_contiguous() && !A_p.is_complex()) + output_differentiability: [True, False, False] + +- name: _linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots) + A: slogdet_backward(grad_sign, grad_logabsdet, A, sign, LU, pivots) + sign, logabsdet: slogdet_jvp(LU, pivots, A_t, sign, A_p.is_contiguous() && !A_p.is_complex()) + output_differentiability: [True, True, False, False] + +- name: block_diag(Tensor[] tensors) -> Tensor + tensors: block_diag_backward(grad, to_args_sizes(tensors), to_args_scalartypes(tensors)) + result: block_diag_jvp(tensors) + +- name: diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor + self: grad.diagonal(offset, dim1, dim2) + result: auto_linear + +- name: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) + self: diagonal_backward_symint(grad, self.sym_sizes(), offset, dim1, dim2) + result: auto_linear + +- name: diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor + grad_output: grad.diagonal(offset, dim1, dim2) + result: auto_linear + +- name: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor + self: norm_backward(grad, self - other, p, result) + other: -norm_backward(grad, self - other, p, result) + result: norm_jvp(self_p - other_p, self_t - other_t, p, result, {}, false) + +# The backward formula is done in this order to improve numerical stability +# of the higher order derivatives, see https://github.com/pytorch/pytorch/issues/43414 +# Note that we don't use "result" because saving it would be BC-breaking when it is used in an inplace operation later +- name: div.Tensor(Tensor self, Tensor other) -> Tensor + self: div_tensor_self_backward(grad, other, self.scalar_type()) + other: div_tensor_other_backward(grad, self, other) + result: (self_t - other_t * result) / other_p + +- name: div.Scalar(Tensor self, Scalar other) -> Tensor + self: div_tensor_self_backward(grad, other, self.scalar_type()) + result: self_t / other + +- name: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor + self: div_tensor_self_backward(grad, other, self.scalar_type(), rounding_mode) + other: div_tensor_other_backward(grad, self, other, rounding_mode) + result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other_p - other_t * (self_p / other_p) / other_p" + +- name: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor + self: div_tensor_self_backward(grad, other, self.scalar_type(), rounding_mode) + result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other" + +- name: dot(Tensor self, Tensor tensor) -> Tensor + self: grad * tensor.conj() + tensor: grad * self.conj() + result: at::dot(self_t, tensor_p) + at::dot(self_p, tensor_t) + +- name: vdot(Tensor self, Tensor other) -> Tensor + self: grad.conj() * other + other: grad * self + result: at::vdot(self_t, other_p) + at::vdot(self_p, other_t) + +- name: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor) + self: _fused_dropout_backward(grad, result1, p) + +- name: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor) + input: "GradMode::is_enabled() ? infinitely_differentiable_native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p)))) : native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p))))" + result0: "(!train.has_value() || train.value()) ? (p == 1 ? 0.0 : 1.0 / (1.0 - p)) * input_t * result1 : input_t" + +- name: native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor + grad_output: "native_dropout_double_backward(grad, grad_output, mask, scale)" + mask: 'not_implemented("native_dropout_backward: mask")' + +- name: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: erf(Tensor self) -> Tensor + self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad + result: auto_element_wise + +- name: erfc(Tensor self) -> Tensor + self: -2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad + result: auto_element_wise + +- name: special_erfcx(Tensor self) -> Tensor + self: (2.0 * self * result - 2.0 / sqrt(M_PI)) * grad + result: auto_element_wise + +- name: erfinv(Tensor self) -> Tensor + self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad + result: auto_element_wise + +- name: exp(Tensor self) -> Tensor + self: grad * result.conj() + result: auto_element_wise + +- name: exp2(Tensor self) -> Tensor + self: grad * result.conj() * M_LN2 + result: auto_element_wise + +- name: expm1(Tensor self) -> Tensor + self: grad * (result.conj() + 1) + result: auto_element_wise + +# TODO: this derivative is not SymInt safe, need sum_to support +- name: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) + self: at::sum_to(grad, self.sym_sizes()) + result: auto_linear + +- name: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask) + +- name: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask) + +- name: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor + self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_tensor_affine_backward(grad, self, scale, zero_point, quant_min, quant_max, grad_factor) : std::tuple()" + +- name: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + self: fake_quantize_per_channel_affine_cachemask_backward(grad, mask) + +- name: _fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor + self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_channel_affine_backward(grad, self, scale, zero_point, axis, quant_min, quant_max, grad_factor) : std::tuple()" + +- name: _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) + self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask) + +- name: fill.Scalar(Tensor self, Scalar value) -> Tensor + self: zeros_like(grad) + result: at::fill(self_t, 0) + +- name: fill.Tensor(Tensor self, Tensor value) -> Tensor + self: zeros_like(grad) + value: grad.sum() + result: at::fill(self_t, value_t) + +- name: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.fill_(0) + +- name: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) + self: zeros_like(grad) + value: grad.sum() + result: self_t.fill_(value_t) + +- name: floor(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: fmod.Scalar(Tensor self, Scalar other) -> Tensor + self: grad + result: auto_element_wise + +- name: fmod.Tensor(Tensor self, Tensor other) -> Tensor + self: grad + other: -grad * self.div(other, /*rounding_mode=*/"trunc") + result: self_t - other_t * self_p.div(other_p, /*rounding_mode=*/"trunc") + +- name: frac(Tensor self) -> Tensor + self: grad + result: self_t + +- name: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent) + self: grad / exponent.exp2() + mantissa: self_t / exponent.exp2() + +- name: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor + self: gather_backward(grad, self, dim, index, sparse_grad) + index: non_differentiable + result: auto_linear + +- name: ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: geqrf(Tensor self) -> (Tensor a, Tensor tau) + self: not_implemented("geqrf") + +- name: indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: _indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: crow_indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: col_indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: ccol_indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: row_indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + input, grid: "grad.defined() ? grid_sampler_2d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple()" + +- name: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + input, grid: "grad.defined() ? grid_sampler_3d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple()" + +# See NOTE [ grid_sample CPU fallback ] +- name: _grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + input, grid: "grad.defined() ? _grid_sampler_2d_cpu_fallback_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners) : std::tuple()" + +- name: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: hardsigmoid(Tensor self) -> Tensor + self: hardsigmoid_backward(grad, self) + result: auto_element_wise + +- name: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor + output_differentiability: [False] + +- name: hardswish(Tensor self) -> Tensor + self: hardswish_backward(grad, self) + result: auto_element_wise + +- name: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor + grad_output: hardswish_backward(grad, self) + self: at::where(at::logical_and(-3.0 < self, self < 3.0), grad * grad_output / 3.0, at::zeros({}, self.options())) + result: "hardswish_backward(grad_output_t, self_p) + + at::where(at::logical_and(-3.0 < self_p, self_p < 3.0), self_t * grad_output_p / 3.0, at::zeros({}, self_p.options()))" + +- name: hypot(Tensor self, Tensor other) -> Tensor + self: grad * self / result + other: grad * other / result + result: self_t * self_p / result + other_t * other_p / result + +- name: i0(Tensor self) -> Tensor + self: grad * at::special_i1(self) + result: auto_element_wise + +- name: special_i0e(Tensor self) -> Tensor + self: grad * (at::special_i1e(self) - self.sgn() * result) + result: auto_element_wise + +- name: special_i1(Tensor self) -> Tensor + self: i1_backward(grad, self, result) + result: auto_element_wise + +- name: special_i1e(Tensor self) -> Tensor + self: i1e_backward(grad, self, result) + result: auto_element_wise + +- name: igamma(Tensor self, Tensor other) -> Tensor + self: 'not_implemented("igamma: input")' + other: grad * exp((self - 1) * log(other) - other - lgamma(self)) + +- name: igammac(Tensor self, Tensor other) -> Tensor + self: 'not_implemented("igammac: input")' + other: -grad * exp((self - 1) * log(other) - other - lgamma(self)) + +- name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + self: index_backward(grad.new_zeros_symint(self.sym_sizes(), self.options()), indices, grad) + result: auto_linear + +- name: _unsafe_index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + self: at::_unsafe_index_put(grad.new_zeros_symint(self.sym_sizes(), self.options()), indices, grad, true) + result: auto_linear + +- name: _unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor + self: at::_unsafe_masked_index_put_accumulate(grad.new_zeros_symint(self.sym_sizes(), self.options()), mask, indices, grad) + mask: non_differentiable + result: _unsafe_masked_index(self_t, mask, indices, 0) + +- name: _unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor + self: grad + mask: non_differentiable + values: at::_unsafe_masked_index(grad, mask, indices, 0) + result: at::_unsafe_masked_index_put_accumulate(self_t, mask, indices, values_t) + +- name: index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor + self: grad + # The case source.dim() == 0 is necessary to support scalar tensors of the form + # source.dim() == 0 and index.dim() == 1 and index.size() == (1,), + # This is because source is not broadcastable to index, as source.dim() < index.dim() + source: "maybe_multiply(source.dim() > 0 ? grad.index_select(dim, index).expand_as(source) : grad.index_select(dim, index.squeeze(0)), alpha)" + index: non_differentiable + result: at::index_add(self_t, dim, index, maybe_multiply(source_t, alpha)) + +- name: index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor + self, source: index_reduce_backward(grad, self, dim, index, source, reduce, include_self, result) + index: non_differentiable + +- name: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor + self: grad.index_fill(dim, index, 0) + # The case source.dim() == 0 is necessary to support scalar tensors of the form + # source.dim() == 0 and index.dim() == 1 and index.size() == (1,), + # This is because source is not broadcastable to index, as source.dim() < index.dim() + source: "source.dim() > 0 ? grad.index_select(dim, index).expand_as(source) : grad.index_select(dim, index.squeeze(0))" + index: non_differentiable + result: self_t.index_copy(dim, index, source_t) + +- name: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + self: grad.index_fill(dim, index, 0) + index: non_differentiable + result: self_t.index_fill(dim, index, 0) + +- name: index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor + self: grad.index_fill(dim, index, 0) + value: grad.index_select(dim, std::get<0>(at::_unique(index, /*sorted=*/false))).sum() + index: non_differentiable + result: self_t.index_fill(dim, index, value_t) + +- name: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + self: "accumulate ? grad : grad.index_put(indices, zeros_like(values), false)" + values: grad.index(indices) + result: self_t.index_put(indices, values_t, accumulate) + +- name: _unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + self: "accumulate ? grad : at::_unsafe_index_put(grad, indices, zeros_like(values), false)" + values: at::_unsafe_index(grad, indices) + result: at::_unsafe_index_put(self_t, indices, values_t, accumulate) + +- name: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!) + self: "accumulate ? grad : grad.index_put(indices, zeros_like(values), false)" + values: grad.index(indices) + result: at::_index_put_impl_(self_t, indices, values_t, accumulate, unsafe) + +- name: index_select(Tensor self, int dim, Tensor index) -> Tensor + self: index_select_backward_symint(grad, self.sym_sizes(), dim, index) + index: non_differentiable + result: auto_linear + +- name: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info) + A: -at::matmul(inverse.mH(), at::matmul(grad, inverse.mH())) + inverse: -at::matmul(at::matmul(inverse, A_t), inverse) + output_differentiability: [True, False] + +- name: linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor + self: pinv_backward(grad, result, self) + result: pinv_jvp(self_p, result, self_t) + +- name: isnan(Tensor self) -> Tensor + self: non_differentiable + +- name: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor + self: "weight.isComplex() ? grad * (1 - weight.conj().toComplexDouble()) : grad * (1 - weight.toDouble())" + end: grad * weight.conj() + result: at::lerp(self_t, end_t, weight) + +- name: lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor + self: grad * (1 - weight).conj() + end: grad * weight.conj() + weight: grad * (end - self).conj() + result: at::lerp(self_t, end_t, weight_p) + weight_t * (end_p - self_p) + +- name: lgamma(Tensor self) -> Tensor + self: grad * digamma(self) + result: auto_element_wise + +- name: digamma(Tensor self) -> Tensor + self: grad * polygamma(1, self) + result: auto_element_wise + +- name: polygamma(int n, Tensor self) -> Tensor + self: grad * polygamma(n + 1, self) + result: auto_element_wise + +- name: polygamma_(Tensor(a!) self, int n) -> Tensor(a!) + self: grad * polygamma(n + 1, self) + result: self_t.mul_(polygamma(n + 1, original_self_p)) + +- name: log(Tensor self) -> Tensor + self: grad.div(self.conj()) + result: auto_element_wise + +- name: log10(Tensor self) -> Tensor + self: grad / (self.conj() * 2.3025850929940456) + result: auto_element_wise + +- name: log1p(Tensor self) -> Tensor + self: log1p_backward(grad, self) + result: auto_element_wise + +- name: log2(Tensor self) -> Tensor + self: grad / (self.conj() * 0.6931471805599453) + result: auto_element_wise + +- name: logaddexp(Tensor self, Tensor other) -> Tensor + self: grad / (1 + exp(other - self)).conj() + other: grad / (1 + exp(self - other)).conj() + result: self_t / (1 + exp(other_p - self_p)) + other_t / (1 + exp(self_p - other_p)) + +- name: logaddexp2(Tensor self, Tensor other) -> Tensor + self: grad / (1 + pow(2, other - self)) + other: grad / (1 + pow(2, self - other)) + result: self_t / (1 + pow(2, other_p - self_p)) + other_t / (1 + pow(2, self_p - other_p)) + +# Note [Gradient formula for xlogy at x = 0, y <= 0] +# x * log(y) is not defined at y <= 0, so we cannot even talk about differentiability +# Now, xlogy(0, y) = 0 by definition. +# This does not make it differentiable as it's not defined in a neighbourhood of a point +# (0, y) when y <= 0. +# Now, when a function is non-differentiable, sometimes we return "a relatively sensible value" +# In this case, as per the discussion in https://github.com/pytorch/pytorch/issues/80770, we choose +# this value to be zero, which is the directional derivative along the line {x = 0}. +- name: xlogy.Tensor(Tensor self, Tensor other) -> Tensor + self: at::xlogy(grad, other).masked_fill((self == 0.) & (other <= 0.), 0.) + other: grad * self / other + result: at::xlogy(self_t, other_p).masked_fill((self_p == 0.) & (other_p <= 0.), 0.) + other_t * self_p / other_p + +- name: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor + other: grad * self / other + result: auto_element_wise + +- name: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor + self: "other.toDouble() > 0. + ? at::xlogy(grad, other) + : at::xlogy(grad, other).masked_fill(self == 0., 0.)" + result: auto_element_wise + +# See Note [Gradient formula for xlogy at x = 0, y <= 0] +# Same here but with y <= -1 +- name: special_xlog1py(Tensor self, Tensor other) -> Tensor + self: at::special_xlog1py(grad, other).masked_fill((self == 0.) & (other <= -1.), 0.) + other: grad * self / (other + 1) + result: at::special_xlog1py(self_t, other_p).masked_fill((self_p == 0.) & (other_p <= -1.), 0.) + other_t * self_p / (other_p + 1) + +- name: special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor + other: grad * self / (other + 1) + result: auto_element_wise + +- name: special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor + self: "other.toDouble() > -1. + ? at::special_xlog1py(grad, other) + : at::special_xlog1py(grad, other).masked_fill(self == 0., 0.)" + result: auto_element_wise + +- name: special_zeta(Tensor self, Tensor other) -> Tensor + self: not_implemented("zeta") + other: grad * -self * special_zeta(self + 1., other) + +- name: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor + other: grad * -self * special_zeta(self.toDouble() + 1., other) + +- name: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor + self: not_implemented("zeta") + +- name: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + self: logsumexp_backward(grad, self, result, dim, keepdim) + result: logsumexp_jvp(self_p, self_t, dim, keepdim) + +- name: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values) + self, b: linalg_lstsq_backward(grad, self, b, grad_input_mask) + solution: linalg_lstsq_jvp(self_p, b_p, self_t, b_t) + output_differentiability: [True, False, False, False] + +- name: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) + A: lu_factor_ex_backward(grad, LU, pivots, pivot) + LU: lu_factor_ex_jvp(A_t, LU, pivots, pivot) + output_differentiability: [True, False, False] + +- name: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots) + A: lu_factor_ex_backward(grad, LU, pivots, pivot) + LU: lu_factor_ex_jvp(A_t, LU, pivots, pivot) + output_differentiability: [True, False] + +- name: linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U) + A: linalg_lu_backward(grad_L, grad_U, P, L, U, pivot) + L: std::get<0>(linalg_lu_jvp(A_t, P, L, U, pivot)) + U: std::get<1>(linalg_lu_jvp(A_t, P, L, U, pivot)) + output_differentiability: [False, True, True] + +- name: linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor + LU: linalg_lu_solve_LU(grad, LU, pivots, result, left, adjoint) + B: "at::linalg_lu_solve(LU, pivots, grad, left, !adjoint)" + result: linalg_lu_solve_jvp(result, LU_p, pivots, LU_t, B_t, left, adjoint) + +- name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U) + LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.sym_size(-2), LU_data.sym_size(-1)) + LU_pivots: non_differentiable + L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril(-1)" + U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu()" + output_differentiability: [False, True, True] + +- name: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor + self: grad.masked_fill(mask, 0) + mask: non_differentiable + result: self_t.masked_fill(mask, 0) + +- name: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor + self: grad.masked_fill(mask, 0) + value: masked_fill_backward(grad, mask) + mask: non_differentiable + result: self_t.masked_fill(mask, value_t) + +- name: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor + self: grad.masked_fill(mask, 0) + source: masked_scatter_backward_symint(grad, mask, source.sym_sizes()) + mask: non_differentiable + result: self_t.masked_scatter(mask, source_t) + +- name: masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor + grad_output: zeros_like(grad_output).masked_scatter(mask, grad) + mask: non_differentiable + result: masked_scatter_backward(grad_output_t, mask, grad_output_t.sizes()) + +- name: masked_select(Tensor self, Tensor mask) -> Tensor + self: masked_select_backward(grad, self, mask) + mask: non_differentiable + result: auto_linear + +- name: linalg_matrix_exp(Tensor self) -> Tensor + self: linalg_matrix_exp_differential(self, grad, /*adjoint*/ true) + result: linalg_matrix_exp_differential(self_p, self_t, /*adjoint*/ false) + +- name: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: max(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + result: evenly_read_jvp(self_t, self_p, result) + +- name: maximum(Tensor self, Tensor other) -> Tensor + self: at::where(self == other, grad / 2, grad).masked_fill_(self < other, 0) + other: at::where(self == other, grad / 2, grad).masked_fill_(self > other, 0) + result: other_t + at::where(self_p == other_p, at::scalar_tensor(0.5, result.options()), (self_p > other_p).to(result.scalar_type())) * (self_t - other_t) + +- name: fmax(Tensor self, Tensor other) -> Tensor + self: grad.masked_fill((self >= other).logical_or_(other.isnan()).logical_not_(), 0) + other: grad.masked_fill((self >= other).logical_or_(other.isnan()), 0) + result: other_t + (self_p > other_p).logical_or_(other_p.isnan()) * (self_t - other_t) + +- name: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor + self: grad.expand_symint(self.sym_sizes()) / self.sym_numel() + result: auto_linear + +- name: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: mean_backward(grad, self.sym_sizes(), dim, self.sym_numel(), keepdim) + result: auto_linear + +- name: median(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + result: evenly_read_jvp(self_t, self_p, result) + +- name: nanmedian(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + result: evenly_read_jvp(self_t, self_p, result) + +# This is in theory incorrect in the following case: +# sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b and the value +# | at middle position of the +# | list between two `b`s. E.g., +# | +# ^the middle position +# The gradient exists and is essentially 0 in this case. +# +# In case where the middle position is at the boundary of `b` range, e.g., +# sorted list: [..., a, b, b, ..., b, b, c, ...] +# | +# ^the middle position +# The backward implementation is correct in the sense that it returns the +# subgradient on one side. +- name: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: min(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + result: evenly_read_jvp(self_t, self_p, result) + +- name: minimum(Tensor self, Tensor other) -> Tensor + self: at::where(self == other, grad / 2, grad).masked_fill_(self > other, 0) + other: at::where(self == other, grad / 2, grad).masked_fill_(self < other, 0) + result: other_t + at::where(self_p == other_p, at::scalar_tensor(0.5, result.options()), (self_p < other_p).to(result.scalar_type())) * (self_t - other_t) + +- name: fmin(Tensor self, Tensor other) -> Tensor + self: grad.masked_fill((self <= other).logical_or_(other.isnan()).logical_not_(), 0) + other: grad.masked_fill((self <= other).logical_or_(other.isnan()), 0) + result: other_t + (self_p <= other_p).logical_or_(other_p.isnan()) * (self_t - other_t) + +- name: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor + self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim) + result: amaxamin_jvp(self_p, self_t, result, dim, keepdim) + +- name: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor + self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim) + result: amaxamin_jvp(self_p, self_t, result, dim, keepdim) + +- name: mm(Tensor self, Tensor mat2) -> Tensor + self: mm_mat1_backward(grad, mat2, self.sym_sizes(), self.sym_strides(), self.layout(), 1) + mat2: mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), 1) + result: at::mm(self_t, mat2_p) + at::mm(self_p, mat2_t) + +- name: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: mul.Tensor(Tensor self, Tensor other) -> Tensor + self: mul_tensor_backward(grad, other, self.scalar_type()) + other: mul_tensor_backward(grad, self, other.scalar_type()) + result: other_t * self_p + self_t * other_p + +- name: mul.Scalar(Tensor self, Scalar other) -> Tensor + self: mul_tensor_backward(grad, other, self.scalar_type()) + result: self_t * other + +- name: mv(Tensor self, Tensor vec) -> Tensor + self: grad.ger(vec.conj()) + vec: self.conj().t().mv(grad) + result: mv(self_t, vec_p) + mv(self_p, vec_t) + +- name: mvlgamma(Tensor self, int p) -> Tensor + self: mvlgamma_backward(grad, self, p) + result: auto_element_wise + +- name: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor + self: grad * at::isfinite(self) + result: auto_element_wise + +- name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps) + +- name: _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps) + +- name: _native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*training=*/false, eps, grad_input_mask) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, /*training=*/false, eps) + +- name: _native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, Tensor(), Tensor(), result1, result2, training, eps, grad_input_mask) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, Tensor(), Tensor(), result1, result2, training, eps) + +- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask) + save_mean: not_implemented("native_batch_norm_backward save_mean") + save_invstd: not_implemented("native_batch_norm_backward save_invstd") + +- name: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_layer_norm_backward_symint(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple()" + result0: layer_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, normalized_shape) + +- name: native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask) + bias: Tensor() + mean: not_implemented("native_layer_norm_backward mean") + rstd: not_implemented("native_layer_norm_backward rstd") + +- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())" + result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group) + result1: group_norm_mean_jvp(input_t, result1, group) + result2: group_norm_invstd_jvp(input_p, input_t, result1, result2, group) + +- name: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: neg(Tensor self) -> Tensor + self: grad.neg() + result: auto_element_wise + +- name: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/true, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, true, eps) + +- name: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/false, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, false, eps) + +- name: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor) + input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, update, eps, save_mean, save_var, grad_input_mask) + save_mean: not_implemented("batch_norm_backward save_mean") + save_var: not_implemented("batch_norm_backward save_var") + reserve: not_implemented("batch_norm_backward reserve") + +- name: nextafter(Tensor self, Tensor other) -> Tensor + self: not_implemented("nextafter") + other: not_implemented("nextafter") + +- name: norm.Scalar(Tensor self, Scalar p=2) -> Tensor + self: norm_backward(grad, self, p, result) + result: norm_jvp(self_p, self_t, p, result) + +- name: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor + self: norm_backward(grad, self, p, result, dim, keepdim) + result: norm_jvp(self_p, self_t, p, result, dim, keepdim) + +- name: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor + self: norm_backward(grad, self.to(grad.scalar_type()), p, result) + result: norm_jvp(self_p, self_t, p, result) + +- name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor + self: norm_backward(grad, self.to(grad.scalar_type()), p, result, dim, keepdim) + result: norm_jvp(self_p, self_t, p, result, dim, keepdim) + +- name: linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: linalg_vector_norm_backward(grad, self, ord, result, dim, keepdim) + result: linalg_vector_norm_jvp(self_p, self_t, ord, result, dim, keepdim) + +- name: _pdist_forward(Tensor self, float p=2) -> Tensor + self: _pdist_backward(grad, self, p, result) + +- name: _pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor + grad: not_implemented("_pdist_backward") + self: not_implemented("_pdist_backward") + pdist: not_implemented("_pdist_backward") + +- name: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor + x1, x2: _euclidean_dist_backward(grad, x1, x2, result) + +- name: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor + x1: _cdist_backward(grad.contiguous(), x1, x2, p, result) + x2: _cdist_backward(grad.mT().contiguous(), x2, x1, p, result.mT().contiguous()) + +- name: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor + grad: not_implemented("_cdist_backward") + x1: not_implemented("_cdist_backward") + x2: not_implemented("_cdist_backward") + cdist: not_implemented("_cdist_backward") + +- name: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor + mean: at::zeros_symint(mean.sym_sizes(), grad.options()) + result: auto_element_wise + +- name: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor + std: at::zeros_symint(std.sym_sizes(), grad.options()) + result: auto_element_wise + +- name: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor + mean: at::zeros_symint(mean.sym_sizes(), grad.options()) + std: at::zeros_symint(std.sym_sizes(), grad.options()) + result: zeros_like(mean_t) + +- name: linalg_householder_product(Tensor input, Tensor tau) -> Tensor + input, tau: householder_product_backward(grad, result, input, tau) + result: householder_product_jvp(input_t, tau_t, result, input_p, tau_p) + +- name: ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor + self, input2, input3: ormqr_backward(grad, result, self, input2, input3, left, transpose, grad_input_mask) + +- name: permute(Tensor(a) self, int[] dims) -> Tensor(a) + self: permute_backwards(grad, dims) + result: auto_linear + +- name: poisson(Tensor self, Generator? generator=None) -> Tensor + self: zeros_like(self) + result: auto_element_wise + +- name: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor + self: pow_backward(grad, self, exponent) + result: auto_element_wise + +- name: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor + self: pow_backward_self(grad, self, exponent) + exponent: pow_backward_exponent(grad, self, exponent, result) + result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result)).conj() + +- name: pow.Scalar(Scalar self, Tensor exponent) -> Tensor + exponent: pow_backward_exponent(grad, self, exponent, result) + result: auto_element_wise + +- name: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor + self: prod_backward(grad, self.to(grad.scalar_type()), result) + result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result) * self_t.conj()).sum().conj() + +- name: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim) + result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result, dim, keepdim) * self_t.conj()).sum(dim, keepdim).conj() + +- name: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor + self: "accumulate ? grad : grad.put(index, zeros_like(source), false)" + index: non_differentiable + source: grad.take(index).reshape_as(source) + result: self_t.put(index, source_t, accumulate) + +- name: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R) + A: linalg_qr_backward(grad_Q, grad_R, Q, R, mode) + Q, R: linalg_qr_jvp(A_t, Q, R, mode) + +- name: rad2deg(Tensor self) -> Tensor + self: rad2deg_backward(grad) + result: auto_element_wise + +- name: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: reciprocal(Tensor self) -> Tensor + self: -grad * (result * result).conj() + result: auto_element_wise + +- name: remainder.Scalar(Tensor self, Scalar other) -> Tensor + self: grad + result: auto_element_wise + +- name: remainder.Tensor(Tensor self, Tensor other) -> Tensor + self: grad + other: -grad * self.div(other, /*rounding_mode=*/"floor") + result: self_t - other_t * self_p.div(other_p, /*rounding_mode=*/"floor") + +- name: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor + self: renorm_backward(grad, self, p, dim, maxnorm) + result: renorm_jvp(self_p, self_t, p, dim, maxnorm) + +- name: repeat(Tensor self, SymInt[] repeats) -> Tensor + self: repeat_backward(grad, repeats, self.sym_sizes()) + result: auto_linear + +- name: special_entr(Tensor self) -> Tensor + self: grad * (-(1 + self.log())) + result: auto_element_wise + +- name: special_ndtri(Tensor self) -> Tensor + self: grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp() + result: auto_element_wise + +- name: special_log_ndtr(Tensor self) -> Tensor + self: grad / std::sqrt(2 * M_PI) * (result + self.pow(2) / 2).neg().exp() + result: auto_element_wise + +# [Note: Sometimes view derivatives] +# The following situation applies to other operations as well. +# TODO: This note is only referenced by to_dense and to_sparse*. Make +# this more generic if it's been referenced more than once. +# +# DO NOT define a backward for reshape! +# reshape is special in that it sometimes returns a view, and sometimes not. +# Defining a backward will make codegen spit out the forward call as +# as_variable(baseType->reshape(self)), +# making it impossible (hard) to detect when it is actually a view. +# - name: reshape(Tensor self, IntArrayRef shape) + +- name: _reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) + self: grad.reshape_symint(self.sym_sizes()) + result: auto_linear + +- name: round(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: round.decimals(Tensor self, *, int decimals) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: rsqrt(Tensor self) -> Tensor + self: -0.5 * grad * result.pow(3).conj() + result: auto_element_wise + +- name: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + self: grad.scatter(dim, index, 0) + index: non_differentiable + src: grad.gather(dim, index) + result: self_t.scatter(dim, index, src_t) + +- name: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + self: grad.scatter(dim, index, 0) + index: non_differentiable + result: self_t.scatter(dim, index, 0) + +- name: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + self: grad + index: non_differentiable + src: grad.gather(dim, index) + result: scatter_add(self_t, dim, index, src_t) + +- name: select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) + dispatch: + Default: + self: select_backward_symint(grad, self.sym_sizes(), dim, index) + result: auto_linear + AutogradNestedTensor: + self: _nested_select_backward_symint(grad, self, dim, index) + +- name: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor + grad_output: grad.select_symint(dim, index) + result: auto_linear + +- name: sigmoid(Tensor self) -> Tensor + self: sigmoid_backward(grad, result) + result: auto_element_wise + +- name: logit(Tensor self, float? eps=None) -> Tensor + self: "GradMode::is_enabled() ? infinitely_differentiable_logit_backward(grad, self, eps) : logit_backward(grad, self, eps)" + result: auto_element_wise + +- name: sign(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: sgn(Tensor self) -> Tensor + self: sgn_backward(self, grad, result) + # Cannot use auto_element_wise here because the Jacobian is *not* Hermitian (in fact, it is symmetric) + # The function is not holomorphic, so there's no reason for its Jacobian to be Hermitian + # auto_element_wise has a name that's a bit deceiving in the complex case + result: sgn_backward(self_p, self_t, result) + +- name: sin(Tensor self) -> Tensor + self: grad * self.cos().conj() + result: auto_element_wise + +- name: sinc(Tensor self) -> Tensor + self: sinc_backward(grad, self) + result: auto_element_wise + +- name: sinh(Tensor self) -> Tensor + self: grad * self.cosh().conj() + result: auto_element_wise + +- name: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + self: slice_backward_wrapper(grad, self.sym_sizes(), dim, start, end, step) + result: auto_linear + +- name: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor + grad_output: grad.slice_symint(dim, start, end, step) + result: auto_linear + +- name: slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + self: grad.slice_symint(dim, start, end, step) + src: slice_scatter_symint(grad, zeros_like(self), dim, start, end, step) + result: auto_linear + +- name: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor + self: slice_scatter_symint(grad, zeros_like(src), dim, start, end, step) + src: grad.slice_symint(dim, start, end, step) + result: auto_linear + +- name: select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor + self: select_scatter_symint(grad, zeros_like(src), dim, index) + src: grad.select_symint(dim, index) + result: auto_linear + +- name: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor + self: diagonal_scatter(grad, zeros_like(src), offset, dim1, dim2) + src: grad.diagonal(offset, dim1, dim2) + result: auto_linear + +- name: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor + self: as_strided_scatter_backward(grad, TensorGeometry(self), TensorGeometry(src), size, stride, storage_offset) + # See Note [as_strided_scatter backward support] + src: grad.contiguous().as_strided_symint(size, stride, storage_offset) + result: auto_linear + +- name: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info) + A, B: linalg_solve_backward(grad, result, A, LU, pivots, left, grad_input_mask[1]) + result: "linalg_solve_jvp(A_t, B_t, result, LU, pivots, left, A_p.is_contiguous() && !A_p.is_complex())" + output_differentiability: [True, False, False, False] # LU is an auxiliary tensor not exposed to the user + +- name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) + output_differentiability: [True, False] + values: gather_with_keepdimed_indices(self_t, dim, indices, true) + +- name: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) + output_differentiability: [True, False] + values: gather_with_keepdimed_indices(self_t, dim, indices, true) + +- name: split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] + self: split_backward(grads, split_size, dim, self.sym_sizes(), self.options()) + result: auto_linear + +- name: unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] + self: split_backward(grads, split_size, dim, self.sym_sizes(), self.options()) + result: auto_linear + +- name: split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] + dispatch: + Default: + self: split_with_sizes_backward(grads, split_sizes, dim, self.sym_sizes(), self.options()) + result: auto_linear + AutogradNestedTensor: + self: _nested_split_with_sizes_backward(grads, split_sizes, dim, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), self.options()) + +- name: unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] + self: split_with_sizes_backward(grads, split_sizes, dim, self.sym_sizes(), self.options()) + result: auto_linear + +- name: sqrt(Tensor self) -> Tensor + self: grad / (2 * result.conj()) + result: auto_element_wise + +- name: squeeze(Tensor(a) self) -> Tensor(a) + self: unsqueeze_to(grad, self.sym_sizes()) + result: auto_linear + +- name: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a) + dispatch: + Default: + self: unsqueeze_to(grad, dim, self.sym_sizes()) + result: auto_linear + AutogradNestedTensor: + self: grad.unsqueeze(dim) + +- name: squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a) + dispatch: + Default: + self: unsqueeze_to(grad, dim, self.sym_sizes()) + result: auto_linear + AutogradNestedTensor: + self: unsqueeze_multiple(grad, dim, self.dim()) + +- name: squeeze_(Tensor(a!) self) -> Tensor(a!) + self: unsqueeze_to(grad, self.sym_sizes()) + result: auto_linear + +- name: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) + self: unsqueeze_to(grad, dim, self.sym_sizes()) + result: auto_linear + +- name: squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!) + self: unsqueeze_to(grad, dim, self.sym_sizes()) + result: auto_linear + +- name: std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor + self: std_backward(result, grad, self, dim, correction, keepdim) + # pointwise (variance) + sum + sqrt + result: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result)).masked_fill_(result == 0, 0) + +- name: std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + self: std_mean_backward(grads[0], grads[1], self, result0, dim, correction, keepdim) + result0: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result0)).masked_fill_(result0 == 0, 0) + # linear + result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim) + +- name: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + other: handle_r_to_c(other.scalar_type(), maybe_multiply(-grad, alpha.conj())) + result: self_t - maybe_multiply(other_t, alpha) + +- name: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + result: auto_element_wise + +- name: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), maybe_multiply(-grad, alpha.conj())) + other: handle_r_to_c(other.scalar_type(), grad) + result: -maybe_multiply(self_t, alpha) + other_t + +- name: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), maybe_multiply(-grad, alpha.conj())) + result: auto_element_wise + +- name: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor + self: grad.expand_symint(self.sym_sizes()) + result: auto_linear + +- name: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + dispatch: + Default: + self: sum_backward(grad, self.sym_sizes(), dim, keepdim) + result: auto_linear + AutogradNestedTensor: + # TODO: replace this function once semantics for nested tensor expand have been settled on + self: _nested_sum_backward(grad, self, dim, keepdim) + +- name: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) + result: at::where(self_p.isnan(), 0, self_t).sum(dim, keepdim, dtype) + +# We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here +- name: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh) + A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow_symint(-1, 0, S.sym_size(-1)) : grad_U, + grad_S, + full_matrices && grad_Vh.defined() ? grad_Vh.narrow_symint(-2, 0, S.sym_size(-1)) : grad_Vh, + full_matrices ? U.narrow_symint(-1, 0, S.sym_size(-1)) : U, + S, + full_matrices ? Vh.narrow_symint(-2, 0, S.sym_size(-1)) : Vh)" + U, S, Vh: linalg_svd_jvp(A_t, U, S, Vh, full_matrices) + +- name: _linalg_eigh(Tensor A, str UPLO="L", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors) + A: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors, /*is_hermitian=*/true) + eigenvalues, eigenvectors: linalg_eig_jvp(A_t, eigenvalues, eigenvectors, /*is_hermitian=*/true) + +- name: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors) + self: handle_r_to_c(self.scalar_type(), linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors, /*is_hermitian=*/false)) + eigenvalues, eigenvectors: linalg_eig_jvp(self_t, eigenvalues, eigenvectors, /*is_hermitian=*/false) + +- name: t(Tensor(a) self) -> Tensor(a) + self: grad.t() + result: auto_linear + +- name: t_(Tensor(a!) self) -> Tensor(a!) + self: grad.t() + result: auto_linear + +- name: one_hot(Tensor self, int num_classes=-1) -> Tensor + self: non_differentiable + +- name: flip(Tensor self, int[] dims) -> Tensor + self: grad.flip(dims) + result: auto_linear + +- name: roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor + self: grad.roll_symint(fmap(reverse_list_symint(shifts), [](c10::SymInt i){return -i;}), reverse_list(dims)) + result: auto_linear + +- name: rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor + self: grad.rot90(-k, dims) + result: auto_linear + +- name: take(Tensor self, Tensor index) -> Tensor + self: take_backward(grad, self, index) + index: non_differentiable + result: auto_linear + +- name: tan(Tensor self) -> Tensor + self: grad * (1 + result.pow(2)).conj() + result: auto_element_wise + +- name: tanh(Tensor self) -> Tensor + self: tanh_backward(grad, result) + result: auto_element_wise + +- name: topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) + output_differentiability: [True, False] + values: gather(self_t, dim, indices) + +- name: trace(Tensor self) -> Tensor + self: trace_backward_symint(grad, self.sym_sizes()) + result: auto_linear + +- name: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) + self: grad.transpose(dim0, dim1) + result: auto_linear + +- name: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) + self: grad.transpose(dim0, dim1) + result: auto_linear + +- name: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient) + self, A: triangular_solve_backward(grad_solution, grad_cloned_coefficient, self, A, solution, upper, transpose, unitriangular, grad_input_mask) + solution: triangular_solve_jvp(solution, A_p, A_t, self_t, upper, transpose, unitriangular) + cloned_coefficient: A_t + +- name: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor + self, B: linalg_solve_triangular_backward(grad, self, result, upper, left, unitriangular, grad_input_mask) + result: linalg_solve_triangular_forward_AD(self_t, B_t, self_p, result, upper, left, unitriangular) + +- name: tril(Tensor self, int diagonal=0) -> Tensor + self: grad.tril(diagonal) + result: auto_linear + +- name: triu(Tensor self, int diagonal=0) -> Tensor + self: grad.triu(diagonal) + result: auto_linear + +- name: trunc(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +# DO NOT define a backward for to_dense +# See [Note: Sometimes view derivatives] +# - name: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor +# +- name: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor + self: to_dense_backward(grad, self, masked_grad) + +# DO NOT define a backward for to_sparse.sparse_dim +# See [Note: Sometimes view derivatives] +# - name: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor +# +- name: _to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse +# See [Note: Sometimes view derivatives] +# - name: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor +# +- name: _to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse_csr +# See [Note: Sometimes view derivatives] +# - name: to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor +# +- name: _to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse_csc +# See [Note: Sometimes view derivatives] +# - name: to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor +# +- name: _to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse_bsr +# See [Note: Sometimes view derivatives] +# - name: to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor +# +- name: _to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse_bsc +# See [Note: Sometimes view derivatives] +# - name: to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor +# +- name: _to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +- name: to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor + self: to_mkldnn_backward(grad, self) + +- name: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a) + self: unfold_backward_symint(grad, self.sym_sizes(), dimension, size, step) + result: auto_linear + +- name: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor + grad_in: grad.unfold(dim, size, step) + result: auto_linear + +- name: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) + output_differentiability: [True, False] + self: not_implemented("_unique") + +- name: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("unique_dim") + +- name: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("unique_consecutive") + +- name: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("unique_dim_consecutive") + +- name: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("_unique2") + +- name: _unsafe_view(Tensor self, SymInt[] size) -> Tensor + self: grad.reshape_symint(self.sym_sizes()) + result: auto_linear + +- name: lift(Tensor self) -> Tensor + self: grad + result: auto_linear + +- name: lift_fresh(Tensor(a) self) -> Tensor(a) + self: grad + result: auto_linear + +- name: unsqueeze(Tensor(a) self, int dim) -> Tensor(a) + self: grad.squeeze(dim) + result: auto_linear + +- name: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!) + self: grad.squeeze(dim) + result: auto_linear + +- name: var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor + self: var_backward(grad, self, dim, correction, keepdim) + # pointwise + sum + result: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) + +- name: var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + self: var_mean_backward(grads[0], grads[1], self, dim, correction, keepdim) + result0: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) + # linear + result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim) + +- name: view(Tensor(a) self, SymInt[] size) -> Tensor(a) + dispatch: + Default: + self: grad.reshape_symint(self.sym_sizes()) + result: auto_linear + AutogradNestedTensor: + self: grad.reshape_as(self) + result: auto_linear + +- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) + output_differentiability: [False] + +- name: view_as_real(Tensor(a) self) -> Tensor(a) + self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1 + result: at::view_as_real(self_t) + +- name: view_as_complex(Tensor(a) self) -> Tensor(a) + self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy] + result: at::view_as_complex(self_t) + +- name: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor + condition: non_differentiable + self: where(condition, grad, 0) + other: where(condition, 0, grad) + result: where(condition, self_t, other_t) + +# weight_norm_cuda_interface_backward does not have an explicitly defined derivative, so if we do happen +# to be running backward with create_graph=True, fall back to a backward function that uses +# differentiable ops. +- name: _weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor) + v, g: "grad.defined() ? (GradMode::is_enabled() ? _weight_norm_differentiable_backward(grad.contiguous(), v, g, result1, dim) : _weight_norm_interface_backward(grad.contiguous(), v, g, result1, dim)) : std::tuple()" + +- name: zero_(Tensor(a!) self) -> Tensor(a!) + self: zeros_like(grad) + result: auto_linear + +- name: sparse_mask(Tensor self, Tensor mask) -> Tensor + self: sparse_mask_backward(grad, mask, self.layout()) + mask: non_differentiable + +- name: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor + indices: non_differentiable + values: grad.sparse_mask(result)._values() + +- name: sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + compressed_indices: non_differentiable + plain_indices: non_differentiable + # TODO: remove to_dense after gh-107381 is fixed + values: grad.to_dense().sparse_mask(result).values() + +- name: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor + self: at::_sparse_sum_backward(grad, self, dim) + +- name: _standard_gamma(Tensor self, Generator? generator=None) -> Tensor + self: grad * _standard_gamma_grad(self, result) + +- name: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor + self: not_implemented("_standard_gamma_grad") + +- name: values(Tensor(a) self) -> Tensor(a) + dispatch: + Default: + self: values_backward(grad, self) + AutogradNestedTensor: + self: at::_nested_view_from_buffer(grad.contiguous(), self._nested_tensor_size(), self._nested_tensor_strides(), self._nested_tensor_storage_offsets()) + +# Why is _values() not differentiable? +# See NOTE [ Sparse: autograd and API ] +- name: _values(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +# NN +- name: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor + i1, i2, i3: "_trilinear_backward(grad, + wrap_opt_if(i1, grad_input_mask[1] || grad_input_mask[2]), + wrap_opt_if(i2, grad_input_mask[0] || grad_input_mask[2]), + wrap_opt_if(i3, grad_input_mask[0] || grad_input_mask[1]), + expand1, expand2, expand3, sumdim, grad_input_mask)" + result: "_trilinear(i1_t, i2_p, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) + + _trilinear(i1_p, i2_t, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) + + _trilinear(i1_p, i2_p, i3_t, expand1, expand2, expand3, sumdim, unroll_dim)" + +- name: constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor + self: constant_pad_nd_backward(grad, pad) + result: constant_pad_nd_symint(self_t, pad, 0) + +- name: binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor + self: binary_cross_entropy_backward(grad, self, target, weight, reduction) + target: binary_cross_entropy_target_backward(grad, self, target, weight, reduction) + result: "apply_loss_reduction( + binary_cross_entropy_backward(self_t, self_p, target_p, weight, at::Reduction::None) + + binary_cross_entropy_target_backward(target_t, self_p, target_p, weight, at::Reduction::None), + reduction)" + +- name: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor + self: binary_cross_entropy_double_backward(grad_output, grad, self, target, weight, reduction) + target: binary_cross_entropy_double_backward_target(grad, grad_output, self, target, weight, reduction) + grad_output: binary_cross_entropy_double_backward_grad_output(grad, self, target, weight, reduction) + result: " binary_cross_entropy_double_backward(grad_output_p, self_t, self_p, target_p, weight, reduction) + + binary_cross_entropy_double_backward_target(target_t, grad_output_p, self_p, target_p, weight, reduction) + + binary_cross_entropy_double_backward_grad_output(grad_output_t, self_p, target_p, weight, reduction)" + +- name: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor + self: binary_cross_entropy_with_logits_backward(grad, self, target, weight, pos_weight, reduction) + target: binary_cross_entropy_with_logits_target_backward(grad, self, target, weight, pos_weight, reduction) + result: "apply_loss_reduction( + binary_cross_entropy_with_logits_backward(self_t, self_p, target_p, weight, pos_weight, at::Reduction::None) + + binary_cross_entropy_with_logits_target_backward(target_t, self_p, target_p, weight, pos_weight, at::Reduction::None), + reduction)" + +- name: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor + indices: non_differentiable + weight: embedding_backward_symint(grad, indices, weight.sym_size(0), padding_idx, scale_grad_by_freq, sparse) + result: auto_linear + +- name: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor + grad_output: embedding_dense_double_backward_symint(grad, indices, padding_idx) + indices: non_differentiable + result: auto_linear + +- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) + indices: non_differentiable + offsets: non_differentiable + weight: _embedding_bag_backward_symint(grad, indices, offsets, result1, result2, result3, weight.sym_size(0), scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx) + per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, offsets, result1, mode, padding_idx) + +- name: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + grad: not_implemented("_embedding_bag_backward") + indices: non_differentiable + offsets: non_differentiable + offset2bag: non_differentiable + bag_size: non_differentiable + maximum_indices: non_differentiable + per_sample_weights: not_implemented("_embedding_bag_backward") + +- name: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + grad: not_implemented("_embedding_bag_dense_backward") + indices: non_differentiable + offset2bag: non_differentiable + bag_size: non_differentiable + maximum_indices: non_differentiable + per_sample_weights: not_implemented("_embedding_bag_dense_backward") + +- name: embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!) + indices: non_differentiable + self: not_implemented("embedding_renorm") + +- name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + self: mse_loss_backward(grad, self, target, reduction) + target: mse_loss_backward(grad, target, self, reduction) + result: apply_loss_reduction(mse_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None).conj() + mse_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None).conj(), reduction) + +- name: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor + self: multi_margin_loss_backward(grad, self, target, p, margin, weight, reduction) + target: non_differentiable + +- name: multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target) + self: multilabel_margin_loss_backward(grad, self, target, reduction, is_target) + target: non_differentiable + +- name: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + self: nll_loss_backward_symint(grad, self, target, weight, reduction, ignore_index, total_weight) + target: non_differentiable + output: std::get<0>(nll_loss_forward_symint(self_t, target, weight, reduction, ignore_index)) + +- name: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + self: nll_loss2d_backward_symint(grad, self, target, weight, reduction, ignore_index, total_weight) + target: non_differentiable + output: std::get<0>(nll_loss2d_forward_symint(self_t, target, weight, reduction, ignore_index)) + +- name: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor + self: smooth_l1_loss_backward(grad, self, target, reduction, beta) + target: smooth_l1_loss_backward(grad, target, self, reduction, beta) + result: apply_loss_reduction(smooth_l1_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None, beta).conj() + smooth_l1_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None, beta).conj(), reduction) + +- name: huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor + self: huber_loss_backward(grad, self, target, reduction, delta) + target: huber_loss_backward(grad, target, self, reduction, delta) + result: apply_loss_reduction(huber_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None, delta).conj() + huber_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None, delta).conj(), reduction) + +- name: soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + self: soft_margin_loss_backward(grad, self, target, reduction) + result: apply_loss_reduction(soft_margin_loss_backward(self_t.conj(), self_p, target, at::Reduction::None).conj(), reduction) + +- name: relu(Tensor self) -> Tensor + self: threshold_backward(grad, result, 0) + result: auto_element_wise + +- name: silu(Tensor self) -> Tensor + self: "GradMode::is_enabled() ? infinitely_differentiable_silu_backward(grad, self) : silu_backward(grad, self)" + result: auto_element_wise + +- name: mish(Tensor self) -> Tensor + self: "GradMode::is_enabled() ? infinitely_differentiable_mish_backward(grad, self) : mish_backward(grad, self)" + result: auto_element_wise + +- name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor + self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ false, self) + result: auto_element_wise + +- name: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!) + self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ true, result) + result: self_t.copy_(elu_backward(original_self_t, alpha, scale, input_scale, /* is_result */ true, result)) + +- name: celu(Tensor self, Scalar alpha=1.0) -> Tensor + self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ false, self) + result: auto_element_wise + +- name: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!) + self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result) + result: self_t.copy_(elu_backward(original_self_t, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result)) + +- name: gelu(Tensor self, *, str approximate='none') -> Tensor + self: gelu_backward(grad, self, approximate) + result: auto_element_wise + +- name: gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor + grad_output: gelu_backward(grad, self, approximate) + self: gelu_double_backward(grad, grad_output, self, approximate) + result: gelu_backward(grad_output_t, self_p, approximate) + gelu_double_backward(self_t, grad_output_p, self_p, approximate) + +- name: glu(Tensor self, int dim=-1) -> Tensor + # TODO: glu_backward can benefit from forward result, + # and forward ad/forward over reverse ad for that matter + self: glu_backward(grad, self, dim) + result: glu_jvp(result, self_p, self_t, dim) + +- name: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor + self: hardshrink_backward(grad, self, lambd) + result: auto_element_wise + +- name: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor + grad_out: hardshrink_backward(grad, self, lambd) + self: zeros_like(grad) + result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_out_t, at::zeros({}, result.options()).expand_as(result)) + +- name: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor + self: hardtanh_backward(grad, self, min_val, max_val) + result: auto_element_wise + +- name: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor + self: leaky_relu_backward(grad, self, negative_slope, false) + result: auto_element_wise + +- name: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!) + self: leaky_relu_backward(grad, result, negative_slope, true) + result: self_t.copy_(leaky_relu_backward(original_self_t.conj(), result, negative_slope, true).conj()) + +- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer) + self: log_sigmoid_backward(grad, self, buffer) + output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj() + output_differentiability: [True, False] + +- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + self: _log_softmax_backward_data(grad, result, dim, self.scalar_type()) + result: self_t - logsumexp_jvp(self_p, self_t, {dim}, true) + +- name: _sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + self: _sparse_log_softmax_backward_data(grad, result, dim, self) + +- name: _masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor + self: _masked_softmax_backward(grad, result, mask, dim) + mask: non_differentiable + +- name: _prelu_kernel(Tensor self, Tensor weight) -> Tensor + self, weight: "grad.defined() ? _prelu_kernel_backward(grad, self, weight) : std::tuple()" + result: at::where(self_p >= 0, self_t, weight_p * self_t + weight_t * self_p) + +- name: _prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) + grad_output: "grads[0].defined() ? + (grads[1].defined() ? at::where(self >= 0, grads[0], grads[0] * weight + grads[1] * self) + : at::where(self >= 0, grads[0], grads[0] * weight)) + : at::where(self >= 0, at::zeros({}, grad_output.options()), grads[1] * self)" + self: "grads[1].defined() ? at::where(self >= 0, at::zeros({}, self.options()), grad_output * grads[1]) : zeros_like(self)" + weight: "grads[0].defined() ? at::where(self >= 0, at::zeros({}, weight.options()), grad_output * grads[0]) : zeros_like(self)" + result0: at::where(self_p >= 0, grad_output_t, grad_output_t * weight_p + grad_output_p * weight_t) + result1: at::where(self_p >= 0, at::zeros({}, self_p.options()), grad_output_p * self_t + grad_output_t * self_p) + +- name: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor + self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false) + result: auto_element_wise + +- name: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) + self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true) + +- name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor + self: _softmax_backward_data(grad, result, dim, self.scalar_type()) + result: result * (self_t - logsumexp_jvp(self_p, self_t, {dim}, true)) + +- name: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + self: _sparse_softmax_backward_data(grad, result, dim, self) + +- name: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor + self: sparse_sparse_matmul_backward(grad, self, other, 0) + other: sparse_sparse_matmul_backward(grad, self, other, 1) + +- name: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor + self: softplus_backward(grad, self, beta, threshold) + result: auto_element_wise + +- name: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor + self: softshrink_backward(grad, self, lambd) + result: auto_element_wise + +- name: threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor + self: threshold_backward(grad, self, threshold) + result: auto_element_wise + +- name: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!) + self: threshold_backward(grad, self, threshold) + result: self_t.copy_(threshold_backward(self_t.conj(), original_self_p, threshold).conj()) + +- name: reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor + self: reflection_pad1d_backward_symint(grad, self, padding) + result: auto_linear + +- name: reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor + self: reflection_pad2d_backward_symint(grad, self, padding) + result: auto_linear + +- name: reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor + self: reflection_pad3d_backward_symint(grad, self, padding) + result: auto_linear + +- name: replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor + self: replication_pad1d_backward_symint(grad, self, padding) + result: auto_linear + +- name: replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor + self: replication_pad2d_backward_symint(grad, self, padding) + result: auto_linear + +- name: replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor + self: replication_pad3d_backward_symint(grad, self, padding) + result: auto_linear + +- name: upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor + self: upsample_linear1d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales) + result: auto_linear + +- name: upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_bilinear2d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) + result: auto_linear + +- name: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + self: _upsample_bilinear2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) + result: auto_linear + +- name: upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_bicubic2d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) + result: auto_linear + +- name: _upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + self: _upsample_bicubic2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) + result: auto_linear + +- name: upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_trilinear3d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_d, scales_h, scales_w) + result: auto_linear + +- name: upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + self: upsample_nearest1d_backward_symint(grad, output_size, self.sym_sizes(), scales) + result: auto_linear + +- name: _upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + self: _upsample_nearest_exact1d_backward_symint(grad, output_size, self.sym_sizes(), scales) + result: auto_linear + +- name: upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_nearest2d_backward_symint(grad, output_size, self.sym_sizes(), scales_h, scales_w) + result: auto_linear + +- name: _upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + self: _upsample_nearest_exact2d_backward_symint(grad, output_size, self.sym_sizes(), scales_h, scales_w) + result: auto_linear + +- name: upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_nearest3d_backward_symint(grad, output_size, self.sym_sizes(), scales_d, scales_h, scales_w) + result: auto_linear + +- name: _upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + self: _upsample_nearest_exact3d_backward_symint(grad, output_size, self.sym_sizes(), scales_d, scales_h, scales_w) + result: auto_linear + +- name: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor + self: pixel_unshuffle(grad, upscale_factor) + result: auto_linear + +- name: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor + self: pixel_shuffle(grad, downscale_factor) + result: auto_linear + +- name: channel_shuffle(Tensor self, SymInt groups) -> Tensor + self: channel_shuffle_symint(grad, grad.sym_size(1) / groups) + result: auto_linear + +- name: _adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor + self: _adaptive_avg_pool2d_backward(grad, self) + result: auto_linear + +- name: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor + self: _adaptive_avg_pool3d_backward(grad, self) + result: auto_linear + +- name: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor) + self: adaptive_max_pool2d_backward(grad, self, result1) + result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1) + output_differentiability: [True, False] + +- name: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor) + self: adaptive_max_pool3d_backward(grad, self, result1) + result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1) + output_differentiability: [True, False] + +- name: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor + self: avg_pool2d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + result: auto_linear + +- name: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor + self: avg_pool3d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + result: auto_linear + +- name: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor) + self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, result1) + result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1) + output_differentiability: [True, False] + +- name: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor) + self: fractional_max_pool3d_backward(grad, self, kernel_size, output_size, result1) + result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1) + output_differentiability: [True, False] + +- name: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor + input, weight, bias: "grad.defined() ? linear_backward(input, grad, weight, grad_input_mask) : std::tuple()" + +- name: linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + self, grad_output, weight: linear_double_backward(grads, self, grad_output, weight) + +#mps +- name: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + self: max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode) + +- name: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + self, weight, bias: "grad.defined() ? mps_convolution_backward_symint(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple()" + +- name: mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + grad_output, self, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) + +- name: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + self: max_pool2d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1) + result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1) + output_differentiability: [True, False] + +- name: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + self: max_pool3d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1) + result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1) + output_differentiability: [True, False] + +- name: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor + self: max_pool_double_backward(grad, indices, 2) + indices: non_differentiable + result: auto_linear + +- name: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor + self: max_pool_double_backward(grad, indices, 3) + indices: non_differentiable + result: auto_linear + +- name: convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" + result: convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups) + +# TorchScript serializes calls to _convolution so this entry is present until that is changed to use convolution. +# Note that the benchmark, deterministic, cudnn_enabled, and allow_tf32 flags are queried from the global context +# by convolution_backward instead of being passed along from the forward pass. +- name: _convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor + input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" + result: _convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32) + +- name: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + grad_output, input, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) + result0: std::get<0>(convolution_backward_symint(grad_output_p, input_p, weight_t, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false})) + std::get<0>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false})) + result1: std::get<1>(convolution_backward_symint(grad_output_p, input_t, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false})) + std::get<1>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false})) + result2: convolution_backward_jvp_grad_bias(grad_output_t, result2) + +- name: convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + input, weight, bias: "grad.defined() ? convolution_backward_overrideable_symint(grad, input, weight, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" + +- name: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + grad_output, input, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) + +- name: slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()" + +- name: slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()" + +- name: _slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor + self, weight, bias: "grad.defined() ? _slow_conv2d_backward_symint(grad, self, weight, kernel_size, stride, padding, grad_input_mask) : std::tuple()" + +- name: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + grad_output, self, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, grad_input_mask) + +- name: _conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()" + +- name: conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()" + +- name: slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, /*dilation=*/ {{1, 1, 1}}, false, /*output_padding=*/ {{0, 0, 0}}, 1, grad_input_mask) : std::tuple()" + +- name: slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" + +- name: slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" + +- name: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + self: im2col(grad, kernel_size, dilation, padding, stride) + result: auto_linear + +- name: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + self: col2im_symint(grad, {self.sym_size(-2), self.sym_size(-1)}, kernel_size, dilation, padding, stride) + result: auto_linear + +- name: _adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor + grad_output: _adaptive_avg_pool2d_symint(grad, {grad_output.sym_size(-2), grad_output.sym_size(-1)}) + self: zeros_like(self) + result: _adaptive_avg_pool2d_backward(grad_output_t, self_p) + +- name: _adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor + grad_output: _adaptive_avg_pool3d_symint(grad, { grad_output.sym_size(-3), grad_output.sym_size(-2), grad_output.sym_size(-1) }) + self: zeros_like(self) + result: _adaptive_avg_pool3d_backward(grad_output_t, self_p) + +- name: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 2) + self: zeros_like(self) + result: auto_linear + +- name: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 3) + self: zeros_like(self) + result: auto_linear + +- name: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor + grad_output: avg_pool2d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + self: zeros_like(self) + result: avg_pool2d_backward(grad_output_t, self_p, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + +- name: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor + grad_output: avg_pool3d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + self: zeros_like(self) + result: avg_pool3d_backward(grad_output_t, self_p, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + +- name: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor + grad_output: elu_backward(grad, alpha, scale, input_scale, is_result, self_or_result) + self_or_result: elu_double_backward(grad, grad_output, alpha, scale, input_scale, is_result, self_or_result) + result: elu_backward(grad_output_t, alpha, scale, input_scale, is_result, self_or_result_p) + elu_double_backward(self_or_result_t, grad_output_p, alpha, scale, input_scale, is_result, self_or_result_p) + +- name: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 2) + self: zeros_like(self) + result: auto_linear + +- name: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 3) + self: zeros_like(self) + result: auto_linear + +- name: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor + grad_output: glu_double_backward_grad_output(grad, self, dim) + self: glu_double_backward(grad, grad_output, self, dim) + result: glu_backward_jvp(result, grad_output_p, self_p, grad_output_t, self_t, dim) + +- name: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor + grad_output: hardtanh_backward(grad, self, min_val, max_val) + self: zeros_like(grad) + result: at::where((self_p > min_val).logical_and(self_p < max_val), grad_output_t, at::zeros({}, result.options()).expand_as(result)) + +- name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor + grad_output: log_sigmoid_backward(grad, self, buffer) + self: log_sigmoid_double_backward(grad * grad_output, self) + result: log_sigmoid_backward(grad_output_t, self_p, buffer) + log_sigmoid_double_backward(self_t * grad_output_p, self_p) + +- name: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor + grad_output: grad.to(output.dtype()) - (grad.to(output.dtype()) * output.exp()).sum(dim, true) + output: (-grad_output.sum(dim, true) * output.exp() * grad.to(output.dtype())).to(output.dtype()) + +- name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor + # self_is_result is always false here since double backward call is an out-of-place call, self is input itself + grad_output: leaky_relu_backward(grad, self, negative_slope, false) + self: zeros_like(grad) + # leaky_relu_backward(grad_output, self, negative_slope, false) + # computes grad_output * at::where(self_p > 0, 1, negative_slope) + # so the jvp formula is the following: + # grad_output_t * at::where(self_p > 0, self_p.new_ones([]), negative_slope); + # + # leaky_relu_backward(grad_output, result, negative_slope, true) + # computes grad_output * at::where(result > 0, 1, negative_slope) + # under the assumption that `negative_slope` is positive (otherwise, + # it is not possible to compute the gradient). + # + # so the jvp formula is the following: + # grad_output_t * at::where(result_p > 0, result_p.new_ones([]), negative_slope); + # with the assumption that negative_slope is positive. + # + # Combined together that results in the following optimized kernel which + # also checks the assumption that negative_slope is positive when self_is_result + # is True: + result: leaky_relu_backward(grad_output_t, self_p, negative_slope, self_is_result) + +# This derivative is mps-only, and `error_for_max_pool2d_double_backward` just raises an error. +- name: max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + grad_output: error_for_max_pool2d_double_backward() + self: zeros_like(self) + result: auto_linear + +- name: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 2) + self: zeros_like(self) + indices: non_differentiable + result: auto_linear + +- name: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 3) + self: zeros_like(self) + indices: non_differentiable + result: auto_linear + +- name: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + grad_output: mse_loss_backward(grad, self, target, reduction) + self: mse_loss_double_backward(grad * grad_output, self, reduction) + target: -mse_loss_double_backward(grad * grad_output, target, reduction) + result: " mse_loss_double_backward(self_t * grad_output_p, self_p, reduction) + - mse_loss_double_backward(target_t * grad_output_p, target_p, reduction) + + mse_loss_backward(grad_output_t, self_p, target_p, reduction) + " + +- name: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + grad_output: nll_loss_symint(grad, target, weight, reduction, ignore_index) + self: zeros_like(grad) + target: non_differentiable + +- name: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + grad_output: nll_loss2d_symint(grad, target, weight, reduction, ignore_index) + self: zeros_like(grad) + target: non_differentiable + +- name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor + # self_is_result is always false here since double backward call is an out-of-place call, self is input itself + grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false) + self: zeros_like(grad) + result: rrelu_with_noise_backward(grad_output_t, self_p, noise, lower, upper, training, false) + +- name: reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + grad_output: reflection_pad1d_symint(grad, padding) + self: zeros_like(self) + result: reflection_pad1d_backward_symint(grad_output_t, self_p, padding) + +- name: reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + grad_output: reflection_pad2d_symint(grad, padding) + self: zeros_like(self) + result: reflection_pad2d_backward_symint(grad_output_t, self_p, padding) + +- name: reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + grad_output: reflection_pad3d_symint(grad, padding) + self: zeros_like(self) + result: reflection_pad3d_backward_symint(grad_output_t, self_p, padding) + +- name: replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + grad_output: replication_pad1d_symint(grad, padding) + self: zeros_like(self) + result: replication_pad1d_backward_symint(grad_output_t, self_p, padding) + +- name: replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + grad_output: replication_pad2d_symint(grad, padding) + self: zeros_like(self) + result: replication_pad2d_backward_symint(grad_output_t, self_p, padding) + +- name: replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + grad_output: replication_pad3d_symint(grad, padding) + self: zeros_like(self) + result: replication_pad3d_backward_symint(grad_output_t, self_p, padding) + +- name: sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self, mat1, mat2: "sparse_sampled_addmm_backward(grad, + self, + wrap_opt_if(mat1, grad_input_mask[2]), + wrap_opt_if(mat2, grad_input_mask[1]), + alpha, beta, grad_input_mask)" + +- name: _sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor) + output_differentiability: [True, False] + self, other: "grad.defined() ? _sparse_mm_reduce_impl_backward(self, grad, other, reduce, result1, grad_input_mask) : std::tuple()" + +- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor + grad_output: smooth_l1_loss_backward(grad, self, target, reduction, beta) + self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta) + target: -smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta) + result: " smooth_l1_loss_double_backward(self_t * grad_output_p, self_p, target_p, reduction, beta) + - smooth_l1_loss_double_backward(target_t * grad_output_p, self_p, target_p, reduction, beta) + + smooth_l1_loss_backward(grad_output_t, self_p, target_p, reduction, beta) + " + +- name: huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor + grad_output: huber_loss_double_backward_grad_output(grad, grad_output, self, target, reduction, delta) + self: huber_loss_double_backward(grad * grad_output, self, target, reduction, delta) + target: -huber_loss_double_backward(grad * grad_output, self, target, reduction, delta) + +- name: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor + grad_output: softplus_backward(grad, self, beta, threshold) + self: softplus_double_backward(grad * grad_output, self, beta, threshold) + result: "softplus_backward(grad_output_t, self_p, beta, threshold) + + softplus_double_backward(self_t * grad_output_p, self_p, beta, threshold)" + +- name: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor + grad_output: _softmax_backward_data(grad.to(output.dtype()), output, dim, input_dtype) + output: softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(output.dtype()) + +- name: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + grad_output: soft_margin_loss_double_backward_grad_output(grad, grad_output, self, target, reduction) + self: soft_margin_loss_double_backward(grad * grad_output, self, target, reduction) + +- name: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor + grad_output: softshrink_backward(grad, self, lambd) + self: zeros_like(grad) + result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_output_t, at::zeros({}, result.options()).expand_as(result)) + +- name: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor + grad_output: threshold_backward(grad, self, threshold) + self: zeros_like(grad) + result: zeros_like(self_t) + threshold_backward(grad_output_t, self_p, threshold) + +- name: upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor + grad_output: upsample_linear1d_symint(grad, output_size, align_corners, scales) + result: auto_linear + +- name: upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_bilinear2d_symint(grad, output_size, align_corners, scales_h, scales_w) + result: auto_linear + +- name: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: _upsample_bilinear2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w) + result: auto_linear + +- name: upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_bicubic2d_symint(grad, output_size, align_corners, scales_h, scales_w) + result: auto_linear + +- name: _upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: _upsample_bicubic2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w) + result: auto_linear + +- name: upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_trilinear3d_symint(grad, output_size, align_corners, scales_d, scales_h, scales_w) + result: auto_linear + +- name: upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + grad_output: upsample_nearest1d_symint(grad, output_size, scales) + result: auto_linear + +- name: _upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + grad_output: _upsample_nearest_exact1d_symint(grad, output_size, scales) + result: auto_linear + +- name: upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_nearest2d_symint(grad, output_size, scales_h, scales_w) + result: auto_linear + +- name: _upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: _upsample_nearest_exact2d_symint(grad, output_size, scales_h, scales_w) + result: auto_linear + +- name: upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_nearest3d_symint(grad, output_size, scales_d, scales_h, scales_w) + result: auto_linear + +- name: _upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: _upsample_nearest_exact3d_symint(grad, output_size, scales_d, scales_h, scales_w) + result: auto_linear + +- name: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor + grad_output: sigmoid_backward(grad, output.conj()) + output: grad.conj() * grad_output * (-2 * output.conj() + 1) + result: sigmoid_backward(grad_output_t, output_p) + output_t.conj() * grad_output_p * (-2 * output_p.conj() + 1) + +- name: tanh_backward(Tensor grad_output, Tensor output) -> Tensor + grad_output: tanh_backward(grad, output.conj()) + output: grad.conj() * (-2 * output.conj() * grad_output) + result: tanh_backward(grad_output_t, output_p) + output_t.conj() * (-2 * output_p.conj() * grad_output_p) + +# cudnn +- name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) + log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity) + +- name: _cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) + log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity) + +- name: cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, output_padding, stride, dilation, true, groups, {grad_input_mask[0], grad_input_mask[1]})" + +- name: _mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + self, weight: "grad.defined() ? mps_convolution_transpose_backward_symint(self, grad, weight, padding, output_padding, stride, dilation, groups, grad_input_mask) : std::tuple()" + +- name: cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, std::vector(padding.size(), 0), stride, dilation, false, groups, {grad_input_mask[0], grad_input_mask[1]})" + +- name: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output + self, grid: "grad.defined() ? cudnn_grid_sampler_backward(self, grid, grad) : std::tuple()" + +- name: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid + theta: cudnn_affine_grid_generator_backward(grad, N, C, H, W) + +# NB: Why is the backwards here so complicated? CuDNN cannot be used to compute +# backward in evaluation mode, because the math for backward in evaluation mode +# is different (since the forward math is different), and CuDNN does not support +# it. And in any case, you shouldn't be using this bn in evaluation mode, +# because it should be merged into the previous convolution (left for future +# work.) +# NB2: The quotes around the gradient are needed to appease YAML parsing rules. +- name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? (training ? cudnn_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon, retain_variables ? result3.clone() : result3) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon) + +# HACK: save_mean and save_var are going to be passed in as +# requires_grad variables (even though we'll never backprop through +# them) so we need to prevent the unpacking from triggering an error. +- name: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) + save_mean: not_implemented("cudnn_batch_norm_backward save_mean") + save_var: not_implemented("cudnn_batch_norm_backward save_var") + reserveSpace: not_implemented("cudnn_batch_norm_backward reserveSpace") + input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask) + +# nnpack + +- name: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor + # NNPACK does not support strided convolutions in the backwards path, which is the reason why we are using the closest available function that does here. + input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector(padding.size(), 1), false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" + +#LSTM MPS +- name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) + output_differentiability: [True, True, True, False, False, False] + input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, result5, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)" + +- name: lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[]) + + + +# Only frst three of _cudnn_rnn outputs can have gradients. +# _cudnn_rnn outputs: (output, hy, cy, reserve, weight_buf) +- name: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + dropout_state: non_differentiable + output_differentiability: [True, True, True, False, False] + input, hx, cx, weight: "_cudnn_rnn_backward_symint(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)" + +- name: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + dropout_state: non_differentiable + input: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + weight: not_implemented_list("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + hx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + cx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + grad_output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + grad_hy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + grad_cy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + +# miopen + +- name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, groups, grad_input_mask) : std::tuple()" + +- name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" + +- name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" + +- name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon) + +- name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor) + save_mean: not_implemented("miopen_batch_norm_backward save_mean") + save_var: not_implemented("miopen_batch_norm_backward save_var") + input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask) + +- name: miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + dropout_state: non_differentiable + output_differentiability: [True, True, True, False, False] + input, hx, cx, weight: "miopen_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)" + +- name: miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + dropout_state: non_differentiable + +- name: mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor) + output_differentiability: [True, True, True, False] + input, weight0, weight1, weight2, weight3, hx_, cx_: "GradMode::is_enabled() ? mkldnn_rnn_layer_differentiable_backward(input, weight0, weight1, weight2, weight3, hx_, cx_, result0, result1, result2, grads[0], grads[1], grads[2], reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, result3) : mkldnn_rnn_layer_backward(input, weight0, weight1, weight2, weight3, hx_, cx_, result0, result1, result2, grads[0], grads[1], grads[2], reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, result3)" + +- name: mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) + +# mkldnn +- name: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" + +- name: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor + self, weight, bias: mkldnn_linear_backward(self, grad, weight, grad_input_mask) + +- name: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + self: mkldnn_max_pool2d_backward(grad, result, self, kernel_size, stride, padding, dilation, ceil_mode) + +- name: mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + self: mkldnn_max_pool3d_backward(grad, result, self, kernel_size, stride, padding, dilation, ceil_mode) + +- name: mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor + self: mkldnn_adaptive_avg_pool2d_backward(grad, self) + +- name: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor + self: grad.reshape_symint(self.sym_sizes()) + +# NestedTensor +- name: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + list: "grad.defined()? at::unbind(grad) : std::vector(list.size())" + +- name: _nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor + t: grad.to_padded_tensor_symint(0, t.sym_sizes()) + mask: non_differentiable + +- name: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor + padded: _nested_from_padded_backward(grad, padded, fuse_transform_0213) + cpu_nested_shape_example: non_differentiable + +- name: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor + self: at::_nested_from_padded(grad, self._nested_tensor_size()) + padding: non_differentiable + +- name: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a) + self: grad.values() + nested_size: non_differentiable + nested_strides: non_differentiable + +- name: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) + self: grad.values() + offsets: non_differentiable + lengths: non_differentiable + dummy: non_differentiable + +- name: _nested_get_values(Tensor(a) self) -> Tensor(a) + self: "_nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? std::optional(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? std::optional(at::_nested_get_max_seqlen(self)) : ::std::nullopt)" + +# Transformer +- name: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + self: _softmax_backward_data(grad, result, dim, self.scalar_type()) + result: result * (self_t - safe_logsumexp_jvp(self_p, self_t, {dim}, true)) + +- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) + output_differentiability: [True, False, False, False] + query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale) + +- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False, False, False, False, False] + query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale) + +- name: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp) + output_differentiability: [True, False] + query, key, value: _scaled_dot_product_flash_attention_for_cpu_backward(grad, query, key, value, output, logsumexp, dropout_p, is_causal, attn_mask, scale) + +- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False] + query, key, value: _flash_attention_backward_symint(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale, window_size_left, window_size_right) + +- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) + output_differentiability: [True, False, False, False, False, False] + query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale) + +- name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False, False, False, False, False] + query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) + +- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False, False, False, False, False] + query, key, value, attn_bias: _scaled_dot_product_fused_attention_overrideable_backward_symint(grad, query, key, value, attn_bias, grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale) + +# fft +- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor + self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back())) + result: auto_linear + +- name: _fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor + self: fft_c2r_backward(grad, dim, normalization) + result: auto_linear + +- name: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor + self: _fft_c2c_symint(grad, dim, normalization, !forward) + result: auto_linear + +- name: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[] + dispatch: + Default: + self: unbind_backward(grads, dim) + result: auto_linear + AutogradNestedTensor: + self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())" + result: auto_linear + +- name: stack(Tensor[] tensors, int dim=0) -> Tensor + tensors: stack_tensors_backward(grad, dim, to_args_scalartypes(tensors)) + result: stack_jvp(tensors, dim) + +# fused RNN kernels + +# Only frst two of _thnn_fused_lstm_cell outputs can have gradients. +# _thnn_fused_lstm_cell outputs: (hy, cy, workspace) +- name: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, True, False] + input_gates, hidden_gates, cx, input_bias, hidden_bias: "GradMode::is_enabled() ? _thnn_differentiable_lstm_cell_backward(grads[0], grads[1], input_gates, hidden_gates, input_bias, hidden_bias, cx, result1) : _thnn_fused_lstm_cell_backward(grads[0], grads[1], cx, result1, result2, input_bias.defined())" + +- name: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor) + input_gates, hidden_gates, hx, input_bias, hidden_bias: "grad.defined() ? (GradMode::is_enabled() ? _thnn_differentiable_gru_cell_backward(grad, input_gates, hidden_gates, hx, input_bias, hidden_bias) : _thnn_fused_gru_cell_backward(grad, result1, input_bias.defined())) : std::tuple()" + +# PackedSequence helpers +- name: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor) + input: _pack_padded_sequence_backward_symint(grad, input.sym_sizes(), result1, batch_first) + +# TH wrappers +- name: eq.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: eq.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: ge.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: ge.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: gt.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: gt.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: le.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: le.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: lt.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: lt.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: ne.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: ne.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor + output_differentiability: [False] + +- name: nonzero(Tensor self) -> Tensor + output_differentiability: [False] + +- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor + data: _segment_reduce_backward(grad, result, data, reduce, lengths, offsets, axis, initial) + +- name: _pin_memory(Tensor self, Device? device=None) -> Tensor + self: grad + +- name: _new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor + self: non_differentiable + other: non_differentiable + output_differentiability: [False] + +- name: _test_warn_in_autograd(Tensor self) -> Tensor + self: warn_backwards(grad) + +- name: _test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor + dispatch: + Default: + self: grad.expand_symint(self.sym_sizes()) + 1 + result: auto_linear + AutogradNestedTensor: + self: grad.mul(grad) + AutogradCUDA: + self: grad.expand_symint(self.sym_sizes()) * 2 + +- name: _test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor + dispatch: + AutogradNestedTensor: + self: grad.mul(grad).add(grad) + +- name: _test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a) + dispatch: + Default: + self: grad.reshape_as(self) + AutogradCUDA: + self: grad.reshape_as(self) + 1 + +- name: _efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + output_differentiability: [False] + +- name: scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor + self, src: scatter_reduce_backward(grad, self, dim, index, src, reduce, include_self, result) + index: non_differentiable + result: scatter_reduce_jvp(self_p, self_t, dim, index, src_p, src_t, reduce, include_self, result) + +- name: special_airy_ai(Tensor x) -> Tensor + x: non_differentiable + +- name: special_bessel_j0(Tensor self) -> Tensor + self: non_differentiable + +- name: special_bessel_j1(Tensor self) -> Tensor + self: non_differentiable + +- name: special_bessel_y0(Tensor self) -> Tensor + self: non_differentiable + +- name: special_bessel_y1(Tensor self) -> Tensor + self: non_differentiable + +- name: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_modified_bessel_i0(Tensor self) -> Tensor + self: non_differentiable + +- name: special_modified_bessel_i1(Tensor self) -> Tensor + self: non_differentiable + +- name: special_modified_bessel_k0(Tensor self) -> Tensor + self: non_differentiable + +- name: special_modified_bessel_k1(Tensor self) -> Tensor + self: non_differentiable + +- name: special_scaled_modified_bessel_k0(Tensor x) -> Tensor + x: non_differentiable + +- name: special_scaled_modified_bessel_k1(Tensor x) -> Tensor + x: non_differentiable + +- name: special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_spherical_bessel_j0(Tensor x) -> Tensor + x: non_differentiable + +- name: _reshape_copy(Tensor self, SymInt[] size) -> Tensor + self: grad.reshape_symint(self.sym_sizes()) + result: auto_linear + +# note(crcrpar): `torchgen/api/autograd` logic would unwantedly replace substrings of `self` and `other` of function names. +- name: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[] + self: div_tensor_self_backward(grads[i], other[i], self[i].scalar_type()) + other: div_tensor_other_backward(grads[i], self[i], other[i]) + result: (self_t - other_t * result[i]) / other_p + +- name: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[] + self: pow_backward_self(grads[i], self[i], exponent[i]) + exponent: pow_backward_exponent(grads[i], self[i], exponent[i], result[i]) + result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result[i])).conj() + +- name: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[] + self: pow_backward(grads[i], self[i], exponent[i]) + result: pow_backward(self_t.conj(), self_p, exponent[i]).conj() + +- name: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[] + exponent: pow_backward_exponent(grads[i], self, exponent[i], result[i]) + +# note(crcrpar): following definitions seem necessary because the reference native functions +# of `maximum` and `minimum` don't have the overload def with Scalar as their second argument. +- name: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] > scalar, 0) + result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p < scalar).to(result[i].scalar_type())) * (self_t - scalar) + +- name: _foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] > scalars[i], 0) + result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p < scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i]) + +- name: _foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] < scalar, 0) + result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p > scalar).to(result[i].scalar_type())) * (self_t - scalar) + +- name: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] < scalars[i], 0) + result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p > scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i]) + +# note(crcrpar): forward-mode AD is tricky for a simple string replace to handle: +# formula.replace("p", "ord") produces `norm_jvord(self_ord, self_t, ord, result)` +- name: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[] + self: norm_backward(grads[i], self[i], ord, result[i]) + result: norm_jvp(self_p, self_t, ord, result[i]) diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py new file mode 100644 index 0000000000000000000000000000000000000000..c32779b3a2825e82d18a57bdeea76c47707e4284 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py @@ -0,0 +1,132 @@ +""" +For procedural tests needed for __torch_function__, we use this function +to export method names and signatures as needed by the tests in +test/test_overrides.py. + +python -m tools.autograd.gen_annotated_fn_args \ + aten/src/ATen/native/native_functions.yaml \ + aten/src/ATen/native/tags.yaml \ + $OUTPUT_DIR \ + tools/autograd + +Where $OUTPUT_DIR is where you would like the files to be +generated. In the full build system, OUTPUT_DIR is +torch/testing/_internal/generated +""" + +from __future__ import annotations + +import argparse +import os +import textwrap +from collections import defaultdict +from typing import Any, Sequence, TYPE_CHECKING + +import torchgen.api.python as python +from torchgen.context import with_native_function +from torchgen.gen import parse_native_yaml +from torchgen.utils import FileManager + +from .gen_python_functions import ( + is_py_fft_function, + is_py_linalg_function, + is_py_nn_function, + is_py_special_function, + is_py_torch_function, + is_py_variable_method, + should_generate_py_binding, +) + + +if TYPE_CHECKING: + from torchgen.model import Argument, BaseOperatorName, NativeFunction + + +def gen_annotated( + native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str +) -> None: + native_functions = parse_native_yaml( + native_yaml_path, tags_yaml_path + ).native_functions + mappings = ( + (is_py_torch_function, "torch._C._VariableFunctions"), + (is_py_nn_function, "torch._C._nn"), + (is_py_linalg_function, "torch._C._linalg"), + (is_py_special_function, "torch._C._special"), + (is_py_fft_function, "torch._C._fft"), + (is_py_variable_method, "torch.Tensor"), + ) + annotated_args: list[str] = [] + for pred, namespace in mappings: + groups: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list) + for f in native_functions: + if not should_generate_py_binding(f) or not pred(f): + continue + groups[f.func.name.name].append(f) + for group in groups.values(): + for f in group: + annotated_args.append(f"{namespace}.{gen_annotated_args(f)}") + + template_path = os.path.join(autograd_dir, "templates") + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_with_template( + "annotated_fn_args.py", + "annotated_fn_args.py.in", + lambda: { + "annotated_args": textwrap.indent("\n".join(annotated_args), " "), + }, + ) + + +@with_native_function +def gen_annotated_args(f: NativeFunction) -> str: + def _get_kwargs_func_exclusion_list() -> list[str]: + # functions that currently don't work with kwargs in test_overrides.py + return [ + "diagonal", + "round_", + "round", + "scatter_", + ] + + def _add_out_arg( + out_args: list[dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool + ) -> None: + for arg in args: + if arg.default is not None: + continue + out_arg: dict[str, Any] = {} + out_arg["is_kwarg_only"] = str(is_kwarg_only) + out_arg["name"] = arg.name + out_arg["simple_type"] = python.argument_type_str( + arg.type, simple_type=True + ) + size_t = python.argument_type_size(arg.type) + if size_t: + out_arg["size"] = size_t + out_args.append(out_arg) + + out_args: list[dict[str, Any]] = [] + _add_out_arg(out_args, f.func.arguments.flat_positional, is_kwarg_only=False) + if f"{f.func.name.name}" not in _get_kwargs_func_exclusion_list(): + _add_out_arg(out_args, f.func.arguments.flat_kwarg_only, is_kwarg_only=True) + + return f"{f.func.name.name}: {repr(out_args)}," + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate annotated_fn_args script") + parser.add_argument( + "native_functions", metavar="NATIVE", help="path to native_functions.yaml" + ) + parser.add_argument("tags", metavar="TAGS", help="path to tags.yaml") + parser.add_argument("out", metavar="OUT", help="path to output directory") + parser.add_argument( + "autograd", metavar="AUTOGRAD", help="path to template directory" + ) + args = parser.parse_args() + gen_annotated(args.native_functions, args.tags, args.out, args.autograd) + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_autograd.py b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e7be149ad6d8043126a9420217eb3bfe4d42e6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_autograd.py @@ -0,0 +1,147 @@ +""" +To run this file by hand from the root of the PyTorch +repository, run: + +python -m tools.autograd.gen_autograd \ + aten/src/ATen/native/native_functions.yaml \ + aten/src/ATen/native/tags.yaml \ + $OUTPUT_DIR \ + tools/autograd + +Where $OUTPUT_DIR is where you would like the files to be +generated. In the full build system, OUTPUT_DIR is +torch/csrc/autograd/generated/ +""" + +# gen_autograd.py generates C++ autograd functions and Python bindings. +# +# It delegates to the following scripts: +# +# gen_autograd_functions.py: generates subclasses of torch::autograd::Node +# gen_variable_type.py: generates VariableType.h which contains all tensor methods +# gen_python_functions.py: generates Python bindings to THPVariable +# + +from __future__ import annotations + +import argparse +import os + +from torchgen.api import cpp +from torchgen.api.autograd import ( + match_differentiability_info, + NativeFunctionWithDifferentiabilityInfo, +) +from torchgen.gen import parse_native_yaml +from torchgen.selective_build.selector import SelectiveBuilder + +from . import gen_python_functions +from .gen_autograd_functions import ( + gen_autograd_functions_lib, + gen_autograd_functions_python, +) +from .gen_inplace_or_view_type import gen_inplace_or_view_type +from .gen_trace_type import gen_trace_type +from .gen_variable_factories import gen_variable_factories +from .gen_variable_type import gen_variable_type +from .gen_view_funcs import gen_view_funcs +from .load_derivatives import load_derivatives + + +def gen_autograd( + native_functions_path: str, + tags_path: str, + out: str, + autograd_dir: str, + operator_selector: SelectiveBuilder, + disable_autograd: bool = False, +) -> None: + # Parse and load derivatives.yaml + differentiability_infos, used_dispatch_keys = load_derivatives( + os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path + ) + + template_path = os.path.join(autograd_dir, "templates") + + native_funcs = parse_native_yaml(native_functions_path, tags_path).native_functions + fns = sorted( + filter( + operator_selector.is_native_function_selected_for_training, native_funcs + ), + key=lambda f: cpp.name(f.func), + ) + fns_with_diff_infos: list[ + NativeFunctionWithDifferentiabilityInfo + ] = match_differentiability_info(fns, differentiability_infos) + + # Generate VariableType.h/cpp + if not disable_autograd: + gen_variable_type( + out, + native_functions_path, + tags_path, + fns_with_diff_infos, + template_path, + used_dispatch_keys, + ) + + gen_inplace_or_view_type( + out, native_functions_path, tags_path, fns_with_diff_infos, template_path + ) + + # operator filter not applied as tracing sources are excluded in selective build + gen_trace_type(out, native_funcs, template_path) + # Generate Functions.h/cpp + gen_autograd_functions_lib(out, differentiability_infos, template_path) + + # Generate variable_factories.h + gen_variable_factories(out, native_functions_path, tags_path, template_path) + + # Generate ViewFuncs.h/cpp + gen_view_funcs(out, fns_with_diff_infos, template_path) + + +def gen_autograd_python( + native_functions_path: str, + tags_path: str, + out: str, + autograd_dir: str, +) -> None: + differentiability_infos, _ = load_derivatives( + os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path + ) + + template_path = os.path.join(autograd_dir, "templates") + + # Generate Functions.h/cpp + gen_autograd_functions_python(out, differentiability_infos, template_path) + + # Generate Python bindings + deprecated_path = os.path.join(autograd_dir, "deprecated.yaml") + gen_python_functions.gen( + out, native_functions_path, tags_path, deprecated_path, template_path + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate autograd C++ files script") + parser.add_argument( + "native_functions", metavar="NATIVE", help="path to native_functions.yaml" + ) + parser.add_argument("tags", metavar="NATIVE", help="path to tags.yaml") + parser.add_argument("out", metavar="OUT", help="path to output directory") + parser.add_argument( + "autograd", metavar="AUTOGRAD", help="path to autograd directory" + ) + args = parser.parse_args() + gen_autograd( + args.native_functions, + args.tags, + args.out, + args.autograd, + SelectiveBuilder.get_nop_selector(), + ) + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py new file mode 100644 index 0000000000000000000000000000000000000000..e8141658b0335cb7272a4bb885b49fdb934d1bbd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py @@ -0,0 +1,675 @@ +# Generates ADInplaceOrViewType.h/cpp +# +# NOTE: If any changes are being made to the ADInplaceOrView codegen please also check +# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp +# The fallback is expected to mimick this codegen, so we should keep the two in sync. + +from __future__ import annotations + +from torchgen.api import cpp +from torchgen.api.autograd import ( + dispatch_strategy, + gen_differentiable_outputs, + NativeFunctionWithDifferentiabilityInfo, +) +from torchgen.api.types import ( + BaseCType, + Binding, + boolT, + ConstRefCType, + CType, + DispatcherSignature, + intArrayRefT, + longT, + OptionalCType, + symIntArrayRefT, + SymIntT, + tensorT, +) +from torchgen.code_template import CodeTemplate +from torchgen.context import with_native_function +from torchgen.model import ( + NativeFunction, + SchemaKind, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.utils import FileManager + +from .context import with_native_function_with_differentiability_info +from .gen_trace_type import ( + get_return_value, + MANUAL_AUTOGRAD, + tie_return_values, + type_wrapper_name, +) + + +# See NOTE [ Autograd View Variables ] in variable.h for details. +# If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT, +# you **MUST** also update the public list of view ops accordingly in +# docs/source/tensor_view.rst. Note not all ATen functions are exposed to public, +# e.g alias & sparse_coo_tensor_with_dims_and_tensors. +# +# A map: function name => name of the argument that all outputs are view of + +VIEW_FUNCTIONS_WITH_METADATA_CHANGE = [ + "view_as_complex", + "view_as_real", + "_conj", + "_neg_view", + "_nested_get_values", + "_nested_view_from_buffer", + "_nested_view_from_jagged", +] + +VIEW_FUNCTIONS = { + "numpy_T": "self", + "alias": "self", + "as_strided": "self", + "diagonal": "self", + "expand": "self", + "permute": "self", + "select": "self", + "slice": "self", + "slice_inverse": "self", + "split": "self", + "split_with_sizes": "self", + "squeeze": "self", + "t": "self", + "transpose": "self", + "unfold": "self", + "unsqueeze": "self", + "flatten": "self", + "view": "self", + "unbind": "self", + "_indices": "self", + "_values": "self", + "indices": "self", + "values": "self", + "crow_indices": "self", + "col_indices": "self", + "ccol_indices": "self", + "row_indices": "self", + # sparse_coo ctor output should really be views of both indices and values, + # but we only supports making as view of a single variable, and indices is + # discrete anyways. + # FIXME: clone indices on construction. + "sparse_coo_tensor_with_dims_and_tensors": "values", + "_reshape_alias": "self", + "_test_autograd_multiple_dispatch_view": "self", +} + +for key in VIEW_FUNCTIONS_WITH_METADATA_CHANGE: + VIEW_FUNCTIONS[key] = "self" + +# note: some VIEW_FUNCTIONS are just compositions of the view functions above +# this list contains both the root view functions and any that are purely composed +# of viewing functions, and is used by the JIT to determine when an operator +# may return a view of its inputs; however they may sometimes return a copy. +# (e.g. `contiguous`) +RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union( + { + "chunk", + "detach", + "contiguous", + "reshape", + "reshape_as", + "expand_as", + "view_as", + "real", + "imag", + "narrow", + "movedim", + "tensor_split", + "swapdims", + "swapaxes", + "mT", + "mH", + "adjoint", + "matrix_H", + } +) + +# These are the functions we consider views for the purposes of validating +# StorageImpl and TensorImpl in gen_variable_type. +# `_unsafe_view` is not included in VIEW_FUNCTIONS above because it is not a +# view for the purposes of ADInplaceOrView kernel, we do not want to call as_view +# See NOTE [Unsafe View] for more info. +ALL_VIEW_FUNCTIONS = { + **VIEW_FUNCTIONS, + "_unsafe_view": "self", +} + +ARRAYREF_TO_VEC = CodeTemplate( + """\ +auto ${vec} = ${arg}.vec(); +""" +) + +OPTIONAL_TO_VAL = CodeTemplate( + """\ +auto ${val} = ${arg}.value_or(${default}); +""" +) + +CALL_DISPATCH = CodeTemplate( + """\ +at::_ops::${unambiguous_name}::call(${unpacked_args})""" +) + +REVERSE_VIEW_DISPATCH = CodeTemplate( + """\ +${reverse_name}(${unpacked_args})""" +) + +MULTI_OUTPUT_VIEW_ITERATION = CodeTemplate( + """\ +for (auto ${view_idx} : c10::irange(${var}.size())) { + ${body} +} +""" +) + +SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate( + """\ +std::unique_ptr func(nullptr); +std::function rev_func=nullptr; +if (${is_view_with_metadata_change} || + !self.unsafeGetTensorImpl()->support_as_strided() || + self.unsafeGetTensorImpl()->is_python_dispatch() || + c10::AutogradState::get_tls_state().get_view_replay_enabled()) { + ${replay_view_func} + ${reverse_replay_view_func} +} +""" +) + +REPLAY_VIEW_FUNC = CodeTemplate( + """\ +func = std::make_unique<${view_func_name}>(${view_func_args}); +""" +) + +REVERSE_REPLAY_VIEW_LAMBDA_FUNC = CodeTemplate( + """\ +rev_func = [=](const at::Tensor& ${input_view}) { + return ${reverse_replay_view_call}; +}; +""" +) + +METHOD_DEFINITION = CodeTemplate( + """\ +${return_type} ${type_wrapper_name}(${formals}) { + ${type_definition_body} +} +""" +) + +WRAPPER_REGISTRATION = CodeTemplate( + """\ +m.impl("${unqual_operator_name_with_overload}", + TORCH_FN(${class_type}::${type_wrapper_name}) +); +""" +) + +AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION = CodeTemplate( + """\ +m.impl("${unqual_operator_name_with_overload}", torch::autograd::autogradNotImplementedFallback()); +""" +) + +INPLACE_REDISPATCH = CodeTemplate( + """\ +{ + at::AutoDispatchBelowADInplaceOrView guard; + at::_ops::${unambiguous_name}::redispatch(${unpacked_args}); +} +""" +) + +ASSIGN_RETURN_VALUE = CodeTemplate( + """\ +${return_values} = ${rhs_value}; +""" +) + +VIEW_REDISPATCH = CodeTemplate( + """\ +${assign_return_values} ([&]() { + at::AutoDispatchBelowADInplaceOrView guard; + return at::_ops::${unambiguous_name}::redispatch(${unpacked_args}); +})(); +""" +) + +TMP_VAR = "_tmp" + + +# FIXME: Ideally these functions should be methods on Type class, but we have a +# comment in codegen/model.py there saying these concepts are not well defined. +# Thus we put a version that commonly used by autograd codegen here. +def is_tensor_type(t: Type) -> bool: + # TODO: Should handle optional here? + return t.is_tensor_like() and t.is_list_like() is None + + +def is_tensor_list_type(t: Type) -> bool: + # TODO: Should handle optional here? + return t.is_tensor_like() and t.is_list_like() is not None + + +UNPACK_TENSOR = CodeTemplate( + """\ +auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});""" +) + + +def unpacked_name(arg_name: str) -> str: + return arg_name + "_" + + +# e.g. select.int -> select_copy_int_inverse() +def inverse_view_name(f: NativeFunction) -> str: + copy_variant = f"{f.root_name}_copy" + overload = f"{f.func.name.overload_name}" + if overload != "": + overload = "_" + overload + return f"{copy_variant}{overload}_inverse" + + +def extract_bindings(f: NativeFunction) -> list[Binding]: + return [ + r + for a in f.func.schema_order_arguments() + for r in cpp.argument( + a, + method=False, + symint=True, + cpp_no_default_args=set(), + faithful=False, + has_tensor_options=False, + ) + ] + + +@with_native_function +def unpack_args(f: NativeFunction) -> tuple[list[str], list[Binding]]: + body: list[str] = [] + unpacked_bindings: list[Binding] = [] + + for i, binding in enumerate(extract_bindings(f)): + assert not isinstance(binding.argument, SelfArgument) + if isinstance(binding.argument, TensorOptionsArguments): + raise RuntimeError("VariableKernel shouldn't take TensorOptions") + + is_nullable = binding.argument.type.is_nullable() + if not binding.argument.type.is_tensor_like() or is_nullable: + unpacked_bindings.append(binding) + continue + + is_tensor_list = is_tensor_list_type(binding.argument.type) + ref = (not is_nullable) and not is_tensor_list + suffix = "_opt" if is_nullable and not is_tensor_list else "" + body.append( + UNPACK_TENSOR.substitute( + arg_name=binding.name, + arg_pos=i, + suffix=suffix, + ref="&" if ref else "", + ) + ) + unpacked_bindings.append( + Binding( + name=unpacked_name(binding.name), + nctype=binding.nctype, + argument=binding.argument, + default=binding.default, + ) + ) + + return body, unpacked_bindings + + +def get_base_name(f: NativeFunction) -> str: + return f.func.name.name.base # TODO: should be str(f.func.name.name)? + + +def get_view_info(f: NativeFunction) -> str | None: + base_name = get_base_name(f) + view_info = VIEW_FUNCTIONS.get(base_name, None) + if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT: + view_info = "self" + return view_info + + +def emit_view_func( + f: NativeFunction, bindings: list[Binding], view_idx: str | None = None +) -> str: + """Generate an additional lambda function to recover views in backward when as_strided is not supported. + See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details. + """ + # TODO: Clean this logic up if we get rid of reverse view funcs or reify them. + input_base = "input_base" + replay_view_func = "" + updated_args: list[str] = [] + known_view_arg_simple_types: list[CType] = [ + BaseCType(longT), + OptionalCType(BaseCType(longT)), + BaseCType(SymIntT), + OptionalCType(BaseCType(SymIntT)), + BaseCType(boolT), + BaseCType(intArrayRefT), + BaseCType(symIntArrayRefT), + ConstRefCType(BaseCType(tensorT)), + ConstRefCType(OptionalCType(BaseCType(tensorT))), + ] + for binding in bindings: + arg, arg_type = binding.name, binding.nctype.type + if arg == "self": + updated_args.append(input_base) + continue + if arg_type not in known_view_arg_simple_types: + known_types_str = ", ".join([str(t) for t in known_view_arg_simple_types]) + raise TypeError( + f"You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: " + f"{known_types_str}. Please update the list or materialize it so that it can be closed " + "over by value, also add a test in pytorch/xla/test/test_operations.py where this code " + "is exercised." + ) + if arg_type == BaseCType(intArrayRefT) or arg_type == BaseCType( + symIntArrayRefT + ): + # It's not safe to close over IntArrayRef by value, since this is a + # reference type, so materialize a vector to close over by value + arg_vec = arg + "_vec" + replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec) + updated_args.append(arg_vec) + elif arg_type == OptionalCType(BaseCType(longT)): + # Materialize int64_t? to int64_t + arg_value = arg + "_val" + replay_view_func += OPTIONAL_TO_VAL.substitute( + arg=arg, val=arg_value, default="0" + ) + updated_args.append(arg_value) + elif arg_type == ConstRefCType(BaseCType(tensorT)) or arg_type == ConstRefCType( + OptionalCType(BaseCType(tensorT)) + ): + # NB: Closing over a tensor. If a user modifies this tensor, this will be silently + # incorrect. The proper thing to do is to store the version counter and copy on write. + updated_args.append(arg) + else: + updated_args.append(arg) + + from .gen_view_funcs import view_func_name + + view_func_args = [b.name for b in bindings if b.name != "self"] + if view_idx is not None: + view_func_args.append(f"{view_idx}") + replay_view_func += REPLAY_VIEW_FUNC.substitute( + view_func_name=view_func_name(f, include_namespace=True), + view_func_args=view_func_args, + ) + + input_view = "input_view" + reverse_unpacked_args = [ + "self", + f"{input_view}", + # inverse_return_mode= + "at::functionalization::InverseReturnMode::AlwaysView", + *(() if view_idx is None else (f"{view_idx}",)), + # skip input_base arg + *updated_args[1:], + ] + + from torchgen.api.functionalization import reverse_name + + reverse_replay_view_call = REVERSE_VIEW_DISPATCH.substitute( + reverse_name=reverse_name(f, include_namespace=True), + unpacked_args=reverse_unpacked_args, + ) + reverse_replay_view_func = REVERSE_REPLAY_VIEW_LAMBDA_FUNC.substitute( + input_view=input_view, reverse_replay_view_call=reverse_replay_view_call + ) + + is_view_with_metadata_change = ( + "true" if cpp.name(f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else "false" + ) + + return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute( + is_view_with_metadata_change=is_view_with_metadata_change, + replay_view_func=replay_view_func, + reverse_replay_view_func=reverse_replay_view_func, + ) + + +def emit_view_body( + fn: NativeFunctionWithDifferentiabilityInfo, var: str +) -> tuple[str, str]: + # See NOTE [ Autograd View Variables ] in variable.h for details. + f = fn.func + base_name = get_base_name(f) + view_info = get_view_info(f) + call = "" + differentiable_outputs = gen_differentiable_outputs(fn) + differentiable_output_vars = {r.name for r in differentiable_outputs} + if not isinstance(view_info, str): + raise TypeError( + f"The view info should be a string for {base_name}, but it is: {view_info}" + ) + if len(differentiable_output_vars) == 0: + # no output is differentiable (.indices() for SparseTensors for example) + rhs_value = ( + f"as_view({view_info}, {var}, " + f"/* is_bw_differentiable */ false, /* is_fw_differentiable */ false)" + ) + elif len(differentiable_output_vars) == 1: + # Single differentiable output (Tensor or Tensor[]) + return_info = differentiable_outputs[0] + # We only support simple Tensor or a TensorList for functions that return views + if not is_tensor_type(return_info.type) and not is_tensor_list_type( + return_info.type + ): + raise RuntimeError( + f"{base_name} that return differentiable views can only return Tensor or Tensor[]" + ) + + # See Note [ View + Inplace detection] + def get_creation_meta_in_mode(original: str) -> str: + creation_meta_with_grad_mode = f"(at::GradMode::is_enabled() ? {original} : CreationMeta::NO_GRAD_MODE)" + return f"InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : {creation_meta_with_grad_mode}" + + # Only allow rebasing of the history if we return a single Tensor + # If we are in a no grad block, raise a warning + # See NOTE [ View + Inplace detection ] for more details about this logic + if is_tensor_list_type(return_info.type): + creation_meta = get_creation_meta_in_mode("CreationMeta::MULTI_OUTPUT_NODE") + view_idx = "view_idx" + view_func = emit_view_func( + f, extract_bindings(f), view_idx=view_idx + ).strip() + as_view_call = ( + f"as_view(/* base */ {view_info}, /* output */ {var}[{view_idx}], " + "/* is_bw_differentiable */ true, /* is_fw_differentiable */ true, " + "/* view_func */ std::move(func), /* rev_view_func */ rev_func, " + f"/* creation_meta */ {creation_meta});" + ) + call += MULTI_OUTPUT_VIEW_ITERATION.substitute( + var=var, view_idx=view_idx, body=f"{view_func}\n{as_view_call}" + ) + rhs_value = f"std::move({var})" + else: + call += emit_view_func(f, extract_bindings(f), view_idx=None) + creation_meta = get_creation_meta_in_mode("CreationMeta::DEFAULT") + rhs_value = ( + f"as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, " + "/* is_fw_differentiable */ true, " + f"/* view_func */ std::move(func), /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})" + ) + else: + # This could be supported but we don't need it at the moment, so keeping things simple. + raise RuntimeError( + "Function that return multiple differentiable output " + "when at least one of them is view is not supported." + ) + return call, rhs_value + + +def modifies_arguments(f: NativeFunction) -> bool: + return f.func.kind() in [SchemaKind.inplace, SchemaKind.out] + + +@with_native_function_with_differentiability_info +def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> list[str]: + f = fn.func + inplace_view_body: list[str] = [] + + dispatcher_sig = DispatcherSignature.from_schema(f.func) + dispatcher_exprs = dispatcher_sig.exprs() + + # code-generated ADInplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. + dispatch_key_set = "ks & c10::after_ADInplaceOrView_keyset" + redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs]) + + # Note that this calls the slow, dispatching variants of manual_cpp_binding ops. + # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal. + if modifies_arguments(f): # inplace op + inplace_view_body.append( + INPLACE_REDISPATCH.substitute( + unambiguous_name=f.func.name.unambiguous_name(), + unpacked_args=redispatch_args, + ) + ) + for r in cpp.return_names(f): + inplace_view_body.append(f"increment_version({r});") + else: + assert get_view_info(f) is not None + inplace_view_body.append( + VIEW_REDISPATCH.substitute( + assign_return_values="auto " + TMP_VAR + " = ", + unambiguous_name=f.func.name.unambiguous_name(), + unpacked_args=redispatch_args, + ) + ) + call, rhs_value = emit_view_body(fn, TMP_VAR) + inplace_view_body.append(call) + assert rhs_value is not None + inplace_view_body.append( + ASSIGN_RETURN_VALUE.substitute( + return_values=tie_return_values(f), rhs_value=rhs_value + ) + ) + if f.func.returns: + inplace_view_body.append(f"return {get_return_value(f)};") + return inplace_view_body + + +@with_native_function +def gen_formals(f: NativeFunction) -> str: + return ", ".join( + # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. + ["c10::DispatchKeySet ks"] + + [ + f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}' + for a in f.func.schema_order_arguments() + ] + ) + + +@with_native_function_with_differentiability_info +def inplace_or_view_method_definition( + fn: NativeFunctionWithDifferentiabilityInfo, +) -> str | None: + f = fn.func + if get_view_info(f) is None and ( + # For functions that modify their inputs but don't return them, + # we can't give them autograd support. + # See https://github.com/pytorch/pytorch/issues/53796 + not modifies_arguments(f) + or len(f.func.returns) == 0 + ): + return None + return METHOD_DEFINITION.substitute( + return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(), + type_wrapper_name=type_wrapper_name(f), + formals=gen_formals(f), + type_definition_body=emit_inplace_or_view_body(fn), + ) + + +@with_native_function_with_differentiability_info +def inplace_or_view_method_registration( + fn: NativeFunctionWithDifferentiabilityInfo, +) -> str | None: + f = fn.func + if get_view_info(f) is None and ( + not modifies_arguments(f) or len(f.func.returns) == 0 + ): + return None + return WRAPPER_REGISTRATION.substitute( + unqual_operator_name_with_overload=f.func.name, + type_wrapper_name=type_wrapper_name(f), + class_type="ADInplaceOrView", + ) + + +def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool: + f = fn.func + name = cpp.name(f.func) + return name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == "use_derived" + + +def gen_inplace_or_view_type_env( + fn: NativeFunctionWithDifferentiabilityInfo, +) -> dict[str, list[str]]: + definition = inplace_or_view_method_definition(fn) + registration = inplace_or_view_method_registration(fn) + + return { + "ops_headers": ( + [f"#include "] + if definition is not None + else [] + ), + "inplace_or_view_method_definitions": [definition] + if definition is not None + else [], + "inplace_or_view_wrapper_registrations": [registration] + if registration is not None + else [], + } + + +def gen_inplace_or_view_type( + out: str, + native_yaml_path: str, + tags_yaml_path: str, + fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo], + template_path: str, +) -> None: + # NOTE: see Note [Sharded File] at the top of the VariableType.cpp + # template regarding sharding of the generated files. + num_shards = 2 + + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_sharded( + "ADInplaceOrViewType.cpp", + [fn for fn in fns_with_infos if use_derived(fn)], + key_fn=lambda fn: fn.func.root_name, + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/ADInplaceOrViewType.cpp", + }, + env_callable=gen_inplace_or_view_type_env, + num_shards=2, + sharded_keys={ + "ops_headers", + "inplace_or_view_method_definitions", + "inplace_or_view_wrapper_registrations", + }, + ) diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_python_functions.py b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_python_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..44453306a0ecbf65452c0287a8c903b9d11f0600 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_python_functions.py @@ -0,0 +1,1402 @@ +# Generates Python bindings for ATen functions +# +# The bindings are generated as methods on python_variable or functions on the +# torch._C._nn. torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse +# or torch._C._special objects. +# + +# Code tries to stick to the following rules: +# +# - templates should be colocated with the functions that use them. +# no templates are currently shared between functions, but if that +# happens, maybe put the template with the first one +# +# - don't use environment dictionaries when calling template.substitute(). +# pass named arguments directly for everything, otherwise it's much too +# hard to track what's actually being used and by who +# +# - colocate any new hacks/adjustments with existing ones of the same kind. +# ideally in a data structure rather than code if possible. See e.g. +# SCHEMA_DEFAULT_CONVERSION_HACKS, etc. +# +# - similarly, conversions from one format to another should ideally happen +# all at once in a single place. +# +# - no nontrivial nested functions. couple-liners are ok but please no more. +# especially avoid functions that read/write outer variables defined far away. +# +# - raise RuntimeError instead of asserting, and put as much +# information as is available into the message. I.e. no need to +# plumb in new params whose only purpose is to fill out an error +# message, but use what's there +# + +from __future__ import annotations + +import itertools +import re +from collections import defaultdict +from typing import Callable, Iterable, Sequence + +import yaml + +from torchgen.api import cpp +from torchgen.api.python import ( + arg_parser_output_exprs, + cpp_dispatch_exprs, + cpp_dispatch_target, + dispatch_lambda_args, + dispatch_lambda_exprs, + dispatch_lambda_return_str, + has_tensor_options, + PythonSignature, + PythonSignatureDeprecated, + PythonSignatureGroup, + PythonSignatureNativeFunctionPair, + signature, + signature_from_schema, + structseq_fieldnames, +) +from torchgen.code_template import CodeTemplate +from torchgen.context import with_native_function +from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml +from torchgen.model import ( + Argument, + BaseOperatorName, + FunctionSchema, + NativeFunction, + SchemaKind, + Type, + Variant, +) +from torchgen.utils import FileManager, split_name_params +from torchgen.yaml_utils import YamlLoader + +from .gen_inplace_or_view_type import is_tensor_list_type +from .gen_trace_type import should_trace + + +# +# declarations blocklist +# We skip codegen for these functions, for various reasons. +# Future PRs will categorize this list and eliminate or hoist +# them out of eager-only codegen. +# See https://github.com/pytorch/pytorch/issues/30788 +# + +# These functions require manual Python bindings or are not exposed to Python +_SKIP_PYTHON_BINDINGS = [ + "alias", + "contiguous", + "is_cuda", + "is_sparse", + "is_sparse_csr", + "size", + "stride", + "sym_size", + "sym_stride", + "sym_storage_offset", + "sym_numel", + ".*_backward", + ".*_backward_(out|input|weight|bias)", + ".*_forward", + ".*_forward_out", + ".*_jvp", + "_unsafe_view", + "tensor", + "_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*", + "_range.*", + "_sparse_add_out", + "_sparse_div.*", + "_sparse_mul.*", + "_sparse_sub.*", + "_sparse_dense_add_out", + "index", + "index_out", + "unique_dim_consecutive", + "_cumsum.*", + "_cumprod.*", + "_sum.*", + "_prod.*", + "_th_.*", + "_thnn_.*", + "range.*", + "_solve.*", + "_inverse.*", + "_cholesky.*", + "_triangular_solve.*", + "_qr.*", + "_svd.*", + "slice", + "item", + "_local_scalar_dense", + "to", + "_to_copy", + "_to_copy_out", + "_reshape_copy", + "_reshape_copy_out", + "copy_sparse_to_sparse_", + "copy_", + "_foreach_copy", + "numpy_T", + "matrix_H", + "mT", + "mH", # these need to be an attributes in Python, not functions + "nonzero(_(out|numpy))?", + "set_data", + ".*_overrideable", # overrideable functions for backend extension + "data", + "is_leaf", + "output_nr", + "_version", + "requires_grad_", + "retains_grad", + "set_", + "_fw_primal", + "fake_quantize_per_tensor_affine_cachemask", + "fake_quantize_per_channel_affine_cachemask", + "_new_zeros_with_same_feature_meta", + "_has_same_storage_numel", # used for forward AD internals + "_reshape_alias", + "replace_", # only used by the functionalization pass, doesn't need to be exposed to python + "copy", # only used by the functionalization pass + "fill.Tensor", # only used by the functionalization pass + "fill.Scalar", # only used by the functionalization pass + "lift.*", + "normal_functional", # only used by the functionalization pass + "nbytes", + "itemsize", + "_batch_norm_with_update", + "_batch_norm_with_update_out", + "_batch_norm_no_update", +] + +SKIP_PYTHON_BINDINGS = [ + re.compile(rf"^{pattern}$") for pattern in _SKIP_PYTHON_BINDINGS +] + +# These function signatures are not exposed to Python. Note that this signature +# list does not support regex. +SKIP_PYTHON_BINDINGS_SIGNATURES = [ + "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", + "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", + "sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", + "sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", + "mul.Scalar(Tensor self, Scalar other) -> Tensor", + "mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", + "div.Scalar(Tensor self, Scalar other) -> Tensor", + "div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", +] + + +@with_native_function +def should_generate_py_binding(f: NativeFunction) -> bool: + # NativeFunctions that are entirely code-generated should not get python bindings + # because these codegen implementations are often inefficient. A handful of + # view_copy style ops were exposed accidentally when they were handwritten and now + # that we are moving them to codegen for bc reasons we need to keep them exposed in + # python. + if "generated" in f.tags and "view_copy" not in f.tags: + return False + + name = cpp.name(f.func) + for skip_regex in SKIP_PYTHON_BINDINGS: + if skip_regex.match(name): + return False + + signature = str(f.func) + for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES: + if pattern == signature: + return False + return True + + +def get_pycname(name: BaseOperatorName) -> str: + return f"THPVariable_{name}" + + +def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool: + return len(overloads) == 1 and overloads[0].signature.arguments_count() == 0 + + +def is_py_variable_method(f: NativeFunction) -> bool: + return f.python_module is None and Variant.method in f.variants + + +def is_py_torch_function(f: NativeFunction) -> bool: + return f.python_module is None and Variant.function in f.variants + + +def is_py_nn_function(f: NativeFunction) -> bool: + return f.python_module == "nn" + + +def is_py_fft_function(f: NativeFunction) -> bool: + return f.python_module == "fft" + + +def is_py_linalg_function(f: NativeFunction) -> bool: + return f.python_module == "linalg" + + +def is_py_nested_function(f: NativeFunction) -> bool: + return f.python_module == "nested" + + +def is_py_sparse_function(f: NativeFunction) -> bool: + return f.python_module == "sparse" + + +def is_py_special_function(f: NativeFunction) -> bool: + return f.python_module == "special" + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Main Function +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def gen( + out: str, + native_yaml_path: str, + tags_yaml_path: str, + deprecated_yaml_path: str, + template_path: str, + *, + symint: bool = True, +) -> None: + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + native_functions = parse_native_yaml( + native_yaml_path, tags_yaml_path + ).native_functions + native_functions = list(filter(should_generate_py_binding, native_functions)) + + methods = load_signatures(native_functions, deprecated_yaml_path, method=True) + create_python_bindings( + fm, + methods, + is_py_variable_method, + None, + "python_variable_methods.cpp", + method=True, + symint=symint, + ) + + # NOTE: num_shards here must be synced with gatherTorchFunctions in + # torch/csrc/autograd/python_torch_functions_manual.cpp + functions = load_signatures(native_functions, deprecated_yaml_path, method=False) + create_python_bindings_sharded( + fm, + functions, + is_py_torch_function, + "torch", + "python_torch_functions.cpp", + method=False, + num_shards=3, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_nn_function, + "torch.nn", + "python_nn_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_fft_function, + "torch.fft", + "python_fft_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_linalg_function, + "torch.linalg", + "python_linalg_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_nested_function, + "torch.nested", + "python_nested_functions.cpp", + method=False, + ) + + create_python_bindings( + fm, + functions, + is_py_sparse_function, + "torch.sparse", + "python_sparse_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_special_function, + "torch.special", + "python_special_functions.cpp", + method=False, + symint=symint, + ) + + # Currently, we only use `functions` to generate `return_types` bindings. + # All methods which return structseq have function variant at this point. + # If any method only operator with structseq is added in the future, + # we will have to address that. + create_python_return_type_bindings( + fm, functions, lambda fn: True, "python_return_types.cpp" + ) + create_python_return_type_bindings_header( + fm, functions, lambda fn: True, "python_return_types.h" + ) + + valid_tags = parse_tags_yaml(tags_yaml_path) + + def gen_tags_enum() -> dict[str, str]: + return { + "enum_of_valid_tags": ( + "".join( + [f'\n.value("{tag}", at::Tag::{tag})' for tag in sorted(valid_tags)] + ) + ) + } + + fm.write("python_enum_tag.cpp", gen_tags_enum) + + +def group_filter_overloads( + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], +) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]: + grouped: dict[ + BaseOperatorName, list[PythonSignatureNativeFunctionPair] + ] = defaultdict(list) + for pair in pairs: + if pred(pair.function): + grouped[pair.function.func.name.name].append(pair) + return grouped + + +def create_python_bindings( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + module: str | None, + filename: str, + *, + method: bool, + symint: bool = True, +) -> None: + """Generates Python bindings to ATen functions""" + py_methods: list[str] = [] + ops_headers: list[str] = [] + py_method_defs: list[str] = [] + py_forwards: list[str] = [] + + grouped = group_filter_overloads(pairs, pred) + + for name in sorted(grouped.keys(), key=str): + overloads = grouped[name] + py_methods.append( + method_impl(name, module, overloads, method=method, symint=symint) + ) + py_method_defs.append(method_def(name, module, overloads, method=method)) + py_forwards.extend(forward_decls(name, overloads, method=method)) + ops_headers.append(f"#include ") + + fm.write_with_template( + filename, + filename, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + "ops_headers": ops_headers, + "py_forwards": py_forwards, + "py_methods": py_methods, + "py_method_defs": py_method_defs, + }, + ) + + +def create_python_return_type_bindings( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + filename: str, +) -> None: + """ + Generate function to initialize and return named tuple for native functions + which returns named tuple and registration invocations in `python_return_types.cpp`. + """ + py_return_types_definition: list[str] = [] + py_return_types_registrations: list[str] = [] + + grouped = group_filter_overloads(pairs, pred) + + for name in sorted(grouped.keys(), key=str): + overloads = grouped[name] + definitions, registrations = generate_return_type_definition_and_registrations( + overloads + ) + py_return_types_definition.append( + "" if not definitions else "\n".join(definitions) + ) + py_return_types_registrations.append( + "" if not registrations else "\n".join(registrations) + ) + + fm.write_with_template( + filename, + filename, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + "py_return_types": py_return_types_definition, + "py_return_types_registrations": py_return_types_registrations, + }, + ) + + +def create_python_return_type_bindings_header( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + filename: str, +) -> None: + """ + Generate function to initialize and return named tuple for native functions + which returns named tuple and relevant entry for the map in `python_return_types.cpp`. + """ + py_return_types_declarations: list[str] = [] + + grouped = group_filter_overloads(pairs, pred) + + for name in sorted(grouped.keys(), key=str): + overloads = grouped[name] + declarations = generate_return_type_declarations(overloads) + py_return_types_declarations.append( + "" if not declarations else "\n".join(declarations) + ) + + fm.write_with_template( + filename, + filename, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + "py_return_types_declarations": py_return_types_declarations, + }, + ) + + +def create_python_bindings_sharded( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + module: str | None, + filename: str, + *, + method: bool, + num_shards: int, + symint: bool = True, +) -> None: + """Generates Python bindings to ATen functions""" + grouped = group_filter_overloads(pairs, pred) + + def key_func( + kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]] + ) -> str: + return kv[0].base + + def env_func( + kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]] + ) -> dict[str, list[str]]: + name, fn_pairs = kv + return { + "ops_headers": [f"#include "], + "py_forwards": list(forward_decls(name, fn_pairs, method=method)), + "py_methods": [ + method_impl(name, module, fn_pairs, method=method, symint=symint) + ], + "py_method_defs": [method_def(name, module, fn_pairs, method=method)], + } + + fm.write_sharded( + filename, + grouped.items(), + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + }, + key_fn=key_func, + env_callable=env_func, + num_shards=num_shards, + sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"}, + ) + + +def load_signatures( + native_functions: list[NativeFunction], + deprecated_yaml_path: str, + *, + method: bool, + skip_deprecated: bool = False, + pyi: bool = False, +) -> Sequence[PythonSignatureNativeFunctionPair]: + @with_native_function + def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair: + return PythonSignatureNativeFunctionPair( + signature=signature(f, method=method, pyi=pyi), + function=f, + ) + + pairs = list(map(gen_signature_pairs, native_functions)) + deprecated = load_deprecated_signatures( + pairs, deprecated_yaml_path, method=method, pyi=pyi + ) + return pairs if skip_deprecated else pairs + deprecated + + +def load_deprecated_signatures( + pairs: Sequence[PythonSignatureNativeFunctionPair], + deprecated_yaml_path: str, + *, + method: bool, + pyi: bool, +) -> list[PythonSignatureNativeFunctionPair]: + # The deprecated.yaml doesn't have complete type information, we need + # find and leverage the original ATen signature (to which it delegates + # the call) to generate the full python signature. + # We join the deprecated and the original signatures using type-only form. + + # group the original ATen signatures by name + grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list) + for pair in pairs: + grouped[pair.signature.name].append(pair) + + # find matching original signatures for each deprecated signature + results: list[PythonSignatureNativeFunctionPair] = [] + + with open(deprecated_yaml_path) as f: + deprecated_defs = yaml.load(f, Loader=YamlLoader) + + for deprecated in deprecated_defs: + schema = FunctionSchema.parse(deprecated["name"]) + aten_name, call_args = split_name_params(deprecated["aten"]) + is_out = aten_name.endswith("_out") + if is_out: + aten_name = aten_name.replace("_out", "") + + # HACK: these are fixed constants used to pass the aten function. + # The type must be known ahead of time + known_constants = { + "1": Type.parse("Scalar"), + } + schema_args_by_name = {a.name: a for a in schema.arguments.flat_all} + for name in call_args: + assert ( + name in schema_args_by_name or name in known_constants + ), f"deprecation definiton: Unrecognized value {name}" + + # Map deprecated signature arguments to their aten signature and test + # if the types and alias annotation match. + def is_schema_compatible( + aten_schema: FunctionSchema, + ) -> bool: + arguments: Iterable[Argument] + if is_out: + arguments = itertools.chain( + aten_schema.arguments.out, aten_schema.arguments.flat_non_out + ) + else: + arguments = aten_schema.arguments.flat_all + + for i, arg in enumerate(arguments): + if i < len(call_args): + arg_name = call_args[i] + if arg_name in known_constants: + schema_type = known_constants[arg_name] + schema_annotation = None + else: + schema_arg = schema_args_by_name[arg_name] + schema_type = schema_arg.type + schema_annotation = schema_arg.annotation + + if schema_type != arg.type or schema_annotation != arg.annotation: + return False + else: + if arg.default is None: + return False + + return len(schema.returns) == len(aten_schema.returns) and all( + a == b for a, b in zip(schema.returns, aten_schema.returns) + ) + + any_schema_found = False + for pair in grouped[aten_name]: + if not is_schema_compatible(pair.function.func): + continue + any_schema_found = True + + python_sig = signature_from_schema( + schema, + category_override=pair.function.category_override, + method=method, + pyi=pyi, + ) + + results.append( + PythonSignatureNativeFunctionPair( + signature=PythonSignatureDeprecated( + name=python_sig.name, + input_args=python_sig.input_args, + input_kwargs=python_sig.input_kwargs, + output_args=python_sig.output_args, + tensor_options_args=python_sig.tensor_options_args, + method=python_sig.method, + deprecated_schema=schema, + deprecated_args_exprs=tuple(call_args), + returns=python_sig.returns, + ), + function=pair.function, + ) + ) + assert ( + any_schema_found + ), f"No native function with name {aten_name} matched signature:\n {str(schema)}" + + return results + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Named Tuple Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@with_native_function +def gen_structseq_typename_key(f: NativeFunction) -> str: + name = cpp.name(f.func) + fieldnames = structseq_fieldnames(f.func.returns) + return "_".join([name] + fieldnames) + + +def emit_structseq_call( + overloads: Sequence[PythonSignatureNativeFunctionPair], +) -> tuple[list[str], dict[str, str]]: + """ + Generate block of named tuple type def inits, and add typeref snippets + to declarations that use them + """ + typenames: dict[ + str, str + ] = {} # map from unique name + field name lists to typedef name + typedefs: list[str] = [] # typedef declarations and init code + + for overload in overloads: + fieldnames = structseq_fieldnames(overload.function.func.returns) + if not fieldnames: + continue + + name = cpp.name(overload.function.func) # use @with_native_function? + tn_key = gen_structseq_typename_key(overload.function) + typename = typenames.get(tn_key) + if typename is None: + typename = f'NamedTuple{"" if not typedefs else len(typedefs)}' + typenames[tn_key] = typename + typedefs.append( + f"""\ +static PyTypeObject* {typename} = generated::get_{name}_structseq();""" + ) + + return typedefs, typenames + + +def generate_return_type_definition_and_registrations( + overloads: Sequence[PythonSignatureNativeFunctionPair], +) -> tuple[list[str], list[str]]: + """ + Generate block of function in `python_return_types.cpp` to initialize + and return named tuple for a native function which returns named tuple + and registration invocations in same file. + """ + typenames: dict[ + str, str + ] = {} # map from unique name + field name lists to typedef name + definitions: list[str] = [] # function definition to register the typedef + registrations: list[str] = [] # register call for the typedef + + for overload in overloads: + fieldnames = structseq_fieldnames(overload.function.func.returns) + if not fieldnames: + continue + + fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames) + + name = cpp.name(overload.function.func) # use @with_native_function? + tn_key = gen_structseq_typename_key(overload.function) + typename = typenames.get(tn_key) + + if typename is None: + typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}' + typenames[tn_key] = typename + definitions.append( + f"""\ +PyTypeObject* get_{name}_structseq() {{ + static PyStructSequence_Field NamedTuple_fields[] = {{ {fields}, {{nullptr}} }}; + static PyTypeObject {typename}; + static bool is_initialized = false; + static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }}; + if (!is_initialized) {{ + PyStructSequence_InitType(&{typename}, &desc); + {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr; + is_initialized = true; + }} + return &{typename}; +}} +""" + ) + registrations.append( + f'addReturnType(return_types_module, "{name}", generated::get_{name}_structseq());' + ) + + return definitions, registrations + + +def generate_return_type_declarations( + overloads: Sequence[PythonSignatureNativeFunctionPair], +) -> list[str]: + """ + Generate block of function declarations in `python_return_types.h` to initialize + and return named tuple for a native function. + """ + typenames: dict[ + str, str + ] = {} # map from unique name + field name lists to typedef name + declarations: list[str] = [] # function declaration to register the typedef + + for overload in overloads: + fieldnames = structseq_fieldnames(overload.function.func.returns) + if not fieldnames: + continue + + name = cpp.name(overload.function.func) # use @with_native_function? + tn_key = gen_structseq_typename_key(overload.function) + typename = typenames.get(tn_key) + + if typename is None: + typename = ( + f'{name}NamedTuple{"" if not declarations else len(declarations)}' + ) + typenames[tn_key] = typename + declarations.append(f"PyTypeObject* get_{name}_structseq();") + + return declarations + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Method Impl Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +# python binding for all overloads of a particular function/method +PY_VARIABLE_METHOD_VARARGS = CodeTemplate( + r"""\ +// ${name} +static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + ${method_header} + static PythonArgParser parser({ + ${signatures} + }, /*traceable=*/${traceable}); + + ParsedArgs<${max_args}> parsed_args; + auto _r = parser.parse(${self_}, args, kwargs, parsed_args); + ${check_has_torch_function} + switch (_r.idx) { + ${dispatch} + } + ${method_footer} +} + +""" +) + +# handler for a single parsed signature - may be a single overload or +# a pair of overloads that whose signatures only differ in output params +# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch}) +PY_VARIABLE_CASE = CodeTemplate( + """\ +case ${overload_index}: { + ${body} +} +""" +) + +# python binding for single-overload function/method +PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate( + """\ +// ${name} +static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + ${method_header} + static PythonArgParser parser({ + ${signatures} + }, /*traceable=*/${traceable}); + + ParsedArgs<${max_args}> parsed_args; + auto _r = parser.parse(${self_}, args, kwargs, parsed_args); + ${check_has_torch_function} + ${dispatch} + ${method_footer} +} + +""" +) + +# python binding for a method with no args, shortcuts parsing +PY_VARIABLE_METHOD_NOARGS = CodeTemplate( + """\ +// ${name} +static PyObject * ${pycname}(PyObject* self_, PyObject* args) +{ + ${method_header} + ${check_has_torch_function} + ${dispatch} + ${method_footer} +} + +""" +) + + +def method_impl( + name: BaseOperatorName, + module: str | None, + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool, + symint: bool = True, +) -> str: + """ + Generate a python binding for all overloads of an op. + """ + pycname = get_pycname(name) + noarg = is_noarg(overloads) + structseq_inits, structseq_typenames = emit_structseq_call(overloads) + + method_header = ["HANDLE_TH_ERRORS"] + method_header += structseq_inits + method_header += ( + ["const Tensor& self = THPVariable_Unpack(self_);"] if method else [] + ) + + method_footer = ([] if noarg else ["Py_RETURN_NONE;"]) + ["END_HANDLE_TH_ERRORS"] + + traceable = "true" if all(should_trace(o.function) for o in overloads) else "false" + + grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads( + overloads, symint=symint + ) + is_singleton = len(grouped_overloads) == 1 + signatures: list[str] = [] + dispatch: list[str] = [] + for overload_index, overload in enumerate(grouped_overloads): + signature = overload.signature.signature_str(symint=symint) + signatures.append(f"{cpp_string(str(signature))},") + dispatch_body = emit_dispatch_case(overload, structseq_typenames, symint=symint) + dispatch.append( + PY_VARIABLE_CASE.substitute( + overload_index=overload_index, body=dispatch_body + ) + if not is_singleton + else dispatch_body + ) + + if noarg: + template = PY_VARIABLE_METHOD_NOARGS + elif is_singleton: + template = PY_VARIABLE_METHOD_VARARGS_SINGLETON + else: + template = PY_VARIABLE_METHOD_VARARGS + + return template.substitute( + name=name, + pycname=pycname, + method_header=method_header, + max_args=max(o.signature.arguments_count() for o in overloads), + signatures=signatures, + traceable=traceable, + check_has_torch_function=gen_has_torch_function_check( + name=name, + module=module, + noarg=noarg, + method=method, + ), + dispatch=dispatch, + method_footer=method_footer, + self_="self_" if method else "nullptr", + ) + + +def gen_has_torch_function_check( + name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool +) -> str: + if noarg: + if method: + return f"""\ +if(check_has_torch_function(self_)) {{ + return handle_torch_function(self_, "{name}"); +}} +""" + else: + return "" + + self_ = "self_" if method else "nullptr" + namespace = ( + { + "torch": "THPVariableFunctionsModule", + "torch.nn": "THPNNVariableFunctionsModule", + "torch.fft": "THPFFTVariableFunctionsModule", + "torch.linalg": "THPLinalgVariableFunctionsModule", + "torch.nested": "THPNestedVariableFunctionsModule", + "torch.sparse": "THPSparseVariableFunctionsModule", + "torch.special": "THPSpecialVariableFunctionsModule", + }[module] + if module + else "THPVariableClass" + ) + + return f"""\ +if(_r.has_torch_function()) {{ + return handle_torch_function(_r, {self_}, args, kwargs, {namespace}, "{module or "torch.Tensor"}"); +}} +""" + + +# handler for output/no-output overload pair +PY_VARIABLE_OUT = CodeTemplate( + """\ +if (_r.isNone(${out_idx})) { + ${call_dispatch} +} else { + ${call_dispatch_out} +} +""" +) + + +def emit_dispatch_case( + overload: PythonSignatureGroup, + structseq_typenames: dict[str, str], + *, + symint: bool = True, +) -> str: + """ + Emit dispatch code for a single parsed signature. This corresponds to either + a single native function, or a pair that differ only in output params. In the + latter case, a single python signature is used for both and dispatching + switches on the presence/absence of passed output args. + """ + if overload.outplace is not None: + # dispatch output and no-output variants, branch on _r.isNone() + return PY_VARIABLE_OUT.substitute( + out_idx=overload.signature.output_idx(), + call_dispatch=emit_single_dispatch( + overload.signature, overload.base, structseq_typenames, symint=symint + ), + call_dispatch_out=emit_single_dispatch( + overload.signature, + overload.outplace, + structseq_typenames, + symint=symint, + ), + ) + else: + # no-output version only + return emit_single_dispatch( + overload.signature, overload.base, structseq_typenames, symint=symint + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Forward Declarations Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def forward_decls( + name: BaseOperatorName, + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool, +) -> tuple[str, ...]: + if method: + return () + + pycname = get_pycname(name) + if is_noarg(overloads): + return ( + f"""\ +static PyObject * {pycname}(PyObject* self_, PyObject* args); +""", + ) + else: + return ( + f"""\ +static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs); +""", + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Method Def (Binding Table Entry) Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def method_def( + name: BaseOperatorName, + module: str | None, + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool, +) -> str: + """ + Generate method def entry. + """ + pycname = get_pycname(name) + + if name.dunder_method: + # PyMethodDef entry for binary op, throws not implemented error + pycname = f"TypeError_to_NotImplemented_<{pycname}>" + + if is_noarg(overloads): + flags = "METH_NOARGS" if method else "METH_VARARGS | METH_KEYWORDS" + else: + pycname = f"castPyCFunctionWithKeywords({pycname})" + flags = "METH_VARARGS | METH_KEYWORDS" + + if module == "torch": + flags += " | METH_STATIC" + + return f'{{"{name}", {pycname}, {flags}, NULL}},' + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Overload Sorting and Grouping +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def group_overloads( + overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True +) -> Sequence[PythonSignatureGroup]: + bases: dict[str, PythonSignatureNativeFunctionPair] = {} + outplaces: dict[str, PythonSignatureNativeFunctionPair] = {} + + # first group by signature ignoring out arguments + for overload in overloads: + sig = overload.signature.signature_str(skip_outputs=True, symint=symint) + if overload.function.func.is_out_fn(): + if sig in outplaces: + raise RuntimeError( + f"Found duplicated function definition:\n- {overload.function.func}.\n" + f"Existing definition:\n- {outplaces[sig].function.func}." + ) + outplaces[sig] = overload + else: + if sig in bases: + raise RuntimeError( + f"Found duplicated function definition:\n- {overload.function.func}.\n" + f"Existing definition:\n- {bases[sig].function.func}." + ) + bases[sig] = overload + + for sig, out in outplaces.items(): + if sig not in bases: + candidates: list[str] = [] + for overload in overloads: + if ( + str(overload.function.func.name.name) + == str(out.function.func.name.name) + and not overload.function.func.is_out_fn() + and not overload.signature.deprecated + ): + candidates.append( + overload.signature.signature_str( + skip_outputs=True, symint=symint + ) + ) + out_sig = out.signature.signature_str(symint=symint) + raise RuntimeError( + f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. " + f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema " + "correctly in native_functions.yaml. We discovered the following candidate(s): \n" + + "\n".join(f"- {candidate}" for candidate in candidates) + ) + + grouped = [ + PythonSignatureGroup.from_pairs( + functional=base, + out=outplaces.get(sig), + ) + for sig, base in bases.items() + ] + return sort_overloads(grouped, symint=symint) + + +# This function declares a partial order on declarations, and sorts them according +# to its linear extension. This is necessary, because there's some ambiguity in the +# choice of overload, and we want a different order. +# +# See Note[Order of overloads matters] +# +# A few examples of ambiguous python signature pairs. +# +# All parameters have the same type, except one taking Tensor the other taking +# Scalar. A numeric PyObject can be casted into Tensor, and a zero-dim Tensor +# object can be accepted as Scalar type parameter (see python_arg_parser.cpp). +# Therefore, same input arguments might be accepted by either python signature. +# We want to always parse the one taking Tensor first. +# +# bitwise_and(Tensor input, Tensor other, *, Tensor out=None) +# bitwise_and(Tensor input, Scalar other, *, Tensor out=None) +# +# If they have different number of parameters then they are not ambiguous - but +# the difference on output param can be ignored as it's optional. +# +# multiply(Tensor input, Tensor other, *, Tensor out=None) +# multiply(Tensor input, Scalar other) +# +# Both positional args and keyword-only args are considered together. +# +# subtract(Tensor other, *, Scalar alpha=1) +# subtract(Scalar other, Scalar alpha=1) +# +# A few ambiguous cases which it does NOT handle yet. +# +# If there is any difference in other parameters besides the Tensor/Scalar +# difference, then they are not considered ambiguous by this method anymore. +# However, the difference could be too trivial to disambiguate. +# +# foo(Tensor input, Scalar other, Scalar bar) +# foo(Tensor input, Tensor other, double bar) +# +# If they are taking different number of parameters then they are not considered +# ambiguous anymore, even if the difference is only on optional kwargs. +# +# foo(Scalar other, Scalar alpha=1) +# foo(Tensor other, *, Scalar alpha=1, Scalar beta=1) +# + + +def sort_overloads( + grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True +) -> Sequence[PythonSignatureGroup]: + # NB: Smaller here means lower priority + + def is_arg_smaller(t1: Type, t2: Type) -> bool: + return ( + str(t1) == "Scalar" + and str(t2) == "Tensor" + or str(t1) == "Scalar?" + and str(t2) == "Tensor?" + or "Dimname" in str(t1) + and "Dimname" not in str(t2) + or + # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been + # discussed why it is important to prioritize int/int? over int[] + str(t1) == "int[]" + and (str(t2) == "int" or str(t2) == "int?") + or + # TensorList currently throws an error during argument parsing, that's why it needs to be + # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087 + str(t1) == "Tensor[]" + and str(t2).find("[]") != -1 + or + # Prioritize IntArrayRef overload over SymIntArrayRef + str(t1) == "SymInt[]" + and str(t2) == "int[]" + or + # Make sure both in, SymInt are sorted consistently w.r.t. Tensor since Tensor can be implicitly + # converted to either int or SymInt. Prioritize the Tensor overload since it otherwise gets shadowed. + (str(t1) == "SymInt" or str(t1) == "int") + and str(t2) == "Tensor" + ) + + def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool: + """Returns True if s1 < s2 in the partial order.""" + args1, args2 = s1.arguments(skip_outputs=True), s2.arguments(skip_outputs=True) + if len(args1) != len(args2): + return False + # TODO: should use some canonical form instead of 'str(arg.type)' - see comments + # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which + # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'. + equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2)) + smaller_or_equal = all( + str(arg1.type) == str(arg2.type) or is_arg_smaller(arg1.type, arg2.type) + for arg1, arg2 in zip(args1, args2) + ) + return smaller_or_equal and not equal + + # First sort by signature + grouped_overloads = sorted( + grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint) + ) + + # Construct the relation graph + larger_than: dict[int, set[int]] = defaultdict(set) + for i1, overload1 in enumerate(grouped_overloads): + for i2, overload2 in enumerate(grouped_overloads): + if is_smaller(overload1.signature, overload2.signature): + larger_than[i1].add(i2) + + if not larger_than: + return list(grouped_overloads) + + # Use a topological sort to sort overloads according to the partial order. + N = len(grouped_overloads) + sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N))) + + for idx in range(N): + # The size of sorted_ids will grow to N eventually. + i = sorted_ids[idx] + for j in sorted(larger_than.keys()): + larger = larger_than[j] + larger.discard(i) + if not larger: + del larger_than[j] + sorted_ids.append(j) + + return [grouped_overloads[x] for x in sorted_ids] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Codegen API Integration +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def emit_single_dispatch( + ps: PythonSignature, + f: NativeFunction, + structseq_typenames: dict[str, str], + *, + symint: bool = True, +) -> str: + """ + Emit dispatch code for a single native function. + """ + + @with_native_function + def go(f: NativeFunction) -> str: + # header comments + if isinstance(ps, PythonSignatureDeprecated): + schema_comment = f"// [deprecated] aten::{ps.deprecated_schema}" + else: + schema_comment = f"// aten::{f.func}" + + deprecated = "[deprecated] " if ps.deprecated else "" + + # dispatch lambda signature + name = cpp.name(f.func) + lambda_formals = ", ".join( + f"{a.type_str} {a.name}" for a in dispatch_lambda_args(ps, f, symint=symint) + ) + lambda_return = dispatch_lambda_return_str(f) + + # dispatch lambda body + dispatch_callee = cpp_dispatch_target(f) + dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps)) + + # from arg parser outputs to dispatch lambda arguments + parser_outputs = arg_parser_output_exprs(ps, f, symint=symint) + lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint) + inits = "\n".join(lambda_arg_exprs.inits) + lambda_args = ", ".join(lambda_arg_exprs.exprs) + + # scatter fields + # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky + # solution for enabling the 'requires_grad' argument for tensor methods + # new_full, new_empty, and new_zeros. A much better but more difficult to + # implement solution involves refactoring according to Ed's description here: + # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589 + need_set_requires_grad = ps.tensor_options_args and ( + not has_tensor_options(f) + or (ps.method and ("requires_grad" in parser_outputs)) + ) + set_requires_grad = ( + f'.set_requires_grad({parser_outputs["requires_grad"].expr})' + if need_set_requires_grad + else "" + ) + + if lambda_return == "void": + # Make in-place foreach return `self` at python-binding level. + # ref: https://github.com/pytorch/pytorch/pull/118622#pullrequestreview-1904804954 + self_arg = f.func.arguments.self_arg + return_stmt: str + if ( + str(f.func.name).startswith("_foreach_") + and f.func.kind() == SchemaKind.inplace + ): + # note(crcrpar): `_foreach_pow.ScalarAndTensor` does NOT have its in-place + # variant and it unlikely to have it in the future. Thus it's safe to have the following assert. + assert self_arg is not None and is_tensor_list_type( + self_arg.argument.type + ) + return_stmt = """PyObject* self_tensorlist = _r.args[0]; +Py_INCREF(self_tensorlist); +return self_tensorlist; +""" + else: + return_stmt = "Py_RETURN_NONE;" + return f"""\ +{schema_comment} +{inits} +auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ + pybind11::gil_scoped_release no_gil; + {dispatch_callee}({dispatch_args}); +}}; +dispatch_{name}({lambda_args}){set_requires_grad}; +{return_stmt} +""" + else: + typename = structseq_typenames.get(gen_structseq_typename_key(f)) + structseq_typeref = f"{typename}, " if typename is not None else "" + return f"""\ +{schema_comment} +{inits} +auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ + pybind11::gil_scoped_release no_gil; + return {dispatch_callee}({dispatch_args}); +}}; +return wrap({structseq_typeref}dispatch_{name}({lambda_args}){set_requires_grad}); +""" + + return go(f) diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_trace_type.py b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_trace_type.py new file mode 100644 index 0000000000000000000000000000000000000000..3b462655010417a655efec0114b118ba2fa0bd6a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_trace_type.py @@ -0,0 +1,536 @@ +from __future__ import annotations + +import itertools +from typing import Sequence + +from torchgen.api import cpp +from torchgen.api.types import DispatcherSignature +from torchgen.code_template import CodeTemplate +from torchgen.context import with_native_function +from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments +from torchgen.utils import FileManager + + +# Note [Manual Backend kernels] +# For these ops, we want to manually register to dispatch key Backend and +# skip codegen-ed registeration to all keys before Backend. +# For codegen this means: +# - op set below must match ops with manual_kernel_registration=True in native_functions.yaml +# where we skip codegen backend kernels +# - all ops below are part of MANUAL_AUTOGRAD to skip codegen Autograd kernel registration +# - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration +# Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now. +# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp +MANUAL_BACKEND = { + "options", + "data", + "set_data", + "is_leaf", + "output_nr", + "_version", + "retain_grad", + "_backward", + "requires_grad_", +} + +# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys. +# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp +MANUAL_AUTOGRAD_AND_TRACER = { + "resize_", + "resize_as_", + "detach", + "detach_", + "copy_", + "_fw_primal", + "_make_dual", +} + +# Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops: +# union(MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER) +# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp +MANUAL_AUTOGRAD = MANUAL_TRACER = MANUAL_BACKEND | MANUAL_AUTOGRAD_AND_TRACER + +# These functions we don't want to record for tracing, because we always want +# to trace their constituent parts. This is a temporary hack in lieue +# of proper scopes, where subsequent compilation passes can ask for the unfolding +# on demand. Only concrete ATen methods can be disabled this way; it will have +# NO EFFECT otherwise. +DONT_RECORD_TRACE = { + "convolution", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "lstm_cell", + "gru_cell", + "rnn_tanh_cell", + "rnn_relu_cell", + # FIXME: figure out a better way when we support sparse tensors in jit + "_coalesced", +} + + +def should_trace(f: NativeFunction) -> bool: + # Operations involving Storage or Type are not traceable at the moment + if any( + str(arg.type) in {"Storage", "Type", "ConstQuantizerPtr"} + for arg in f.func.schema_order_arguments() + ): + return False + # We can't trace functions which don't have any Tensor or TensorList returns + if not any(r.type.is_tensor_like() for r in f.func.returns): + return False + return f.func.name.name.base not in DONT_RECORD_TRACE + + +SELECT = CodeTemplate( + """\ + +if (${cond}) { + ${true} +} else { + ${false} +} +""" +) + +OP_NAME = CodeTemplate( + """\ +op_name = c10::Symbol::fromQualString("aten::${trace_name}"); +""" +) + +# These functions have their names recorded under trace renamed, +RENAME_TRACE = { + "zero": "zeros_like", # replacing aten::zero_ with aten::zeros_like + "fill": "full_like", # replacing aten::fill_ with aten::full_like +} + + +def format_trace_op_name(f: NativeFunction) -> str: + # TODO: byte-for-byte compatible with old codegen behavior - should clean up + if ( + f.func.kind() in (SchemaKind.functional, SchemaKind.out) + or f.func.name.name.dunder_method + ): + # special case for *_out functions: the in-place and out-of-place ops + # are overloaded with the same name in the JIT + trace_name = str(f.func.name.name) + trace_name = RENAME_TRACE.get(trace_name, trace_name) + return OP_NAME.substitute(trace_name=trace_name) + + # otherwise, this is an in-place op and we need to emit both in- and + # out-of-place versions + outplace_trace_name = f.func.name.name.base + inplace_trace_name = cpp.name(f.func) + outplace_trace_name = RENAME_TRACE.get(outplace_trace_name, outplace_trace_name) + inplace_trace_name = RENAME_TRACE.get(inplace_trace_name, inplace_trace_name) + + return SELECT.substitute( + cond="tracer_state->force_outplace", + true=OP_NAME.substitute(trace_name=outplace_trace_name), + false=OP_NAME.substitute(trace_name=inplace_trace_name), + ) + + +ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""") + + +def format_trace_inputs(f: NativeFunction) -> str: + def dispatch_trace_input(arg: Argument | TensorOptionsArguments) -> Sequence[str]: + if isinstance(arg, TensorOptionsArguments): + name = "options" + return [ + ADD_TRACE_INPUT.substitute( + name=name, input="c10::optTypeMetaToScalarType(options.dtype_opt())" + ), + ADD_TRACE_INPUT.substitute(name=name, input="options.layout()"), + ADD_TRACE_INPUT.substitute(name=name, input="options.device()"), + ADD_TRACE_INPUT.substitute(name=name, input="options.pinned_memory()"), + ] + else: + name = arg.name + if str(arg.type) == "Tensor?[]": + return [f'jit::tracer::addInputs(node, "{name}", {name});'] + else: + return [ADD_TRACE_INPUT.substitute(name=name, input=name)] + + args: list[Argument | TensorOptionsArguments] = list( + f.func.schema_order_arguments() + ) + + if f.func.is_out_fn(): + # *_out functions take the result as a separate argument, but we don't want to + # trace that argument directly. Instead, we trace its TensorOptions. + # So first, we need to remove the out argument from the list of arguments to trace. + num_out_args = len(f.func.arguments.out) + args = args[:-num_out_args] + + trace_inputs = itertools.chain.from_iterable( + dispatch_trace_input(arg) for arg in args + ) + + if f.func.is_out_fn(): + # for *_out functions, handle the result argument differently for inplace/outplace. + # For inplace: just add the input to the end to confirm with the JIT schema + inplace = [ + ADD_TRACE_INPUT.substitute( + name=f.func.arguments.out[i].name, input=f.func.arguments.out[i].name + ) + for i in range(num_out_args) + ] + + # for outplace: do nothing, except if the function is a factory. + # Factories are a bit special because their out-of-place overloads + # take an extra TensorOptions argument, which is missing in the _out function + has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns) + has_tensor_input_arg = any( + a.type.is_tensor_like() for a in f.func.arguments.flat_non_out + ) + is_factory_method = f.category_override == "factory" or ( + has_tensor_return and not has_tensor_input_arg + ) + + # HACK: preserve old codegen behavior - the old codegen set the `is_factory_method` + # flag for the whole family of ops with the same basename if any of them is a + # factory method. For most cases the whole family of ops are indeed all factory + # method - 'normal' is the only exception. So we handle it specially here to avoid + # cloning the old logic. + if f.func.name.name.base == "normal": + is_factory_method = True + + if is_factory_method: + outplace = [ + ADD_TRACE_INPUT.substitute( + name="out", + input="c10::optTypeMetaToScalarType(out.options().dtype_opt())", + ), + ADD_TRACE_INPUT.substitute(name="out", input="out.options().layout()"), + ADD_TRACE_INPUT.substitute(name="out", input="out.options().device()"), + ADD_TRACE_INPUT.substitute( + name="out", input="out.options().pinned_memory()" + ), + ] + else: + outplace = [] + + trace_inputs = itertools.chain( + trace_inputs, + [ + SELECT.substitute( + cond="tracer_state->force_outplace", + true="\n".join(outplace), + false="\n".join(inplace), + ) + ], + ) + + return "\n".join(trace_inputs) + + +# `torch.jit.trace` have undocumented keyword argument `_force_outplace`, +# which force jit to replace functions with outplace variants (for +# example `aten::add_` becomes `aten::add`). +# +# This replacement implemented in-place with minimum modifications of +# arguments stack (as it assumes that outplace call has the same arguments +# as inplace version). +# +# However there are no such substitutions available for `aten::fill_` +# and `aten::zero_` operators, as we never implemented `aten::fill` +# and `aten::zero`. So jit tracing hack replacing `aten::zero_` with +# `aten::zeros_like` and replacing `aten::fill_` with `aten::full_like`. +# +# But as they potentially can have different arguments, we also have +# to hack into the stack and add missing ones. +# +# A possible alternative would be: +# +# - Add `aten::fill` and `aten::zero` +# +# - Or keep `aten::zeros_like` arguments aligned with `aten::zero_` +# arguments (inside of the `native_functions.yaml`) +RENAME_TRACE_ADD_ARGS = { + "fill": """\ + jit::tracer::addInputs(node, "options", ::std::optional()); + jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt)); + jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt)); + jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt)); + ::std::optional memory_format = c10::MemoryFormat::Preserve; + jit::tracer::addInputs(node, "memory_format", memory_format); +""", + "zero": """\ + jit::tracer::addInputs(node, "options", ::std::optional()); + jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt)); + jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt)); + jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt)); + ::std::optional memory_format = c10::MemoryFormat::Preserve; + jit::tracer::addInputs(node, "memory_format", memory_format); +""", +} + +INPLACE_GUARD = CodeTemplate( + """\ +jit::tracer::ensureUniqueIfOutOfPlaced("${name}", ${mutable_input}); +""" +) + +PRE_RECORD_TRACE = CodeTemplate( + """\ +torch::jit::Node* node = nullptr; +std::shared_ptr tracer_state; +if (jit::tracer::isTracing()) { + tracer_state = jit::tracer::getTracingState(); + at::Symbol op_name; + ${set_op_name} + node = tracer_state->createNode(op_name, /*num_outputs=*/0); + jit::tracer::recordSourceLocation(node); + ${add_trace_inputs} + tracer_state->insertNode(node); + ${inplace_guard} + jit::tracer::setTracingState(nullptr); +} +""" +) + + +def format_prerecord_trace(f: NativeFunction) -> str: + if not should_trace(f): + return "" + + # TODO: clean up old codegen behavior + is_inplace = ( + f.func.kind() in (SchemaKind.inplace, SchemaKind.out) + and not f.func.name.name.dunder_method + ) + add_args = ( + RENAME_TRACE_ADD_ARGS.get(f.func.name.name.base, "") if is_inplace else "" + ) + additional_inputs = ( + SELECT.substitute( + cond="tracer_state->force_outplace", + true=add_args, + false="", + ) + if add_args + else "" + ) + + return PRE_RECORD_TRACE.substitute( + set_op_name=format_trace_op_name(f), + add_trace_inputs=format_trace_inputs(f) + additional_inputs, + inplace_guard=INPLACE_GUARD.substitute( + name=cpp.name(f.func), + mutable_input=f.func.arguments.out[0].name + if f.func.arguments.out + else "self", + ) + if is_inplace + else "", + ) + + +POST_RECORD_TRACE = CodeTemplate( + """\ +if (tracer_state) { + jit::tracer::setTracingState(std::move(tracer_state)); + ${add_trace_outputs} +} +""" +) + + +def format_postrecord_trace(f: NativeFunction) -> str: + if not should_trace(f): + return "" + + # For outplacing ops, *_out overloads require special handling to move the + # output *argument* to a return value + if f.func.is_out_fn(): + output_names_outplace = [arg.name for arg in f.func.arguments.out] + output_names_inplace = cpp.return_names(f) + + # Code size optimization: the common case is that the return value is + # the same for both variants + if output_names_outplace == output_names_inplace: + outputs = [ + f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace + ] + return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) + + selection = SELECT.substitute( + cond="force_outplace", + true="\n".join( + f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace + ), + false="\n".join( + f"jit::tracer::addOutput(node, {n});" for n in output_names_inplace + ), + ) + return POST_RECORD_TRACE.substitute(add_trace_outputs=selection) + else: + output_names = cpp.return_names(f) + outputs = [f"jit::tracer::addOutput(node, {n});" for n in output_names] + return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) + + +def tie_return_values(f: NativeFunction) -> str: + if len(f.func.returns) == 1: + return f'auto {f.func.returns[0].name or "result"}' + names = cpp.return_names(f) + return f'auto [{", ".join(names)}]' + + +def get_return_value(f: NativeFunction) -> str: + names = cpp.return_names(f) + if len(f.func.returns) == 1: + return names[0] + if f.func.kind() == SchemaKind.out: + return f'std::forward_as_tuple({", ".join(names)})' + else: + moved = ", ".join(f"std::move({name})" for name in names) + return f"std::make_tuple({moved})" + + +TRACE_DISPATCH = CodeTemplate( + """\ +${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args});""" +) + + +def emit_trace_body(f: NativeFunction) -> list[str]: + trace_body: list[str] = [] + + trace_body.append(format_prerecord_trace(f)) + + dispatcher_sig = DispatcherSignature.from_schema(f.func) + dispatcher_exprs = dispatcher_sig.exprs() + + # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. + dispatch_key_set = "ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)" + redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs]) + + assign_return_values = ( + f"{tie_return_values(f)} = " + if f.func.kind() in [SchemaKind.functional, SchemaKind.mutable] + and f.func.returns + else "" + ) + + # Note that this calls the slow, dispatching variants of manual_cpp_binding ops. + # We could probably work harder to ensure that the fast variants are + # called instead, but the perf benefit would be minimal. + trace_body.append( + TRACE_DISPATCH.substitute( + assign_return_values=assign_return_values, + unambiguous_name=f.func.name.unambiguous_name(), + unpacked_args=redispatch_args, + ) + ) + + trace_body.append(format_postrecord_trace(f)) + if f.func.returns: + trace_body.append(f"return {get_return_value(f)};") + return trace_body + + +METHOD_DEFINITION = CodeTemplate( + """\ +${return_type} ${type_wrapper_name}(${formals}) { + ${type_definition_body} +} +""" +) + + +def type_wrapper_name(f: NativeFunction, key: str = "Default") -> str: + if f.func.name.overload_name: + name = f"{cpp.name(f.func)}_{f.func.name.overload_name}" + else: + name = cpp.name(f.func) + + # The key argument is only used in gen_variable_type where we need fns per autograd dispatch key. + # In gen_trace_type and gen_inplace_view_type where only one fn per native_fn must be generated, + # the key argument should not be passed. + # We do not append key if it is Default so that generated functions from + # before per-dispatch-key derivatives were added retain the same names. + if key != "Default": + name = name + f"_{key}" + return name + + +@with_native_function +def method_definition(f: NativeFunction) -> str: + assert cpp.name(f.func) not in MANUAL_TRACER + + formals = ", ".join( + # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. + ["c10::DispatchKeySet ks"] + + [ + f'{cpp.argument_type(a, binds="__placeholder__", symint=True).cpp_type()} {a.name}' + for a in f.func.schema_order_arguments() + ] + ) + + return METHOD_DEFINITION.substitute( + return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(), + type_wrapper_name=type_wrapper_name(f), + formals=formals, + type_definition_body=emit_trace_body(f), + ) + + +WRAPPER_REGISTRATION = CodeTemplate( + """\ +m.impl("${name}", + TORCH_FN(${class_type}::${type_wrapper_name}) +); +""" +) + + +@with_native_function +def method_registration(f: NativeFunction) -> str: + assert cpp.name(f.func) not in MANUAL_TRACER + + return WRAPPER_REGISTRATION.substitute( + name=f.func.name, + type_wrapper_name=type_wrapper_name(f), + class_type="TraceType", + ) + + +def gen_trace_type_func(fn: NativeFunction) -> dict[str, list[str]]: + return { + "ops_headers": [f"#include "], + "trace_method_definitions": [method_definition(fn)], + "trace_wrapper_registrations": [method_registration(fn)], + } + + +def gen_trace_type( + out: str, native_functions: list[NativeFunction], template_path: str +) -> None: + # NOTE: see Note [Sharded File] at the top of the VariableType.cpp + # template regarding sharding of the generated files. + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_sharded( + "TraceType.cpp", + [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER], + key_fn=lambda fn: fn.root_name, + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/TraceType.cpp", + }, + env_callable=gen_trace_type_func, + num_shards=5, + sharded_keys={ + "ops_headers", + "trace_method_definitions", + "trace_wrapper_registrations", + }, + ) diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_variable_factories.py b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_variable_factories.py new file mode 100644 index 0000000000000000000000000000000000000000..f206939bd535a887827a8f8170e99e6d37a71aef --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_variable_factories.py @@ -0,0 +1,116 @@ +# Generates C++ functions that wrap ATen tensor factory methods to turn them into Variables. +# +# This writes one file: variable_factories.h + +from __future__ import annotations + +import re + +import torchgen.api.python as python +from torchgen.api import cpp +from torchgen.api.types import CppSignatureGroup +from torchgen.context import with_native_function +from torchgen.gen import parse_native_yaml +from torchgen.model import NativeFunction, TensorOptionsArguments, Variant +from torchgen.utils import FileManager, mapMaybe + + +OPTIONAL_TYPE_PATTERN = re.compile(r"std::optional<(.+)>") +TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)") + + +# Add 'at::' to types defined in ATen namespace, e.g. Tensor, TensorList, IntArrayRef and etc. +# TODO: maybe update the cpp argument API to take optional namespace argument? +def fully_qualified_type(argument_type: str) -> str: + def maybe_optional_type(type: str, is_opt: bool) -> str: + return f"std::optional<{type}>" if is_opt else type + + opt_match = OPTIONAL_TYPE_PATTERN.match(argument_type) + is_opt = opt_match is not None + if opt_match: + argument_type = argument_type[opt_match.start(1) : opt_match.end(1)] + match = TYPE_PATTERN.match(argument_type) + if match is None: + return maybe_optional_type(argument_type, is_opt) + index = match.start(1) + qualified_type = f"{argument_type[:index]}at::{argument_type[index:]}" + return maybe_optional_type(qualified_type, is_opt) + + +def gen_variable_factories( + out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str +) -> None: + native_functions = parse_native_yaml( + native_yaml_path, tags_yaml_path + ).native_functions + factory_functions = [fn for fn in native_functions if is_factory_function(fn)] + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_with_template( + "variable_factories.h", + "variable_factories.h", + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/variable_factories.h", + "ops_headers": [ + f"#include " for fn in factory_functions + ], + "function_definitions": list(mapMaybe(process_function, factory_functions)), + }, + ) + + +@with_native_function +def is_factory_function(f: NativeFunction) -> bool: + if Variant.function not in f.variants: + return False + + name = cpp.name(f.func) + has_tensor_options = python.has_tensor_options(f) + return has_tensor_options or name.endswith("_like") + + +@with_native_function +def process_function(f: NativeFunction) -> str | None: + name = cpp.name(f.func) + has_tensor_options = python.has_tensor_options(f) + is_factory = has_tensor_options or name.endswith("_like") + + if Variant.function not in f.variants or not is_factory: + return None + + cpp_sigs = CppSignatureGroup.from_native_function(f, method=False) + sigs = [cpp_sigs.signature] + if cpp_sigs.symint_signature is not None: + sigs.append(cpp_sigs.symint_signature) + r = "" + for sig in sigs: + formals: list[str] = [] + exprs: list[str] = [] + requires_grad = "false" + for arg in sig.arguments(): + qualified_type = fully_qualified_type(arg.type) + if arg.default: + formals.append(f"{qualified_type} {arg.name} = {arg.default}") + else: + formals.append(f"{qualified_type} {arg.name}") + + if isinstance(arg.argument, TensorOptionsArguments): + # note: we remove the requires_grad setting from the TensorOptions because + # it is ignored anyways (and we actually have an assertion that it isn't set + # which would fail otherwise). We handle requires_grad explicitly here + # instead of passing it through to the kernel. + exprs.append( + f"at::TensorOptions({arg.name}).requires_grad(::std::nullopt)" + ) + # Manually set the requires_grad bit on the result tensor. + requires_grad = f"{arg.name}.requires_grad()" + else: + exprs.append(arg.name) + + r += f"""\ +inline at::Tensor {sig.name()}({', '.join(formals)}) {{ + at::AutoDispatchBelowADInplaceOrView guard; + return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad}); +}} +""" + return r diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_variable_type.py b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_variable_type.py new file mode 100644 index 0000000000000000000000000000000000000000..4bec1871ae483ea7b12f7c3ff9ecc6198ea8c383 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_variable_type.py @@ -0,0 +1,2180 @@ +# Generates VariableType.h/cpp +# +# **If any changes are being made to the VariableType codegen please also check +# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp +# +# VariableType is a subclass of at::Type that provides the binding code +# necessary to provide a differentiable version of ATen operators. There are a +# number of different things we could mean: +# +# - Given a non-differentiable forward implementation, we might +# directly associate it with a backward implementation to make +# it differentiable. This is the common case. +# +# - Some functions don't need a backwards implementation, because +# backpropagation will never propagate beyond them. There are a +# number of different reasons why this may be the case: +# +# - The function has no differentiable inputs +# - The function's output is not differentiable +# - The function has no data dependency on its input +# +# - Some function don't need a backwards implementation because they +# are implemented as a composition of other (differentiable) ATen +# functions. These are dispatched directly to the Type superclass, +# which will in turn dispatch back to VariableType for its +# differentiable subcomponents. +# + +from __future__ import annotations + +import re +from typing import Callable, Sequence + +from torchgen.api import cpp +from torchgen.api.autograd import ( + DifferentiableInput, + dispatch_strategy, + ForwardDerivative, + gen_differentiable_outputs, + is_differentiable, + NativeFunctionWithDifferentiabilityInfo, + SavedAttribute, +) +from torchgen.api.types import ( + ArrayRefCType, + BaseCppType, + BaseCType, + Binding, + DispatcherSignature, + intArrayRefT, + iTensorListRefT, + ListCType, + MutRefCType, + OptionalCType, + scalarT, + SpecialArgName, + stringT, + symIntArrayRefT, + TENSOR_LIST_LIKE_CTYPES, + tensorListT, + tensorT, + TupleCType, + VectorCType, +) +from torchgen.code_template import CodeTemplate +from torchgen.context import ( + native_function_manager, + with_native_function, + with_native_function_and, +) +from torchgen.model import ( + Argument, + BaseType, + ListType, + NativeFunction, + SchemaKind, + SelfArgument, + TensorOptionsArguments, +) +from torchgen.utils import FileManager, mapMaybe + +from .context import with_native_function_with_differentiability_info_and_key +from .gen_inplace_or_view_type import ( + ALL_VIEW_FUNCTIONS, + ASSIGN_RETURN_VALUE, + AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION, + gen_formals, + get_base_name, + get_view_info, + is_tensor_list_type, + is_tensor_type, + METHOD_DEFINITION, + modifies_arguments, + TMP_VAR, + unpack_args, + unpacked_name, + use_derived, + WRAPPER_REGISTRATION, +) +from .gen_trace_type import ( + get_return_value, + MANUAL_AUTOGRAD_AND_TRACER, + MANUAL_BACKEND, + tie_return_values, + type_wrapper_name, +) + + +# We don't set or modify grad_fn on these methods. Generally, they return +# tensors that have requires_grad=False. In-place functions listed here will +# not examine or modify requires_grad or grad_fn. +# NB: this does NOT include overload name +DONT_REQUIRE_DERIVATIVE = { + # These only depend on the input Tensor's shape and device, not the data + "empty_like", + "ones_like", + "full_like", + "zeros_like", + "rand_like", + "randn_like", + "new_empty", + "new_empty_strided", + "new_full", + "new_zeros", + "new_ones", + # These are only implemented on integral types + "__and__", + "__iand__", + "__ilshift__", + "__ior__", + "__irshift__", + "__ixor__", + "__lshift__", + "__or__", + "__rshift__", + "__xor__", + # These work on integral data types, and hence don't require derivative + "_sobol_engine_draw", + "_sobol_engine_ff", + "_sobol_engine_scramble_", + "_sobol_engine_initialize_state_", + # This is an unsafe method that is meant to be out of reach of autograd. + "_coalesced_", + # Quantize functions should not record gradients + "quantize_per_tensor", + "quantize_per_channel", + # Functions that return integers should not have output that require gradients + "argmax", + "argmin", + "argsort", + "searchsorted", + "bucketize", + # Functions that return booleans are not differentiable + "isnan", + "isposinf", + "isneginf", + "isinf", + "signbit", + "isin", + "allclose", + # Functions return none are not differentiable + "record_stream", + # These functions are not differentiable + "logical_and", + "logical_xor", + "logical_not", + "logical_or", + # This function returns nested_tensor shape as a tensor that is non-differentiable + "_nested_tensor_size", + "_nested_tensor_strides", + "_nested_tensor_storage_offsets", +} + +# The C -> R functions at the time of adding this are still being audited and tested +# but will not error out. +# C -> C, R -> C functions for which backward is correctly implemented and tested +GRADIENT_IMPLEMENTED_FOR_COMPLEX = { + "fill", + "t", + "t_copy", + "view", + "reshape", + "reshape_as", + "view_as", + "view_copy", + "roll", + "clone", + "block_diag", + "diag_embed", + "repeat", + "expand", + "expand_copy", + "flip", + "fliplr", + "flipud", + "rot90", + "nanmean", + "nansum", + "transpose", + "permute", + "squeeze", + "unsqueeze", + "unsqueeze_copy", + "resize", + "resize_as", + "tril", + "triu", + "chunk", + "zero_", + "eq_", + "ne_", + "add", + "__radd__", + "sum", + "_conj", + "sin", + "cos", + "mul", + "sinc", + "sinh", + "cosh", + "__rmul__", + "sgn", + "asin", + "acos", + "sub", + "div", + "cat", + "view_as_complex", + "index_put", + "neg", + "complex", + "select", + "where", + "as_strided", + "as_strided_copy", + "as_strided_scatter", + "slice", + "constant_pad_nd", + "unbind", + "split", + "split_with_sizes", + "unsafe_split", + "split_with_sizes_backward", + "dot", + "vdot", + "cholesky", + "triangular_solve", + "mm", + "_unsafe_view", + "mv", + "outer", + "bmm", + "diagonal", + "alias", + "atan", + "log", + "log10", + "log1p", + "log2", + "logaddexp", + "logsumexp", + "logcumsumexp", + "reciprocal", + "tan", + "pow", + "rsqrt", + "tanh", + "tanh_backward", + "asinh", + "acosh", + "atanh", + "take", + "fill_", + "exp", + "exp2", + "expm1", + "nonzero", + "mean", + "std_mean", + "var_mean", + "inverse", + "solve", + "linalg_cholesky", + "addcmul", + "addcdiv", + "matrix_exp", + "linalg_matrix_exp", + "_linalg_eigh", + "cholesky_solve", + "linalg_qr", + "_linalg_svd", + "_fft_c2c", + "_fft_r2c", + "linalg_solve", + "sqrt", + "stack", + "gather", + "index_select", + "index_add_", + "linalg_inv", + "linalg_inv_ex", + "baddbmm", + "addbmm", + "addmm", + "addmv", + "addr", + "linalg_householder_product", + "ormqr", + "reflection_pad1d", + "reflection_pad2d", + "reflection_pad3d", + "linalg_cholesky_ex", + "linalg_eig", + "diagonal_copy", + "diagonal_scatter", + "alias_copy", + "select_backward", + "diagonal_backward", + "slice_backward", + "reflection_pad1d_backward", + "reflection_pad2d_backward", + "reflection_pad3d_backward", + "_sparse_sparse_matmul", + "replication_pad1d", + "replication_pad2d", + "replication_pad3d", + "put", + "put_", + "_to_copy", + "replication_pad1d_backward", + "replication_pad2d_backward", + "replication_pad3d_backward", + "diag", + "masked_scatter", + "masked_select", + "index_add", + "index_fill", + "trace", + "polar", + "cumsum", + "rsub", + "eig", + "lerp", + "linalg_vector_norm", + "cumprod", + "prod", + "index_copy", + "lu", + "unfold", + "unfold_backward", + "index", + "masked_fill", + "masked_scatter_backward", + "linalg_cross", + "lu_unpack", + "renorm", + "_conj_physical", + "linalg_lu_factor_ex", + "scatter", + "scatter_add", + "sigmoid", + "sigmoid_backward", + "sparse_mask", + "trapezoid", + "cumulative_trapezoid", + "conj_physical_", + "_neg_view", + "_reshape_alias", + "_reshape_copy", + "_linalg_det", + "lu_solve", + "linalg_solve_triangular", + "linalg_pinv", + "linalg_lstsq", + "unfold_copy", + "col2im", + "im2col", + "cholesky_inverse", + "to_sparse", + "sparse_sampled_addmm", + "linalg_lu", + "pixel_shuffle", + "pixel_unshuffle", + "channel_shuffle", + "linalg_lu_solve", + "_linalg_slogdet", + "_linalg_solve_ex", + "_unsafe_index", + "_unsafe_index_put", + "_unsafe_masked_index", + "_unsafe_masked_index_put_accumulate", +} + +GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = { + "_to_dense", + "_coalesce", + "coalesce", + "values", + "_sparse_coo_tensor_with_dims_and_tensors", + "_sparse_addmm", +} + +GRADIENT_IMPLEMENTED_FOR_COMPLEX.update(GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX) + +# Some operators invalidate the grad_accumulator. Let's reset it. +RESET_GRAD_ACCUMULATOR = {"set_", "resize_"} + +# NOTE [ TensorImpl and Storage Pointer Sanity Checks ] +# +# We check the following properties: +# 1) A function should never change the input tensors' underlying c10::TensorImpl +# pointers or c10::Storage pointers, even if it modifies its input tensors (via +# inplace or out-variants) +# If the function does not modify its arguments, we also check the following properties +# pertaining to its output: +# 2) Its TensorImpl has use_count of 1 +# 3) If the function is a view function, it has the same StorageImpl as that of +# the input it is aliased with. Otherwise, its StorageImpl has use_count of 1 +# +# The following code templates implement the checks for this invariant: +SAVE_TENSOR_STORAGE = CodeTemplate( + """\ +auto ${tensor_name}_storage_saved = + ${tensor_name}.has_storage() ? ::std::optional(${tensor_name}.storage()) : ::std::nullopt; +""" +) + + +# If tensor_name == out_tensor_name, used to enforce (1), otherwise used for (2) +ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate( + """\ +if (${tensor_name}_storage_saved.has_value() && + !at::impl::dispatch_mode_enabled() && + !at::impl::tensor_has_dispatch(${tensor_name}) && + !at::impl::tensor_has_dispatch(${out_tensor_name})) + TORCH_INTERNAL_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${out_tensor_name}.storage())); +""" +) + +SAVE_TENSORLIST_STORAGE = CodeTemplate( + """\ +std::vector<::std::optional> ${tensorlist_name}_storage_saved(${tensorlist_name}.size()); +for (const Tensor& tensor : ${tensorlist_name}) + ${tensorlist_name}_storage_saved.push_back( + tensor.has_storage() ? ::std::optional(tensor.storage()) : ::std::nullopt); +""" +) + +ENFORCE_SAME_TENSORLIST_STORAGE = CodeTemplate( + """\ +for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) { + if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name})) + TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(${tensorlist_name}[i].storage())); +} +""" +) + +SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate( + """\ +std::vector<::std::optional> ${tensorlist_name}_storage_saved(${tensorlist_name}.size()); +for (const ::std::optional& tensor : ${tensorlist_name}) + ${tensorlist_name}_storage_saved.push_back( + tensor.has_value() && tensor->has_storage() ? ::std::optional(tensor->storage()) : ::std::nullopt); +""" +) + +ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate( + """\ +for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) { + if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name})) + TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of( + static_cast<::std::optional>(${tensorlist_name}[i])->storage())); +} +""" +) + +SAVE_TENSOR_IMPL = CodeTemplate( + """\ +c10::intrusive_ptr ${tensor_name}_impl_saved; +if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr(); +""" +) + +ENFORCE_SAME_TENSOR_IMPL = CodeTemplate( + """\ +if (${tensor_name}_impl_saved && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) + TORCH_INTERNAL_ASSERT(${tensor_name}_impl_saved == ${tensor_name}.getIntrusivePtr()); +""" +) + +ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE = CodeTemplate( + """\ +if (!at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) + TORCH_INTERNAL_ASSERT(${tensor_name}.use_count() <= 1, "function: ${fn_name}"); +""" +) + +ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE = CodeTemplate( + """\ +if (${tensor_name}.has_storage() && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) { + TORCH_INTERNAL_ASSERT(${tensor_name}.storage().use_count() == 1, "function: ${fn_name}"); +} +""" +) + +SAVE_TENSORLIST_IMPL = CodeTemplate( + """\ +std::vector> ${tensorlist_name}_impl_saved(${tensorlist_name}.size()); +for (size_t i=0; i<${tensorlist_name}.size(); i++) + if (${tensorlist_name}[i].defined()) ${tensorlist_name}_impl_saved[i] = ${tensorlist_name}[i].getIntrusivePtr(); +""" +) + +ENFORCE_SAME_TENSORLIST_IMPL = CodeTemplate( + """\ +for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) { + if (${tensorlist_name}_impl_saved[i] && !at::impl::tensorlist_has_dispatch(${tensorlist_name})) + TORCH_INTERNAL_ASSERT(${tensorlist_name}_impl_saved[i] == ${tensorlist_name}[i].getIntrusivePtr()); +} +""" +) + +SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate( + """\ +std::vector> ${tensorlist_name}_impl_saved(${tensorlist_name}.size()); +for (size_t i=0; i<${tensorlist_name}.size(); i++) { + ::std::optional t = ${tensorlist_name}[i]; + if (t.has_value() && t->defined()) ${tensorlist_name}_impl_saved[i] = t->getIntrusivePtr(); +} +""" +) + +ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate( + """\ +for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) { + if (${tensorlist_name}_impl_saved[i]) + TORCH_INTERNAL_ASSERT( + ${tensorlist_name}_impl_saved[i] == static_cast<::std::optional>(${tensorlist_name}[i])->getIntrusivePtr()); +} +""" +) + +# The following list contains functions that we don't enforce the invariant on. +DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = { + # These functions are expected to change impl or storage of input tensors + "set_", + "_cudnn_rnn_flatten_weight", + "_unsafe_masked_index", + "_unsafe_masked_index_put_accumulate", +} +DONT_ENFORCE_TENSOR_IMPL_USE_COUNT = { + # These non-inplace, non-out functions return tensors with use_count > 1 + # Therefore, they MAY (but not necessarily) return one of its inputs as-is + # See https://github.com/pytorch/pytorch/issues/60426 for more information + "_embedding_bag", + "_embedding_bag_forward_only", + "q_per_channel_scales", + "q_per_channel_zero_points", + "lu_unpack", + "_cudnn_rnn_backward", + # The below failed StorageImpl use_count check but we skip tensor_impl check + # just in case + "_cudnn_rnn", + "dequantize_self", + # lift() should never actually be called with a requires_grad=True tensor, + "lift", + "lift_fresh", + "lift_fresh_copy", + # Nested Tensors related functions + # _nested_tensor_size() should never actually be called with requires_grad=True tensor + "_nested_tensor_size", + "_nested_tensor_strides", + "_nested_tensor_storage_offsets", +} + +DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = { + # These non-view functions return tensors with storage use_count != 1 + "_slow_conv2d_forward", + "slow_conv3d_forward", + "channel_shuffle", + # If an input is returned as-is in output, we cannot guarantee its storage_impl + # use count to be 1 either. + *DONT_ENFORCE_TENSOR_IMPL_USE_COUNT, +} +# END CHECKS FOR [ TensorImpl and Storage Pointer Sanity Checks ] + +DECLARE_GRAD_FN = CodeTemplate( + """\ +std::shared_ptr<${op}> grad_fn; +""" +) + +DECLARE_VECTOR_OF_GRAD_FN = CodeTemplate( + """\ +std::vector> grad_fns; +""" +) + +SETUP_ANY_REQUIRES_GRAD = CodeTemplate( + """\ +[[maybe_unused]] auto _any_requires_grad = compute_requires_grad( ${args_with_derivatives} ); +${extra_differentiability_conditions} +""" +) + +SETUP_DERIVATIVE = CodeTemplate( + """\ +if (_any_requires_grad) { + ${setup} +} +""" +) + +SETUP_NONE_REQUIRES_GRAD = CodeTemplate( + """\ +if (compute_requires_grad( ${args_to_check} )) { + throw_error_out_requires_grad("${base_name}"); +} +""" +) + +ASSIGN_GRAD_FN = CodeTemplate( + """\ +grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode); +grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} )); +""" +) + +# note(crcrpar): `compute_requires_grad` in the template below is supplied with arguments indexed with `i` +# while the `SETUP_ANY_REQUIRES_GRAD` above takes whole tensors and scalars. +ASSIGN_VECTOR_OF_GRAD_FN = CodeTemplate( + """\ +for (const auto& i : c10::irange( ${irange} )) { + const auto ith_requires_grad = compute_requires_grad(${args_with_derivatives}); + check_inplace(self[i], ith_requires_grad); + grad_fns.push_back([&]() -> std::shared_ptr<${op}> { + if (!ith_requires_grad) { + return nullptr; + } else { + auto grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode); + grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} )); + return grad_fn; + } + }()); +} +""" +) + +CALL_REDISPATCH = CodeTemplate( + """\ +at::redispatch::${api_name}(${unpacked_args})""" +) +# If the non-variable operation has return values, we use the `tmp` variable to hold the +# values temporarily and pass the values to the return variables outside of the +# `at::AutoDispatchBelowAutograd` guard block. +DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP = CodeTemplate( + """\ +auto ${tmp_var} = ([&]() { + if (${any_has_forward_grad}) { + static c10::OperatorName full_name("aten::${op_name}", "${op_overload}"); + static ::std::optional opt_op = c10::Dispatcher::singleton().findSchema(full_name); + return impl::run_jit_decomposition_with_args_for_jvp<${return_types}>("${op_name}", *opt_op, ks, ${arg_names}); + } else { + ${guard} + return ${base_type_call}; + } +})(); +""" +) + +DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate( + """\ +auto ${tmp_var} = ([&]() { + ${guard} + return ${base_type_call}; +})(); +""" +) + +DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES = CodeTemplate( + """\ +{ + ${guard} + ${base_type_call}; +} +""" +) + +SET_HISTORY = CodeTemplate( + """\ +if (grad_fn) { + ${fn}_history(${differentiable_outputs}, grad_fn); +} +""" +) + +LOOP_OVER_VECTOR_OF_GRAD_FNS = CodeTemplate( + """\ +if (!grad_fns.empty()) { + ${preamble} + for (const auto& i : c10::irange(grad_fns.size())) { + auto grad_fn = grad_fns[i]; + if (grad_fn != nullptr) { + ${statements} + } + } +} +""" +) + +CONDITIONAL = CodeTemplate( + """\ +if (${cond}) { + ${statements} +} +""" +) + +RUN_ONLY_IN_DEBUG_MODE = CodeTemplate( + """\ +#ifndef NDEBUG +${statements} +#endif +""" +) + +FW_DERIVATIVE_CHECK_TEMPLATE = CodeTemplate( + """\ +isFwGradDefined(${req_inp})\ +""" +) +FW_DERIVATIVE_SIZE_CHECK_TEMPLATE = CodeTemplate( + """\ +TORCH_CHECK( + self.size() == ${inp_name}.size(), + "Tensor lists must have the same number of tensors, got ", + self.size(), + " and ", + ${inp_name}.size()); +""" +) + +FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE = CodeTemplate( + """\ +isFwGradDefinedTensorList(${req_inp})\ +""" +) + +FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate( + """\ +auto ${inp_name}_t_raw = toNonOptFwGrad(${inp}); +auto ${inp_name}_tensor = toNonOptTensor(${inp}); +auto ${inp_name}_t = (${inp_name}_t_raw.defined() || !${inp_name}_tensor.defined()) + ? ${inp_name}_t_raw : at::${zeros_fn}(${inp_name}_tensor.sym_sizes(), ${inp_name}_tensor.options()); +""" +) + +FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate( + """\ +auto ${inp_name}_p = toNonOptPrimal(${inp}); +""" +) + +FW_DERIVATIVE_SETTER_TENSOR = CodeTemplate( + """\ +if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}.defined()) { + // The hardcoded 0 here will need to be updated once we support multiple levels. + ${out_arg}._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace}); +} +""" +) + +FW_DERIVATIVE_SETTER_TENSOR_FOREACH = CodeTemplate( + """\ +for (const auto& i : c10::irange(${out_arg}_new_fw_grad_opts.size())) { + auto& ${out_arg}_new_fw_grad_opt = ${out_arg}_new_fw_grad_opts[i]; + if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}[i].defined()) { + // The hardcoded 0 here will need to be updated once we support multiple levels. + ${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace}); + } +} +""" +) + +FW_DERIVATIVE_SETTER_MULTI_OUTPUT = CodeTemplate( + """\ +if (${all_res}_new_fw_grad_opt.has_value() && std::get<${idx}>(${all_res}_new_fw_grad_opt.value()).defined() + && ${out_arg}.defined()) { + ${out_arg}._set_fw_grad(std::get<${idx}>(${all_res}_new_fw_grad_opt.value()), /* level */ 0, /* is_inplace_op */ false); +} +""" +) + +FW_DERIVATIVE_SETTER_TENSOR_LIST = CodeTemplate( + """\ +if (${out_arg}_new_fw_grad_opt.has_value()) { + auto ${out_arg}_new_fw_grad = ${out_arg}_new_fw_grad_opt.value(); + TORCH_INTERNAL_ASSERT(${out_arg}.size() == ${out_arg}_new_fw_grad.size()); + for (const auto i : c10::irange(${out_arg}.size())) { + if (${out_arg}_new_fw_grad[i].defined() && ${out_arg}[i].defined()) { + // The hardcoded 0 here will need to be updated once we support multiple levels. + ${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad[i], /* level */ 0, /* is_inplace_op */ ${is_inplace}); + } + } +} +""" +) + +FW_DERIVATIVE_TEMPLATE = CodeTemplate( + """\ +${fw_grad_opt_definition} +if (${requires_fw_grad}) { + ${unpacked_arguments} + ${out_arg}_new_fw_grad_opt = ${formula}; +} +""" +) + +FW_DERIVATIVE_FOREACH_TEMPLATE = CodeTemplate( + """\ +${fw_grad_opt_definition} +for (const auto& i : c10::irange(${vector_of_optional_tensor}.size())) { + if (${any_has_forward_grad_for_current_index}) { + ${unpacked_arguments} + ${vector_of_optional_tensor}[i] = ${formula}; + } +} +""" +) + +FW_DERIVATIVE_FORBID_TEMPLATE = CodeTemplate( + """\ +TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}"); +""" +) + +FW_DERIVATIVE_FORBID_LIST_TEMPLATE = CodeTemplate( + """\ +for (const auto& _t: ${arg}) { + TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}"); +} +""" +) + + +def gen_variable_type( + out: str, + native_yaml_path: str, + tags_yaml_path: str, + fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo], + template_path: str, + used_keys: set[str], +) -> None: + """VariableType.h and VariableType.cpp body + + This is the at::Type subclass for differentiable tensors. The + implementation of each function dispatches to the base tensor type to + compute the output. The grad_fn is attached to differentiable functions. + """ + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write( + "VariableType.h", + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/VariableType.h" + }, + ) + + # helper that generates a TORCH_LIBRARY_IMPL macro for each + # dispatch key that appears in derivatives.yaml + def wrapper_registrations(used_keys: set[str]) -> str: + library_impl_macro_list: list[str] = [] + for key in sorted(used_keys): + dispatch_key = key + if key == "Default": + dispatch_key = "Autograd" + library_impl_macro = ( + f"TORCH_LIBRARY_IMPL(aten, {dispatch_key}, m) " + + "{\n" + + "${" + + f"wrapper_registrations_{key}" + + "}\n}" + ) + library_impl_macro_list += [library_impl_macro] + return "\n\n".join(library_impl_macro_list) + + # Generate a new template from VariableType.cpp which replaces ${wrapper_registrations} + # with per key TORCH_LIBRARY_IMPL macros for each key that appears in derivatives.yaml + fm1 = FileManager( + install_dir=out + "/templates", template_dir=template_path, dry_run=False + ) + fm1.write( + "VariableType.cpp", + lambda: { + "type_derived_method_definitions": "\n\n".join( + [ + "${" + f"type_derived_method_definitions_{key}" + "}" + for key in sorted(used_keys) + ] + ), + "wrapper_registrations": wrapper_registrations(used_keys), + }, + ) + + # Generate final VariableType_*.cpp files from the generated template + fm2 = FileManager(install_dir=out, template_dir=out + "/templates", dry_run=False) + + sharded_keys = set( + [f"type_derived_method_definitions_{key}" for key in sorted(used_keys)] + + [f"wrapper_registrations_{key}" for key in sorted(used_keys)] + ) + # NOTE: see Note [Sharded File] at the top of the VariableType.cpp + # template regarding sharding of the generated files. + fm2.write_sharded( + "VariableType.cpp", + [fn for fn in fns_with_diff_infos if use_derived(fn)], + key_fn=lambda fn: cpp.name(fn.func.func), + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/VariableType.cpp", + }, + env_callable=gen_variable_type_func, + num_shards=5, + sharded_keys=sharded_keys, + ) + + +@with_native_function_and +def gen_wrapper_registration(f: NativeFunction, key: str = "Default") -> str: + return WRAPPER_REGISTRATION.substitute( + unqual_operator_name_with_overload=f.func.name, + type_wrapper_name=type_wrapper_name(f, key), + class_type="VariableType", + ) + + +def gen_variable_type_func( + fn: NativeFunctionWithDifferentiabilityInfo, +) -> dict[str, list[str]]: + f = fn.func + result = {} + with native_function_manager(f): + name = cpp.name(f.func) + formals = gen_formals(f) + + if ( + fn.info is None + and str(f.func.name.name) not in RESET_GRAD_ACCUMULATOR + and get_base_name(f) not in DONT_REQUIRE_DERIVATIVE + and len(gen_differentiable_outputs(fn)) > 0 + and cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE + and type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT + and type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT + ): + # NOTE: [ Registering AutogradNotImplemented boxed kernel ] + # + # When there is no derivatives.yaml entry, we register a generic boxed + # NotImplemented kernel to set grad_fn to be NotImplemented, so that forward + # proceeds as usual but an error is properly produced on backward. + # TODO: it would be nice to not have these special cases + # + # There are several cases where still let codegen handle it: + # 1) ops that need to reset grad accumulator (we let codegen handle this case + # because) the list is (currently) only accessible in Python. + # 2) User explicitly specifies DONT_REQUIRE_DERIVATIVE. This basically makes + # autograd a fallthrough with NDEBUG checks. This can be useful for when all + # outputs are integral. + # 3) When there are no differentiable outputs. This is similar to (2). + # 4) There are certain ops where we skip certain NDEBUG checks. this is similar + # to (1). + type_definition = "" + wrapper_registration = AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION.substitute( + unqual_operator_name_with_overload=f.func.name + ) + result["type_derived_method_definitions_Default"] = [type_definition] + result["wrapper_registrations_Default"] = [wrapper_registration] + else: + if not fn.info: + key = "Default" + type_definition = METHOD_DEFINITION.substitute( + return_type=cpp.returns_type( + f.func.returns, symint=True + ).cpp_type(), + type_wrapper_name=type_wrapper_name(f, key), + type_definition_body=emit_body(fn, key), + formals=formals, + ) + wrapper_registration = gen_wrapper_registration(f, key) + result[f"type_derived_method_definitions_{key}"] = [type_definition] + result[f"wrapper_registrations_{key}"] = [wrapper_registration] + else: + for key in fn.info.keys(): + type_definition = METHOD_DEFINITION.substitute( + return_type=cpp.returns_type( + f.func.returns, symint=True + ).cpp_type(), + type_wrapper_name=type_wrapper_name(f, key), + type_definition_body=emit_body(fn, key), + formals=formals, + ) + wrapper_registration = gen_wrapper_registration(f, key) + result[f"type_derived_method_definitions_{key}"] = [type_definition] + result[f"wrapper_registrations_{key}"] = [wrapper_registration] + # See Note [Manual Backend kernels] + assert (name in MANUAL_BACKEND) == f.manual_kernel_registration + # If you want to register a kernel to Autograd, you must make the op abstract. + # In other words, this op must have dispatch section in native_functions.yaml. + if name in MANUAL_AUTOGRAD_AND_TRACER or ( + fn.info and any(info.has_derivatives for info in fn.info.values()) + ): + msg = ( + f"There's a formula for {name}(or its functional variant) in derivatives.yaml. " + f"It's required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA " + f"or CompositeExplicitAutograd in native_functions.yaml. Please see " + f"https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword " + f"for instructions to choose the right dispatch keyword." + ) + assert f.is_abstract, msg + + return result + + +_foreach_ops_without_differentiability_info = { + # No reference backward available due to the lack of `{maximum, minimum}(tensor, scalar)`. + ("_foreach_maximum", "Scalar"), + ("_foreach_maximum", "ScalarList"), + ("_foreach_minimum", "Scalar"), + ("_foreach_minimum", "ScalarList"), + # No reference backward available as addcdiv/addcmul don't support Tensor as scaling factor. + ("_foreach_addcdiv", "Tensor"), + ("_foreach_addcmul", "Tensor"), + ("_foreach_copy", ""), +} + +_foreach_ops_with_different_arity = { + # These ops lack `alpha` of scaling factor to applied to the right hand side argument. + ("_foreach_add", "Scalar"), + ("_foreach_add", "ScalarList"), + ("_foreach_sub", "Scalar"), + ("_foreach_sub", "ScalarList"), +} + + +@with_native_function_with_differentiability_info_and_key +def emit_body( + fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default" +) -> list[str]: + assert dispatch_strategy(fn) == "use_derived" + f = fn.func + info = fn.info[key] if fn.info else None + fw_derivatives = fn.fw_derivatives.get(key, []) if fn.fw_derivatives else [] + + name = cpp.name(f.func) + inplace = f.func.kind() == SchemaKind.inplace + is_out_fn = f.func.kind() == SchemaKind.out + returns_void = len(f.func.returns) == 0 + base_name = get_base_name(f) + view_info = get_view_info(f) + + is_foreach = name.startswith("_foreach") + is_inplace_foreach = is_foreach and inplace + if is_inplace_foreach: + inplace_foreacharg2refarg: dict[Argument, Argument] = {} + refargname2inplace_foreacharg: dict[str, Argument] = {} + base_name_and_overload_name = (f.func.name.name.base, f.func.name.overload_name) + if info is None: + assert ( + base_name_and_overload_name + in _foreach_ops_without_differentiability_info + ), f"{'.'.join(base_name_and_overload_name)} should have a differentiability info" + else: + assert ( + len(f.func.arguments.flat_non_out) + == len(info.func.func.arguments.flat_non_out) + ) or (base_name_and_overload_name in _foreach_ops_with_different_arity), ( + f"{'.'.join(base_name_and_overload_name)} has {len(f.func.arguments.flat_non_out)} args " + f"but the reference has {len(info.func.func.arguments.flat_non_out)}" + ) + for foreach_arg, ref_arg in zip( + f.func.arguments.flat_non_out, info.func.func.arguments.flat_non_out + ): + foreach_arg_type = foreach_arg.type + if isinstance(foreach_arg_type, ListType): + foreach_arg_type = foreach_arg_type.elem + assert foreach_arg_type == ref_arg.type + inplace_foreacharg2refarg[foreach_arg] = ref_arg + refargname2inplace_foreacharg[ref_arg.name] = foreach_arg + + def gen_differentiable_input( + arg: Argument | SelfArgument | TensorOptionsArguments, + ) -> DifferentiableInput | None: + if isinstance(arg, TensorOptionsArguments): + return None + a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg + + # TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove. + # NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are + # not handled properly as they are irrelevant for this codegen. + cpp_type = cpp.argument_type(a, binds=a.name, symint=True).cpp_type() + + if not is_differentiable(a.name, a.type, info): + return None + return DifferentiableInput( + name=a.name, + type=a.type, + cpp_type=cpp_type, + ) + + @with_native_function + def gen_differentiable_inputs(f: NativeFunction) -> list[DifferentiableInput]: + arguments = list(f.func.arguments.non_out) + if is_inplace_foreach and info is not None: + for i, arg in enumerate(f.func.arguments.flat_non_out): + if arg in inplace_foreacharg2refarg: + # note(crcrpar): From what I understand, what matters is only the name. + # Thus originally I only replace argument only when the names are different. + # TODO(crcrpar): Make it simpler. + mapped_arg = inplace_foreacharg2refarg[arg] + arguments[i] = Argument( + mapped_arg.name, + mapped_arg.type, + mapped_arg.default, + mapped_arg.annotation, + ) + return list(mapMaybe(gen_differentiable_input, arguments)) + + def find_args_with_derivatives( + differentiable_inputs: list[DifferentiableInput], + ) -> list[DifferentiableInput]: + """Find arguments that have derivative definitions""" + if info is None or not info.has_derivatives: + return differentiable_inputs + names = {name for d in info.derivatives for name in d.var_names} + differentiable = [arg for arg in differentiable_inputs if arg.name in names] + if len(differentiable) != len(names): + missing = names - {arg.name for arg in differentiable} + raise RuntimeError( + f"Missing arguments for derivatives: {missing} in {info.name}" + ) + return differentiable + + differentiable_inputs = gen_differentiable_inputs(f) + args_with_derivatives = find_args_with_derivatives(differentiable_inputs) + differentiable_outputs = gen_differentiable_outputs(fn, key) + + undifferentiable = (base_name in DONT_REQUIRE_DERIVATIVE) or ( + name in DONT_REQUIRE_DERIVATIVE + ) + + requires_derivative = ( + (not undifferentiable) + and (len(differentiable_inputs) > 0) + and ( + (len(differentiable_outputs) > 0) + # note(crcrpar): In-place foreach functions are a void function. + or is_inplace_foreach + ) + ) + + if ( + info is not None + and info.has_derivatives + and not requires_derivative + # out= ops are allowed to have zero returns which cause requires_derivative to be False + # we shouldn't error out though (out= ops for autograd just redispatch) + and len(f.func.returns) > 0 + ): + raise RuntimeError( + f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative" + ) + + # note(crcrpar): In-place foreach functions do not support forward AD + if requires_derivative and len(fw_derivatives) > 0 and not is_inplace_foreach: + assert sum(len(derivative.var_names) for derivative in fw_derivatives) == len( + differentiable_outputs + ), ( + "Expected the number of forward derivatives implemented to match the " + "number of differentiable outputs. NB: This only applies when at least " + "one forward derivative is implemented. Not implementing any forward " + "derivatives is also okay, and we would require inputs to the op to " + "not have associated tangents in that case." + ) + + try_jit_decomposition = ( + requires_derivative + and len(fw_derivatives) == 0 + and (not modifies_arguments(f)) + and (not returns_void) + ) + + def emit_save_inputs() -> list[str]: + setup: list[str] = [] + if info is None or not info.has_derivatives: + return setup + + has_tensorlist_arg = any( + is_tensor_list_type(arg.type) for arg in args_with_derivatives + ) + + # We don't want to save tensors if we know that they will never be used + # when computing the derivative, so we add guards to those statements + def guard_for(arg: SavedAttribute) -> str | None: + assert info is not None + + # It's hard to determine the edge offset if we have TensorLists + # NOTE(crcrpar): in-place foreach functions' arguments include tensorlist + # but their derivatives don't use it, so let them bypass this check. + if has_tensorlist_arg and (not is_inplace_foreach): + return None + + # Empirical evaluation of the cases where we insert those guards in + # backward show that they are somewhat useless. E.g. there's no need + # to guard on some values captured from forward, because they had to + # require_grad if the backward function even gets executed. I don't + # have any good ideas for detecting those cases, so I simply disabled the + # checks. + if "backward" in info.name: + return None + + # If there's a single derivative we could compute, we already have + # a requires_grad check that is sufficient + if len(args_with_derivatives) <= 1: + return None + + # We really only care about trimming down the amount of tensors we save + if arg.nctype.type != BaseCType(tensorT): + return None + + # We want to emit simple guards, so we only allow that if checking one + # input is enough to determine whether we need that value + used_in = [d for d in info.derivatives if arg in d.saved_inputs] + assert len(used_in) > 0 + if len(used_in) != 1: + return None + derivative = used_in[0] + + # Case with multioutput formulas + # TODO: process all derivative formulas!!! + if len(derivative.var_names) != 1: + wrap_opt_if_start = derivative.formula.find( + f"wrap_opt_if({arg.nctype.name}" + ) + if wrap_opt_if_start == -1: + return None + + wrap_opt_if_match = re.match( + rf"wrap_opt_if\({arg.nctype.name},(.*?)\)", + derivative.formula[wrap_opt_if_start:], + ) + assert wrap_opt_if_match is not None + + # Condition is between 'wrap_opt_if(var_name,' and ')'. + condition_slice = slice(len(rf"wrap_opt_if\({arg.nctype.name},"), -1) + wrap_opt_if_condition = wrap_opt_if_match.group(0)[ + condition_slice + ].strip() + # replace 'grad_input_mask[num]' with 'grad_fn->should_compute_output(num)' + wrap_opt_if_condition = re.sub( + r"grad_input_mask\[(\d+)\]", + r"grad_fn->should_compute_output(\1)", + wrap_opt_if_condition, + ) + return f"{wrap_opt_if_condition}" + + # Figure out the offset of the edge that uses this variable + derivative_var_name = derivative.var_names[0] + for edge_off, a in enumerate(args_with_derivatives): + if a.name == derivative_var_name: + break + else: + raise AssertionError + return f"grad_fn->should_compute_output({edge_off})" + + if is_inplace_foreach: + save_input_stmts = save_variables(info.all_saved_inputs, False, guard_for) + if save_input_stmts: + setup.append( + LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute( + preamble="", statements=save_input_stmts + ) + ) + else: + setup.extend(save_variables(info.all_saved_inputs, False, guard_for)) + for arg in args_with_derivatives: + if is_tensor_list_type(arg.type): + setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();") + return setup + + def setup_derivative(differentiable_inputs: list[DifferentiableInput]) -> list[str]: + body: list[str] = [] + if is_out_fn: + # For out functions, ensure that no input or output requires grad + body.append(DECLARE_GRAD_FN.substitute(op="Node")) + body.append( + SETUP_NONE_REQUIRES_GRAD.substitute( + base_name=base_name, + args_to_check=[arg.name for arg in differentiable_inputs], + ) + ) + body.append( + SETUP_NONE_REQUIRES_GRAD.substitute( + base_name=base_name, + args_to_check=[arg.name for arg in differentiable_outputs], + ) + ) + return body + + op = info.op if info is not None and info.has_derivatives else "NotImplemented" + setup = [] + if not is_inplace_foreach: + setup.extend( + ASSIGN_GRAD_FN.substitute( + op=op, + op_ctor="" + if info is not None and info.has_derivatives + else f'"{cpp.name(f.func)}"', + args_with_derivatives=[arg.name for arg in args_with_derivatives], + ).split("\n") + ) + else: + # note(crcrpar): Assuming in-place foreach function's self_arg is always TensorList. + list_like_arg = "self" + args = [arg.name for arg in args_with_derivatives] + for i, arg in enumerate(args): + if is_inplace_foreach and info is not None: + if arg in refargname2inplace_foreacharg: + foreach_arg = refargname2inplace_foreacharg[arg] + args[i] = foreach_arg.name + ( + "[i]" if isinstance(foreach_arg.type, ListType) else "" + ) + else: + if arg == list_like_arg: + args[i] = arg + "[i]" + setup.extend( + ASSIGN_VECTOR_OF_GRAD_FN.substitute( + op=op, + op_ctor="" + if info is not None and info.has_derivatives + else f'"{cpp.name(f.func)}"', + args_with_derivatives=args, + irange=f"{list_like_arg}.size()", + ).split("\n") + ) + setup.extend(emit_save_inputs()) + + body.extend( + emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives) + ) + declare_grad_fn_template = ( + DECLARE_GRAD_FN if not is_inplace_foreach else DECLARE_VECTOR_OF_GRAD_FN + ) + body.append(declare_grad_fn_template.substitute(op=op)) + body.append(SETUP_DERIVATIVE.substitute(setup=setup)) + return body + + def emit_check_if_in_complex_autograd_allowlist() -> list[str]: + body: list[str] = [] + if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX: + return body + for arg in differentiable_outputs: + name = arg.name + # TODO: should be `arg.type.is_tensor_like()`? + if arg.cpp_type == "at::Tensor" or arg.cpp_type in TENSOR_LIST_LIKE_CTYPES: + body.append(f'throw_error_for_complex_autograd({name}, "{base_name}");') + return body + + def emit_check_no_requires_grad( + tensor_args: list[DifferentiableInput], + args_with_derivatives: list[DifferentiableInput], + ) -> list[str]: + """Checks that arguments without derivatives don't require grad""" + body: list[str] = [] + for arg in tensor_args: + if arg in args_with_derivatives: + continue + arg_name = arg.name + if info and arg_name in info.non_differentiable_arg_names: + continue + if arg_name == "output": + # Double-backwards definitions sometimes take in 'input' and + # 'output', but only define the derivative for input. + continue + body.append(f'check_no_requires_grad({arg_name}, "{arg_name}", "{name}");') + return body + + def emit_original_self_definition() -> list[str]: + body: list[str] = [] + if inplace: + if is_inplace_foreach: + body.append( + "std::vector<::std::optional> original_selfs(self.size());" + ) + else: + body.append("::std::optional original_self;") + + all_forward_grad_cond = [] + for derivative in fw_derivatives: + if derivative.required_original_self_value: + all_forward_grad_cond.append( + get_any_has_forward_grad_name(derivative.var_names) + ) + + if all_forward_grad_cond: + if not is_inplace_foreach: + body.append(f'if ({" || ".join(all_forward_grad_cond)}) {{') + body.append(" original_self = self.clone();") + body.append("}") + else: + current_all_forward_grad_cond = [ + f"{cond}[i]" for cond in all_forward_grad_cond + ] + body.append("for (const auto& i : c10::irange(self.size())) {") + body.append( + f" if ({' || '.join(current_all_forward_grad_cond)}) {{" + ) + body.append(" original_selfs[i] = self[i].clone();") + body.append(" }") + body.append("}") + + return body + + def save_variables( + saved_variables: Sequence[SavedAttribute], + is_output: bool, + guard_for: Callable[[SavedAttribute], str | None] = lambda name: None, + ) -> Sequence[str]: + # assign the saved variables to the generated grad_fn + stmts: list[str] = [] + for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)): + name = ( + arg.nctype.name.name + if isinstance(arg.nctype.name, SpecialArgName) + else arg.nctype.name + ) + foreacharg: Argument | None = None + is_foreacharg_list_type: bool = False + type = arg.nctype.type + expr = arg.expr + stmts_prepend = None + if is_inplace_foreach and info is not None: + # todo(crcrpar): See if we can add some check e.g. `assert foreacharg is not None`. + # for now the example assert would fail. + name_to_query = name.split("_scalar_type")[0] + if name_to_query in refargname2inplace_foreacharg: + foreacharg = refargname2inplace_foreacharg[name_to_query] + is_foreacharg_list_type = isinstance(foreacharg.type, ListType) + if foreacharg is not None: + name_in_expr = ( + f"{foreacharg.name}{'[i]' if is_foreacharg_list_type else ''}" + ) + src_name = name + if "_scalar_type" in src_name: + split_src_name = src_name.split("_scalar_type") + assert len(split_src_name) == 2 + src_name = split_src_name[0] + expr = expr.replace(src_name, name_in_expr) + if ( + type == BaseCType(tensorT) + or type == OptionalCType(BaseCType(tensorT)) + or type == MutRefCType(OptionalCType(BaseCType(tensorT))) + or (is_output and type == BaseCType(scalarT)) + ): + # note(crcrpar): Here `expr` is generated from scratch, `arg.expr` is ignored. + var = name + name += "_" + if var == "self" and inplace: + original_self_var = ( + "original_self" + if not is_inplace_foreach + else "original_selfs[i]" + ) + self_var = var if not is_inplace_foreach else var + "[i]" + stmts_prepend = f"if (!{original_self_var}.has_value()) {original_self_var} = {self_var}.clone()" + var = f"{original_self_var}.value()" + assert not is_output + if inplace and is_output: + assert name == "result_" + var = ( + "self[i]" + if is_inplace_foreach or is_foreacharg_list_type + else "self" + ) + is_inplace_view = f"{var}.is_view()" + expr = f"SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})" + else: + expr = f"SavedVariable({var}, {str(is_output).lower()})" + if foreacharg is not None and "original_selfs" not in expr: + expr = expr.replace(src_name, name_in_expr) + elif ( + type == BaseCType(tensorListT) + or type == ListCType(OptionalCType(BaseCType(tensorT))) + or type == BaseCType(iTensorListRefT) + or type == VectorCType(BaseCType(tensorT)) + ): + # See Note [nuanced return type of out-of-place foreach functions] + if type == VectorCType(BaseCType(tensorT)): + assert is_foreach and is_output + expr = f"make_saved_variable_list({name}, {str(is_foreach and is_output).lower()})" + name += "_" + elif type == BaseCType(intArrayRefT): + expr = expr + ".vec()" + elif type == BaseCType(symIntArrayRefT): + expr = expr + ".vec()" + elif type == BaseCType(stringT): + expr = f"std::string({expr})" + elif type == OptionalCType(BaseCType(stringT)): + expr = f"{expr}.has_value() ? ::std::optional(std::string({expr}.value())) : ::std::nullopt" + elif type == ArrayRefCType( + elem=BaseCType(type=BaseCppType(ns="at", name="Scalar")) + ): + expr = expr + ".vec()" + + guard = guard_for(arg) + if guard is None: + if stmts_prepend: + stmts.append(f"{stmts_prepend};") + stmts.append(f"grad_fn->{name} = {expr};") + else: + stmts.append(f"if ({guard}) {{") + if stmts_prepend: + stmts.append(f" {stmts_prepend};") + stmts.append(f" grad_fn->{name} = {expr};") + stmts.append("}") + return stmts + + # Generates a Dispatcher::redispatch() call into the dispatcher. We do this mainly for performance reasons: + # - Pre-compute the full DispatchKeySet. This saves the dispatcher from having to read from TLS. + # - redispatch() avoids a redundant call to RecordFunction, which was already called right before + # we entered this autograd kernel. + def emit_dispatch_call( + f: NativeFunction, input_base: str, unpacked_args: Sequence[str] + ) -> str: + """Dispatch call via function in a namespace or method on Tensor.""" + dispatcher_sig = DispatcherSignature.from_schema(f.func) + dispatcher_exprs = dispatcher_sig.exprs() + + # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance. + # Ops also always have a function variant of the redispatch API. + # See Note [Plumbing Keys Through The Dispatcher] for details. + dispatch_key_set = "ks & c10::after_autograd_keyset" + call = CALL_REDISPATCH.substitute( + api_name=cpp.name( + f.func, + faithful_name_for_out_overloads=True, + symint_overload=f.func.has_symint(), + ), + unpacked_args=[dispatch_key_set] + list(unpacked_args), + ) + return call + + def wrap_output( + f: NativeFunction, unpacked_bindings: list[Binding], var: str + ) -> str: + call = "" + rhs_value: str | None = None + if not any(r.type.is_tensor_like() for r in f.func.returns): + rhs_value = var + else: + rhs_value = f"std::move({var})" + assert rhs_value is not None + call += ASSIGN_RETURN_VALUE.substitute( + return_values=tie_return_values(f), rhs_value=rhs_value + ) + return call + + def check_tensorimpl_and_storage( + call: str, unpacked_bindings: list[Binding] + ) -> str: + # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ] + stmts_before_call: list[str] = [] + stmts_after_call: list[str] = [] + + if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE: + return call + + # Check properties of inputs (enforce (1)) + for unpacked_binding in unpacked_bindings: + arg = unpacked_binding.name + noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref() + if noref_cpp_type == BaseCType(tensorListT) or noref_cpp_type == BaseCType( + iTensorListRefT + ): + stmts_before_call += [ + SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg), + SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg), + ] + stmts_after_call += [ + ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg), + ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg), + ] + elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))): + stmts_before_call += [ + SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg), + SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg), + ] + stmts_after_call += [ + ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute( + tensorlist_name=arg + ), + ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute( + tensorlist_name=arg + ), + ] + elif noref_cpp_type == BaseCType(tensorT): + stmts_before_call += [ + SAVE_TENSOR_STORAGE.substitute(tensor_name=arg), + SAVE_TENSOR_IMPL.substitute(tensor_name=arg), + ] + stmts_after_call += [ + ENFORCE_SAME_TENSOR_STORAGE.substitute( + tensor_name=arg, out_tensor_name=arg + ), + ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg), + ] + + assert (stmts_before_call and stmts_after_call) or ( + not stmts_before_call and not stmts_after_call + ) + + # Check properties of outputs (enforce (2), (3)) + if f.func.kind() not in (SchemaKind.inplace, SchemaKind.out): + base_name = f.func.name.name.base # TODO: should be str(f.func.name.name)? + aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None) + if aliased_arg_name is not None: + aliased_arg_name = unpacked_name(aliased_arg_name) + for i, (ret, ret_name) in enumerate( + zip(f.func.returns, cpp.return_names(f)) + ): + noref_cpp_type = cpp.return_type(ret, symint=True).remove_const_ref() + if noref_cpp_type == BaseCType(tensorT): + if aliased_arg_name is not None: + assert ( + i == 0 + ), "Expect non-CompositeImplicitAutograd view function {base} to return single output" + stmts_after_call += [ + ENFORCE_SAME_TENSOR_STORAGE.substitute( + tensor_name=aliased_arg_name, out_tensor_name=ret_name + ) + ] + else: + if ( + type_wrapper_name(f) + not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT + ): + stmts_after_call += [ + ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute( + tensor_name=ret_name, fn_name=type_wrapper_name(f) + ) + ] + + if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT: + stmts_after_call += [ + ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute( + tensor_name=ret_name, fn_name=type_wrapper_name(f) + ) + ] + + # Currently we don't have any functions that return the following types, but + # we should update the checks once we do + elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))): + raise AssertionError( + f"Please add use_count checks for {noref_cpp_type}" + ) + elif noref_cpp_type == BaseCType(tensorListT): + raise AssertionError( + f"Please add use_count checks for {noref_cpp_type}" + ) + + if stmts_before_call and stmts_after_call: + call = ( + RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call) + + call + + RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call) + ) + return call + + def emit_call( + f: NativeFunction, unpacked_bindings: list[Binding], try_jit_decomposition: bool + ) -> str: + # We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch + # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure + # the baseType operations still dispatch to non-Variable type, even if the arguments passed + # in are now Variables. + # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details. + unpacked_args = [b.name for b in unpacked_bindings] + base_type_call = emit_dispatch_call(f, "self_", unpacked_args) + + if get_view_info(f) is not None or modifies_arguments(f): + guard = "at::AutoDispatchBelowAutograd guard;" + else: + guard = "at::AutoDispatchBelowADInplaceOrView guard;" + + any_has_forward_grad = ( + get_any_has_fw_grad_cond(derivative=None) + if requires_derivative + else "false" + ) + return_types = ", ".join( + [cpp.return_type(a, symint=True).cpp_type() for a in f.func.returns] + ) + if len(f.func.returns) > 1: + return_types = f"std::tuple<{return_types}>" + + arg_names = [ + a.name + for a in cpp.arguments( + f.func.arguments, + faithful=True, + symint=True, + method=False, + cpp_no_default_args=set(), + ) + ] + + if not modifies_arguments(f) and not returns_void: + if try_jit_decomposition: + call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP.substitute( + base_type_call=base_type_call, + tmp_var=TMP_VAR, + guard=guard, + any_has_forward_grad=any_has_forward_grad, + op_name=cpp.name(f.func), + op_overload=f.func.name.overload_name, + return_types=return_types, + arg_names=arg_names, + ) + else: + call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute( + base_type_call=base_type_call, + tmp_var=TMP_VAR, + guard=guard, + ) + + call += wrap_output(f, unpacked_bindings, TMP_VAR) + else: + assert not try_jit_decomposition + call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute( + base_type_call=base_type_call, guard=guard + ) + call = check_tensorimpl_and_storage(call, unpacked_bindings) + return call + + def emit_history() -> str: + fn = "rebase" if modifies_arguments(f) and view_info is None else "set" + output_names = [r.name for r in differentiable_outputs] + # TODO: flatten allocates a std::vector, which could be expensive + outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute( + outs=output_names if not is_inplace_foreach else "self" + ) + if not is_inplace_foreach: + return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs) + else: + return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute( + preamble=( + f"auto differentiable_outputs = {outs};\n" + f"TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());" + ), + statements=f"{fn}_history(differentiable_outputs[i], grad_fns[i]);", + ) + + def emit_save_outputs() -> str: + if is_out_fn: + # out functions don't currently support differentiation + return "" + if info is not None and info.has_derivatives: + stmts = save_variables(info.all_saved_outputs, True) + if len(stmts) == 0: + return "" + if not is_inplace_foreach: + return CONDITIONAL.substitute(cond="grad_fn", statements=stmts) + else: + return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute( + preamble="", statements=stmts + ) + return "" + + def emit_any_requires_grad() -> list[str]: + extra_condition = "" + if info and info.output_differentiability_conditions: + assert len(info.output_differentiability_conditions) == 1 + extra_condition = f"_any_requires_grad &= ({info.output_differentiability_conditions[0]});" + names_of_args_with_derivatives = [arg.name for arg in args_with_derivatives] + if is_inplace_foreach and info is not None: + for i, arg in enumerate(names_of_args_with_derivatives): + for f_arg, r_arg in inplace_foreacharg2refarg.items(): + if arg == r_arg.name: + names_of_args_with_derivatives[i] = f_arg.name + return [ + SETUP_ANY_REQUIRES_GRAD.substitute( + args_with_derivatives=names_of_args_with_derivatives, + extra_differentiability_conditions=extra_condition, + ) + ] + + def get_any_has_forward_grad_name(var_names: tuple[str, ...]) -> str: + if len(var_names) == 1: + return f"_any_has_forward_grad_{var_names[0]}" + else: + return f'_any_has_forward_grad_{"_".join(var_names)}' + + def emit_any_has_forward_grad() -> list[str]: + content: list[str] = [] + if not is_foreach: + for derivative in fw_derivatives: + requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative) + if info and info.output_differentiability_conditions: + assert len(info.output_differentiability_conditions) == 1 + requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && {requires_fw_grad}" + content.append( + f"[[maybe_unused]] auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};" + ) + else: + for derivative in fw_derivatives: + bool_vector_name = get_any_has_forward_grad_name(derivative.var_names) + cur_derivative_conditions = [] + for inp in differentiable_inputs: + if derivative.required_inputs_fw_grad is None: + continue + if inp.name not in derivative.required_inputs_fw_grad: + continue + inp_name = ( + inp.name + if not inplace + else refargname2inplace_foreacharg[inp.name].name + ) + inp_type = ( + inp.type + if not inplace + else refargname2inplace_foreacharg[inp.name].type + ) + is_list_type = is_tensor_list_type(inp_type) + if is_list_type: + if inp_name != "self": + content.append( + FW_DERIVATIVE_SIZE_CHECK_TEMPLATE.substitute( + inp_name=inp_name + ) + ) + cur_derivative_conditions.append( + FW_DERIVATIVE_CHECK_TEMPLATE.substitute( + req_inp=inp_name + "[i]" + ) + ) + else: + cur_derivative_conditions.append( + FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp_name) + ) + + content.append(f"std::vector {bool_vector_name}(self.size());") + content.append("for (const auto& i : c10::irange(self.size())) {") + content.append( + f" {bool_vector_name}[i] = {' || '.join(cur_derivative_conditions)};" + ) + content.append("}") + return content + + def emit_check_inplace() -> list[str]: + if not inplace: + return [] + return [ + f"check_inplace({arg.name}, _any_requires_grad);" + for arg in differentiable_outputs + ] + + def emit_fw_derivatives() -> list[str]: + content: list[str] = [] + fw_grad_setters: list[str] = [] + for derivative in fw_derivatives: + res = derivative.var_names + if f.func.name.name.inplace: + assert ( + len(res) == 1 + ), "Expected number of outputs to be 1 if function is inplace" + # TODO update this when inplace namings are unified + res = ("self",) + + assert derivative.required_inputs_fw_grad is not None + + unpacked_arguments = "" + for inp in differentiable_inputs: + inp_name = inp.name + is_input_tensorlist = is_foreach and is_tensor_list_type( + inp.type + if not inplace + else refargname2inplace_foreacharg[inp.name].type + ) + input_suffix = "[i]" if is_input_tensorlist else "" + if is_inplace_foreach: + if inp.name in refargname2inplace_foreacharg: + inp_name = refargname2inplace_foreacharg[inp.name].name + zeros_fn = ( + "zeros_symint" + if inplace and inp.name == "self" + else "_efficientzerotensor_symint" + ) + if inp.name in derivative.required_inputs_fw_grad: + unpacked_arguments += ( + FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute( + inp_name=inp.name, + inp=inp_name + input_suffix, + zeros_fn=zeros_fn, + ) + ) + if inp.name in (derivative.required_inputs_primal or []): + unpacked_arguments += ( + FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute( + inp_name=inp.name, + inp=inp_name + input_suffix, + ) + ) + if derivative.required_original_self_value: + input_suffix = "s[i]" if is_inplace_foreach else "" + unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute( + inp_name="original_self", + inp="original_self" + input_suffix, + zeros_fn=zeros_fn, + ) + unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute( + inp_name="original_self", + inp="original_self" + input_suffix, + ) + elif inplace and derivative.is_reusing_outplace_formula: + # The gradient wasn't already cloned, do it if grad mode is enabled + unpacked_arguments += ( + "self_t = GradMode::is_enabled() ? self_t.clone() : self_t;" + ) + + if inplace: + is_inplace_str = "true" + else: + is_inplace_str = "false" + + requires_fw_grad = get_any_has_forward_grad_name(derivative.var_names) + + if all( + (isinstance(var_type, BaseType) and var_type.is_tensor_like()) + for var_type in derivative.var_types + ): + # Is there a way to get from BaseType to BaseCType + if len(derivative.var_types) == 1: + opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type() + if not is_foreach: + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_TENSOR.substitute( + out_arg=res[0], is_inplace=is_inplace_str + ) + ) + else: + assert res[0] == ("result" if not inplace else "self") + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute( + out_arg=res[0], is_inplace=is_inplace_str + ) + ) + requires_fw_grad += f" && ({derivative.var_names[0]}.defined())" + else: + tuple_type = TupleCType( + [BaseCType(tensorT)] * len(derivative.var_types) + ) + opt_res_grad_type = OptionalCType(tuple_type).cpp_type() + for idx, single_res in enumerate(res): + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_MULTI_OUTPUT.substitute( + idx=idx, all_res="_".join(res), out_arg=single_res + ) + ) + elif ( + isinstance(derivative.var_types[0], ListType) + and derivative.var_types[0].is_tensor_like() + ): + assert ( + len(derivative.var_types) == 1 + ), "Expected number of outputs to be 1 if function returns ListType" + if not is_foreach: + opt_res_grad_type = OptionalCType( + VectorCType(BaseCType(tensorT)) + ).cpp_type() + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute( + out_arg=res[0], is_inplace=is_inplace_str + ) + ) + else: + # TODO(crcrpar): Should this (= the foreach specific logic) be refactored somehow? + # Only out-place foreach functions that have entries in `tools/autograd/derivatives.yaml` + # can reach here. + opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type() + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute( + out_arg=res[0], is_inplace=is_inplace_str + ) + ) + else: + raise RuntimeError("Unsupported output type for forward derivative") + + if not is_foreach: + fw_grad_opt_definition = f"{opt_res_grad_type} {'_'.join(res)}_new_fw_grad_opt = ::std::nullopt;" + # View ops create fw_grad that already is a view of the base's fw_grad so just use that + content.append( + FW_DERIVATIVE_TEMPLATE.substitute( + fw_grad_opt_definition=fw_grad_opt_definition, + requires_fw_grad=requires_fw_grad, + formula=derivative.formula, + out_arg="_".join(res), + unpacked_arguments=unpacked_arguments, + ) + ) + else: + # note(crcrpar): Assuming `self` is TensorList. + fw_grad_opt_definition = ( + f"std::vector<{opt_res_grad_type}> {'_'.join(res)}_new_fw_grad_opts" + "(self.size(), ::std::nullopt);" + ) + foreach_forward_grad_formula = derivative.formula + _foreach_arg: Argument | DifferentiableInput + if inplace: + for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items(): + # note(crcrpar): Massage only Scalar and ArrayRef here. + if not ( + is_tensor_type(_foreach_arg.type) + or is_tensor_list_type(_foreach_arg.type) + ): + pattern = _foreach_arg.name + if isinstance(_foreach_arg.type, ListType): + pattern += "[i]" + foreach_forward_grad_formula = ( + foreach_forward_grad_formula.replace( + _ref_arg.name, pattern + ) + ) + else: + if ( + "result" in foreach_forward_grad_formula + and "result[i]" not in foreach_forward_grad_formula + ): + foreach_forward_grad_formula = ( + foreach_forward_grad_formula.replace("result", "result[i]") + ) + + content.append( + FW_DERIVATIVE_FOREACH_TEMPLATE.substitute( + fw_grad_opt_definition=fw_grad_opt_definition, + vector_of_optional_tensor=f"{'_'.join(res)}_new_fw_grad_opts", + any_has_forward_grad_for_current_index=" || ".join( + get_any_has_forward_grad_name(derivative.var_names) + "[i]" + for derivative in fw_derivatives + ), + formula=foreach_forward_grad_formula, + unpacked_arguments=unpacked_arguments, + ) + ) + + # Set all the grads at the end to avoid: https://github.com/pytorch/pytorch/issues/67367 + content.append("\n".join(fw_grad_setters)) + return content + + def get_any_has_fw_grad_cond(derivative: ForwardDerivative | None) -> str: + # + # Produces a condition string (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)") + # + if derivative is None: + # (1) If a derivative is NOT provided, cond will check fw_grad of ALL differentiable inputs + # - Used in the out_fn case when we want to forbid fw derivatives + # - Used in the case where the fw_derivative is not defined, but we want + # To check if there is a decomposition registered for jvp + to_check: list[str] = [] + for inp in list( + mapMaybe( + gen_differentiable_input, + f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator] + ) + ): + if is_tensor_type(inp.type): + to_check.append( + FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name) + ) + elif is_tensor_list_type(inp.type): + to_check.append( + FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE.substitute( + req_inp=inp.name + ) + ) + else: + raise RuntimeError( + f'Unsupported input type for "{name}" when forbidding forward AD usage.' + ) + return f'({" || ".join(to_check)})' + else: + # (2) If derivative is provided, use that information to determine which inputs + # to check fw_grad for + assert derivative.required_inputs_fw_grad is not None + + if len(derivative.required_inputs_fw_grad) == 0: + # Handle functions like stack + # For these, we don't unpack anything and always call the user function + if not ( + len(differentiable_inputs) == 1 + and is_tensor_list_type(differentiable_inputs[0].type) + ): + raise RuntimeError( + f'No differentiable input to "{name}" is a differentiable Tensor (as the provided ' + "forward AD formula does not use any input tangent) even though a forward gradient " + "formula has been defined for it. This case should only happen for function that " + "take a single TensorList as input. All other cases are not supported right now." + ) + any_has_fw_grad = "true" + else: + any_has_fw_grad = " || ".join( + [ + ( + FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE + if is_tensor_list_type(inp.type) + else FW_DERIVATIVE_CHECK_TEMPLATE + ).substitute(req_inp=inp.name) + for inp in differentiable_inputs + if inp.name in derivative.required_inputs_fw_grad + ] + ) + any_has_fw_grad = f"({any_has_fw_grad})" + + return any_has_fw_grad + + def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str: + if is_out_fn: + msg = "because it is an out= function" + else: + msg = ( + "because it has not been implemented yet.\\nPlease file an issue " + "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml " + "so that we can prioritize its implementation." + ) + cond = get_any_has_fw_grad_cond(derivative=None) + return ( + FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond, name=name, msg=msg) + if cond != "" + else "" + ) + + body: list[str] = [] + unpack_args_stats, unpacked_bindings = unpack_args(f) + + body.extend(unpack_args_stats) + if requires_derivative: + body.extend(emit_any_requires_grad()) + body.extend(emit_any_has_forward_grad()) + body.extend(emit_check_inplace()) + body.extend(emit_original_self_definition()) + body.extend(setup_derivative(differentiable_inputs)) + + body.append(emit_call(f, unpacked_bindings, try_jit_decomposition)) + if requires_derivative: + # set_flags has to appear after version_counter, because rebase_history + # requires that the counter is incremented before it is called + body.append(emit_history()) + body.extend(emit_check_if_in_complex_autograd_allowlist()) + + if is_out_fn: + body.append(emit_forbid_fw_derivatives(is_out_fn=True)) + else: + if requires_derivative and not try_jit_decomposition: + if len(fw_derivatives) > 0: + body.extend(emit_fw_derivatives()) + else: + body.append(emit_forbid_fw_derivatives()) + + if requires_derivative: + # Save only after the forward AD has been set up + body.append(emit_save_outputs()) + + if str(f.func.name.name) in RESET_GRAD_ACCUMULATOR: + # `inplace` implies that there is exactly one output named `self`, + # so we can keep the generated code easy. If you need to + # `reset_grad_accumulator` in an operator that's not `inplace`, you can + # remove this assert but the code generation will get more elaborate + assert inplace + body.append("reset_grad_accumulator(self);") + if not returns_void: + body.append(f"return {get_return_value(f)};") + return body diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_view_funcs.py b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_view_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..245a77106dc65a2b9ab89c9006ff317eabf1ed1c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_view_funcs.py @@ -0,0 +1,340 @@ +# Generates ViewFuncs.h/cpp +# +# NOTE: If any changes are being made to the ViewFunc codegen please also check +# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp +# The fallback is expected to mimic this codegen, so we should keep the two in sync. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torchgen.api.dispatcher as dispatcher +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + NamedCType, + SymIntT, + tensorT, + VectorCType, +) +from torchgen.code_template import CodeTemplate +from torchgen.model import Argument, NativeFunction, OptionalType +from torchgen.utils import FileManager + +from .gen_inplace_or_view_type import ( + CALL_DISPATCH, + extract_bindings, + get_view_info, + modifies_arguments, + use_derived, +) + + +if TYPE_CHECKING: + from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo + + +FUNCTION_DECLARATION = CodeTemplate( + """\ +#define ${uppercase_op}_AVAILABLE +struct ${op} : public ${superclass} { + ${op}(${constructor_args}) ${initializer_list} + {}; + virtual ~${op}() override {}; + virtual std::vector get_symints() const override; + virtual size_t num_symints() const override; + virtual std::vector get_tensors() const override; + virtual size_t num_tensors() const override; + virtual at::Tensor operator()(const at::Tensor&) const override; + virtual std::unique_ptr clone_and_set( + std::optional> = ::std::nullopt, + std::optional> = ::std::nullopt) const override; + +protected: + virtual void set_symints(std::vector) override; + virtual void set_tensors(std::vector) override; + +private: + ${state} +}; + +""" +) + +FUNCTION_DEFINITION = CodeTemplate( + """\ +std::vector ${op}::get_symints() const { + ${get_symints} +} + +size_t ${op}::num_symints() const { + return static_cast(${num_symints}); +} + +void ${op}::set_symints(std::vector ${symints_vec}) { + TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints()); + ${set_symints} +} + +std::vector ${op}::get_tensors() const { + ${get_tensors} +} + +size_t ${op}::num_tensors() const { + return static_cast(${num_tensors}); +} + +void ${op}::set_tensors(std::vector ${tensors_vec}) { + TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors()); + ${set_tensors} +} + +at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const { + return ${op_call}; +} + +std::unique_ptr ${op}::clone_and_set( + std::optional> ${symints_vec}, + std::optional> ${tensors_vec}) const { + auto output = std::make_unique<${op}>(${clone_args}); + if (${symints_vec}.has_value()) { + output->set_symints(std::move(*(${symints_vec}))); + } + if (${tensors_vec}.has_value()) { + output->set_tensors(std::move(*(${tensors_vec}))); + } + return output; +} + +""" +) + + +# e.g. as_strided -> AsStridedViewFunc for camel case or +# as_strided_view_func otherwise +def view_func_name( + f: NativeFunction, include_namespace: bool = False, camel_case: bool = True +) -> str: + name = f.func.name.unambiguous_name() + view_func_name = f"{name.replace('.', '_')}_view_func" + if camel_case: + is_private = view_func_name.startswith("_") + view_func_name = "".join( + [p.title() for p in view_func_name.replace(".", "_").split("_")] + ) + if is_private: + # put the leading underscore back in + view_func_name = f"_{view_func_name}" + namespace = "torch::autograd::generated::" if include_namespace else "" + return f"{namespace}{view_func_name}" + + +def is_symint_or_tensor(arg: Argument) -> bool: + return arg.type.is_tensor_like() or arg.type.is_symint_like() + + +def remove_const_ref(binding: Binding) -> Binding: + return Binding( + name=binding.name, + nctype=binding.nctype.remove_const_ref(), + argument=binding.argument, + default=binding.default, + ) + + +def returns_multi_tensor(fn: NativeFunction) -> bool: + returns = fn.func.returns + assert len(returns) == 1 + returns_list_like = returns[0].type.is_list_like() is not None + returns_tensor_like = returns[0].type.is_tensor_like() + return returns_list_like and returns_tensor_like + + +# Generates strings with logic for getting / setting state of a particular type. +# +# Args: +# bindings (list): List of state bindings of interest (may be empty) +# state_vec_type (NamedCType): Type of vector to either return or copy from +# +# Returns: +# tuple: (list of getter logic strings, list of setter logic strings, string +# with num items expression) +def generate_state_getter_setter( + bindings: list[Binding], + state_vec_type: NamedCType, +) -> tuple[list[str], list[str], str]: + getter_logic = [] + setter_logic = [] + + state_vec = state_vec_type.name + getter_logic.append(f"{state_vec_type.cpp_type()} {state_vec};") + if len(bindings) > 0: + setter_logic.append("auto i = 0;") + + num_exprs = [] + for i, b in enumerate(bindings): + assert isinstance(b.argument, Argument) + if b.argument.type.is_list_like(): + # Handle list-likes. + num_expr = f"{b.name}.size()" + num_exprs.append(num_expr) + getter = f"{state_vec}.insert({state_vec}.end(), {b.name}.begin(), {b.name}.end());" + setter = f"std::copy({state_vec}.begin() + i, {state_vec}.begin() + i + {b.name}.size(), {b.name}.begin());" + elif isinstance(b.argument.type, OptionalType): + # Handle optionals. + num_expr = f"({b.name}.has_value() ? 1 : 0)" + num_exprs.append(num_expr) + conditional = f"if({b.name}.has_value())" + getter = ( + f"{conditional} {state_vec}.insert({state_vec}.end(), *({b.name}));" + ) + setter = f"{conditional} {b.name} = {state_vec}[i];" + else: + num_expr = "1" + num_exprs.append(num_expr) + getter = f"{state_vec}.push_back({b.name});" + setter = f"{b.name} = {state_vec}[i];" + + getter_logic.append(getter) + setter_logic.append(setter) + if i < len(bindings) - 1: + setter_logic.append(f"i += {num_expr};") + + # Reserve / assert based on the total number of items expression. + num_items = "0" if len(num_exprs) == 0 else " + ".join(num_exprs) + if len(bindings) > 0: + getter_logic.insert(1, f"{state_vec}.reserve({num_items});") + + getter_logic.append(f"return {state_vec};") + + return getter_logic, setter_logic, num_items + + +def process_function(fn: NativeFunction, template: CodeTemplate) -> str: + bindings = extract_bindings(fn) + non_self_bindings = [b for b in bindings if b.name != "self"] + + non_self_args = fn.func.arguments.flat_all[1:] + non_self_value_bindings = [ + dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args + ] + + # Generate constructor / clone args for the generated struct. + constructor_args = [b.defn() for b in non_self_bindings] + clone_args = [b.name for b in non_self_bindings] + + # Generate state variable declarations for the generated struct. + state_variables = [ + f"{remove_const_ref(b).defn()};" for b in non_self_value_bindings + ] + + # Generate initializer list expressions for the generated struct. + # allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as + # vectors. + init_exprs = translate( + non_self_bindings, non_self_value_bindings, allow_expensive_conversions=True + ) + initializers = [] + for b, init_expr in zip(non_self_bindings, init_exprs): + name = b.nctype.name + assert isinstance(name, str) + initializers.append(f"{name}({init_expr.expr})") + + # Generate call to underlying view op + call_input_name = "input_base" + op_call_args = [call_input_name, *(b.name for b in non_self_bindings)] + op_call = CALL_DISPATCH.substitute( + unambiguous_name=fn.func.name.unambiguous_name(), + unpacked_args=op_call_args, + ) + + # Multi-output views additionally require a view_idx for disambiguation. + if returns_multi_tensor(fn): + view_idx_name = "view_idx" + view_idx_typename = "int64_t" + view_idx_decl = f"{view_idx_typename} {view_idx_name}" + constructor_args.append(view_idx_decl) + clone_args.append(view_idx_name) + state_variables.append(f"{view_idx_decl};") + initializers.append(f"{view_idx_name}({view_idx_name})") + op_call += f"[{view_idx_name}]" + + # Generate initializer list for the generated struct. + initializer_list = f": {', '.join(initializers)}" if len(initializers) > 0 else "" + + # Generate getter / setter logic for any symints. + symint_bindings = [ + b + for b in non_self_bindings + if isinstance(b.argument, Argument) and b.argument.type.is_symint_like() + ] + symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT))) + get_symints, set_symints, num_symints = generate_state_getter_setter( + symint_bindings, symints_vec_type + ) + + # Generate getter / setter logic for any tensors. + tensor_bindings = [ + b + for b in non_self_bindings + if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like() + ] + tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT))) + get_tensors, set_tensors, num_tensors = generate_state_getter_setter( + tensor_bindings, tensors_vec_type + ) + + return template.substitute( + op=view_func_name(fn), + uppercase_op=view_func_name(fn, camel_case=False).upper(), + superclass="torch::autograd::ViewFunc", + initializer_list=initializer_list, + state=state_variables, + constructor_args=constructor_args, + clone_args=clone_args, + symints_vec=symints_vec_type.name, + get_symints=get_symints, + set_symints=set_symints, + num_symints=num_symints, + tensors_vec=tensors_vec_type.name, + get_tensors=get_tensors, + set_tensors=set_tensors, + num_tensors=num_tensors, + call_input_name=call_input_name, + op_call=op_call, + ) + + +def gen_view_funcs( + out: str, + fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo], + template_path: str, +) -> None: + # don't need the info parts, just the function + fns = [fn.func for fn in fns_with_infos if use_derived(fn)] + # only want out-of-place views + view_fns = [ + fn for fn in fns if get_view_info(fn) is not None and not modifies_arguments(fn) + ] + + declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns] + definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns] + ops_headers = [f"#include " for fn in view_fns] + + file_basename = "ViewFuncs" + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + for suffix in [".h", ".cpp"]: + fname = file_basename + suffix + fm.write_with_template( + fname, + fname, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/" + + fname, + "view_func_declarations": declarations, + "view_func_definitions": definitions, + "ops_headers": ops_headers, + }, + ) diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/load_derivatives.py b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/load_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..645a569c45e3dc9877f61b4329d2434fe987cf76 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/load_derivatives.py @@ -0,0 +1,1014 @@ +# Parses derivatives.yaml into autograd functions +# +# Each autograd function is represented by `DifferentiabilityInfo` containing +# a list of `Derivative`. See `torchgen.api.autograd` for the data models. + +from __future__ import annotations + +import re +from collections import defaultdict +from typing import Any, Counter, Dict, Sequence, Set, Tuple + +import yaml + +from torchgen.api import cpp +from torchgen.api.autograd import ( + Derivative, + DifferentiabilityInfo, + ForwardDerivative, + SavedAttribute, +) +from torchgen.api.types import ( + BaseCType, + Binding, + boolT, + CppSignatureGroup, + layoutT, + longT, + NamedCType, + OptionalCType, + scalarTypeT, + SpecialArgName, + stringT, + symIntArrayRefT, + SymIntT, + tensorGeometryT, + tensorOptionsT, + typeAndSizeT, + VectorCType, +) +from torchgen.context import with_native_function +from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml +from torchgen.model import ( + AUTOGRAD_KEYS, + FunctionSchema, + NativeFunction, + NativeFunctionsViewGroup, + OperatorName, + SchemaKind, + Type, + Variant, +) +from torchgen.utils import concatMap, IDENT_REGEX, split_name_params +from torchgen.yaml_utils import YamlLoader + + +DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]] + +_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {} + +_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS) + + +# This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op. +# Since every {view} and {view}_copy op shares the same derivative formula, +# we generate them here instead of duplicating them in the yaml. +# See Note [Codegen'd {view}_copy Operators] +def add_view_copy_derivatives( + infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], + view_groups: list[NativeFunctionsViewGroup], +) -> None: + # Get the map from each view op's name to its corresponding view group + view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = { + g.view.func.name: g for g in view_groups + } + + view_infos = {} + + for info_dispatch_dict in infos.values(): + # maybe_view_group only needs to be calculated once per info_dispatch_dict + maybe_view_group = None + view_copy_differentiability_infos = {} + for dispatch_key, info in info_dispatch_dict.items(): + maybe_view_group = view_name_to_group.get(info.func.func.name, None) + if maybe_view_group is not None and maybe_view_group.view_copy is not None: + view_copy_info = info.create_view_copy_from_view_derivative( + maybe_view_group + ) + if view_copy_info is not None: + fn_schema = view_copy_info.func.func + view_copy_differentiability_infos[dispatch_key] = view_copy_info + else: + break + # prefer manually-defined derivatives if any + if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos: + assert fn_schema is not None + view_infos[fn_schema] = view_copy_differentiability_infos + + infos.update(view_infos) + + +def load_derivatives( + derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str +) -> DerivativeRet: + # Do some caching as this is a deterministic function + global _GLOBAL_LOAD_DERIVATIVE_CACHE + key = (derivatives_yaml_path, native_yaml_path) + if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE: + with open(derivatives_yaml_path) as f: + definitions = yaml.load(f, Loader=YamlLoader) + + funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions + # From the parsed native functions, separate out the (generated) view_copy functions, + # so we can generate derivatives for them separately. + native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs) + native_functions = concatMap( + lambda g: [g] + if isinstance(g, NativeFunction) + else list(g.functions(include_copy=True)), + native_functions_with_view_groups, + ) + view_groups = [ + g + for g in native_functions_with_view_groups + if isinstance(g, NativeFunctionsViewGroup) + ] + + # What's the difference between function schema v.s. signature? + # function schema is the complete declaration including mutability annotation / default value and etc. + # signature is the canonical schema for a group of functions (in-place/out/functional variants) + # that are semantically related. + functions_by_signature: dict[ + FunctionSchema, list[NativeFunction] + ] = defaultdict(list) + functions_by_schema: dict[str, NativeFunction] = {} + for function in native_functions: + functions_by_signature[function.func.signature()].append(function) + assert str(function.func) not in functions_by_schema + functions_by_schema[str(function.func)] = function + + # Keep track of how many of which ops we've seen so we can + # disambiguate them with a numeric suffix. + op_counter = Counter[str]() + + # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos + # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info + # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema + infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {} + used_dispatch_keys: set[str] = set() + for defn_dict in definitions: + # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded. + if "dispatch" not in defn_dict: + specification = defn_dict.pop("name") + output_differentiability = defn_dict.pop( + "output_differentiability", None + ) + defn_dict = {"name": specification, "dispatch": {"Default": defn_dict}} + if output_differentiability: + defn_dict["output_differentiability"] = output_differentiability + name, per_dispatch_diffinfos = create_differentiability_info( + defn_dict, + functions_by_signature, + functions_by_schema, + op_counter, + used_dispatch_keys, + ) + infos[name] = per_dispatch_diffinfos + + add_view_copy_derivatives(infos, view_groups) + + # cache both loaded infos as well a a set of all the dispatch_keys/aliases + # that appear in derivatives.yaml. used_dispatch_keys is useful for generating + # VariableType.cpp where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used + _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos, used_dispatch_keys + + return _GLOBAL_LOAD_DERIVATIVE_CACHE[key] + + +# TODO: Why is this going through CppSignatureGroup, that doesn't make sense... +@with_native_function +def cpp_arguments(f: NativeFunction) -> Sequence[Binding]: + sigs = CppSignatureGroup.from_native_function(f, method=False) + if sigs.symint_signature is not None: + return sigs.symint_signature.arguments() + else: + return sigs.signature.arguments() + + +def create_derivative( + f: NativeFunction, + formula: str, + var_names: tuple[str, ...], + available_named_gradients: Sequence[str], +) -> Derivative: + original_formula = formula + arguments: list[NamedCType] = [ + a.nctype.remove_const_ref() for a in cpp_arguments(f) + ] + + return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f)) + return_types = tuple( + cpp.return_type(r, symint=True).remove_const_ref() for r in f.func.returns + ) + + named_returns = [ + NamedCType(name, type) for name, type in zip(return_names, return_types) + ] + + formula, saved_inputs = saved_variables(formula, arguments, var_names) + formula, saved_outputs = saved_variables(formula, named_returns, var_names) + + used_named_gradients = { + name + for name in available_named_gradients + if re.search(IDENT_REGEX.format(name), formula) + } + + # Check that the referenced derivatives in the formula are in bounds + for i in used_gradient_indices(formula): + if i >= len(f.func.returns): + raise RuntimeError( + f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} " + f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs." + ) + + return Derivative( + formula=formula, + original_formula=original_formula, + var_names=var_names, + saved_inputs=saved_inputs, + saved_outputs=saved_outputs, + named_gradients=used_named_gradients, + ) + + +def create_forward_derivative( + f: NativeFunction, formula: str, names: tuple[str, ...] +) -> ForwardDerivative: + var_names = names + var_types: tuple[Type, ...] | None = None + for r in f.func.returns: + if r.name in var_names: + if var_types is None: + var_types = () + var_types = var_types + (r.type,) + + # Handle default return names + if var_types is None: + if var_names == ("result",): + assert len(f.func.returns) == 1 + var_types = (f.func.returns[0].type,) + else: + for var_name in var_names: + res = re.findall(r"^result(\d+)$", var_name) + if len(res) == 1: + if var_types is None: + var_types = () + arg_idx = int(res[0]) + var_types = var_types + (f.func.returns[arg_idx].type,) + + assert var_types is not None, "No matching output for forward derivative definition" + return ForwardDerivative( + formula=formula, + var_names=var_names, + var_types=var_types, + required_inputs_fw_grad=None, + required_inputs_primal=None, + required_original_self_value=False, + is_reusing_outplace_formula=False, + ) + + +def postprocess_forward_derivatives( + f: NativeFunction, + defn_name: str, + all_arg_names: list[str], + derivatives: list[Derivative], + forward_derivatives: list[ForwardDerivative], + args_with_derivatives: Sequence[Binding], +) -> list[ForwardDerivative]: + def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]: + is_foreach = f.func.name.name.base.startswith("_foreach_") + required_inputs = set() + for arg in args_with_derivatives: + if ( + arg.type in ("at::TensorList", "const at::ITensorListRef &") + and not is_foreach + ): + # The functions taking TensorList handle everything internally + continue + arg_name = arg.name + + found = re.search(IDENT_REGEX.format(arg_name), formula) + if found: + raise RuntimeError( + f"The forward formula for {defn_name} is using the base name of the {arg_name} " + f"argument which is ambiguous. You should use {arg_name}_p to access the primal " + f"value and {arg_name}_t to access the tangent." + ) + + found = re.search(IDENT_REGEX.format(arg_name + postfix), formula) + if found: + required_inputs.add(arg_name) + + return tuple(required_inputs) + + updated_derivatives: list[ForwardDerivative] = [] + + for defn in forward_derivatives: + formula = defn.formula + required_inputs_tangent = find_required_inputs(formula, "_t") + if formula == "auto_element_wise": + assert ( + f.func.kind() != SchemaKind.inplace + ), f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant" + if ( + (not len(args_with_derivatives) == 1) + or len(forward_derivatives) > 1 + or len(forward_derivatives[0].var_names) > 1 + ): + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml defines the " + "forward definition of gradient as element_wise but this only " + "works for functions with a single differentiable input and a " + "single differentiable output." + ) + if not len(derivatives) == 1: + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml defines the " + "forward definition of gradient as element_wise but it does not " + "defines the gradient formula for its argument which is required." + ) + # This transformation is based on the observation that for element-wise functions, the Jacobian + # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions) + # For the complex case, we use hermitian transpose and get (v.conj() J).conj() + # So here we are going to re-use the backward formula and replace two things: + # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input. + # 2) all usage of an original input "foo" with its primal value "foo_p". + # 3) conjugate the final result + # For example, for abs, the backward formula is: + # grad * self.sgn() + # And this function generates a forward formula that is: + # (self_t.conj() * self_p.sgn()).conj() + + backward_formula = derivatives[0].original_formula + input_name = args_with_derivatives[0].name + + # Do replacement 1) of the grad + def repl(m: Any) -> str: + return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}" + + fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula) + + # Do replacement 2) of the input variables + for arg in args_with_derivatives: + arg_name = arg.name + + def repl(m: Any) -> str: + return f"{m.group(1)}{arg_name}_p{m.group(2)}" + + fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula) + + # Do the final conjugate 3) + fw_formula = f"({fw_formula}).conj()" + + # Since there is a single differentiable inputs and we necessarily need its tangent we can + # simply require all differentiable input's tangent. + required_inputs_tangent = tuple(all_arg_names) + formula = fw_formula + elif formula == "auto_linear": + if ( + len(forward_derivatives) > 1 + or len(forward_derivatives[0].var_names) > 1 + ): + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml defines the " + "forward definition of gradient as linear but this only works " + "for functions with a single differentiable output." + ) + # This transformation is based on the observation that linear functions can be written as: + # y = f(x) = A * x + # For some matrix A and the Jacobian of the function f is also A. + # So doing J * v = A * v = f(v). + # Hence to do the jvp, we simply need to evaluate the function at the point v instead of x. + # We do this by calling the forward again by replacing any occurrence of the differentiable + # input "foo" by it's tangent "foo_t". + # Note that multiple inputs are not a problem as long as the function is truly linear wrt to + # the vector where all the differentiable inputs are stacked. + + diff_arg_names = [arg.name for arg in args_with_derivatives] + assert len(diff_arg_names) > 0 + + # Do replacement of input variables + new_args = [] + for arg_name in all_arg_names: + if arg_name in diff_arg_names: + arg_name = arg_name + "_t" + new_args.append(arg_name) + + # TODO we are trolling + if f.func.has_symint(): + defn_name += "_symint" + + # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions. + if Variant.function in f.variants: + fw_formula = f"at::{defn_name}({', '.join(new_args)})" + else: + assert Variant.method in f.variants + fw_formula = f"{new_args[0]}.{defn_name}({', '.join(new_args[1:])})" + + # All of the input tangents are always used so all of them are required here. + required_inputs_tangent = tuple(diff_arg_names) + formula = fw_formula + + # At this point, the formula is final and is not modified anymore. + + # During forward formula, we use the primal instead of the input Tensors. + # This call inspects the formula to find for which input's primal are used. + required_inputs_primal = find_required_inputs(formula, "_p") + + updated_derivatives.append( + ForwardDerivative( + formula=formula, + var_names=defn.var_names, + var_types=defn.var_types, + required_inputs_fw_grad=required_inputs_tangent, + required_inputs_primal=required_inputs_primal, + required_original_self_value=False, + is_reusing_outplace_formula=False, + ) + ) + + return updated_derivatives + + +def is_forward_derivative_definition( + all_arg_names: list[str], names: tuple[str, ...] +) -> bool: + for name in names: + return name not in all_arg_names + raise RuntimeError("Expected `names` to be non-empty") + + +def create_differentiability_info( + defn_dict: dict[Any, Any], + functions_by_signature: dict[FunctionSchema, list[NativeFunction]], + functions_by_schema: dict[str, NativeFunction], + op_counter: Counter[str], + used_dispatch_keys: set[str], +) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]: + """Processes a single entry `defn` in derivatives.yaml""" + + def canonical_function( + functions: Sequence[NativeFunction], name: str + ) -> NativeFunction: + for f in functions: + if ( + not f.func.is_functional_fn() + and not f.func.is_out_fn() + and name == str(f.func.name.name) + ): + return f + # some functions only have in-place variants + assert name + "_" == cpp.name(functions[0].func) + return functions[0] + + def split_names(raw_names: str) -> tuple[str, ...]: + """Given "foo, bar", return ["foo", "bar"].""" + return tuple(x.strip() for x in raw_names.split(",")) + + def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None: + """ + Check for some subtle mistakes one might make when writing derivatives. + These mistakes will compile, but will be latent until a function is + used with double backwards. + """ + + uses_grad = False # true if any derivative uses "grad" + num_grads_uses = 0 # count of uses of "grads" or "grads[INDEX]" + uses_named_grads = False # true if any derivative uses "grad_{name}" + used_grads_indices: list[int] = [] # which indices of grads are used + for d in derivatives: + formula = d.formula + uses_grad = uses_grad or bool( + re.findall(IDENT_REGEX.format("grad"), formula) + ) + num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula)) + uses_named_grads = uses_named_grads or bool(d.named_gradients) + used_grads_indices.extend(used_gradient_indices(formula)) + # This is a basic sanity check: the number of places we see + # "grads" should be no fewer than the number of indices we see + # inside "grads". They may not be equal because we may use + # "grads" without an index. + assert num_grads_uses >= len(used_grads_indices) + # Thus if the number is equal, every use of grads is also + # indexed. + only_used_grads_indices = num_grads_uses == len(used_grads_indices) + + if uses_grad and num_grads_uses > 0: + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml illegally " + "mixes use of 'grad' and 'grads'. Consider replacing " + "occurrences of 'grad' with 'grads[0]'" + ) + + if only_used_grads_indices and set(used_grads_indices) == {0}: + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml solely " + "refers to 'grads[0]'. If the first output is indeed the " + "only differentiable output, replace 'grads[0]' with 'grad'; " + "otherwise, there is a likely error in your derivatives " + "declaration." + ) + + if uses_named_grads and (uses_grad or num_grads_uses > 0): + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml illegally " + 'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use ' + "only one method for identifying gradients." + ) + + @with_native_function + def set_up_derivatives( + f: NativeFunction, + ) -> tuple[ + Sequence[Derivative], + Sequence[ForwardDerivative], + Sequence[Binding], + Sequence[str], + Sequence[str], + ]: + # Set up the derivative information + derivatives: list[Derivative] = [] + forward_derivatives: list[ForwardDerivative] = [] + non_differentiable_arg_names: list[str] = [] + args_with_derivatives_set: set[str] = set() + + all_arg_names = [a.name for a in cpp_arguments(f)] + all_ret_names = [ + r.name for r in f.func.returns + ] # only used for the assert below + # output_differentiability is captured from the enclosed + # scope. Don't modify it. + # + # If it is not present, then no output is explicitly + # undifferentiable. + # + # It may be present and shorter than the length of return + # values. If that's the case, any return value that does not + # have a corresponding entry is considered not differentiable. + differentiability = output_differentiability or [True] * len(f.func.returns) + # A return is available as a named gradient ... + available_named_gradients = [ + f"grad_{ret.name}" + for ret, differentiable in zip(f.func.returns, differentiability) + # if it has not been explicitly made undifferentiable + if differentiable + # and if it has a name + and ret.name is not None + # and if its type is differentiable + and ret.type.is_tensor_like() + ] + + for raw_names in sorted(defn.keys()): + formula = defn[raw_names] + names = split_names(raw_names) + + for name in names: + assert not (name in all_arg_names and name in all_ret_names), ( + f"While processing the derivative formula for '{f.func.name}' wrt '{name}', " + f"expected '{name}' to not be both an input arg and named return. " + ) + + if is_forward_derivative_definition(all_arg_names, names): + forward_derivatives.append(create_forward_derivative(f, formula, names)) + else: + if formula.lower().strip() == "non_differentiable": + non_differentiable_arg_names += names + else: + derivative = create_derivative( + f, formula, names, available_named_gradients + ) + derivatives.append(derivative) + args_with_derivatives_set |= set(names) + + overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names) + if overlap: + raise RuntimeError( + f"derivatives definition for {defn} have overlapped non_differentiable " + f"and differentiable variables: {overlap}" + ) + + # Next, let us determine the list of inputs in order. + # TODO: do we need eagerly calculate and save it here? Can it be derived + # from NativeFunction and `derivatives` on callsites instead? + args_with_derivatives = [ + a for a in cpp_arguments(f) if a.name in args_with_derivatives_set + ] + + # Postprocess forward derivatives definitions now that we know the differentiable arguments + forward_derivatives = postprocess_forward_derivatives( + f, + defn_name, + all_arg_names, + derivatives, + forward_derivatives, + args_with_derivatives, + ) + + # Test to see if the use of 'grads' makes sense. + check_grad_usage(defn_name, derivatives) + + return ( + derivatives, + forward_derivatives, + args_with_derivatives, + non_differentiable_arg_names, + available_named_gradients, + ) + + # NB: Removes 'name' from defn dictionary + specification = defn_dict.pop("name") + defn_name, _ = split_name_params(specification) + # NB: Removes 'output_differentiability' from defn dictionary + # `None` means all differentiable. + output_differentiability = defn_dict.pop("output_differentiability", None) + output_differentiability_conditions = None + if output_differentiability and any( + isinstance(diff, str) for diff in output_differentiability + ): + if len(output_differentiability) != 1: + raise RuntimeError( + f"Not supported: for {specification}," + f"output_differentiability must either be " + f"List[bool] or a List[str] where each str is a " + f"condition. In the case where it is a condition, " + f"we only support single-output functions. " + f"Please file us an issue. " + ) + output_differentiability_conditions = output_differentiability + output_differentiability = [True] + + schema_function = functions_by_schema.get(specification) + if not schema_function: + avail = "\n".join( + k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name + ) + raise RuntimeError( + f"could not find ATen function for schema: {specification} " + f". Available signatures:\n{avail}" + ) + + # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here + # to map in-place schemas to the out-of-place variants. + # TODO: maybe the logic to handle the legacy schema is no longer necessary? + signature = schema_function.func.signature() + functions = functions_by_signature[signature] + if len(functions) == 0: + avail = "\n".join( + str(k) + for k, v in functions_by_signature.items() + if cpp.name(k) == defn_name + ) + raise RuntimeError( + f"could not find ATen function for legacy signature: {signature} " + f"corresponding to schema {specification}. Please report a bug to PyTorch. " + f"Available signatures:\n{avail}" + ) + + canonical = canonical_function(functions, defn_name) + if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)): + raise RuntimeError( + f"Schema for {defn_name} has an argument named grad_input_mask, " + "but this name would be shadowed by our codegen. " + "Please use a different name in native_functions.yaml." + ) + + if "result" in (a.name for a in cpp_arguments(canonical)): + raise RuntimeError( + f"Schema for {defn_name} has an argument named result, " + "but this is only allowed for outputs." + "Please use a different name in native_functions.yaml." + ) + + diffinfo_dict = {} + for key, defn in defn_dict["dispatch"].items(): + if key != "Default" and key not in _VALID_AUTOGRAD_KEYS: + raise RuntimeError( + f"Invalid dispatch key {key} in derivatives.yaml for {specification}," + f" expected key to be one of {_VALID_AUTOGRAD_KEYS}" + ) + if key not in used_dispatch_keys: + used_dispatch_keys.add(key) + + ( + derivatives, + forward_derivatives, + args_with_derivatives, + non_differentiable_arg_names, + available_named_gradients, + ) = set_up_derivatives(canonical) + + used_named_gradients: set[str] = set() + for d in derivatives: + used_named_gradients |= d.named_gradients + + # only assign an op name if we are actually going to calculate a derivative + op = None + if args_with_derivatives: + op_prefix = _create_op_prefix(defn_name) + if key != "Default": + op_prefix = op_prefix + key + op = f"{op_prefix}{op_counter[op_prefix]}" + op_counter[op_prefix] += 1 + + diffinfo_dict[key] = DifferentiabilityInfo( + name=defn_name, + func=canonical, + op=op, + derivatives=derivatives, + forward_derivatives=forward_derivatives, + all_saved_inputs=dedup_vars( + [v for d in derivatives for v in d.saved_inputs] + ), + all_saved_outputs=dedup_vars( + [v for d in derivatives for v in d.saved_outputs] + ), + available_named_gradients=available_named_gradients, + used_named_gradients=used_named_gradients, + args_with_derivatives=args_with_derivatives, + non_differentiable_arg_names=non_differentiable_arg_names, + output_differentiability=output_differentiability, + output_differentiability_conditions=output_differentiability_conditions, + ) + + return canonical.func, diffinfo_dict + + +GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]" + + +def used_gradient_indices(formula: str) -> list[int]: + """Determine a list of gradient indices (the i in grads[i]) that + are used by the formula. + + >>> used_gradient_indices("foo(grads[0], grads[1])") + [0, 1] + """ + return [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)] + + +def saved_variables( + formula: str, + nctypes: list[NamedCType], + var_names: tuple[str, ...], +) -> tuple[str, tuple[SavedAttribute, ...]]: + def stride_expr(name: str) -> str: + assert var_names == (name,), ( + 'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor ' + 'that ".strides()" is being called on.' + ) + return f'strides_or_error({name}, "{name}")' + + REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [ + # replace self.sym_sizes() with self_sym_sizes + ( + r"{}.sym_sizes\(\)", + { + "suffix": "_sym_sizes", + "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)), + }, + ), + # replace self->sym_sizes() with self_sym_sizes_opt + ( + r"{}->sym_sizes\(\)", + { + "suffix": "_sym_sizes_opt", + "nctype": lambda name: NamedCType( + name, OptionalCType(BaseCType(symIntArrayRefT)) + ), + "expr": lambda name: f"{name}.has_value() ? std::optional({name}->sym_sizes()) : std::nullopt", + }, + ), + # replace self.sym_blocksize() with self_sym_blocksize_opt + ( + r"{}.sym_blocksize\(\)", + { + "suffix": "_self_sym_blocksize_opt", + "nctype": lambda name: NamedCType( + name, OptionalCType(BaseCType(symIntArrayRefT)) + ), + "expr": lambda name: f"at::sparse_csr::getSymIntBlockSize({name})", + }, + ), + # replace self.options() with self_options + ( + r"{}.options\(\)", + { + "suffix": "_options", + "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)), + }, + ), + # replace zeros_like(self) with self_info + ( + r"zeros_like\({}\)", + { + "suffix": "_info", + "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)), + "expr": lambda name: name, # at save-time + "res": lambda name: name + "_info.zeros()", # at eval-time + }, + ), + # replace self.sym_size(2) with self_sym_size_2 + ( + r"{}.sym_size\((-?\w+)\)", + { + "suffix": lambda m: f"_sym_argsize_{m.groups()[0].replace('-', 'minus_')}", + "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)), + }, + ), + # replace self.numel() with self_numel + ( + r"{}.numel\(\)", + { + "suffix": "_numel", + "nctype": lambda name: NamedCType(name, BaseCType(longT)), + }, + ), + # replace self.sym_numel() with self_sym_numel + ( + r"{}.sym_numel\(\)", + { + "suffix": "_sym_numel", + "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)), + }, + ), + # replace to_args_sizes(self) with self_args_sizes + ( + r"to_args_sizes\({}\)", + { + "suffix": "_args_sizes", + "nctype": lambda name: NamedCType( + name, VectorCType(VectorCType(BaseCType(longT))) + ), + }, + ), + # replace to_args_sizes_symint(self) with self_args_sizes + ( + r"to_args_sizes_symint\({}\)", + { + "suffix": "_args_sizes_symint", + "nctype": lambda name: NamedCType( + name, VectorCType(VectorCType(BaseCType(SymIntT))) + ), + }, + ), + # replace to_args_scalartypes(self) with self_args_scalartypes + ( + r"to_args_scalartypes\({}\)", + { + "suffix": "_args_scalartypes", + "nctype": lambda name: NamedCType( + name, VectorCType(BaseCType(scalarTypeT)) + ), + }, + ), + # replace TensorGeometry(self) with self_geometry + ( + r"TensorGeometry\({}\)", + { + "suffix": "_geometry", + "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)), + }, + ), + ( + r"{}.scalar_type\(\)", + { + "suffix": "_scalar_type", + "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)), + }, + ), + # replace self.dim() with self_dim + ( + r"{}.dim\(\)", + { + "suffix": "_dim", + "nctype": lambda name: NamedCType(name, BaseCType(longT)), + }, + ), + # replace self.sym_strides() with self_sym_strides + ( + r"{}.sym_strides\(\)", + { + "suffix": "_sym_strides", + "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)), + "expr": stride_expr, + }, + ), + # replace self.layout() with self_layout + ( + r"{}.layout\(\)", + { + "suffix": "_layout", + "nctype": lambda name: NamedCType(name, BaseCType(layoutT)), + }, + ), + # replace self.is_conj() with self_conjugate + ( + r"{}.is_conj\(\)", + { + "suffix": "_conjugate", + "nctype": lambda name: NamedCType(name, BaseCType(boolT)), + }, + ), + ] + + # find which arguments need to be saved + saved: list[SavedAttribute] = [] + + if ".sizes()" in formula or "->sizes()" in formula: + raise RuntimeError( + ".sizes() is not supported in derivative formulas. Instead, please use the SymInt version," + + f".sym_sizes(), which returned a c10::SymIntArrayRef. formula={formula}" + ) + if re.search(r"\.size\([-]?\d+\)", formula) or re.search( + r"->size\([-]?\d+\)", formula + ): + raise RuntimeError( + ".size(int) is not supported in derivative formulas. Instead, please use the SymInt version," + + f".sym_size(int), which returned a c10::SymIntArrayRef. formula={formula}" + ) + if ".strides()" in formula or "->strides()" in formula: + raise RuntimeError( + ".strides() is not supported in derivative formulas. Instead, please use the SymInt version," + + f".sym_strides(), which returned a c10::SymIntArrayRef. formula={formula}" + ) + for nctype in nctypes: + name = ( + nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name + ) + # First search the formula for expressions which can be evaluated + # when the autograd Function is created to avoid saving variables + for regex, info in REPLACEMENTS: + + def repl(m: re.Match[str]) -> str: + suffix: str = ( + info["suffix"](m) if callable(info["suffix"]) else info["suffix"] + ) + expr: str = info["expr"](name) if "expr" in info else m.group(0) + saved.append( + SavedAttribute( + nctype=info["nctype"](name + suffix), + expr=expr, + ) + ) + if "res" in info: + replacement: str = info["res"](name) + return replacement + return name + suffix + + formula = re.sub(regex.format(name), repl, formula) + + # std::optional types stored in Backward nodes must be + # converted to std::optional before being passed into + # the backward function + if nctype.type == OptionalCType(BaseCType(stringT)): + formula = re.sub( + rf"\b{name}\b", + f"{name}.has_value() ? std::optional({name}.value()) : std::nullopt", + formula, + ) + + # Find any variables which remain in the formula and save them + if re.search(IDENT_REGEX.format(name), formula): + saved.append( + SavedAttribute( + nctype=nctype, + expr=name, + ) + ) + + return formula, tuple(saved) + + +def _create_op_prefix(name: str) -> str: + """Takes a native function name converts to a op prefix name. + + Note that the "name" parameter must be the native function name + without the optional variant suffix, so "add" instead of + "add.out". + + OP names correspond to classes, hence the change to title case. + + Example:: + >>> _create_op_prefix('add') + 'AddBackward' + """ + camel_case = "".join([p.title() for p in name.split("_")]) + return (camel_case + "Backward").replace("ForwardBackward", "Backward") + + +def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]: + seen: set[str] = set() + saved: list[SavedAttribute] = [] + for var in vars: + name = ( + var.nctype.name.name + if isinstance(var.nctype.name, SpecialArgName) + else var.nctype.name + ) + if name in seen: + continue + seen.add(name) + saved.append(var) + return saved diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/ADInplaceOrViewType.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/ADInplaceOrViewType.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e8276697eee065a36d1b16e583a5f011f92541c2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/ADInplaceOrViewType.cpp @@ -0,0 +1,38 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include "torch/csrc/autograd/VariableTypeUtils.h" +#include "torch/csrc/autograd/generated/ViewFuncs.h" + +#include +#include +#include + +// ${generated_comment} + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using namespace at; +using torch::autograd::CreationMeta; +using torch::autograd::as_view; +using torch::autograd::increment_version; + +namespace torch { + +namespace ADInplaceOrView { + +namespace { +${inplace_or_view_method_definitions} +} // namespace +} // namespace ADInplaceOrView + +namespace { + +TORCH_LIBRARY_IMPL(aten, ADInplaceOrView, m) { + ${inplace_or_view_wrapper_registrations}; +} + +} // namespace +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/Functions.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/Functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5bc089f67df74b300bc8de6568b702d48e0cb6c2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/Functions.cpp @@ -0,0 +1,20 @@ +#include "torch/csrc/autograd/FunctionsManual.h" +#include "torch/csrc/dynamo/compiled_autograd.h" + +// ${generated_comment} + +// The manual function definitions that used to be here are now in torch/csrc/autograd/FunctionsManual.cpp +// This speeds up re-compilation and allow to share these implementations so that they can be +// used for forward mode AD formulas as well. + +using namespace torch::autograd::generated::details; +using at::Tensor; +using at::Scalar; +using at::IntArrayRef; +using at::TensorList; + +namespace torch::autograd::generated { + +${autograd_function_definitions} + +} // namespace torch::autograd::generated diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/Functions.h b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/Functions.h new file mode 100644 index 0000000000000000000000000000000000000000..911d7d905c002b29941167ccff112a8079d48266 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/Functions.h @@ -0,0 +1,51 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include + +#include "torch/csrc/autograd/function.h" +#include "torch/csrc/autograd/variable.h" +#include "torch/csrc/autograd/saved_variable.h" +#include + +#include + +namespace torch { namespace autograd { namespace generated { + +using at::Scalar; +using at::Tensor; +using at::IntArrayRef; +using at::ArrayRef; +using at::Type; +using at::TensorGeometry; +using at::ScalarType; +using std::optional; +using c10::fmap; + +inline std::vector unpack_list(at::ArrayRef xs, std::shared_ptr saved_for = nullptr) { + // NB: we must explicitly do the conversion in the lambda, otherwise template + // deduction will give a Tensor of Variable which is not convertible + return fmap(xs, [&saved_for](const SavedVariable& x) { + // TODO(crcrpar): Use `std::move(saved_for)` to avoid incrementing refcount, which would need refactoring. + return static_cast(x.unpack(saved_for)); + }); +} + +inline c10::List> unpack_opt_list(at::ArrayRef xs, std::shared_ptr saved_for = nullptr) { + torch::List> result; + result.reserve(xs.size()); + for (const SavedVariable& v : xs) { + auto var = v.unpack(saved_for); + result.push_back(var.defined() ? std::optional(var) : ::std::nullopt); + } + return result; +} + +using torch::autograd::TypeAndSize; + +${autograd_function_declarations} + +}}} // namespace torch::autograd::generated diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/TraceType.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/TraceType.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fb5e7ae44a5353a3cc2a90858fe33b7fc0ef8bfd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/TraceType.cpp @@ -0,0 +1,40 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include "torch/csrc/jit/frontend/tracer.h" + +#include + +#include "torch/csrc/autograd/function.h" + +#include "ATen/quantized/Quantizer.h" + +// ${generated_comment} + +// See the `Tracer` section in `torch/csrc/jit/OVERVIEW.md`. +// NOTE See [Sharded File] comment in VariableType + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using namespace at; + +namespace torch { + +namespace TraceType { + +namespace { +${trace_method_definitions} +} // namespace +} // namespace TraceType + +namespace { + +TORCH_LIBRARY_IMPL(aten, Tracer, m) { + ${trace_wrapper_registrations}; +} + +} // namespace + +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/VariableType.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/VariableType.cpp new file mode 100644 index 0000000000000000000000000000000000000000..08f1f8b698e528ca382ead2fb64ee0a45a708b08 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/VariableType.cpp @@ -0,0 +1,65 @@ +#include "torch/csrc/autograd/VariableTypeUtils.h" +#include "torch/csrc/autograd/generated/VariableType.h" +#include "torch/csrc/autograd/FunctionsManual.h" + +#include +#include +#include +#include + +#include + + +// ${generated_comment} + +// NOTE [Sharded File]: on this file's split-into-shards state +// +// Back in the good old days, VariableType.cpp was generated as one +// file with every function in it, and everything was great and +// simple. +// +// However, this file was also very large (over 36,000 lines), and +// compiling it was very slow, and in fact was a significant +// bottleneck for incremental rebuilds. To address this, we now +// generate the file split across multiple shards, named +// VariableType_0.cpp and so on, which can be compiled in parallel. +// +// For ease of inspection and debugging, so that it's not necessary to +// go rooting around in multiple files, we also generate all the +// functions together in VariableTypeEverything.cpp. This generated +// file is only for convenience; it's not actually used in the +// build. If the file you're looking at now is one of the shards, you +// may want to switch over to the Everything variant to make you +// grepping smoother. + +using namespace at; +using namespace torch::autograd::generated; +using namespace torch::autograd::generated::details; + + +namespace torch::autograd { + +namespace VariableType { +namespace{ + C10_UNUSED void reset_grad_accumulator(Variable & self) { + AutogradMeta* meta = torch::autograd::impl::get_autograd_meta(self); + if (meta != nullptr) { + meta->grad_accumulator_.reset(); + } + } +} + +namespace { + + +${type_derived_method_definitions} +} +} + +namespace { + +${wrapper_registrations} + +} + +} // namespace torch::autograd diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/VariableType.h b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/VariableType.h new file mode 100644 index 0000000000000000000000000000000000000000..08da173f94bf868517ed6a52fd449e6f144904ce --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/VariableType.h @@ -0,0 +1,59 @@ +#pragma once + +// ${generated_comment} + +#include +#include + +#include + +#include +#include + +#include // for size_t +#include // for function +#include // for unique_ptr +#include +#include + +namespace at { + struct Quantizer; +}; + +namespace torch { namespace autograd { + +using Variable = at::Tensor; +using at::Context; +using at::Device; +using at::Dimname; +using at::DimnameList; +using at::Generator; +using at::IntArrayRef; +using at::MemoryFormat; +using at::QScheme; +using at::Scalar; +using at::ScalarType; +using at::Storage; +using at::Tensor; +using at::TensorList; +using at::TensorOptions; +using at::Quantizer; +// This is temporary typedef to enable Quantizer in aten native function API +// we'll remove them when we are actually exposing Quantizer class +// to frontend +using ConstQuantizerPtr = const c10::intrusive_ptr&; +using std::optional; + +namespace VariableType { + TORCH_API std::vector allCUDATypes(); + TORCH_API std::vector allXPUTypes(); + TORCH_API std::vector allCPUTypes(); + TORCH_API std::vector allPrivateUser1Types(); + + at::Tensor & unpack(Tensor & t, const char * name, int pos); + const at::Tensor & unpack(const Tensor & t, const char * name, int pos); + at::Tensor unpack_opt(const Tensor & t, const char * name, int pos); + std::vector unpack(const at::ITensorListRef& tl, const char *name, int pos); +}; + +}} // namespace torch::autograd diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..11b9b194fb46f924e863c4c1dab5cbb8dbb0601b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.cpp @@ -0,0 +1,14 @@ +#include + +// ${generated_comment} + +using at::Tensor; +using at::Scalar; +using at::IntArrayRef; +using at::TensorList; + +namespace torch::autograd::generated { + +${view_func_definitions} + +} // namespace torch::autograd::generated diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.h b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.h new file mode 100644 index 0000000000000000000000000000000000000000..1f69c062d344e4cd5f98cf5f34fd4278019fdf8a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.h @@ -0,0 +1,28 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +namespace torch::autograd::generated { + +using at::Scalar; +using at::Tensor; +using at::IntArrayRef; +using at::ArrayRef; +using at::Type; +using at::ScalarType; +using std::optional; +using c10::fmap; + +${view_func_declarations} + +} // namespace torch::autograd::generated diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/annotated_fn_args.py.in b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/annotated_fn_args.py.in new file mode 100644 index 0000000000000000000000000000000000000000..1012c008451745b8f1ed1454a864f666caf2618a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/annotated_fn_args.py.in @@ -0,0 +1,11 @@ +""" +This file is needed for generating procedural tests required for +testing __torch_function__. See tests/test_overrides.py. +""" + +# flake8: noqa +import torch + +annotated_args = { +${annotated_args} +} diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_enum_tag.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_enum_tag.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83cfad1d7ba4d6fc3529caf78e036c5883e7bc23 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_enum_tag.cpp @@ -0,0 +1,15 @@ +#include +#include +#include +#include + +namespace py = pybind11; +namespace torch { + namespace autograd { + void initEnumTag(PyObject* module) { + auto m = py::handle(module).cast(); + py::enum_(m, "Tag") + ${enum_of_valid_tags}; + m.doc() = "An Enum that contains tags that can be assigned to an operator registered in C++."; + } +}} diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_fft_functions.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_fft_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..71ac4e2226d2db418eba5690995424d3f007e620 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_fft_functions.cpp @@ -0,0 +1,81 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include "torch/csrc/Device.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_fft_functions.h" +#include "torch/csrc/autograd/generated/python_return_types.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/autograd/generated/variable_factories.h" +#include "torch/csrc/utils/out_types.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/utils/device_lazy_init.h" + +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using at::Tensor; +using at::Device; +using at::Layout; +using at::Scalar; +using at::ScalarType; +using at::Backend; +using at::OptionalDeviceGuard; +using at::DeviceGuard; +using at::TensorOptions; +using at::IntArrayRef; +using at::Generator; +using at::TensorList; +using at::Dimname; +using at::DimnameList; + +using torch::utils::check_out_type_matches; +using namespace torch::autograd::utils; + +namespace torch::autograd { + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef fft_functions[] = { + ${py_method_defs} + {NULL} +}; + +static PyObject* THPFFTVariableFunctionsModule = NULL; + +void initFFTFunctions(PyObject* module) { + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, + "torch._C._fft", + NULL, + -1, + fft_functions + }; + PyObject* fft = PyModule_Create(&def); + THPFFTVariableFunctionsModule = fft; + if (!fft) { + throw python_error(); + } + // steals a reference to fft + if (PyModule_AddObject(module, "_fft", fft) != 0) { + throw python_error(); + } +} + +// generated methods start here + +${py_methods} + +} // namespace torch::autograd diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_functions.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1522d6cd0f5a2a1fc0188bf9d6d0d59fe1b27d85 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_functions.cpp @@ -0,0 +1,37 @@ +#include + +// ${generated_comment} + +#include +#include + +#include +#include "torch/csrc/autograd/generated/Functions.h" +#include "torch/csrc/autograd/python_cpp_function.h" +#include +#include +#include +#include +#include + +// NOTE: See [Sharded File] comment in VariableType + +namespace torch::autograd::generated { + +template +static void addClass(PyObject* module, PyTypeObject& type, const char* name, + PyGetSetDef* function_properties=NULL, PyMethodDef* function_methods=NULL) +{ + _initFunctionPyTypeObject(type, name, function_properties, function_methods); + Py_INCREF(&type); + PyModule_AddObject(module, name, (PyObject*)&type); + registerCppFunction(typeid(C), &type); +} + +${py_function_props_and_getters} + +void initialize_autogenerated_functions${shard_id}(PyObject* module) { + ${py_function_initializers} +} + +} // namespace torch::autograd::generated diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_functions.h b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..22e37207e219431100fefaf21b02e3ed0f63d956 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_functions.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +// ${generated_comment} + +// Python bindings for automatically generated autograd functions + +namespace torch { namespace autograd { namespace generated { + +${shard_forward_declare} + +inline void initialize_autogenerated_functions(PyObject* module) { + ${shard_call} +} + +}}} // namespace torch::autograd::generated diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_linalg_functions.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_linalg_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c93752a3ddbfcf111426f98c3ea68fc625e94def --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_linalg_functions.cpp @@ -0,0 +1,68 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include "torch/csrc/Device.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_linalg_functions.h" +#include "torch/csrc/autograd/generated/python_return_types.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/structseq.h" + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using at::Tensor; +using at::Scalar; +using at::ScalarType; +using at::MemoryFormat; +using at::Generator; +using at::IntArrayRef; +using at::TensorList; + +using namespace torch::autograd::utils; + +namespace torch::autograd { + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef linalg_functions[] = { + ${py_method_defs} + {NULL} +}; + +static PyObject* THPLinalgVariableFunctionsModule = NULL; + +void initLinalgFunctions(PyObject* module) { + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, + "torch._C._linalg", + NULL, + -1, + linalg_functions + }; + PyObject* linalg = PyModule_Create(&def); + THPLinalgVariableFunctionsModule = linalg; + if (!linalg) { + throw python_error(); + } + // steals a reference to linalg + if (PyModule_AddObject(module, "_linalg", linalg) != 0) { + throw python_error(); + } +} + +// generated methods start here + +${py_methods} + +} // namespace torch::autograd diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_nested_functions.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_nested_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3acb5128cee1e180de887080106e7cf5559f15ee --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_nested_functions.cpp @@ -0,0 +1,81 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include "torch/csrc/Device.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_nested_functions.h" +#include "torch/csrc/autograd/generated/python_return_types.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/autograd/generated/variable_factories.h" +#include "torch/csrc/utils/out_types.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/utils/device_lazy_init.h" + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using at::Tensor; +using at::Device; +using at::Layout; +using at::Scalar; +using at::ScalarType; +using at::Backend; +using at::OptionalDeviceGuard; +using at::DeviceGuard; +using at::TensorOptions; +using at::IntArrayRef; +using at::OptionalIntArrayRef; +using at::Generator; +using at::TensorList; +using at::Dimname; +using at::DimnameList; + +using namespace torch::autograd::utils; + +namespace torch::autograd { + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef nested_functions[] = { + {NULL, NULL, 0, NULL}, + ${py_method_defs} + {NULL} +}; + +static PyObject* THPNestedVariableFunctionsModule = NULL; + +void initNestedFunctions(PyObject* module) { + nested_functions[0] = get_nested_functions_manual()[0]; + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, + "torch._C._nested", + NULL, + -1, + nested_functions + }; + PyObject* nested = PyModule_Create(&def); + THPNestedVariableFunctionsModule = nested; + if (!nested) { + throw python_error(); + } + // steals a reference to nested + if (PyModule_AddObject(module, "_nested", nested) != 0) { + throw python_error(); + } +} + +// generated methods start here + +${py_methods} + +} // namespace torch::autograd diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_nn_functions.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_nn_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4877df6584bd6702f259f0797e2ff45d3c719bd3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_nn_functions.cpp @@ -0,0 +1,113 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include "torch/csrc/Device.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_nn_functions.h" +#include "torch/csrc/autograd/generated/python_return_types.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/utils/tensor_memoryformats.h" + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using at::Tensor; +using at::Scalar; +using at::MemoryFormat; +using at::Generator; +using at::IntArrayRef; +using at::ArrayRef; + +using namespace torch::autograd::utils; + +namespace torch::autograd { + +static PyObject* THPNNVariableFunctionsModule = NULL; + +static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + "to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + "to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + }); + ParsedArgs<5> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + if (r.has_torch_function()) { + return handle_torch_function(r, args, kwargs, THPNNVariableFunctionsModule, "torch.nn", "_parse_to"); + } + auto parsed = parse_to_conversion(r, /*allow_copy*/ false); // we don't want copy for nn.Module.to + auto& device = std::get<0>(parsed); + auto& scalarType = std::get<1>(parsed); + auto non_blocking = std::get<2>(parsed); + auto opt_memory_format = std::get<4>(parsed); + auto tuple = THPObjectPtr{PyTuple_New(4)}; + if (!tuple) throw python_error(); + if (device) { + PyTuple_SET_ITEM(tuple.get(), 0, THPDevice_New(*device)); + } else { + Py_INCREF(Py_None); + PyTuple_SET_ITEM(tuple.get(), 0, Py_None); + } + if (scalarType) { + PyTuple_SET_ITEM(tuple.get(), 1, Py_NewRef(torch::getTHPDtype(*scalarType))); + } else { + Py_INCREF(Py_None); + PyTuple_SET_ITEM(tuple.get(), 1, Py_None); + } + PyTuple_SET_ITEM(tuple.get(), 2, torch::autograd::utils::wrap(non_blocking)); + if (opt_memory_format.has_value()) { + PyTuple_SET_ITEM(tuple.get(), 3, Py_NewRef(torch::utils::getTHPMemoryFormat(opt_memory_format.value()))); + } else { + Py_INCREF(Py_None); + PyTuple_SET_ITEM(tuple.get(), 3, Py_None); + } + return tuple.release(); + END_HANDLE_TH_ERRORS +} + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef nn_functions[] = { + {"_parse_to", castPyCFunctionWithKeywords(THPVariable__parse_to), + METH_VARARGS | METH_KEYWORDS, nullptr}, + ${py_method_defs} + {NULL} +}; + +void initNNFunctions(PyObject* module) { + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, + "torch._C._nn", + NULL, + -1, + nn_functions + }; + PyObject* nn = PyModule_Create(&def); + THPNNVariableFunctionsModule = nn; + if (!nn) { + throw python_error(); + } + // steals a reference to nn + if (PyModule_AddObject(module, "_nn", nn) != 0) { + throw python_error(); + } +} + +// generated methods start here + +${py_methods} + +} // namespace torch::autograd diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_return_types.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_return_types.cpp new file mode 100644 index 0000000000000000000000000000000000000000..139e6b8958336cfcc8328fa33581e9f1ab6d5532 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_return_types.cpp @@ -0,0 +1,52 @@ +#include + +#include +#include +#include + +#include "torch/csrc/autograd/generated/python_return_types.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/Exceptions.h" + +namespace torch { namespace autograd { namespace generated { + +${py_return_types} + +}}} + +namespace torch::autograd { + +static void addReturnType( + PyObject* module, + const char* name, + PyTypeObject* type) { + // hold onto the TypeObject for the unlikely case of user + // deleting or overriding it. + Py_INCREF(type); + if (PyModule_AddObject( + module, + name, + (PyObject*)type) != 0) { + Py_DECREF(type); + throw python_error(); + } +} + +void initReturnTypes(PyObject* module) { + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, "torch._C._return_types", nullptr, -1, {}}; + PyObject* return_types_module = PyModule_Create(&def); + if (!return_types_module) { + throw python_error(); + } + + ${py_return_types_registrations} + + // steals a reference to return_types on success + if (PyModule_AddObject(module, "_return_types", return_types_module) != 0) { + Py_DECREF(return_types_module); + throw python_error(); + } +} + +} // namespace torch::autograd diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_return_types.h b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_return_types.h new file mode 100644 index 0000000000000000000000000000000000000000..ce6c355ea146a272709255b898603764112168b9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_return_types.h @@ -0,0 +1,14 @@ +#pragma once + +namespace torch { +namespace autograd { +namespace generated { + +${py_return_types_declarations} + +} + +void initReturnTypes(PyObject* module); + +} // namespace autograd +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_sparse_functions.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_sparse_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..648d91442102e9b950cb2ddb8db545c4b4e1100e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_sparse_functions.cpp @@ -0,0 +1,67 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include "torch/csrc/Device.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_sparse_functions.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/structseq.h" + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using at::Tensor; +using at::Scalar; +using at::ScalarType; +using at::MemoryFormat; +using at::Generator; +using at::IntArrayRef; +using at::TensorList; + +using namespace torch::autograd::utils; + +namespace torch::autograd { + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef sparse_functions[] = { + ${py_method_defs} + {NULL} +}; + +static PyObject* THPSparseVariableFunctionsModule = NULL; + +void initSparseFunctions(PyObject* module) { + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, + "torch._C._sparse", + NULL, + -1, + sparse_functions + }; + PyObject* sparse = PyModule_Create(&def); + THPSparseVariableFunctionsModule = sparse; + if (!sparse) { + throw python_error(); + } + // steals a reference to sparse + if (PyModule_AddObject(module, "_sparse", sparse) != 0) { + throw python_error(); + } +} + +// generated methods start here + +${py_methods} + +} // namespace torch::autograd diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_special_functions.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_special_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bf9e109b4a77352cd85ba828b97d67d329543867 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_special_functions.cpp @@ -0,0 +1,79 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include "torch/csrc/Device.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/autograd/python_special_functions.h" +#include "torch/csrc/autograd/generated/python_return_types.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/autograd/generated/variable_factories.h" +#include "torch/csrc/utils/out_types.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/utils/device_lazy_init.h" + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +using at::Tensor; +using at::Device; +using at::Layout; +using at::Scalar; +using at::ScalarType; +using at::Backend; +using at::OptionalDeviceGuard; +using at::DeviceGuard; +using at::TensorOptions; +using at::IntArrayRef; +using at::Generator; +using at::TensorList; +using at::Dimname; +using at::DimnameList; + +using torch::utils::check_out_type_matches; +using namespace torch::autograd::utils; + +namespace torch::autograd { + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef special_functions[] = { + ${py_method_defs} + {NULL} +}; + +static PyObject* THPSpecialVariableFunctionsModule = NULL; + +void initSpecialFunctions(PyObject* module) { + static struct PyModuleDef def = { + PyModuleDef_HEAD_INIT, + "torch._C._special", + NULL, + -1, + special_functions + }; + PyObject* special = PyModule_Create(&def); + THPSpecialVariableFunctionsModule = special; + if (!special) { + throw python_error(); + } + // steals a reference to special + if (PyModule_AddObject(module, "_special", special) != 0) { + throw python_error(); + } +} + +// generated methods start here + +${py_methods} + +} // namespace torch::autograd diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_torch_functions.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_torch_functions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c17d1040e1892b6a215a8c4264fe5a5345265bc7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_torch_functions.cpp @@ -0,0 +1,93 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +// Python bindings for torch.* functions implemented through ATen. +// +// The functions are bound as static methods on a class +// torch._C._VariableFunctions which is also aliased as Variable._torch +// and also copied into 'torch' module. + +#include + +// Undefine the copysign macro so that at::copysign works as intended with MSVC +// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196 +#ifdef _MSC_VER +#undef copysign +#endif // _MSC_VER + +#include "torch/csrc/autograd/python_torch_functions.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/Dtype.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/utils/out_types.h" +#include "torch/csrc/utils/pybind.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/tensor_layouts.h" +#include "torch/csrc/utils/tensor_new.h" +#include "torch/csrc/utils/tensor_numpy.h" +#include "torch/csrc/jit/frontend/tracer.h" +#include "torch/csrc/autograd/generated/variable_factories.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/utils/device_lazy_init.h" +#include "torch/csrc/autograd/generated/python_return_types.h" + +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#endif + +#include +#include +#include +#include + +using at::Tensor; +using at::Device; +using at::Layout; +using at::Scalar; +using at::ScalarType; +using at::Backend; +using at::OptionalDeviceGuard; +using at::DeviceGuard; +using at::TensorOptions; +using at::IntArrayRef; +using at::Generator; +using at::TensorList; +using at::Dimname; +using at::DimnameList; +using at::ArrayRef; + +using torch::utils::check_out_type_matches; +using namespace torch::autograd::utils; + +// NOTE: See [Sharded File] comment in VariableType + +namespace torch::autograd { + +// generated forward declarations start here + +${py_forwards} + +static PyMethodDef torch_functions_shard[] = { + ${py_method_defs} +}; + +void gatherTorchFunctions${shard_id}(std::vector &torch_functions) { + constexpr size_t num_functions = sizeof(torch_functions_shard) / sizeof(torch_functions_shard[0]); + torch_functions.insert( + torch_functions.end(), + torch_functions_shard, + torch_functions_shard + num_functions); +} + +// generated methods start here + +${py_methods} + +} // namespace torch::autograd diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_variable_methods.cpp b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_variable_methods.cpp new file mode 100644 index 0000000000000000000000000000000000000000..16c3b9e5efd6a6eab58f4d29557386be6f893a2c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/python_variable_methods.cpp @@ -0,0 +1,1333 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include + +// Undefine the copysign macro so that at::copysign works as intended with MSVC +// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196 +#ifdef _MSC_VER +#undef copysign +#endif // _MSC_VER + +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" +#include "torch/csrc/Size.h" +#include "torch/csrc/autograd/generated/VariableType.h" +#include "torch/csrc/autograd/python_variable.h" +#include "torch/csrc/autograd/utils/python_arg_parsing.h" +#include "torch/csrc/autograd/utils/error_messages.h" +#include "torch/csrc/autograd/utils/wrap_outputs.h" +#include "torch/csrc/jit/frontend/tracer.h" +#ifdef USE_CUDA +#include "torch/csrc/cuda/Event.h" +#endif +#include "torch/csrc/utils/device_lazy_init.h" +#include +#include "torch/csrc/utils/object_ptr.h" +#include "torch/csrc/utils/pycfunction_helpers.h" +#include "torch/csrc/utils/python_arg_parser.h" +#include "torch/csrc/utils/python_numbers.h" +#include "torch/csrc/utils/python_strings.h" +#include "torch/csrc/utils/python_tuples.h" +#include "torch/csrc/utils/tensor_apply.h" +#include "torch/csrc/utils/tensor_list.h" +#include "torch/csrc/utils/tensor_new.h" +#include "torch/csrc/utils/tensor_numpy.h" +#include "torch/csrc/utils/tensor_types.h" +#include "torch/csrc/utils/structseq.h" +#include "torch/csrc/autograd/generated/python_return_types.h" + +#include +#include +#include "c10/util/Optional.h" +#include "c10/core/Stream.h" + +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +$ops_headers +#include +#endif + +using at::DeviceGuard; +using at::device_of; +using at::OptionalDeviceGuard; +using at::Backend; +using at::Scalar; +using at::ScalarType; +using at::Tensor; +using c10::Stream; +using namespace torch::autograd::utils; + +namespace torch::autograd { + +static PyObject * THPVariable__is_view(PyObject *self, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "_is_view", args); + } + auto& self_ = THPVariable_Unpack(self); + if (self_.is_view()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +// implemented on the python object bc no support for first-class functions in native_functions.yaml +// See: ATen/native/README.md for more context +static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + auto args = py::make_tuple(py::handle(arg)); + return handle_torch_function(self, "apply_", args.ptr()); + } + auto& self_ = THPVariable_Unpack(self); + if (self_.requires_grad()) { + throw std::runtime_error( + "Can't call apply_() on Variable that requires grad. Use " + "var.detach().apply_() instead."); + } + return THPVariable_Wrap(torch::utils::apply_(self_, arg)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "size(int64_t? dim=None)", + "size(Dimname dim)", + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + if (r.idx == 0) { + if (!r.toInt64Optional(0).has_value()) { + return THPSize_NewFromSymSizes(self_); + } + if (jit::tracer::isTracing()) { + // will error out if a tensor has symints + return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0))); + } else { + return torch::toPyObject(self_.sym_size(r.toInt64(0))); + } + } else if (r.idx == 1) { + if (jit::tracer::isTracing()) { + TORCH_INTERNAL_ASSERT(false, "NYI: Named tensors w/ JIT"); + } + return wrap(self_.size(r.dimname(0))); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_stride(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "stride(int64_t? dim=None)", + "stride(Dimname dim)", + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + if (r.idx == 0) { + if (r.toInt64Optional(0).has_value()) { + return torch::toPyObject(self_.sym_stride(r.toInt64(0))); + } + // yes, this is called strides in ATen. + at::SymIntArrayRef strides = self_.sym_strides(); + // we can't do the normal wrapping here because IntArrayRef maps to both + // torch.Size and tuple in python + // TODO: consider factoring this out + THPObjectPtr tuple(PyTuple_New(strides.size())); + if (!tuple) throw python_error(); + for (size_t i = 0; i != strides.size(); i++) { + PyObject* s = torch::toPyObject(strides[i]); + if (!s) throw python_error(); + PyTuple_SET_ITEM(tuple.get(), i, s); + } + return tuple.release(); + } else if (r.idx == 1) { + return wrap(self_.stride(r.dimname(0))); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self_)) { + return handle_torch_function(self_, "get_device", args, nullptr); + } + auto& self = THPVariable_Unpack(self_); + return wrap(self.get_device()); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_has_names(PyObject* self_, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self_)) { + return handle_torch_function(self_, "has_names", args); + } + auto& self = THPVariable_Unpack(self_); + return wrap(self.has_names()); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_data_ptr(PyObject* self_, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self_)) { + return handle_torch_function(self_, "data_ptr", args); + } + auto& self = THPVariable_Unpack(self_); + return wrap(self.data_ptr()); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_storage_offset(PyObject* self_, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self_)) { + return handle_torch_function(self_, "storage_offset"); + } + auto& self = THPVariable_Unpack(self_); + return py::cast(self.sym_storage_offset()).release().ptr(); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_dim(PyObject* self, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "dim", args); + } + auto& self_ = THPVariable_Unpack(self); + return THPUtils_packInt64(self_.dim()); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_numel(PyObject* self, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "numel", args); + } + auto& self_ = THPVariable_Unpack(self); + if (jit::tracer::isTracing()) { + return wrap(jit::tracer::getNumelOf(self_)); + } else { + return py::cast(self_.sym_numel()).release().ptr(); + } + END_HANDLE_TH_ERRORS +} + +static Tensor dispatch_contiguous(const Tensor & self, at::MemoryFormat memory_format) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + return self.contiguous(memory_format); +} + +static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "contiguous(*, MemoryFormat memory_format=contiguous_format)", + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto& self_ = THPVariable_Unpack(self); + auto memory_format = r.memoryformat(0); + // avoids touching the GIL or current device if self is already contiguous + if (self_.is_contiguous(memory_format)) { + // NOTE: this logic is duplicated from VariableType.cpp. Since we need to + // record this call to contiguous() in the trace regardless of whether + // we actually call contiguous here, we need to record this information + // manually. + if (jit::tracer::isTracing()) { + auto tracer_state = jit::tracer::getTracingState(); + auto op_name = c10::Symbol::fromQualString("aten::contiguous"); + auto node = tracer_state->createNode(op_name, /*num_outputs=*/0); + jit::tracer::recordSourceLocation(node); + jit::tracer::addInputs(node, "self", self_); + jit::tracer::addInputs(node, "memory_format", memory_format); + tracer_state->insertNode(node); + jit::tracer::addOutput(node, self_); + } + Py_INCREF(self); + return self; + } + return THPVariable_Wrap(dispatch_contiguous(self_, memory_format)); + END_HANDLE_TH_ERRORS +} + +static Tensor dispatch_copy_(const Tensor & self, const Tensor & other, bool non_blocking) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + return self.copy_(other, non_blocking); +} + + static PyObject * THPVariable_copy_(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "copy_(Tensor other, bool non_blocking=False)", + "copy_(Tensor other, bool async=False)|deprecated" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<2> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + return THPVariable_Wrap(dispatch_copy_(self_, r.tensor(0), r.toBool(1))); + END_HANDLE_TH_ERRORS +} + +template +static T dispatch_to(const Tensor & self) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + TORCH_CHECK_VALUE(self.sym_numel() == 1, "only one element tensors can be converted to Python scalars"); + return self.template item(); +} + +static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "__float__", args); + } + jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW); + auto& self_ = THPVariable_Unpack(self); + return wrap(dispatch_to(self_)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "__complex__", args); + } + jit::tracer::warn("Converting a tensor to a Python complex", jit::tracer::WARN_PYTHON_DATAFLOW); + auto& self_ = THPVariable_Unpack(self); + return wrap(dispatch_to>(self_)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "__int__", args); + } + jit::tracer::warn("Converting a tensor to a Python integer", jit::tracer::WARN_PYTHON_DATAFLOW); + auto& self_ = THPVariable_Unpack(self); + if (isFloatingType(self_.scalar_type())) { + // we can't dispatch to item here because we want to avoid ATen overflow checks; + // the python integral type (long in python2) can't overflow. + return THPUtils_packDoubleAsInt(dispatch_to(self_)); + } else { + return wrap(dispatch_to(self_)); + } + END_HANDLE_TH_ERRORS +} + +// This is the __index__ function in Python which is similar to __int__, but +// called when used as a slice. +static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "__index__", args); + } + auto& self_ = THPVariable_Unpack(self); + // TODO: change the condition to `self_.dim() != 0` once we expose scalars + // in PyTorch. + if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true) || self_.sym_numel() != 1) { + throw TypeError("only integer tensors of a single element can be converted to an index"); + } + return wrap(dispatch_to(self_)); + END_HANDLE_TH_ERRORS +} + +static Tensor dispatch_invert(const Tensor & self) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + return self.bitwise_not(); +} + +static PyObject * THPVariable_invert(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "__invert__", args); + } + auto& self_ = THPVariable_Unpack(self); + if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true)) { + throw TypeError("~ (operator.invert) is only implemented on integer and Boolean-type tensors"); + } + return THPVariable_Wrap(dispatch_invert(self_)); + END_HANDLE_TH_ERRORS +} + +static Tensor dispatch_to(const Tensor & self, Device device, bool non_blocking, bool copy, std::optional optional_memory_format) { + pybind11::gil_scoped_release no_gil; + // NOTE: this is where we record aten::to in the graph during tracing. However, the behavior of aten::to + // is different with respect to TensorOptions fields that are not present: aten::to inherits fields that + // are missing from the self argument while the tracer assumes that they should be populated with the + // default values (eg. float for scalar type). By explicitly copying over the tensor options here we fully + // specify all tensor options and thus record the proper trace + return self.to(self.options().device(device).memory_format(optional_memory_format), non_blocking, copy); +} + +static Tensor dispatch_to(const Tensor & self, bool non_blocking, bool copy, std::optional optional_memory_format) { + pybind11::gil_scoped_release no_gil; + return self.to(self.options().memory_format(optional_memory_format), non_blocking, copy); +} + +static Tensor dispatch_to(const Tensor & self, ScalarType dtype, bool non_blocking, bool copy, std::optional optional_memory_format) { + pybind11::gil_scoped_release no_gil; + // TODO: Make this call the TensorOptions version, maybe? + return self.to(dtype, non_blocking, copy, optional_memory_format); +} + +static Tensor dispatch_to(const Tensor & self, Device device, ScalarType dtype, bool non_blocking, bool copy, std::optional optional_memory_format) { + pybind11::gil_scoped_release no_gil; + // TODO: Make this call the TensorOptions version, maybe? + return self.to(device, dtype, non_blocking, copy, optional_memory_format); +} + +static PyObject * THPVariable_cpu(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "cpu(*, MemoryFormat? memory_format=None)" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_Wrap(dispatch_to(self_, at::Device(at::DeviceType::CPU), false, false, opt_memory_format)); + END_HANDLE_TH_ERRORS +} + +static Tensor dispatch_nonzero(const Tensor & self) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + return self.nonzero(); +} + +static std::vector dispatch_nonzero_numpy(const Tensor & self) { + pybind11::gil_scoped_release no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + return self.nonzero_numpy(); +} + +static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "nonzero()", + "nonzero(*, bool as_tuple)", + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<2> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + if (r.idx == 0 || (r.idx == 1 && !r.toBool(0))) { + return wrap(dispatch_nonzero(self_)); + } else { + return wrap(dispatch_nonzero_numpy(self_)); + } + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_cuda(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "cuda(Device? device=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)", + "cuda(Device? device=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto device = r.isNone(0) ? at::Device(at::DeviceType::CUDA) : r.device(0); + auto opt_memory_format = r.memoryformatOptional(2); + TORCH_CHECK(device.is_cuda(), "Invalid device, must be cuda device"); + torch::utils::device_lazy_init(at::kCUDA); + return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_mtia(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "mtia(Device? device=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)", + "mtia(Device? device=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if (r.has_torch_function()) { + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto device = r.isNone(0) ? at::Device(at::DeviceType::MTIA) : r.device(0); + auto opt_memory_format = r.memoryformatOptional(2); + TORCH_CHECK(device.is_mtia(), "Invalid device, must be MTIA device"); + torch::utils::device_lazy_init(at::kMTIA); + return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_xpu(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "xpu(Device? device=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)", + "xpu(Device? device=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if (r.has_torch_function()) { + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto device = r.isNone(0) ? at::Device(at::DeviceType::XPU) : r.device(0); + auto opt_memory_format = r.memoryformatOptional(2); + TORCH_CHECK(device.is_xpu(), "Invalid device, must be xpu device"); + torch::utils::device_lazy_init(at::kXPU); + return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_ipu(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "ipu(Device? device=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)", + "ipu(Device? device=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if (r.has_torch_function()) { + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto device = r.isNone(0) ? at::Device(at::DeviceType::IPU) : r.device(0); + auto opt_memory_format = r.memoryformatOptional(2); + TORCH_CHECK(device.is_ipu(), "Invalid device, must be ipu device"); + return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_to_type(PyObject* self, ScalarType scalarType, std::optional optional_memory_format) { + HANDLE_TH_ERRORS + auto& self_ = THPVariable_Unpack(self); + return THPVariable_Wrap(dispatch_to(self_, scalarType, false, false, optional_memory_format)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_byte(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "byte(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Byte, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_char(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "char(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Char, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_double(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "double(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Double, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_float(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "float(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Float, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_cdouble(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "cdouble(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::ComplexDouble, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_cfloat(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "cfloat(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::ComplexFloat, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_half(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "half(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Half, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_int(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "int(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Int, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_long(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "long(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Long, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_short(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "short(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Short, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_bool(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "bool(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::Bool, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_bfloat16(PyObject* self, PyObject* args, PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "bfloat16(*, MemoryFormat? memory_format=None)" + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + auto opt_memory_format = r.memoryformatOptional(0); + return THPVariable_to_type(self, ScalarType::BFloat16, opt_memory_format); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_element_size(PyObject* self, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "element_size", args); + } + auto& self_ = THPVariable_Unpack(self); + return THPUtils_packInt64(self_.element_size()); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object bc PyObjects not declarable in native_functions.yaml +// See: ATen/native/README.md for more context +static PyObject * THPVariable_numpy(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "numpy(*, bool force=False)" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if (r.has_torch_function()) { + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + jit::tracer::warn("Converting a tensor to a NumPy array", jit::tracer::WARN_PYTHON_DATAFLOW); + return torch::utils::tensor_to_numpy(self_, r.toBool(0)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_requires_grad_(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "requires_grad_(bool requires_grad=True)", + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + // temporary hack to improve functorch UX. + const auto& functorch_tls = at::functorch::functorchTLSAccessor(); + if (functorch_tls) { + functorch_tls->checkSupportsInplaceRequiresGrad(); + } + + auto requires_grad = r.toBool(0); + // should we throw if requires_grad is true? var.requires_grad = True throws here + // but it's nice to let this be a no-op. + if (!self_.is_leaf() && !requires_grad) { + throw std::runtime_error(autograd::utils::requires_grad_leaf_error(requires_grad)); + } + if (requires_grad && ! isDifferentiableType(at::typeMetaToScalarType(self_.dtype()))) { + throw std::runtime_error("only Tensors of floating point dtype can require gradients"); + } + self_.set_requires_grad(requires_grad); + return THPVariable_Wrap(self_); + END_HANDLE_TH_ERRORS +} + +inline bool dispatch_is_contiguous(const Tensor & self, MemoryFormat memory_format) { + return self.is_contiguous(memory_format); +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_is_contiguous(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "is_contiguous(*, MemoryFormat memory_format=contiguous_format)", + }); + ParsedArgs<1> parsed_args; + auto r = parser.parse(self_, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self_, args, kwargs, PyObject_Type(self_), "torch.Tensor"); + } + + auto memory_format = r.memoryformat(0); + auto& self = THPVariable_Unpack(self_); + return wrap(dispatch_is_contiguous(self, memory_format)); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_item(PyObject* self, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "item", args); + } + jit::tracer::warn("Converting a tensor to a Python number", jit::tracer::WARN_PYTHON_DATAFLOW); + auto& self_ = THPVariable_Unpack(self); + auto dispatch_item_ = [](const Tensor& self) -> at::Scalar { + pybind11::gil_scoped_release no_gil; + return self.item(); + }; + return py::cast(dispatch_item_(self_)).release().ptr(); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object bc no support for first class functions in native_functions.yaml +// See: ATen/native/README.md for more context +static PyObject * THPVariable_map_(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ "map_(Tensor other, PyObject* callable)" }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<2> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + Variable other = r.tensor(0); + if (self_.requires_grad() || other.requires_grad()) { + throw std::runtime_error( + "Can't call map_() on Variable that requires grad. Use " + "var.detach().map_() instead."); + } + TORCH_CHECK( + !self_.unsafeGetTensorImpl()->is_python_dispatch() && !other.unsafeGetTensorImpl()->is_python_dispatch(), + ".map_ is not supported for tensor subclasses."); + + return THPVariable_Wrap(torch::utils::map_(self_, other, r.pyobject(1))); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object bc no support for first class functions in native_functions.yaml +// See: ATen/native/README.md for more context +static PyObject * THPVariable_map2_(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ "map2_(Tensor x, Tensor y, PyObject* callable)" }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + Variable x = r.tensor(0); + Variable y = r.tensor(1); + if (self_.requires_grad() || x.requires_grad() || y.requires_grad()) { + throw std::runtime_error( + "Can't call map2_() on Variable that requires grad. Use " + "var.detach().map2_() instead."); + } + TORCH_CHECK( + !x.unsafeGetTensorImpl()->is_python_dispatch() && !y.unsafeGetTensorImpl()->is_python_dispatch(), + ".map2_ is not supported for tensor subclasses."); + return THPVariable_Wrap(torch::utils::map2_(self_, x, y, r.pyobject(2))); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "new", args, kwargs); + } + auto& self_ = THPVariable_Unpack(self); + OptionalDeviceGuard device_guard(device_of(self_)); + return THPVariable_Wrap(torch::utils::legacy_tensor_new(legacyExtractDispatchKey(self_), self_.scalar_type(), args, kwargs)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "new_tensor", args, kwargs); + } + auto& self_ = THPVariable_Unpack(self); + OptionalDeviceGuard device_guard(device_of(self_)); + return THPVariable_Wrap(torch::utils::new_tensor(legacyExtractDispatchKey(self_), self_.scalar_type(), args, kwargs)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_storage(PyObject* self, PyObject* arg) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "untyped_storage"); + } + auto& self_ = THPVariable_Unpack(self); + return createPyObject(self_.storage()); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + "to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + "to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + }); + ParsedArgs<5> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + if (r.has_torch_function()) { + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto parsed = parse_to_conversion(r, /*allow_copy*/ true); + auto& device = std::get<0>(parsed); + auto& scalarType = std::get<1>(parsed); + auto non_blocking = std::get<2>(parsed); + auto copy = std::get<3>(parsed); + auto opt_memory_format = std::get<4>(parsed); + auto& self_ = THPVariable_Unpack(self); + torch::utils::maybe_initialize_device(device); + if (!device && !scalarType && !copy && !opt_memory_format.has_value()) { + Py_INCREF(self); + return self; + } else if (!device && !scalarType) { + return THPVariable_Wrap( + dispatch_to(self_, non_blocking, copy, opt_memory_format)); + } else if (!device) { + return THPVariable_Wrap(dispatch_to(self_, *scalarType, non_blocking, copy, opt_memory_format)); + } else if (!scalarType) { + return THPVariable_Wrap(dispatch_to(self_, *device, non_blocking, copy, opt_memory_format)); + } else { + return THPVariable_Wrap(dispatch_to(self_, *device, *scalarType, non_blocking, copy, opt_memory_format)); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +// implemented on the python object b/c arbitrarily nested list not declarable in native_functions.yaml +// See: ATen/native/README.md for more context +static PyObject * THPVariable_tolist(PyObject* self, PyObject* args) +{ + HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "tolist", args); + } + jit::tracer::warn("Converting a tensor to a Python list", jit::tracer::WARN_PYTHON_DATAFLOW); + auto self_ = THPVariable_Unpack(self); + return torch::utils::tensor_to_list(self_); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "type(PyObject* dtype=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)", + "type(PyObject* dtype=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated" + }); + auto& self_ = THPVariable_Unpack(self); + ParsedArgs<3> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + + if (r.isNone(0)) { + return THPUtils_packString(torch::utils::options_to_string(self_.options())); + } + auto obj = r.pyobject(0); + auto opt_memory_format = r.memoryformatOptional(2); + std::string type_name; + bool is_dtype = false; + if (PyType_Check(obj)) { + if (obj == THPVariableClass) { + type_name = "torch.Tensor"; + } else { + type_name = ((PyTypeObject*)obj)->tp_name; + } + } else if (THPUtils_checkString(obj)) { + type_name = THPUtils_unpackString(obj); + } else if (THPDtype_Check(obj)) { + is_dtype = true; + } else { + throw TypeError("dtype must be a type, str, or dtype object"); + } + ScalarType scalar_type; + Device device = self_.device(); + if (is_dtype) { + scalar_type = r.scalartype(0); + return THPVariable_Wrap(dispatch_to(self_, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false, opt_memory_format)); + } + at::TensorOptions options = torch::utils::options_from_string(type_name); + scalar_type = at::typeMetaToScalarType(options.dtype()); + auto device_type = options.device().type(); + if (device_type != device.type()) { + device = at::Device(device_type); + } + torch::utils::maybe_initialize_device(device); + return THPVariable_Wrap(dispatch_to(self_, device, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false, opt_memory_format)); + END_HANDLE_TH_ERRORS +} + +// generated methods start here + +${py_methods} + +static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) { + if (check_has_torch_function(self)) { + HANDLE_TH_ERRORS + return handle_torch_function(self, "__bool__", args); + END_HANDLE_TH_ERRORS + } + jit::tracer::warn("Converting a tensor to a Python boolean", jit::tracer::WARN_PYTHON_DATAFLOW); + return THPVariable_is_nonzero(self, args); +} + +static PyObject * THPVariable___eq__(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + HANDLE_TH_ERRORS +#ifdef USE_NUMPY + if (torch::utils::is_numpy_available()) { + static PythonArgParser parser({ + "__eq__(PyObject* other)", + }, /*traceable=*/true); + + ParsedArgs<1> parsed_args; + auto _r = parser.parse(self_, args, kwargs, parsed_args); + if(_r.has_torch_function()) { + return handle_torch_function(_r, self_, args, kwargs, THPVariableClass, "torch.Tensor"); + } + switch (_r.idx) { + case 0: { + auto other = _r.pyobject(0); + if (PyArray_Check(other)) { + auto other_tensor = torch::utils::tensor_from_numpy(other); + auto dispatch_eq = [](const at::Tensor & self, const at::Tensor & other) -> at::Tensor { + pybind11::gil_scoped_release no_gil; + return self.eq(other); + }; + const Tensor& self = THPVariable_Unpack(self_); + return wrap(dispatch_eq(self, other_tensor)); + } + } + } + } +#endif + return THPVariable_eq(self_, args, kwargs); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +// Wrapper converts a raised TypeError into returning NotImplemented +// Used to implement binary arithmetic operators +template +static PyObject * TypeError_to_NotImplemented_(PyObject* self, PyObject* args, PyObject* kwargs) { + + PyObject* ret = Func(self, args, kwargs); + if (!ret && PyErr_ExceptionMatches(PyExc_TypeError)) { + PyErr_Clear(); + Py_INCREF(Py_NotImplemented); + ret = Py_NotImplemented; + } + return ret; +} + +// set_ has to be defined in the template because the c10::Storage object +// does not have a type, and we need to make sure the Python storage object's +// type matches the tensor's type +static PyObject* THPVariable_set_( + PyObject* self_, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + const Tensor& self = THPVariable_Unpack(self_); + static PythonArgParser parser( + { + "set_()", + "set_(Storage source)", + "set_(Storage source, SymInt storage_offset, SymIntArrayRef size, SymIntArrayRef stride=None)", + "set_(Tensor source)", + "set_(Tensor source, SymInt storage_offset, SymIntArrayRef size, SymIntArrayRef stride=None)", + }, + /*traceable=*/false); + + ParsedArgs<4> parsed_args; + auto _r = parser.parse(args, kwargs, parsed_args); + + switch (_r.idx) { + case 0: { + // aten::set_(Tensor(a!) self) -> Tensor(a!) + auto dispatch_set_ = [](const Tensor& self) -> Tensor { + pybind11::gil_scoped_release no_gil; + return self.set_(); + }; + return wrap(dispatch_set_(self)); + } + case 1: { + // aten::set_.source_Storage(Tensor(a!) self, Storage source) -> + // Tensor(a!) + at::ScalarType storage_scalar_type; + bool is_typed_storage = true; + at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage); + TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage, + "Expected a Storage of type ", self.dtype(), + " or an UntypedStorage, but got type ", storage_scalar_type, + " for argument 1 'storage'"); + auto dispatch_set_ = [](const Tensor& self, Storage source) -> Tensor { + pybind11::gil_scoped_release no_gil; + return self.set_(source); + }; + return wrap(dispatch_set_(self, storage)); + } + case 2: { + // aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage + // source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!) + at::ScalarType storage_scalar_type; + bool is_typed_storage = true; + at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage); + TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage, + "Expected a Storage of type ", self.dtype(), + " or an UntypedStorage, but got type ", storage_scalar_type, + " for argument 1 'storage'"); + auto dispatch_set_ = [](const Tensor& self, + Storage source, + c10::SymInt storage_offset, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) -> Tensor { + pybind11::gil_scoped_release no_gil; + return self.set__symint(source, storage_offset, size, stride); + }; + return wrap(dispatch_set_( + self, storage, _r.toSymInt(1), _r.symintlist(2), _r.symintlist(3))); + } + case 3: { + // aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!) + auto dispatch_set_ = [](const Tensor& self, const Tensor& source) -> Tensor { + TORCH_CHECK(source.dtype() == self.dtype(), "Could not set tensor of type ", source.dtype(), " to a tensor of type ", self.dtype()); + pybind11::gil_scoped_release no_gil; + return self.set_(source); + }; + return wrap(dispatch_set_(self, _r.tensor(0))); + } + case 4: { + // aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor + // source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!) + at::Tensor storage = _r.tensor(0); + auto dispatch_set_ = [](const Tensor& self, + const Tensor& source, + c10::SymInt storage_offset, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) -> Tensor { + pybind11::gil_scoped_release no_gil; + return self.set__symint(source, storage_offset, size, stride); + }; + return wrap(dispatch_set_( + self, storage, _r.toSymInt(1), _r.symintlist(2), _r.symintlist(3))); + } + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +// XXX: ops that are bound here are not exposed to the C++ api nor the JIT. +// Any new ops added here should be accompanied with a comment why they are not +// being registered through native_functions.yaml, and be tagged cpp / JIT +PyMethodDef variable_methods[] = { + // These magic methods are all implemented on python object to wrap NotImplementedError + {"__add__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__radd__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__iadd__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__rmul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__mul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__imul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__sub__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__isub__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__div__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__truediv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__floordiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__idiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__ifloordiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__mod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__imod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__eq__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__ne__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__lt__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__le__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__gt__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__ge__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__rand__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__ror__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__rxor__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"__bool__", THPVariable_bool_scalar, METH_NOARGS, NULL}, + {"__float__", THPVariable_float_scalar, METH_NOARGS, NULL}, + {"__complex__", THPVariable_complex_scalar, METH_NOARGS, NULL}, + {"__int__", THPVariable_integral_scalar, METH_NOARGS, NULL}, + {"__long__", THPVariable_integral_scalar, METH_NOARGS, NULL}, + {"__index__", THPVariable_index_scalar, METH_NOARGS, NULL}, + {"__nonzero__", THPVariable_bool_scalar, METH_NOARGS, NULL}, + {"__invert__", THPVariable_invert, METH_NOARGS, NULL}, + {"__matmul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"_is_view", THPVariable__is_view, METH_NOARGS, NULL}, + {"apply_", THPVariable_apply_, METH_O, NULL}, + {"bfloat16", castPyCFunctionWithKeywords(THPVariable_bfloat16), METH_VARARGS | METH_KEYWORDS, NULL}, + {"byte", castPyCFunctionWithKeywords(THPVariable_byte), METH_VARARGS | METH_KEYWORDS, NULL}, + {"char", castPyCFunctionWithKeywords(THPVariable_char), METH_VARARGS | METH_KEYWORDS, NULL}, + {"contiguous", castPyCFunctionWithKeywords(THPVariable_contiguous), METH_VARARGS | METH_KEYWORDS, NULL}, + {"copy_", castPyCFunctionWithKeywords(THPVariable_copy_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"cpu", castPyCFunctionWithKeywords(THPVariable_cpu), METH_VARARGS | METH_KEYWORDS, NULL}, + {"cuda", castPyCFunctionWithKeywords(THPVariable_cuda), METH_VARARGS | METH_KEYWORDS, NULL}, + {"mtia", castPyCFunctionWithKeywords(THPVariable_mtia), METH_VARARGS | METH_KEYWORDS, NULL}, + {"xpu", castPyCFunctionWithKeywords(THPVariable_xpu), METH_VARARGS | METH_KEYWORDS, NULL}, + {"ipu", castPyCFunctionWithKeywords(THPVariable_ipu), METH_VARARGS | METH_KEYWORDS, NULL}, + {"data_ptr", THPVariable_data_ptr, METH_NOARGS, NULL}, + {"dim", THPVariable_dim, METH_NOARGS, NULL}, + {"has_names", THPVariable_has_names, METH_NOARGS, NULL}, + {"double", castPyCFunctionWithKeywords(THPVariable_double), METH_VARARGS | METH_KEYWORDS, NULL}, + {"cdouble", castPyCFunctionWithKeywords(THPVariable_cdouble), METH_VARARGS | METH_KEYWORDS, NULL}, + {"element_size", THPVariable_element_size, METH_NOARGS, NULL}, + {"float", castPyCFunctionWithKeywords(THPVariable_float), METH_VARARGS | METH_KEYWORDS, NULL}, + {"cfloat", castPyCFunctionWithKeywords(THPVariable_cfloat), METH_VARARGS | METH_KEYWORDS, NULL}, + {"get_device", THPVariable_get_device, METH_NOARGS, NULL}, + {"bool", castPyCFunctionWithKeywords(THPVariable_bool), METH_VARARGS | METH_KEYWORDS, NULL}, + {"half", castPyCFunctionWithKeywords(THPVariable_half), METH_VARARGS | METH_KEYWORDS, NULL}, + {"int", castPyCFunctionWithKeywords(THPVariable_int), METH_VARARGS | METH_KEYWORDS, NULL}, + {"is_contiguous", castPyCFunctionWithKeywords(THPVariable_is_contiguous), METH_VARARGS | METH_KEYWORDS, NULL}, + {"item", THPVariable_item, METH_NOARGS, NULL}, + {"long", castPyCFunctionWithKeywords(THPVariable_long), METH_VARARGS | METH_KEYWORDS, NULL}, + {"map_", castPyCFunctionWithKeywords(THPVariable_map_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"map2_", castPyCFunctionWithKeywords(THPVariable_map2_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"ndimension", THPVariable_dim, METH_NOARGS, NULL}, + {"nelement", THPVariable_numel, METH_NOARGS, NULL}, + {"new", castPyCFunctionWithKeywords(THPVariable_new), METH_VARARGS | METH_KEYWORDS, NULL}, + {"new_tensor", castPyCFunctionWithKeywords(THPVariable_new_tensor), METH_VARARGS | METH_KEYWORDS, NULL}, + {"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS, NULL}, + {"numel", THPVariable_numel, METH_NOARGS, NULL}, + {"numpy", castPyCFunctionWithKeywords(THPVariable_numpy), METH_VARARGS | METH_KEYWORDS, NULL}, + {"requires_grad_", castPyCFunctionWithKeywords(THPVariable_requires_grad_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, NULL}, + {"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, NULL}, + {"size", castPyCFunctionWithKeywords(THPVariable_size), METH_VARARGS | METH_KEYWORDS, NULL}, + {"untyped_storage", THPVariable_storage, METH_NOARGS, NULL}, + {"storage_offset", THPVariable_storage_offset, METH_NOARGS, NULL}, + {"stride", castPyCFunctionWithKeywords(THPVariable_stride), METH_VARARGS | METH_KEYWORDS, NULL}, + {"to", castPyCFunctionWithKeywords(THPVariable_to), METH_VARARGS | METH_KEYWORDS, NULL}, + {"tolist", THPVariable_tolist, METH_NOARGS, NULL}, + {"type", castPyCFunctionWithKeywords(THPVariable_type), METH_VARARGS | METH_KEYWORDS, NULL}, + ${py_method_defs} + {NULL} +}; + +} // namespace torch::autograd diff --git a/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/variable_factories.h b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/variable_factories.h new file mode 100644 index 0000000000000000000000000000000000000000..2b55f441ab6249cb7963c5e4a15070f626f775b7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/variable_factories.h @@ -0,0 +1,135 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +$ops_headers +#endif + +#include +#include +#include + +namespace torch { + +/// NOTE: Currently `torch::tensor(...)` doesn't support mixed data types +/// (i.e. `torch::tensor({{bool, 2.0}})` doesn't work). We might be able to +/// support it in the future by iterating over all sub-lists to find +/// the largest data type that can represent all of the elements, or by using +/// variadic templates. +/// +/// NOTE: C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` / `std::vector` / +/// (nested) braced-init-list of floating-point types always produces a tensor of dtype +/// `torch::get_default_dtype()`, matching Python `torch.tensor` behavior. +/// +/// NOTE: C++ `torch::tensor` with an integer type or an `at::ArrayRef` / `std::vector` / +/// (nested) braced-init-list of integer types always produces a tensor of dtype `at::kLong` +/// (aka. int64_t), matching Python `torch.tensor` behavior. +/// +/// NOTE: The following dtypes are not supported by `torch::tensor` currently: +/// - `unsigned int` +/// - `unsigned long int` +/// - `unsigned long long int` +/// - `long long int` +inline at::Tensor tensor(detail::TensorDataContainer tensor_data_container, const at::TensorOptions& options = {}) { + return autograd::make_variable( + // note: we remove the requires_grad setting from the TensorOptions because + // it is ignored anyways (and we actually have an assertion that it isn't set + // which would fail otherwise). We handle requires_grad explicitly here + // instead of passing it through to the kernel. + tensor_data_container.convert_to_tensor(options.requires_grad(::std::nullopt)), + options.requires_grad()); +} + +/// A generic deleter function. +using Deleter = std::function; +using at::MemoryFormat; + +/// Exposes the given `data` as a `Tensor` without taking ownership of the +/// original data. `sizes` should specify the shape of the tensor, `strides` the +/// stride in each dimension. The `deleter` function (a +/// `std::function`) will be called on the `data` when the Tensor +/// data would normally be deallocated. The `TensorOptions` specify additional +/// configuration options for the returned tensor, such as what type to +/// interpret the `data` as. +inline at::Tensor from_blob( + void* data, + at::IntArrayRef sizes, + at::IntArrayRef strides, + const Deleter& deleter, + const at::TensorOptions& options = at::TensorOptions()) { + at::Tensor tensor = ([&]() { + at::AutoDispatchBelowAutograd guard; // TODO: remove + at::tracer::impl::NoTracerDispatchMode tracer_guard; + return at::from_blob(data, sizes, strides, deleter, options.requires_grad(::std::nullopt)); + })(); + return autograd::make_variable(tensor, options.requires_grad()); +} + +/// Exposes the given `data` as a `Tensor` without taking ownership of the +/// original data. `sizes` should specify the shape of the tensor, `strides` the +/// stride in each dimension. The `TensorOptions` +/// specify additional configuration options for the returned tensor, such as +/// what type to interpret the `data` as. +inline at::Tensor from_blob( + void* data, + at::IntArrayRef sizes, + at::IntArrayRef strides, + const at::TensorOptions& options = at::TensorOptions()) { + at::Tensor tensor = ([&]() { + at::AutoDispatchBelowAutograd guard; // TODO: remove + at::tracer::impl::NoTracerDispatchMode tracer_guard; + return at::from_blob(data, sizes, strides, options.requires_grad(::std::nullopt)); + })(); + return autograd::make_variable(tensor, options.requires_grad()); +} + +/// Exposes the given `data` as a `Tensor` without taking ownership of the +/// original data. `sizes` should specify the shape of the tensor. The `deleter` +/// (a `std::function`) function will be called on the `data` when +/// the Tensor data would normally be deallocated. The `TensorOptions` specify +/// additional configuration options for the returned tensor, such as what type +/// to interpret the `data` as. +inline at::Tensor from_blob( + void* data, + at::IntArrayRef sizes, + const Deleter& deleter, + const at::TensorOptions& options = at::TensorOptions()) { + at::Tensor tensor = ([&]() { + at::AutoDispatchBelowAutograd guard; // TODO: remove + at::tracer::impl::NoTracerDispatchMode tracer_guard; + return at::from_blob(data, sizes, deleter, options.requires_grad(::std::nullopt)); + })(); + return autograd::make_variable(tensor, options.requires_grad()); +} + +/// Exposes the given `data` as a `Tensor` without taking ownership of the +/// original data. `sizes` should specify the shape of the tensor. The +/// `TensorOptions` specify additional configuration options for the returned +/// tensor, such as what type to interpret the `data` as. +inline at::Tensor from_blob( + void* data, + at::IntArrayRef sizes, + const at::TensorOptions& options = at::TensorOptions()) { + at::Tensor tensor = ([&]() { + at::AutoDispatchBelowAutograd guard; // TODO: remove + at::tracer::impl::NoTracerDispatchMode tracer_guard; + return at::from_blob(data, sizes, options.requires_grad(::std::nullopt)); + })(); + return autograd::make_variable(tensor, options.requires_grad()); +} + +${function_definitions} + +} // namespace torch diff --git a/.venv/lib/python3.11/site-packages/urllib3/contrib/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/urllib3/contrib/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89ba836e74d4a6713a9b31523dbbf106b14a4522 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/urllib3/contrib/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/urllib3/contrib/__pycache__/pyopenssl.cpython-311.pyc b/.venv/lib/python3.11/site-packages/urllib3/contrib/__pycache__/pyopenssl.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0b7309956f732354a7fb8550ffb4cf18b621856 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/urllib3/contrib/__pycache__/pyopenssl.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/urllib3/contrib/__pycache__/socks.cpython-311.pyc b/.venv/lib/python3.11/site-packages/urllib3/contrib/__pycache__/socks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8132dd40958e3ce51d5c2269b7e0c39bb457c31a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/urllib3/contrib/__pycache__/socks.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/urllib3/contrib/emscripten/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/urllib3/contrib/emscripten/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2a2a536a53c81ee744b668d1d6999e20594047c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/urllib3/contrib/emscripten/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/urllib3/contrib/emscripten/__pycache__/fetch.cpython-311.pyc b/.venv/lib/python3.11/site-packages/urllib3/contrib/emscripten/__pycache__/fetch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa9cad247299f21b9b43a8bc525de8be7b8137ce Binary files /dev/null and b/.venv/lib/python3.11/site-packages/urllib3/contrib/emscripten/__pycache__/fetch.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/urllib3/contrib/emscripten/__pycache__/response.cpython-311.pyc b/.venv/lib/python3.11/site-packages/urllib3/contrib/emscripten/__pycache__/response.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0eba886b85f10c0e6ca96b458b6b86de981ea308 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/urllib3/contrib/emscripten/__pycache__/response.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/urllib3/util/connection.py b/.venv/lib/python3.11/site-packages/urllib3/util/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..f92519ee9124e91e5da7d60ccc3f274312ed3514 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/urllib3/util/connection.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import socket +import typing + +from ..exceptions import LocationParseError +from .timeout import _DEFAULT_TIMEOUT, _TYPE_TIMEOUT + +_TYPE_SOCKET_OPTIONS = list[tuple[int, int, typing.Union[int, bytes]]] + +if typing.TYPE_CHECKING: + from .._base_connection import BaseHTTPConnection + + +def is_connection_dropped(conn: BaseHTTPConnection) -> bool: # Platform-specific + """ + Returns True if the connection is dropped and should be closed. + :param conn: :class:`urllib3.connection.HTTPConnection` object. + """ + return not conn.is_connected + + +# This function is copied from socket.py in the Python 2.7 standard +# library test suite. Added to its signature is only `socket_options`. +# One additional modification is that we avoid binding to IPv6 servers +# discovered in DNS if the system doesn't have IPv6 functionality. +def create_connection( + address: tuple[str, int], + timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + source_address: tuple[str, int] | None = None, + socket_options: _TYPE_SOCKET_OPTIONS | None = None, +) -> socket.socket: + """Connect to *address* and return the socket object. + + Convenience function. Connect to *address* (a 2-tuple ``(host, + port)``) and return the socket object. Passing the optional + *timeout* parameter will set the timeout on the socket instance + before attempting to connect. If no *timeout* is supplied, the + global default timeout setting returned by :func:`socket.getdefaulttimeout` + is used. If *source_address* is set it must be a tuple of (host, port) + for the socket to bind as a source address before making the connection. + An host of '' or port 0 tells the OS to use the default. + """ + + host, port = address + if host.startswith("["): + host = host.strip("[]") + err = None + + # Using the value from allowed_gai_family() in the context of getaddrinfo lets + # us select whether to work with IPv4 DNS records, IPv6 records, or both. + # The original create_connection function always returns all records. + family = allowed_gai_family() + + try: + host.encode("idna") + except UnicodeError: + raise LocationParseError(f"'{host}', label empty or too long") from None + + for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + sock = None + try: + sock = socket.socket(af, socktype, proto) + + # If provided, set socket level options before connecting. + _set_socket_options(sock, socket_options) + + if timeout is not _DEFAULT_TIMEOUT: + sock.settimeout(timeout) + if source_address: + sock.bind(source_address) + sock.connect(sa) + # Break explicitly a reference cycle + err = None + return sock + + except OSError as _: + err = _ + if sock is not None: + sock.close() + + if err is not None: + try: + raise err + finally: + # Break explicitly a reference cycle + err = None + else: + raise OSError("getaddrinfo returns an empty list") + + +def _set_socket_options( + sock: socket.socket, options: _TYPE_SOCKET_OPTIONS | None +) -> None: + if options is None: + return + + for opt in options: + sock.setsockopt(*opt) + + +def allowed_gai_family() -> socket.AddressFamily: + """This function is designed to work in the context of + getaddrinfo, where family=socket.AF_UNSPEC is the default and + will perform a DNS search for both IPv6 and IPv4 records.""" + + family = socket.AF_INET + if HAS_IPV6: + family = socket.AF_UNSPEC + return family + + +def _has_ipv6(host: str) -> bool: + """Returns True if the system can bind an IPv6 address.""" + sock = None + has_ipv6 = False + + if socket.has_ipv6: + # has_ipv6 returns true if cPython was compiled with IPv6 support. + # It does not tell us if the system has IPv6 support enabled. To + # determine that we must bind to an IPv6 address. + # https://github.com/urllib3/urllib3/pull/611 + # https://bugs.python.org/issue658327 + try: + sock = socket.socket(socket.AF_INET6) + sock.bind((host, 0)) + has_ipv6 = True + except Exception: + pass + + if sock: + sock.close() + return has_ipv6 + + +HAS_IPV6 = _has_ipv6("::1") diff --git a/.venv/lib/python3.11/site-packages/urllib3/util/response.py b/.venv/lib/python3.11/site-packages/urllib3/util/response.py new file mode 100644 index 0000000000000000000000000000000000000000..0f4578696fa2e17a900c6890ec26d65e860b0b72 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/urllib3/util/response.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import http.client as httplib +from email.errors import MultipartInvariantViolationDefect, StartBoundaryNotFoundDefect + +from ..exceptions import HeaderParsingError + + +def is_fp_closed(obj: object) -> bool: + """ + Checks whether a given file-like object is closed. + + :param obj: + The file-like object to check. + """ + + try: + # Check `isclosed()` first, in case Python3 doesn't set `closed`. + # GH Issue #928 + return obj.isclosed() # type: ignore[no-any-return, attr-defined] + except AttributeError: + pass + + try: + # Check via the official file-like-object way. + return obj.closed # type: ignore[no-any-return, attr-defined] + except AttributeError: + pass + + try: + # Check if the object is a container for another file-like object that + # gets released on exhaustion (e.g. HTTPResponse). + return obj.fp is None # type: ignore[attr-defined] + except AttributeError: + pass + + raise ValueError("Unable to determine whether fp is closed.") + + +def assert_header_parsing(headers: httplib.HTTPMessage) -> None: + """ + Asserts whether all headers have been successfully parsed. + Extracts encountered errors from the result of parsing headers. + + Only works on Python 3. + + :param http.client.HTTPMessage headers: Headers to verify. + + :raises urllib3.exceptions.HeaderParsingError: + If parsing errors are found. + """ + + # This will fail silently if we pass in the wrong kind of parameter. + # To make debugging easier add an explicit check. + if not isinstance(headers, httplib.HTTPMessage): + raise TypeError(f"expected httplib.Message, got {type(headers)}.") + + unparsed_data = None + + # get_payload is actually email.message.Message.get_payload; + # we're only interested in the result if it's not a multipart message + if not headers.is_multipart(): + payload = headers.get_payload() + + if isinstance(payload, (bytes, str)): + unparsed_data = payload + + # httplib is assuming a response body is available + # when parsing headers even when httplib only sends + # header data to parse_headers() This results in + # defects on multipart responses in particular. + # See: https://github.com/urllib3/urllib3/issues/800 + + # So we ignore the following defects: + # - StartBoundaryNotFoundDefect: + # The claimed start boundary was never found. + # - MultipartInvariantViolationDefect: + # A message claimed to be a multipart but no subparts were found. + defects = [ + defect + for defect in headers.defects + if not isinstance( + defect, (StartBoundaryNotFoundDefect, MultipartInvariantViolationDefect) + ) + ] + + if defects or unparsed_data: + raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data) + + +def is_response_to_head(response: httplib.HTTPResponse) -> bool: + """ + Checks whether the request of a response has been a HEAD-request. + + :param http.client.HTTPResponse response: + Response to check if the originating request + used 'HEAD' as a method. + """ + # FIXME: Can we do this somehow without accessing private httplib _method? + method_str = response._method # type: str # type: ignore[attr-defined] + return method_str.upper() == "HEAD" diff --git a/.venv/lib/python3.11/site-packages/urllib3/util/ssl_match_hostname.py b/.venv/lib/python3.11/site-packages/urllib3/util/ssl_match_hostname.py new file mode 100644 index 0000000000000000000000000000000000000000..453cfd420d835be58b5af581c3065e7b37079ecf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/urllib3/util/ssl_match_hostname.py @@ -0,0 +1,159 @@ +"""The match_hostname() function from Python 3.5, essential when using SSL.""" + +# Note: This file is under the PSF license as the code comes from the python +# stdlib. http://docs.python.org/3/license.html +# It is modified to remove commonName support. + +from __future__ import annotations + +import ipaddress +import re +import typing +from ipaddress import IPv4Address, IPv6Address + +if typing.TYPE_CHECKING: + from .ssl_ import _TYPE_PEER_CERT_RET_DICT + +__version__ = "3.5.0.1" + + +class CertificateError(ValueError): + pass + + +def _dnsname_match( + dn: typing.Any, hostname: str, max_wildcards: int = 1 +) -> typing.Match[str] | None | bool: + """Matching according to RFC 6125, section 6.4.3 + + http://tools.ietf.org/html/rfc6125#section-6.4.3 + """ + pats = [] + if not dn: + return False + + # Ported from python3-syntax: + # leftmost, *remainder = dn.split(r'.') + parts = dn.split(r".") + leftmost = parts[0] + remainder = parts[1:] + + wildcards = leftmost.count("*") + if wildcards > max_wildcards: + # Issue #17980: avoid denials of service by refusing more + # than one wildcard per fragment. A survey of established + # policy among SSL implementations showed it to be a + # reasonable choice. + raise CertificateError( + "too many wildcards in certificate DNS name: " + repr(dn) + ) + + # speed up common case w/o wildcards + if not wildcards: + return bool(dn.lower() == hostname.lower()) + + # RFC 6125, section 6.4.3, subitem 1. + # The client SHOULD NOT attempt to match a presented identifier in which + # the wildcard character comprises a label other than the left-most label. + if leftmost == "*": + # When '*' is a fragment by itself, it matches a non-empty dotless + # fragment. + pats.append("[^.]+") + elif leftmost.startswith("xn--") or hostname.startswith("xn--"): + # RFC 6125, section 6.4.3, subitem 3. + # The client SHOULD NOT attempt to match a presented identifier + # where the wildcard character is embedded within an A-label or + # U-label of an internationalized domain name. + pats.append(re.escape(leftmost)) + else: + # Otherwise, '*' matches any dotless string, e.g. www* + pats.append(re.escape(leftmost).replace(r"\*", "[^.]*")) + + # add the remaining fragments, ignore any wildcards + for frag in remainder: + pats.append(re.escape(frag)) + + pat = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE) + return pat.match(hostname) + + +def _ipaddress_match(ipname: str, host_ip: IPv4Address | IPv6Address) -> bool: + """Exact matching of IP addresses. + + RFC 9110 section 4.3.5: "A reference identity of IP-ID contains the decoded + bytes of the IP address. An IP version 4 address is 4 octets, and an IP + version 6 address is 16 octets. [...] A reference identity of type IP-ID + matches if the address is identical to an iPAddress value of the + subjectAltName extension of the certificate." + """ + # OpenSSL may add a trailing newline to a subjectAltName's IP address + # Divergence from upstream: ipaddress can't handle byte str + ip = ipaddress.ip_address(ipname.rstrip()) + return bool(ip.packed == host_ip.packed) + + +def match_hostname( + cert: _TYPE_PEER_CERT_RET_DICT | None, + hostname: str, + hostname_checks_common_name: bool = False, +) -> None: + """Verify that *cert* (in decoded format as returned by + SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125 + rules are followed, but IP addresses are not accepted for *hostname*. + + CertificateError is raised on failure. On success, the function + returns nothing. + """ + if not cert: + raise ValueError( + "empty or no certificate, match_hostname needs a " + "SSL socket or SSL context with either " + "CERT_OPTIONAL or CERT_REQUIRED" + ) + try: + # Divergence from upstream: ipaddress can't handle byte str + # + # The ipaddress module shipped with Python < 3.9 does not support + # scoped IPv6 addresses so we unconditionally strip the Zone IDs for + # now. Once we drop support for Python 3.9 we can remove this branch. + if "%" in hostname: + host_ip = ipaddress.ip_address(hostname[: hostname.rfind("%")]) + else: + host_ip = ipaddress.ip_address(hostname) + + except ValueError: + # Not an IP address (common case) + host_ip = None + dnsnames = [] + san: tuple[tuple[str, str], ...] = cert.get("subjectAltName", ()) + key: str + value: str + for key, value in san: + if key == "DNS": + if host_ip is None and _dnsname_match(value, hostname): + return + dnsnames.append(value) + elif key == "IP Address": + if host_ip is not None and _ipaddress_match(value, host_ip): + return + dnsnames.append(value) + + # We only check 'commonName' if it's enabled and we're not verifying + # an IP address. IP addresses aren't valid within 'commonName'. + if hostname_checks_common_name and host_ip is None and not dnsnames: + for sub in cert.get("subject", ()): + for key, value in sub: + if key == "commonName": + if _dnsname_match(value, hostname): + return + dnsnames.append(value) + + if len(dnsnames) > 1: + raise CertificateError( + "hostname %r " + "doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames))) + ) + elif len(dnsnames) == 1: + raise CertificateError(f"hostname {hostname!r} doesn't match {dnsnames[0]!r}") + else: + raise CertificateError("no appropriate subjectAltName fields were found") diff --git a/.venv/lib/python3.11/site-packages/urllib3/util/ssltransport.py b/.venv/lib/python3.11/site-packages/urllib3/util/ssltransport.py new file mode 100644 index 0000000000000000000000000000000000000000..6d59bc3bce2489c3a0aa5bcb83b737dcf33c033b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/urllib3/util/ssltransport.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +import io +import socket +import ssl +import typing + +from ..exceptions import ProxySchemeUnsupported + +if typing.TYPE_CHECKING: + from typing_extensions import Self + + from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT + + +_WriteBuffer = typing.Union[bytearray, memoryview] +_ReturnValue = typing.TypeVar("_ReturnValue") + +SSL_BLOCKSIZE = 16384 + + +class SSLTransport: + """ + The SSLTransport wraps an existing socket and establishes an SSL connection. + + Contrary to Python's implementation of SSLSocket, it allows you to chain + multiple TLS connections together. It's particularly useful if you need to + implement TLS within TLS. + + The class supports most of the socket API operations. + """ + + @staticmethod + def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None: + """ + Raises a ProxySchemeUnsupported if the provided ssl_context can't be used + for TLS in TLS. + + The only requirement is that the ssl_context provides the 'wrap_bio' + methods. + """ + + if not hasattr(ssl_context, "wrap_bio"): + raise ProxySchemeUnsupported( + "TLS in TLS requires SSLContext.wrap_bio() which isn't " + "available on non-native SSLContext" + ) + + def __init__( + self, + socket: socket.socket, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + suppress_ragged_eofs: bool = True, + ) -> None: + """ + Create an SSLTransport around socket using the provided ssl_context. + """ + self.incoming = ssl.MemoryBIO() + self.outgoing = ssl.MemoryBIO() + + self.suppress_ragged_eofs = suppress_ragged_eofs + self.socket = socket + + self.sslobj = ssl_context.wrap_bio( + self.incoming, self.outgoing, server_hostname=server_hostname + ) + + # Perform initial handshake. + self._ssl_io_loop(self.sslobj.do_handshake) + + def __enter__(self) -> Self: + return self + + def __exit__(self, *_: typing.Any) -> None: + self.close() + + def fileno(self) -> int: + return self.socket.fileno() + + def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes: + return self._wrap_ssl_read(len, buffer) + + def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes: + if flags != 0: + raise ValueError("non-zero flags not allowed in calls to recv") + return self._wrap_ssl_read(buflen) + + def recv_into( + self, + buffer: _WriteBuffer, + nbytes: int | None = None, + flags: int = 0, + ) -> None | int | bytes: + if flags != 0: + raise ValueError("non-zero flags not allowed in calls to recv_into") + if nbytes is None: + nbytes = len(buffer) + return self.read(nbytes, buffer) + + def sendall(self, data: bytes, flags: int = 0) -> None: + if flags != 0: + raise ValueError("non-zero flags not allowed in calls to sendall") + count = 0 + with memoryview(data) as view, view.cast("B") as byte_view: + amount = len(byte_view) + while count < amount: + v = self.send(byte_view[count:]) + count += v + + def send(self, data: bytes, flags: int = 0) -> int: + if flags != 0: + raise ValueError("non-zero flags not allowed in calls to send") + return self._ssl_io_loop(self.sslobj.write, data) + + def makefile( + self, + mode: str, + buffering: int | None = None, + *, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO: + """ + Python's httpclient uses makefile and buffered io when reading HTTP + messages and we need to support it. + + This is unfortunately a copy and paste of socket.py makefile with small + changes to point to the socket directly. + """ + if not set(mode) <= {"r", "w", "b"}: + raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)") + + writing = "w" in mode + reading = "r" in mode or not writing + assert reading or writing + binary = "b" in mode + rawmode = "" + if reading: + rawmode += "r" + if writing: + rawmode += "w" + raw = socket.SocketIO(self, rawmode) # type: ignore[arg-type] + self.socket._io_refs += 1 # type: ignore[attr-defined] + if buffering is None: + buffering = -1 + if buffering < 0: + buffering = io.DEFAULT_BUFFER_SIZE + if buffering == 0: + if not binary: + raise ValueError("unbuffered streams must be binary") + return raw + buffer: typing.BinaryIO + if reading and writing: + buffer = io.BufferedRWPair(raw, raw, buffering) # type: ignore[assignment] + elif reading: + buffer = io.BufferedReader(raw, buffering) + else: + assert writing + buffer = io.BufferedWriter(raw, buffering) + if binary: + return buffer + text = io.TextIOWrapper(buffer, encoding, errors, newline) + text.mode = mode # type: ignore[misc] + return text + + def unwrap(self) -> None: + self._ssl_io_loop(self.sslobj.unwrap) + + def close(self) -> None: + self.socket.close() + + @typing.overload + def getpeercert( + self, binary_form: typing.Literal[False] = ... + ) -> _TYPE_PEER_CERT_RET_DICT | None: ... + + @typing.overload + def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None: ... + + def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET: + return self.sslobj.getpeercert(binary_form) # type: ignore[return-value] + + def version(self) -> str | None: + return self.sslobj.version() + + def cipher(self) -> tuple[str, str, int] | None: + return self.sslobj.cipher() + + def selected_alpn_protocol(self) -> str | None: + return self.sslobj.selected_alpn_protocol() + + def shared_ciphers(self) -> list[tuple[str, str, int]] | None: + return self.sslobj.shared_ciphers() + + def compression(self) -> str | None: + return self.sslobj.compression() + + def settimeout(self, value: float | None) -> None: + self.socket.settimeout(value) + + def gettimeout(self) -> float | None: + return self.socket.gettimeout() + + def _decref_socketios(self) -> None: + self.socket._decref_socketios() # type: ignore[attr-defined] + + def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes: + try: + return self._ssl_io_loop(self.sslobj.read, len, buffer) + except ssl.SSLError as e: + if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs: + return 0 # eof, return 0. + else: + raise + + # func is sslobj.do_handshake or sslobj.unwrap + @typing.overload + def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None: ... + + # func is sslobj.write, arg1 is data + @typing.overload + def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int: ... + + # func is sslobj.read, arg1 is len, arg2 is buffer + @typing.overload + def _ssl_io_loop( + self, + func: typing.Callable[[int, bytearray | None], bytes], + arg1: int, + arg2: bytearray | None, + ) -> bytes: ... + + def _ssl_io_loop( + self, + func: typing.Callable[..., _ReturnValue], + arg1: None | bytes | int = None, + arg2: bytearray | None = None, + ) -> _ReturnValue: + """Performs an I/O loop between incoming/outgoing and the socket.""" + should_loop = True + ret = None + + while should_loop: + errno = None + try: + if arg1 is None and arg2 is None: + ret = func() + elif arg2 is None: + ret = func(arg1) + else: + ret = func(arg1, arg2) + except ssl.SSLError as e: + if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + # WANT_READ, and WANT_WRITE are expected, others are not. + raise e + errno = e.errno + + buf = self.outgoing.read() + self.socket.sendall(buf) + + if errno is None: + should_loop = False + elif errno == ssl.SSL_ERROR_WANT_READ: + buf = self.socket.recv(SSL_BLOCKSIZE) + if buf: + self.incoming.write(buf) + else: + self.incoming.write_eof() + return typing.cast(_ReturnValue, ret)