koichi12 commited on
Commit
36383c5
·
verified ·
1 Parent(s): 8509ad7

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .venv/lib/python3.11/site-packages/mistral_common/data/tekken_240718.json +3 -0
  3. .venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/util.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/__init__.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/all_to_all_operator.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/from_operators.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/input_data_operator.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/map_operator.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/n_ary_operator.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/one_to_one_operator.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/read_operator.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/write_operator.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/ray/data/_internal/logical/rules/randomize_blocks.py +77 -0
  14. .venv/lib/python3.11/site-packages/torchgen/__pycache__/gen.cpython-311.pyc +3 -0
  15. .venv/lib/python3.11/site-packages/torchgen/api/autograd.py +870 -0
  16. .venv/lib/python3.11/site-packages/torchgen/api/functionalization.py +199 -0
  17. .venv/lib/python3.11/site-packages/torchgen/api/lazy.py +467 -0
  18. .venv/lib/python3.11/site-packages/torchgen/api/meta.py +13 -0
  19. .venv/lib/python3.11/site-packages/torchgen/api/python.py +1519 -0
  20. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/README.md +3 -0
  21. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__init__.py +0 -0
  22. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/__init__.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/context.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_annotated_fn_args.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd_functions.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_inplace_or_view_type.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_python_functions.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_trace_type.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_factories.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_type.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_view_funcs.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/load_derivatives.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/build.bzl +14 -0
  35. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/context.py +31 -0
  36. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/derivatives.yaml +0 -0
  37. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py +132 -0
  38. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_autograd.py +147 -0
  39. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py +675 -0
  40. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_python_functions.py +1402 -0
  41. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_trace_type.py +536 -0
  42. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_variable_factories.py +116 -0
  43. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_variable_type.py +2180 -0
  44. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_view_funcs.py +340 -0
  45. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/load_derivatives.py +1014 -0
  46. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/ADInplaceOrViewType.cpp +38 -0
  47. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/Functions.cpp +20 -0
  48. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/Functions.h +51 -0
  49. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/TraceType.cpp +40 -0
  50. .venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/VariableType.cpp +65 -0
.gitattributes CHANGED
@@ -398,3 +398,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
398
  .venv/lib/python3.11/site-packages/mistral_common/data/mistral_instruct_tokenizer_240216.model.v2 filter=lfs diff=lfs merge=lfs -text
399
  .venv/lib/python3.11/site-packages/numpy/lib/tests/__pycache__/test_io.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
400
  .venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
 
398
  .venv/lib/python3.11/site-packages/mistral_common/data/mistral_instruct_tokenizer_240216.model.v2 filter=lfs diff=lfs merge=lfs -text
399
  .venv/lib/python3.11/site-packages/numpy/lib/tests/__pycache__/test_io.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
400
  .venv/lib/python3.11/site-packages/numpy/ma/tests/__pycache__/test_core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
401
+ .venv/lib/python3.11/site-packages/mistral_common/data/tekken_240718.json filter=lfs diff=lfs merge=lfs -text
402
+ .venv/lib/python3.11/site-packages/torchgen/__pycache__/gen.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/mistral_common/data/tekken_240718.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eccd1665d2e477697c33cb7f0daa6f6dfefc57a0a6bceb66d4be52952f827516
3
+ size 14801223
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/__pycache__/util.cpython-311.pyc ADDED
Binary file (3.76 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (209 Bytes). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/all_to_all_operator.cpython-311.pyc ADDED
Binary file (8.57 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/from_operators.cpython-311.pyc ADDED
Binary file (7.21 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/input_data_operator.cpython-311.pyc ADDED
Binary file (5.4 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/map_operator.cpython-311.pyc ADDED
Binary file (15.7 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/n_ary_operator.cpython-311.pyc ADDED
Binary file (3.34 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/one_to_one_operator.cpython-311.pyc ADDED
Binary file (4.95 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/read_operator.cpython-311.pyc ADDED
Binary file (5.51 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/operators/__pycache__/write_operator.cpython-311.pyc ADDED
Binary file (1.92 kB). View file
 
.venv/lib/python3.11/site-packages/ray/data/_internal/logical/rules/randomize_blocks.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from collections import deque
3
+
4
+ from ray.data._internal.logical.interfaces import LogicalOperator, LogicalPlan, Rule
5
+ from ray.data._internal.logical.operators.all_to_all_operator import (
6
+ AbstractAllToAll,
7
+ RandomizeBlocks,
8
+ )
9
+
10
+
11
+ class ReorderRandomizeBlocksRule(Rule):
12
+ """Rule for reordering RandomizeBlocks logical operator.
13
+
14
+ Reordering RandomizeBlocks operators is to help fuse multiple
15
+ AbstractUDFMap operators together for better performance.
16
+
17
+ 1. Dedupes multiple RandomizeBlocks operators if they are not seeded.
18
+ 2. Moves RandomizeBlocks operator to the end of a sequence of AbstractUDFMap
19
+ operators. RandomizeBlocks operators are not moved across AbstractAllToAll operator
20
+ boundaries.
21
+ """
22
+
23
+ def apply(self, plan: LogicalPlan) -> LogicalPlan:
24
+ optimized_dag: LogicalOperator = self._apply(plan.dag)
25
+ new_plan = LogicalPlan(dag=optimized_dag, context=plan.context)
26
+ return new_plan
27
+
28
+ def _apply(self, op: LogicalOperator) -> LogicalOperator:
29
+ operators = []
30
+
31
+ # Post-order traversal.
32
+ nodes = deque()
33
+ for node in op.post_order_iter():
34
+ nodes.appendleft(node)
35
+
36
+ while len(nodes) > 0:
37
+ current_op = nodes.pop()
38
+ upstream_ops = current_op.input_dependencies
39
+
40
+ # Iterate through all upstream ops, and remove all RandomizeBlocks
41
+ # operators.
42
+ for i in range(len(upstream_ops)):
43
+ if isinstance(upstream_ops[i], RandomizeBlocks):
44
+ # If no seeds are provided, then collapse into a single
45
+ # RandomizeBlocks operator.
46
+ current_seed = upstream_ops[i]._seed
47
+ if not operators or current_seed or operators[-1]._seed:
48
+ # We need to make a copy of the operator.
49
+ # Because the operator instance may be shared by multiple
50
+ # Datasets. We shouldn't modify it in place.
51
+ operators.append(copy.copy(upstream_ops[i]))
52
+
53
+ # Remove RandomizeBlocks operator from the dag and wire in new input
54
+ # dependencies.
55
+ assert len(upstream_ops[i].input_dependencies) == 1
56
+ upstream_ops[i] = upstream_ops[i].input_dependencies[0]
57
+ if isinstance(current_op, AbstractAllToAll) and not isinstance(
58
+ current_op, RandomizeBlocks
59
+ ):
60
+ # If this operator is a an AllToAll Operator, then insert
61
+ # RandomizeBlocks right before this operator rather than the end of the
62
+ # DAG.
63
+ # All-to-all operators can have only 1 input operator.
64
+ assert len(upstream_ops) == 1
65
+ input_op = upstream_ops[0]
66
+ for random_op in operators:
67
+ random_op._input_dependencies = [input_op]
68
+ input_op = random_op
69
+ upstream_ops[0] = input_op
70
+ operators = []
71
+
72
+ # Add RandomizeBlocks operator as the last operator in the DAG if necessary.
73
+ for random_op in operators:
74
+ random_op._input_dependencies = [op]
75
+ op = random_op
76
+
77
+ return op
.venv/lib/python3.11/site-packages/torchgen/__pycache__/gen.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:651bd8f392a2068689c3b3a80e08fda2bab7e27693fdef2c9f01c2c6303ab472
3
+ size 123663
.venv/lib/python3.11/site-packages/torchgen/api/autograd.py ADDED
@@ -0,0 +1,870 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from dataclasses import dataclass
5
+ from typing import cast, Sequence
6
+
7
+ from torchgen import local
8
+ from torchgen.api import cpp
9
+ from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
10
+ from torchgen.model import (
11
+ BaseTy,
12
+ BaseType,
13
+ FunctionSchema,
14
+ ListType,
15
+ NativeFunction,
16
+ NativeFunctionsViewGroup,
17
+ SchemaKind,
18
+ Type,
19
+ )
20
+ from torchgen.utils import IDENT_REGEX
21
+
22
+
23
+ # Represents a saved attribute involved in backward calculation.
24
+ # Note that it can be a derived property of an input argument, e.g.:
25
+ # we could save `other.scalar_type()` instead of the entire `other` tensor.
26
+ @dataclass(frozen=True)
27
+ class SavedAttribute:
28
+ # The NamedCType holds the updated name and cpp type of the attribute
29
+ # for the name, Suffix is appended if it's derived property, e.g.: `other_scalar_type`
30
+ nctype: NamedCType
31
+
32
+ # The expression to read the derived property at save time, e.g.:
33
+ # `other.scalar_type()`.
34
+ expr: str
35
+
36
+
37
+ # Represents a backward formula that calculates derivatives for one
38
+ # or more tensors.
39
+ @dataclass(frozen=True)
40
+ class Derivative:
41
+ # The formula string (legit C++ expression).
42
+ # Note that expressions against input arguments have been replaced with the
43
+ # corresponding saved attributes.
44
+ # E.g.:
45
+ # raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
46
+ # here: `mul_tensor_backward(grad, self, other_scalar_type)`
47
+ formula: str
48
+
49
+ # The formula string before input argument replacement
50
+ original_formula: str
51
+
52
+ # Names of the arguments for which this formula calculates derivatives.
53
+ var_names: tuple[str, ...]
54
+
55
+ # Saved inputs that are referenced by the formula.
56
+ saved_inputs: tuple[SavedAttribute, ...]
57
+
58
+ # Saved outputs that are referenced by the formula.
59
+ saved_outputs: tuple[SavedAttribute, ...]
60
+
61
+ # Gradients that are referenced by name in the formula.
62
+ named_gradients: set[str]
63
+
64
+
65
+ # Represents a forward formula that calculates forward derivatives
66
+ # for one tensor.
67
+ @dataclass(frozen=True)
68
+ class ForwardDerivative:
69
+ # The formula string (legit C++ expression).
70
+ # Note that special keywords such as "linear" or "element_wise" have been
71
+ # replaced by the automatically generated formula.
72
+ formula: str
73
+
74
+ # Name of the output arguments for which this formula calculates forward
75
+ # derivatives
76
+ var_names: tuple[str, ...]
77
+
78
+ # Type of the output arguments for which this formula calculates forward
79
+ # derivatives
80
+ var_types: tuple[Type, ...]
81
+
82
+ # Inputs for which the forward derivatives are required for this formula
83
+ required_inputs_fw_grad: tuple[str, ...] | None
84
+
85
+ # Inputs for which the primal is required for this formula
86
+ required_inputs_primal: tuple[str, ...] | None
87
+
88
+ # Flag to specify if this formula requires the original value of self
89
+ # This is only used by inplace operations
90
+ required_original_self_value: bool
91
+
92
+ # If this formula is specified in derivatives.yaml or if we are re-using the
93
+ # out of place formula for inplace
94
+ is_reusing_outplace_formula: bool
95
+
96
+
97
+ # Represents differentiability info for a NativeFunction.
98
+ @dataclass(frozen=True)
99
+ class DifferentiabilityInfo:
100
+ # The base name read from derivatives.yaml.
101
+ name: str
102
+
103
+ # The matching native function.
104
+ #
105
+ # There can be multiple NativeFunction having the same base name:
106
+ # - different overloads with different types of input arguments;
107
+ # - in-place/out/functional variants of the same function;
108
+ #
109
+ # We first use the schema string (under the 'name' key) in derivatives.yaml
110
+ # to find the NativeFunction having the same schema string.
111
+ # Then we find the in-place/out/functional variants of the matching function.
112
+ # Among these variants, we choose the one having the same name as the
113
+ # derivatives.yaml entry. If there is no exact match, then we choose the
114
+ # in-place variant.
115
+ # TODO: maybe the logic to search for all variants is no longer necessary?
116
+ func: NativeFunction
117
+
118
+ # The name of the generated autograd function.
119
+ # It's set only if we will calculate a derivative, i.e.
120
+ # 'args_with_derivatives' is not empty.
121
+ op: str | None
122
+
123
+ # The derivatives formulae for this function.
124
+ # Note that the length of this sequence is the number of differentiable inputs
125
+ derivatives: Sequence[Derivative]
126
+
127
+ # The forward derivatives formulae for this function.
128
+ # Note that the length of this sequence is the number of differentiable outputs
129
+ forward_derivatives: Sequence[ForwardDerivative]
130
+
131
+ # The union of 'saved_inputs' of all 'derivatives'.
132
+ all_saved_inputs: Sequence[SavedAttribute]
133
+
134
+ # The union of 'saved_outputs' of all 'derivatives'.
135
+ all_saved_outputs: Sequence[SavedAttribute]
136
+
137
+ # All named gradients that are available for use, in the same
138
+ # order as in the grads vector.
139
+ available_named_gradients: Sequence[str]
140
+
141
+ # The named gradients that are used in any of the derivatives.
142
+ # Invariant: all(name in available_named_gradients for name in used_named_gradients)
143
+ used_named_gradients: set[str]
144
+
145
+ # The function's input arguments for which it calculates derivatives.
146
+ # It's the union of 'var_names' of all 'derivatives', sorted by the
147
+ # argument order in the function schema.
148
+ args_with_derivatives: Sequence[Binding]
149
+
150
+ # Names of arguments whose derivative formula is 'non_differentiable'.
151
+ non_differentiable_arg_names: Sequence[str]
152
+
153
+ # Raw data read from derivatives.yaml.
154
+ output_differentiability: list[bool] | None
155
+
156
+ # output_differentiability in derivatives.yaml can be a list of
157
+ # conditions that express if the output is differentiable. In this case,
158
+ # the number of conditions must match the number of outputs
159
+ # (NB: we only support one condition right now).
160
+ # output_differentiability gets populated with True for each condition,
161
+ # while output_differentiability_conditions gets populated with the conditions
162
+ output_differentiability_conditions: list[str] | None
163
+
164
+ @property
165
+ def has_derivatives(self) -> bool:
166
+ return len(self.args_with_derivatives) > 0
167
+
168
+ # Generates a new DifferentiabilityInfo using the exact same set of derivative information,
169
+ # but with a new operator name.
170
+ # This is used when generating "copy" variants of view ops,
171
+ # which are able to use the exact same derivative formula as the original view op
172
+ # See Note [Codegen'd {view}_copy Operators]
173
+ def create_view_copy_from_view_derivative(
174
+ self, g: NativeFunctionsViewGroup
175
+ ) -> DifferentiabilityInfo | None:
176
+ if g.view_copy is None:
177
+ return None
178
+ f = g.view_copy
179
+
180
+ name_split_by_period = self.name.split(".", maxsplit=2)
181
+ # Append a "_copy" to the base name of the operator (but keep the overload name the same)
182
+ view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
183
+ name_split_by_period[1:]
184
+ )
185
+ view_copy_op_name = None if self.op is None else f"{self.op}_copy"
186
+
187
+ return DifferentiabilityInfo(
188
+ # Use the "_copy" version of name/func/op
189
+ name=view_copy_name,
190
+ func=f,
191
+ op=view_copy_op_name,
192
+ # But keep all derivative info the same
193
+ derivatives=self.derivatives,
194
+ forward_derivatives=self.forward_derivatives,
195
+ all_saved_inputs=self.all_saved_inputs,
196
+ all_saved_outputs=self.all_saved_outputs,
197
+ available_named_gradients=self.available_named_gradients,
198
+ used_named_gradients=self.used_named_gradients,
199
+ args_with_derivatives=self.args_with_derivatives,
200
+ non_differentiable_arg_names=self.non_differentiable_arg_names,
201
+ output_differentiability=self.output_differentiability,
202
+ output_differentiability_conditions=self.output_differentiability_conditions,
203
+ )
204
+
205
+
206
+ def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
207
+ if info is None:
208
+ return False
209
+ for derivative in info.derivatives:
210
+ formula = derivative.formula
211
+ if re.search(IDENT_REGEX.format(ident), formula):
212
+ return True
213
+ return False
214
+
215
+
216
+ def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
217
+ return uses_ident(info, "retain_variables")
218
+
219
+
220
+ def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
221
+ return uses_ident(info, "grad")
222
+
223
+
224
+ # Represents a differentiable `Argument`.
225
+ # How is it different from the `Argument` type?
226
+ # - It's processed Arguments which are differentiable and only used in the
227
+ # context of the autograd codegen;
228
+ # - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
229
+ @dataclass(frozen=True)
230
+ class DifferentiableInput:
231
+ name: str
232
+ type: Type
233
+
234
+ # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
235
+ cpp_type: str
236
+
237
+
238
+ # Represents a differentiable `Return`.
239
+ # How it it different from the `Return` type?
240
+ # - The name in `Return` is optional. Here it is always populated using the same
241
+ # `cpp.return_names()` method.
242
+ # TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
243
+ # - It's processed Returns which are differentiable, in compliance with the
244
+ # `output_differentiability` field defined in derivatives.yaml (if specified),
245
+ # and are only used in the context of the autograd codegen;
246
+ @dataclass(frozen=True)
247
+ class DifferentiableOutput:
248
+ name: str
249
+ type: Type
250
+
251
+ # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
252
+ cpp_type: str
253
+
254
+
255
+ @dataclass(frozen=True)
256
+ class NativeFunctionWithDifferentiabilityInfo:
257
+ func: NativeFunction
258
+ info: dict[str, DifferentiabilityInfo] | None
259
+ fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
260
+
261
+
262
+ # TODO: Update comment below since it is out of date.
263
+ def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
264
+ """How are we going to call the underlying implementation of a
265
+ declaration? There are two strategies:
266
+ - use_derived: we want to call the implementation on CPUDoubleType
267
+ (or a similar, derived Type instance). Because these derived
268
+ instances deal in Tensors, not Variables (it's a completely different
269
+ object, so it doesn't dispatch back to VariableType), code on
270
+ this dispatch path needs to wrap/unwrap tensors. If the
271
+ derived implementation takes and returns tensors, the
272
+ implementation is usually differentiable (although we also use
273
+ the derived dispatch path for non-differentiable functions
274
+ that we still want to dispatch on the derived Type instance;
275
+ e.g., size())
276
+ - use_type: we want to call the implementation on Type, because
277
+ it is implemented concretely, and the functions it invokes will
278
+ get dispatched back to VariableType (which will ensure that they
279
+ are differentiable.)
280
+ """
281
+ # fn is derived as long as any of its per-key differentiability infos
282
+ # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
283
+ # and ADInplaceOrViewType. We want to generate these functions as long as a
284
+ # derivative is defined for ANY dispatch key.
285
+ if fn.func.is_abstract or (
286
+ fn.info is not None and any(info.has_derivatives for info in fn.info.values())
287
+ ):
288
+ # If the function is abstract (not implemented on at::Type), we must
289
+ # call the implementation on the derived type with unpacked tensors.
290
+
291
+ # If the function has a derivative specified and is concrete, we could
292
+ # call either implementation. We prefer the calling the derived
293
+ # type's implementation with unpacked tensors because it is more
294
+ # performant in some cases: any internal calls to other ATen functions
295
+ # won't have the history tracked.
296
+
297
+ # If the function has a type dispatched argument (i.e. is a factory),
298
+ # we prefer calling the derived type's implementation both because it is
299
+ # more performant and to ensure factory functions return tensors with _version
300
+ # of 0 (probably not strictly necessary, but nice to have to keeps versions simple
301
+ # to understand.
302
+
303
+ return "use_derived"
304
+ else:
305
+ # If the function is concrete (we don't have to override it) and we
306
+ # didn't declare it in derivatives.yaml, we'll assume that it is
307
+ # actually implemented out of differentiable functions. (This
308
+ # assumption might not hold, but then you'll see gradcheck fail.)
309
+ return "use_type"
310
+
311
+
312
+ def is_foreach_func(f: NativeFunction) -> bool:
313
+ return f.func.name.name.base.startswith("_foreach_")
314
+
315
+
316
+ # note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind
317
+ # is functional for their backward derivatives (and forward derivatives in the future), i.e.,
318
+ # they would find such one in `functional_info_by_signature`. There however are some exceptions:
319
+ _foreach_with_inplace_ref = {"_foreach_zero_"}
320
+ _foreach_with_tensor_overload = {
321
+ "_foreach_add.Tensor",
322
+ "_foreach_mul.Tensor",
323
+ "_foreach_div.Tensor",
324
+ }
325
+ # The following do not support the alpha kwarg, which the nonforeach versions support.
326
+ _skip_argument_len_check = {
327
+ "_foreach_add.Scalar",
328
+ "_foreach_add_.Scalar",
329
+ "_foreach_add.ScalarList",
330
+ "_foreach_add_.ScalarList",
331
+ "_foreach_sub.Scalar",
332
+ "_foreach_sub_.Scalar",
333
+ "_foreach_sub.ScalarList",
334
+ "_foreach_sub_.ScalarList",
335
+ }
336
+
337
+
338
+ # Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function
339
+ # reference to generate derivatives.
340
+ def is_reference_for_foreach(
341
+ f: NativeFunction,
342
+ function_schema: FunctionSchema,
343
+ ) -> bool:
344
+ return (
345
+ f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
346
+ and (
347
+ not function_schema.name.name.inplace
348
+ or str(f.func.name) in _foreach_with_inplace_ref
349
+ )
350
+ and (
351
+ str(f.func.name) in _skip_argument_len_check
352
+ or len(f.func.arguments.flat_non_out)
353
+ == len(function_schema.arguments.flat_non_out)
354
+ )
355
+ and all(
356
+ ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
357
+ for arg, ref_arg in zip(
358
+ f.func.arguments.flat_non_out,
359
+ function_schema.arguments.flat_non_out,
360
+ )
361
+ )
362
+ )
363
+
364
+
365
+ # TODO(crcrpar): Avoid hard coding "Default" ideally.
366
+ def gen_foreach_derivativeinfo(
367
+ foreach_function: NativeFunction,
368
+ functional_info_by_signature: dict[
369
+ FunctionSchema, dict[str, DifferentiabilityInfo]
370
+ ],
371
+ non_functional_info_by_signature: dict[
372
+ FunctionSchema, dict[str, DifferentiabilityInfo]
373
+ ],
374
+ dispatch_key: str = "Default",
375
+ ) -> tuple[DifferentiabilityInfo | None, bool]:
376
+ """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.
377
+
378
+ The second return value indicates whether the info is generated in this function.
379
+ """
380
+ ref_diff_info: DifferentiabilityInfo | None = None
381
+
382
+ for function_schema, diff_info in functional_info_by_signature.items():
383
+ if not is_reference_for_foreach(foreach_function, function_schema):
384
+ continue
385
+ ref_diff_info = diff_info[dispatch_key]
386
+ if ref_diff_info is not None:
387
+ break
388
+ # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
389
+ # while the info of `zero_` is in non_functional_info_by_signature
390
+ if (
391
+ ref_diff_info is None
392
+ and foreach_function.func.kind() == SchemaKind.inplace
393
+ and str(foreach_function.func.name) in _foreach_with_inplace_ref
394
+ ):
395
+ for function_schema, diff_info in non_functional_info_by_signature.items():
396
+ if not is_reference_for_foreach(foreach_function, function_schema):
397
+ continue
398
+ ref_diff_info = diff_info[dispatch_key]
399
+ if ref_diff_info is not None:
400
+ break
401
+ if ref_diff_info is None:
402
+ return None, False
403
+
404
+ # non out-place uses the existing Derivative.
405
+ if foreach_function.func.kind() == SchemaKind.inplace:
406
+ return ref_diff_info, False
407
+
408
+ map_refarg2foreacharg, map_name2arg = {}, {}
409
+ for i, (arg, ref_arg) in enumerate(
410
+ zip(
411
+ foreach_function.func.arguments.flat_non_out,
412
+ function_schema.arguments.flat_non_out,
413
+ )
414
+ ):
415
+ map_refarg2foreacharg[ref_arg.name] = arg.name
416
+ map_name2arg[arg.name] = arg
417
+
418
+ all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
419
+ modified_derivative_formulas = []
420
+ for i, derivative in enumerate(ref_diff_info.derivatives):
421
+ modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
422
+ "result", "result[i]"
423
+ )
424
+ saved_inputs, saved_outputs = [], []
425
+ # note(crcrpar): This context seems necessary to call `cpp.argument_type`
426
+ with local.parametrize(
427
+ use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
428
+ use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
429
+ ):
430
+ for ref_input in derivative.saved_inputs:
431
+ ref_input_jit_name = ref_input.expr.split(".")[0]
432
+ mapped_name = map_refarg2foreacharg[ref_input_jit_name]
433
+ if isinstance(map_name2arg[mapped_name].type, ListType):
434
+ mapped_expr = mapped_name + "[i]"
435
+ else:
436
+ mapped_expr = mapped_name
437
+ new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
438
+ modified_formula = modified_formula.replace(
439
+ cast(str, ref_input.nctype.name), new_expr
440
+ )
441
+
442
+ nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
443
+ canonical_nctype = NamedCType(
444
+ nctype.name, nctype.type.remove_const_ref()
445
+ )
446
+ saved_inputs.append(
447
+ SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
448
+ )
449
+ for ref_output in derivative.saved_outputs:
450
+ if ref_output.nctype.name == "result":
451
+ saved_outputs.append(
452
+ SavedAttribute(
453
+ nctype=NamedCType(
454
+ name="result", type=BaseCType(tensorListT)
455
+ ),
456
+ expr="result",
457
+ )
458
+ )
459
+ else:
460
+ raise RuntimeError("")
461
+ var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
462
+ all_var_names.extend(var_names)
463
+ all_saved_inputs.extend(saved_inputs)
464
+ all_saved_outputs.extend(saved_outputs)
465
+ modified_derivative = Derivative(
466
+ formula=modified_formula,
467
+ original_formula=derivative.formula,
468
+ var_names=tuple(var_names),
469
+ saved_inputs=tuple(saved_inputs),
470
+ saved_outputs=tuple(saved_outputs),
471
+ named_gradients=set(),
472
+ )
473
+ modified_derivative_formulas.append(modified_derivative)
474
+
475
+ with local.parametrize(
476
+ use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
477
+ use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
478
+ ):
479
+ args_with_derivatives = [
480
+ Binding(
481
+ name=arg.name,
482
+ nctype=cpp.argument_type(arg, binds=arg.name),
483
+ argument=arg,
484
+ default=None,
485
+ )
486
+ for arg in foreach_function.func.arguments.flat_non_out
487
+ if arg.name in all_var_names
488
+ ]
489
+
490
+ forward_derivatives: list[ForwardDerivative] = []
491
+ fw_derivative: ForwardDerivative
492
+ for fw_derivative in ref_diff_info.forward_derivatives:
493
+ var_names: list[str] = list(fw_derivative.var_names) # type: ignore[no-redef]
494
+ var_types: list[Type] = list(fw_derivative.var_types)
495
+ required_inputs_fw_grad: list[str] = []
496
+ required_inputs_primal: list[str] = []
497
+ if fw_derivative.required_inputs_fw_grad is not None:
498
+ required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
499
+ if fw_derivative.required_inputs_primal:
500
+ required_inputs_primal = list(fw_derivative.required_inputs_primal)
501
+ modified_formula = fw_derivative.formula
502
+
503
+ # Foreach's result is TensorList
504
+ if "result" in modified_formula:
505
+ modified_formula = fw_derivative.formula.replace("result", "result[i]")
506
+
507
+ for foreach_arg, ref_arg in zip(
508
+ foreach_function.func.arguments.flat_non_out,
509
+ ref_diff_info.func.func.arguments.flat_non_out,
510
+ ):
511
+ # Modify reference forward formula
512
+ if (
513
+ isinstance(foreach_arg.type, ListType)
514
+ and not foreach_arg.type.is_tensor_like()
515
+ ):
516
+ # Assuming ScalarList
517
+ modified_formula = modified_formula.replace(
518
+ ref_arg.name, foreach_arg.name + "[i]"
519
+ )
520
+ elif foreach_arg.type.is_tensor_like():
521
+ # Assuming TensorList / Tensor
522
+ # assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}"
523
+ assert isinstance(foreach_arg.type, ListType) or (
524
+ foreach_arg.type == BaseType(BaseTy.Tensor)
525
+ and str(foreach_function.func.name) in _foreach_with_tensor_overload
526
+ ), f"{foreach_function.func.name}, {foreach_arg.type}"
527
+ for suffix in ("_p", "_t"):
528
+ curr_expr = ref_arg.name + suffix
529
+ if curr_expr in modified_formula:
530
+ new_expr = foreach_arg.name + suffix
531
+ modified_formula = modified_formula.replace(curr_expr, new_expr)
532
+ else:
533
+ # Assuming Scalar
534
+ if foreach_arg.name != ref_arg.name:
535
+ modified_formula = modified_formula.replace(
536
+ ref_arg.name, foreach_arg.name
537
+ )
538
+
539
+ # note(crcrpar): there should exist a cooler way...
540
+ for i, name in enumerate(var_names):
541
+ if name == ref_arg.name:
542
+ var_names[i] = foreach_arg.name
543
+ var_types[i] = foreach_arg.type
544
+ for i, name in enumerate(required_inputs_fw_grad):
545
+ if name == ref_arg.name:
546
+ required_inputs_fw_grad[i] = foreach_arg.name
547
+ for i, name in enumerate(required_inputs_primal):
548
+ if name == ref_arg.name:
549
+ required_inputs_primal[i] = foreach_arg.name
550
+ forward_derivatives.append(
551
+ ForwardDerivative(
552
+ formula=modified_formula,
553
+ var_names=tuple(var_names),
554
+ var_types=tuple(var_types),
555
+ required_inputs_fw_grad=tuple(required_inputs_fw_grad),
556
+ required_inputs_primal=tuple(required_inputs_primal),
557
+ required_original_self_value=fw_derivative.required_original_self_value,
558
+ is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
559
+ )
560
+ )
561
+
562
+ return (
563
+ DifferentiabilityInfo(
564
+ name=foreach_function.func.name.name.base,
565
+ func=foreach_function,
566
+ op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
567
+ derivatives=modified_derivative_formulas,
568
+ forward_derivatives=forward_derivatives,
569
+ all_saved_inputs=tuple(set(all_saved_inputs)),
570
+ all_saved_outputs=tuple(set(all_saved_outputs)),
571
+ available_named_gradients=(),
572
+ used_named_gradients=set(),
573
+ args_with_derivatives=args_with_derivatives,
574
+ non_differentiable_arg_names=[],
575
+ output_differentiability=None,
576
+ output_differentiability_conditions=None,
577
+ ),
578
+ True,
579
+ )
580
+
581
+
582
def match_differentiability_info(
    native_functions: list[NativeFunction],
    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
) -> list[NativeFunctionWithDifferentiabilityInfo]:
    """Pair every native function with its differentiability info, if any.

    For each function the lookup (see ``find_info``) tries, in order: an exact
    schema match, the functional (out-of-place) entry sharing the same
    signature, the mutable entry for code-generated functional variants, and
    finally an auto-generated entry for foreach functions.

    For in-place functions that have forward derivatives, the forward formula
    is rewritten so that it refers to ``self_p`` / ``original_self_p`` /
    ``original_self_t`` instead of ``result`` / ``self_p`` / ``self_t`` and,
    when an out-of-place formula is being reused, so that the existing forward
    grad is updated in place.

    Args:
        native_functions: functions to annotate.
        differentiability_infos: infos keyed by schema, then by dispatch key.
            Generated foreach infos are inserted into this dict as a side effect.

    Returns:
        One ``NativeFunctionWithDifferentiabilityInfo`` per input function, in
        input order. ``info`` and ``fw_derivatives`` are ``None`` when nothing
        matched.

    Raises:
        RuntimeError: an explicitly defined in-place forward formula uses
            ``self_p`` (the value being overwritten).
        AssertionError: an unsupported formula shape is encountered (see the
            individual assertions below).
    """

    functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() == SchemaKind.functional
    }
    non_functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() != SchemaKind.functional
    }

    def find_info(
        f: NativeFunction,
    ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
        """Return ``(info_dict, is_exact_match)`` for ``f``; ``(None, False)`` if nothing applies."""
        # Don't bother matching info to generated out= variants
        if "generated" in f.tags and f.func.kind() == SchemaKind.out:
            return None, False

        # (1) Check for an exact match
        if f.func in differentiability_infos:
            return differentiability_infos[f.func], True

        # (2) If no exact match, check if the out-of-place variant
        # of this operator has a match.
        # i.e mul() for mul_() or mul_out()
        # note(crcrpar): Check foreach or not because in-place foreach functions use backward defined for the existing
        # native functions instead of the out-place counterparts.
        f_sig = f.func.signature(strip_default=True)
        if f_sig in functional_info_by_signature and not is_foreach_func(f):
            return functional_info_by_signature[f_sig], False

        # (3) Some operators have a derivative explicitly defined for the mutable
        # variant, but get a code-generated out-of-place variant which does *not*
        # come with a derivative formula.
        # For the generated out-of-place variant, use the mutable variant's formula
        # if it exists.
        if "generated" in f.tags and f_sig in non_functional_info_by_signature:
            info_dict = non_functional_info_by_signature[f_sig]
            # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
            assert not any(
                any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs)
                for info in info_dict.values()
            ), f"""\
Attempted to convert a derivative formula for a mutable operator
 to be used by automatically by its functional variant ("{str(f.func)}").
 this is not currently supported (we'd need to fix up the formula in the codegen)."""
            return info_dict, False

        # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml`
        if is_foreach_func(f):
            assert f.func not in differentiability_infos
            diff_info, is_generated = gen_foreach_derivativeinfo(
                f,
                functional_info_by_signature,
                non_functional_info_by_signature,
            )
            if diff_info is None:
                return None, False
            # TODO(crcrpar): Avoid hard coding "Default" ideally.
            diff_info_dict = {"Default": diff_info}
            if is_generated:
                differentiability_infos[f.func] = diff_info_dict
                # NOTE(review): every other key of this dict is
                # `schema.signature(strip_default=True)`, while this insert is
                # keyed by the raw schema -- confirm that is intended.
                functional_info_by_signature[f.func] = diff_info_dict
            return diff_info_dict, is_generated

        return None, False

    result: list[NativeFunctionWithDifferentiabilityInfo] = []
    for f in native_functions:
        info_dict, is_exact_match = find_info(f)

        # Currently, the '.strides()' to 'strides_or_error' replacement does not support
        # 'self' derivatives of an inplace function, so we must check for this case.
        if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
            for info in info_dict.values():
                for derivative in info.derivatives:
                    if "self" in derivative.var_names:
                        for saved_input in derivative.saved_inputs:
                            assert "strides_or_error" not in saved_input.expr, (
                                "Calling '.strides()' in the 'self' derivative formula of an "
                                f"in-place function is not supported: {f.func}"
                            )

        if not info_dict:
            result.append(
                NativeFunctionWithDifferentiabilityInfo(
                    func=f, info=None, fw_derivatives=None
                )
            )
            continue

        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
        for key, info in info_dict.items():
            if not info.forward_derivatives:
                fw_derivative_dict[key] = []
                continue

            forward_derivatives = info.forward_derivatives

            # For functions that have a single def for out-of-place and inplace (like abs())
            if f.func.kind() == SchemaKind.inplace:
                # For inplace functions there is a little bit of work to do:
                # 1) Validate the formula and make sure the input that is modified in not used:
                # - If there is a formula for the inplace variant of the function (is_exact_match == True) then
                # we make sure that the original value of the input that is being modified inplace (self_p) is
                # not used in the formula. Note that the formula can use "original_self_p" here and that would
                # trigger a clone of the original input.
                # - If we are re-using the out of place formula (is_exact_match == False) then we replace every
                # occurrence of self_p and self_t by original_self_p and original_self_t. These will be
                # populated by cloned version of the original input (either the clone done by the backward AD
                # logic if self is also used in a backward formula or a special clone that we add).
                # 2) At this point, there cannot be a self_p in the formula.
                # 3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
                # simply called self (as it is modified inplace).
                # 4) Update the required primals data in case it used to contain "result" but should now contain
                # "self"
                # 5) If it is not an exact match, the user formula is not modifying the existing forward grad
                # inplace as it should. So add some code that makes sure that we do so if the forward grad
                # already exists.

                assert (
                    len(info.forward_derivatives) == 1
                )  # Only single output inplace should exist
                fw_info = info.forward_derivatives[0]
                formula = fw_info.formula

                def replace_self_with_original_self(formula: str, postfix: str) -> str:
                    def repl(m: re.Match[str]) -> str:
                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"

                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)

                if re.search(IDENT_REGEX.format("self_p"), formula):
                    if is_exact_match:
                        # For manually defined formulas, don't allow the original value to be used
                        raise RuntimeError(
                            f'The formula for "{f.func.name}" is using the original value of self '
                            "that is being modified inplace. This would lead to wrong forward gradients. "
                            'Please use "result" in the formula only.'
                        )
                    else:
                        # When the original formula is out of place, we save a clone of the primal
                        # value to be able to access this value if needed
                        # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
                        formula = replace_self_with_original_self(formula, "_p")
                        formula = replace_self_with_original_self(formula, "_t")

                # replace "result" from the formula by "self_p"
                def repl(m: re.Match[str]) -> str:
                    return f"{m.group(1)}self_p{m.group(2)}"

                formula = re.sub(IDENT_REGEX.format("result"), repl, formula)

                required_primals = fw_info.required_inputs_primal
                if re.search(IDENT_REGEX.format("self_p"), formula):
                    required_primals = (
                        required_primals + ("self",) if required_primals else ("self",)
                    )

                if not is_exact_match:
                    # NOTE [In-place forward AD formula Optimization]
                    #
                    # This optimization transforms the formula to directly do inplace, i.e.
                    # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
                    #
                    # 1) the formula satisfies the pattern: "self_t.op(*args)"
                    # 2) "op" in (1) needs to be the same as the op the derivative is for
                    #
                    # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
                    # If there is a need, we can relax (2) to allow any op that has an in-place variant
                    is_single_method_on_self_t = False
                    directly_do_inplace = False
                    op_name: str | None = None
                    between_parens: str | None = None
                    # The dot is escaped so that only a real method call on
                    # `self_t` matches; an unescaped "." would accept any
                    # character in that position.
                    match = re.fullmatch(r"self_t\.([\w]*)\((.*)\)", formula)
                    if match:
                        op_name, between_parens = match.group(1), match.group(2)

                        # We want to...
                        # Match: self_t.op1(other_p.op2(arg))
                        # Avoid: self_t.op1(args) + self_t.op2(args)
                        # Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
                        def check_parens_nest_level_gt_zero(s: str) -> bool:
                            # Start at 1 for the opening paren of `self_t.op(`;
                            # hitting 0 means that call closed before the end of `s`.
                            level = 1
                            for ch in s:
                                if ch == ")":
                                    level -= 1
                                    if level == 0:
                                        return False
                                if ch == "(":
                                    level += 1
                            return True

                        is_single_method_on_self_t = check_parens_nest_level_gt_zero(
                            between_parens
                        )
                        directly_do_inplace = (
                            is_single_method_on_self_t and op_name == info.name
                        )

                    if directly_do_inplace:
                        assert op_name is not None
                        assert between_parens is not None
                        formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
                    else:
                        # Make sure that the forward grad is modified inplace when the original formula
                        # is out of place
                        formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"

                required_original_self_value = bool(
                    re.search(IDENT_REGEX.format("original_self_p"), formula)
                ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))

                forward_derivatives = [
                    ForwardDerivative(
                        formula=formula,
                        var_names=("self",),
                        var_types=fw_info.var_types,
                        required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
                        required_inputs_primal=required_primals,
                        required_original_self_value=required_original_self_value,
                        is_reusing_outplace_formula=not is_exact_match,
                    ),
                ]

            fw_derivative_dict[key] = forward_derivatives

        result.append(
            NativeFunctionWithDifferentiabilityInfo(
                func=f, info=info_dict, fw_derivatives=fw_derivative_dict
            )
        )

    return result
825
+
826
+
827
def is_differentiable(
    name: str, type: Type, info: DifferentiabilityInfo | None
) -> bool:
    """Tell whether the value ``name`` of schema type ``type`` is differentiable.

    A value qualifies when its type is tensor-like and, if ``info`` is given,
    its name is not listed in ``info.non_differentiable_arg_names``.
    """
    if not type.is_tensor_like():
        return False
    if info is None:
        return True
    return name not in info.non_differentiable_arg_names
833
+
834
+
835
def gen_differentiable_outputs(
    fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> list[DifferentiableOutput]:
    """List the outputs of ``fn`` that are differentiable under dispatch key ``key``.

    If the selected info carries an explicit ``output_differentiability`` mask,
    the outputs flagged truthy in it are returned. Otherwise every output that
    passes ``is_differentiable`` is a candidate, and only the first candidate
    is kept when ``uses_single_grad(info)`` holds.

    Raises:
        RuntimeError: the mask length differs from the number of outputs, or
            the mask contains ``False`` for an in-place function.
    """
    native_fn = fn.func
    info = fn.info[key] if fn.info else None

    all_outputs: list[DifferentiableOutput] = []
    for ret_name, ret in zip(cpp.return_names(native_fn), native_fn.func.returns):
        all_outputs.append(
            DifferentiableOutput(
                name=ret_name,
                type=ret.type,
                cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
            )
        )

    mask = info.output_differentiability if info else None
    if mask is None:
        # No explicit mask: fall back to type-based filtering.
        candidates = [
            out for out in all_outputs if is_differentiable(out.name, out.type, info)
        ]
        return candidates[:1] if uses_single_grad(info) else candidates

    if len(mask) != len(all_outputs):
        raise RuntimeError(
            f"The length of output_differentiability ({len(mask)}), "
            f"does not match the number of outputs ({len(all_outputs)})."
        )
    if False in mask and native_fn.func.kind() == SchemaKind.inplace:
        raise RuntimeError(
            "output_differentiability=False for inplace operation (version_counter won't get updated)"
        )
    return [out for flagged, out in zip(mask, all_outputs) if flagged]
.venv/lib/python3.11/site-packages/torchgen/api/functionalization.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from torchgen.api import dispatcher
4
+ from torchgen.api.types import (
5
+ BaseCppType,
6
+ BaseCType,
7
+ Binding,
8
+ boolT,
9
+ ConstRefCType,
10
+ CType,
11
+ longT,
12
+ NamedCType,
13
+ tensorT,
14
+ )
15
+ from torchgen.model import (
16
+ Argument,
17
+ BaseTy,
18
+ BaseType,
19
+ FunctionSchema,
20
+ NativeFunction,
21
+ NativeFunctionsViewGroup,
22
+ )
23
+
24
+
25
+ # This file describes the translation of JIT schema to API's used
26
+ # when creating view lambdas that are used by the functionalization pass.
27
+ # There are two types of lambdas: forward lambdas and reverse lambdas.
28
+ # These API's mostly follow the dispatcher API, with a few quirks:
29
+ # - The lambda capture has to convert reference types to value types
30
+ # - While the forward lambda just directly calls into the at::_ops API
31
+ # (following the dispatcher convention), the logic here for the reverse lambda
32
+ # is responsible for generating both the call-site, and the declarations
33
+ # (which are implemented manually in the at::functionalization::impl namespace).
34
+
35
+ # The lambdas generated for each view op in the functionalization pass are of the form
36
+ # [capture_arguments](outer_arguments) -> returns_type {
37
+ # return name(inner_arguments);
38
+ # }
39
+
40
+ # Define some specific lambda input arguments.
41
# Bindings for the fixed (non-schema) parameters of the view lambdas.
# Each Binding pairs a C++ name/type (`nctype`) with a placeholder schema
# `Argument`; the placeholder is kept consistent with the binding it belongs to.

# The tensor the view is (re)applied to.
base_binding = Binding(
    name="base",
    nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))),
    argument=Argument(
        name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
    ),
    default=None,
)
# The mutated view handed to the reverse lambda.
mutated_view_binding = Binding(
    name="mutated_view",
    nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
    argument=Argument(
        # Previously copy-pasted as "base"; named after the binding itself.
        name="mutated_view",
        type=BaseType(BaseTy.Tensor),
        default=None,
        annotation=None,
    ),
    default=None,
)
# Index selecting one output of a multi-output view op.
mutated_view_idx_binding = Binding(
    name="mutated_view_idx",
    nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
    argument=Argument(
        # Previously copy-pasted as a Tensor named "base"; the C++ type above
        # is an integer, so the placeholder is an int as well.
        name="mutated_view_idx",
        type=BaseType(BaseTy.int),
        default=None,
        annotation=None,
    ),
    default=None,
)
# Flag captured by the forward lambda.
reapply_views_binding = Binding(
    name="reapply_views",
    nctype=NamedCType(name="reapply_views", type=BaseCType(boolT)),
    argument=Argument(
        name="reapply_views", type=BaseType(BaseTy.bool), default=None, annotation=None
    ),
    default=None,
)

InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")
# Mode captured by the reverse lambda.
inverse_return_mode_binding = Binding(
    name="inverse_return_mode",
    nctype=NamedCType(name="inverse_return_mode", type=BaseCType(InverseReturnModeT)),
    argument=Argument(
        name="inverse_return_mode",
        # NB: not actually a bool but it doesn't matter because this isn't used
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    default=None,
)
87
+
88
+
89
+ # The lambda capture itself doesn't have a name.
90
+ # The name returned here corresponds to the name of the inner function called by the lambda.
91
def name(
    g: NativeFunctionsViewGroup,
    *,
    is_reverse: bool,
    include_namespace: bool,
    reapply_views: bool | None = None,
) -> str:
    """Return the name of the function called inside the view lambda for ``g``.

    Reverse lambdas call the view-inverse function (see ``reverse_name``).
    Forward lambdas call ``at::_ops::<op>::call``, where ``<op>`` is the view
    op when ``reapply_views`` is truthy and its ``view_copy`` otherwise.
    """
    # reapply_views only matters for the forward lambda; the reverse function
    # receives the runtime flag itself, so it may be omitted there only.
    assert is_reverse or reapply_views is not None
    if is_reverse:
        return reverse_name(g.view, include_namespace)

    # The forward lambda calls straight into the at::_ops API, which always
    # needs the namespace.
    assert include_namespace
    assert g.view_copy is not None
    target = g.view if reapply_views else g.view_copy
    return "at::_ops::" + target.func.name.unambiguous_name() + "::call"
113
+
114
+
115
def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
    """Return the name of the view-inverse function for view op ``f``.

    Call sites need the fully qualified name; declarations use the bare
    ``<op>_inverse`` name. A single inverse serves both the copy and the
    non-copy variant because "reapply_views" is plumbed into it at runtime.
    """
    inverse = f"{f.func.name.unambiguous_name()}_inverse"
    if not include_namespace:
        return inverse
    return f"at::functionalization::FunctionalInverses::{inverse}"
125
+
126
+
127
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
    """Return the bindings captured by the view lambda for ``func``.

    The list starts with the mode binding (``inverse_return_mode`` for reverse
    lambdas, ``reapply_views`` for forward ones), followed by every argument
    except the leading tensor. Those are converted with
    ``remove_non_owning_ref_types=True`` so the capture holds value types
    rather than references that could dangle.
    """
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)
    captured_values = [
        dispatcher.argument(arg, remove_non_owning_ref_types=True)
        for arg in flat_args[1:]
    ]
    if is_reverse:
        mode_binding = inverse_return_mode_binding
    else:
        mode_binding = reapply_views_binding
    return [mode_binding, *captured_values]
143
+
144
+
145
def returns_type(func: FunctionSchema) -> CType:
    """Return the C++ return type of a view lambda: always a single tensor.

    Asserts that ``func`` has at least one return and that all returns are
    tensor-like. Multi-tensor outputs are tracked one tensor per lambda.
    """
    assert len(func.returns) >= 1
    assert all(ret.type.is_tensor_like() for ret in func.returns)
    return BaseCType(tensorT)
153
+
154
+
155
def outer_arguments(*, is_reverse: bool) -> list[Binding]:
    """Return the parameter bindings of the lambda itself.

    Both directions take ``base`` and ``mutated_view_idx``; the reverse lambda
    additionally takes ``mutated_view`` between them.
    """
    bindings = [base_binding]
    if is_reverse:
        bindings.append(mutated_view_binding)
    bindings.append(mutated_view_idx_binding)
    return bindings
160
+
161
+
162
def inner_call_index(func: FunctionSchema) -> Binding | None:
    """Return the index binding for multi-output view ops, else ``None``.

    View ops that return several tensors (like `split`), or a single list-like
    value, get one lambda per output, and replaying them requires indexing
    into the result.
    """
    returns = func.returns
    if len(returns) > 1:
        return mutated_view_idx_binding
    if len(returns) == 1 and returns[0].type.is_list_like():
        return mutated_view_idx_binding
    return None
170
+
171
+
172
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    """Return the bindings passed to the function called inside the lambda.

    The forward lambda calls the at::_ops API and the reverse lambda calls the
    view inverse API; both follow the dispatcher convention for the non-self
    arguments.
    """
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)
    trailing = [dispatcher.argument(arg) for arg in flat_args[1:]]

    if not is_reverse:
        # The forward lambda swaps the original tensor argument for the lambda arg "base".
        return [base_binding] + trailing

    # The reverse lambda does the same, with an extra "mutated_view" argument
    # and the inverse-return mode. By calling convention, the inverse of a
    # view op with multiple tensor outputs also takes an index argument.
    leading = [base_binding, mutated_view_binding, inverse_return_mode_binding]
    index_binding = inner_call_index(func)
    if index_binding is not None:
        leading.append(index_binding)
    return leading + trailing
.venv/lib/python3.11/site-packages/torchgen/api/lazy.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from torchgen.api.types import (
6
+ BaseCppType,
7
+ BaseCType,
8
+ boolT,
9
+ CType,
10
+ deviceT,
11
+ doubleT,
12
+ generatorT,
13
+ layoutT,
14
+ ListCType,
15
+ longT,
16
+ memoryFormatT,
17
+ NamedCType,
18
+ OptionalCType,
19
+ scalarT,
20
+ scalarTypeT,
21
+ stringT,
22
+ SymIntT,
23
+ VectorCType,
24
+ )
25
+ from torchgen.model import (
26
+ Argument,
27
+ BaseTy,
28
+ BaseType,
29
+ FunctionSchema,
30
+ ListType,
31
+ OperatorName,
32
+ OptionalType,
33
+ Return,
34
+ TensorOptionsArguments,
35
+ Type,
36
+ )
37
+
38
+
39
# Module-wide IR value type; stays None until setValueT() is called.
_valueT: BaseCppType | None = None


# A ValueT is an IR type which represents the computation of a Tensor. In other
# words, a PyTorch user will do operations on lazy tensors, and each output lazy
# tensor internally tracks a ValueT representing the IR node that would have
# actually produced the value of this tensor for real.
#
# This is configurable because different lazy tensor backends (LTC vs XLA) will
# have different IR representations. (Though, arguably, after unification they
# shouldn't!)
def getValueT() -> BaseCppType:
    """Return the configured IR value type.

    Raises:
        NotImplementedError: no value type has been set yet.
    """
    if _valueT:
        return _valueT
    raise NotImplementedError(
        "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
    )
58
+
59
+
60
def setValueT(val: BaseCppType) -> None:
    """Store ``val`` in the module-level ``_valueT``, the value that getValueT() returns."""
    global _valueT
    _valueT = val
63
+
64
+
65
+ # this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
66
+ # making it easier to represent special properties of an arg.
67
+ tensorListValueT = BaseCppType("torch::lazy", "Value")
68
+
69
+
70
def process_ir_type(
    typ: Type, properties: LazyIrProperties, *, symint: bool
) -> BaseCType | VectorCType | OptionalCType | ListCType:
    """
    Convert a NativeFunctions schema type into the C++ type used by the lazy
    tensor codegen.

    The conversion
    (1) turns at::Tensor into the configured lazy Value type (see getValueT()),
        with dedicated handling for Optional[Tensor], Tensor[] and Tensor?[]
    (2) wraps base types in a BaseCType
    (3) maps list types to owning containers (e.g. a vector) rather than
        reference types such as IntArrayRef

    Scalars become lazy Values unless ``properties.TreatScalarsAsConstants`` is
    set; SymInt becomes a lazy Value only when ``symint`` is true.

    Coverage is incomplete on purpose: unsupported types raise AssertionError so
    they can be added as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        base = typ.name
        if base == BaseTy.Tensor:
            return BaseCType(getValueT())
        if base == BaseTy.Scalar:
            if properties.TreatScalarsAsConstants:
                return BaseCType(scalarT)
            # at::scalar has special handling,
            # and is wrapped in an lazy::Value just like at::tensor
            return BaseCType(getValueT())
        if base == BaseTy.SymInt:
            return BaseCType(getValueT()) if symint else BaseCType(longT)

        # Base types that map one-to-one onto a C++ type.
        direct_ctypes = {
            BaseTy.ScalarType: scalarTypeT,
            BaseTy.int: longT,
            BaseTy.bool: boolT,
            BaseTy.float: doubleT,
            BaseTy.str: stringT,
            BaseTy.Device: deviceT,
            BaseTy.Generator: generatorT,
            BaseTy.Layout: layoutT,
            BaseTy.MemoryFormat: memoryFormatT,
        }
        if base in direct_ctypes:
            return BaseCType(direct_ctypes[base])
        raise AssertionError(f"TODO add support for type {repr(typ)}")

    if isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem, properties, symint=symint))

    if isinstance(typ, ListType):
        elem_str = str(typ.elem)
        if elem_str == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(getValueT())))
        if elem_str == "Tensor":
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        if typ.elem == BaseType(BaseTy.SymInt):
            # TODO: return a value type. The problem here is analogous to
            # the problem with tensorListValueT: if you have SymInt[] you
            # cannot conveniently save the list of Value directly, as nodes
            # expect to save values as a vector for ALL arguments. So you
            # need a separate IR node that represents all of the size nodes
            # assembled into a list. I'm not an LTC dev so I don't want to
            # figure it out right now. Y'all figure it out...
            return VectorCType(BaseCType(longT))
        return VectorCType(process_ir_type(typ.elem, properties, symint=symint))

    raise AssertionError(f"unrecognized type {repr(typ)}")
145
+
146
+
147
+ # TODO: Determining this based off of CType is bad; this should be computed
148
+ # from Type directly; then the same logic as process_ir_type can be used
149
+ #
150
+ # Invariant: passed typ should be an *owning* CType (e.g., we will report
151
+ # that ArrayRef<Value> is NOT a value type)
152
def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool:
    """
    Report whether an already-transformed CType is Value-like.

    Base types count as Value-like when they are the lazy value type, SymIntT,
    or scalarT (the latter only while `properties.TreatScalarsAsConstants` is
    not set).  Optional/List/Vector wrappers defer to their element type,
    except a vector of SymIntT which is reported as not Value-like.
    """
    if not isinstance(typ, BaseCType):
        if typ == VectorCType(BaseCType(SymIntT)):
            # TODO: report True for this
            return False
        if isinstance(typ, (OptionalCType, ListCType, VectorCType)):
            return isValueType(typ.elem, properties)
        return False

    # I am regretting my naming conventions, but now we are wrapping at::scalar in
    # lazy value, while preserving other 'scalar' types as scalars in the IR
    scalars_are_constants = properties and properties.TreatScalarsAsConstants
    base = typ.type
    if base == getValueT():
        return True
    if base == scalarT and not scalars_are_constants:
        return True
    return base == SymIntT
172
+
173
+
174
def isSymIntType(typ: Type) -> bool:
    """Return True only for a bare (unwrapped) BaseType whose name is SymInt."""
    if not isinstance(typ, BaseType):
        return False
    return typ.name == BaseTy.SymInt
176
+
177
+
178
def isWrappedScalarType(typ: Type) -> bool:
    """
    Tell whether `typ` is a Scalar, possibly nested inside Optional/List types.

    The lazy type for such an argument is a Value rather than a scalar, so the
    fact that it started out as a Scalar has to be recorded separately; this
    predicate is what identifies those arguments.
    """
    innermost = typ
    # Peel off Optional/List layers until something else is reached.
    while isinstance(innermost, (OptionalType, ListType)):
        innermost = innermost.elem
    # I am regretting my naming conventions, but now we are wrapping at::scalar in
    # lazy value, while preserving other 'scalar' types as scalars in the IR
    return isinstance(innermost, BaseType) and innermost.name == BaseTy.Scalar
191
+
192
+
193
+ # TODO: dedupe with Type.is_generator_like
194
def isGeneratorType(typ: Type) -> bool:
    """Return True for Generator, or Generator nested in Optional types (lists are not unwrapped)."""
    innermost = typ
    while isinstance(innermost, OptionalType):
        innermost = innermost.elem
    return isinstance(innermost, BaseType) and innermost.name == BaseTy.Generator
200
+
201
+
202
+ # This class caches a few derived properties computed from an Argument
203
+ # and LazyIrProperties
204
class LazyArgument:
    # Per-argument cache of facts derived once from an Argument and the
    # LazyIrProperties in effect: its lazy CType plus a handful of boolean
    # classifications used by the code generators.
    name: str
    orig_type: Type
    lazy_type_: CType | None
    is_wrapped_scalar: bool
    is_generator: bool
    # TODO: this is lies, it is false for symint list
    is_symint_or_list: bool

    # Whether or not we are treating this as symint or not
    symint: bool

    # true if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(
        self, arg: Argument, properties: LazyIrProperties, *, symint: bool
    ) -> None:
        declared = arg.type
        optional = isinstance(declared, OptionalType)

        self.name = arg.name
        self.orig_type = declared
        self.symint = symint
        self.is_optional = optional
        self.is_generator = isGeneratorType(declared)
        self.lazy_type_ = process_ir_type(declared, properties, symint=symint)
        self.is_wrapped_scalar = isWrappedScalarType(declared)

        # Only SymInt and SymInt? qualify here.
        # TODO: lists of symints are not currently treated as value types
        # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
        unwrapped = declared.elem if optional else declared
        self.is_symint_or_list = symint and isSymIntType(unwrapped)

        self.is_lazy_value = isValueType(self.lazy_type, properties)

    @property
    def lazy_type(self) -> CType:
        # Checked accessor for lazy_type_: asserts that a lazy type was computed.
        computed = self.lazy_type_
        assert (
            computed is not None
        ), f"Attempted to access lazy_type for invalid argument {self.name}"
        return computed
244
+
245
+
246
class LazyIrProperties:
    """Collection of properties for an IR node

    The property groups are listed below. Each group is mutually
    exclusive, meaning that only one property from each group can be True
    at any one time. The properties can be accessed as if they were normal
    attributes. The mutual exclusivity is automatically handled.

    Internally a single dict (stored under ``self.__dict__["properties"]``)
    maps each group tuple to the name of the currently active property in
    that group, or ``None`` when no property of the group is active.
    """

    Properties: tuple[tuple[str, ...], ...] = (
        (
            "ShapePrecompute",  # Assume shape has been precomputed
            "ShapeCompute",  # Need to compute the shape on construction
            "ShapeCache",  # Utilize the shape cache to defer computation
        ),
        (
            "Lower",  # Codegen full lower function
            "LowerDeclOnly",  # Codegen only lower function declaration
        ),
        (
            "CanBeReused",  # Codegen full reuse function
            "CanBeReusedDeclOnly",  # Codegen only reuse function declaration
        ),
        (
            "CreateFn",  # Codegen full create function
            "CreateFnDeclOnly",  # Codegen only create function declaration
        ),
        (
            "TreatScalarsAsConstants",  # Treat Scalars as constants instead of handling like values
        ),
    )

    def __init__(self, *default_properties: str) -> None:
        """Start with every group empty, then activate each named property.

        Raises:
            KeyError: if a name is not listed in ``Properties``.
        """
        properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
            LazyIrProperties.Properties
        )
        # Written through __dict__ because __setattr__ only accepts property names.
        self.__dict__["properties"] = properties
        for p in default_properties:
            setattr(self, p, True)

    def __getattr__(self, key: str) -> Any:
        """Return True/False for a known property name.

        Any other name falls through to normal attribute lookup and therefore
        raises ``AttributeError``.
        """
        for values in LazyIrProperties.Properties:
            if key in values:
                # Look the state up lazily: protocols such as copy/pickle probe
                # attributes on instances whose __init__ has not run, and they
                # expect AttributeError (not KeyError) for anything missing.
                state = self.__dict__.get("properties")
                if state is None:
                    raise AttributeError(key)
                return state[values] == key

        return self.__getattribute__(key)

    def __setattr__(self, key: str, value: Any) -> Any:
        """Activate (truthy ``value``) or deactivate (falsy ``value``) a property.

        Activating a property replaces whichever property of the same group was
        active. Deactivating only has an effect when ``key`` itself is the
        active property, so switching off an inactive property does not
        disturb its active sibling.

        Raises:
            KeyError: if ``key`` is not listed in ``Properties``.
        """
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                if value:
                    properties[values] = key
                elif properties[values] == key:
                    properties[values] = None
                return value

        raise KeyError(f"Invalid property: {key}")
302
+
303
+
304
+ # Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
305
+ # Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
306
+ # but carries type information from a native FunctionSchema modified for use with IR nodes,
307
+ # and preserving original argument names.
308
+ #
309
+ # TODO: This is not idiomatic with how other torchgen APIs transform on schema.
310
class LazyIrSchema:
    # Wraps a FunctionSchema: every positional and keyword argument becomes a
    # LazyArgument, and filtered views over those arguments are exposed.

    # The name of the operator this function schema describes.
    name: OperatorName

    positional_args: tuple[LazyArgument, ...]
    keyword_args: tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: tuple[Return, ...]

    # if this schema has a Generator arg, list its orig ctype/name but don't
    # build a LazyArgument since lazy IR doesn't support it
    generator_arg: NamedCType | None = None

    # original function schema
    func: FunctionSchema

    # Whether or not we are code-genning for SymInt or not
    symint: bool

    # Default properties, used when the constructor is not given any.
    # NOTE(review): this default object lives on the class, so instances built
    # without explicit properties appear to share it - confirm nobody mutates it.
    properties: LazyIrProperties = LazyIrProperties(
        "ShapePrecompute",
        "Lower",
        "CanBeReused",
    )
    opkind: str | None = None

    def __init__(
        self,
        func: FunctionSchema,
        properties: LazyIrProperties | None = None,
        *,
        symint: bool,
    ) -> None:
        if properties:
            self.properties = properties

        self.func = func
        self.symint = symint

        schema_args = func.arguments

        def wrap(argument: Argument) -> LazyArgument:
            return LazyArgument(argument, self.properties, symint=symint)

        # Positional arguments, in order: pre-self, self, post-self.
        ordered: list[LazyArgument] = []
        leading = schema_args.pre_self_positional
        if leading is not None:
            ordered.extend(wrap(a) for a in leading)
        if schema_args.self_arg is not None:
            ordered.append(wrap(schema_args.self_arg.argument))
        trailing = schema_args.post_self_positional
        if trailing is not None:
            ordered.extend(wrap(a) for a in trailing)
        self.positional_args = tuple(ordered)

        # Keyword arguments, group by group; a TensorOptionsArguments group is
        # expanded into its individual arguments first.
        keyword_groups = (
            "pre_tensor_options_kwarg_only",
            "tensor_options",
            "post_tensor_options_kwarg_only",
            "out",
        )
        named: list[LazyArgument] = []
        for group_name in keyword_groups:
            group = getattr(schema_args, group_name)
            if group is None:
                continue
            if isinstance(group, TensorOptionsArguments):
                group = group.all()
            # Record the (single) generator argument, if this group has one.
            for candidate in group:
                if not isGeneratorType(candidate.type):
                    continue
                assert (
                    self.generator_arg is None
                ), "We expect there is only one generator arg"
                self.generator_arg = NamedCType(
                    candidate.name, candidate.type  # type:ignore[arg-type]
                )
            named.extend(wrap(a) for a in group)
        self.keyword_args = tuple(named)

        self.name = func.name
        self.returns = func.returns

    @property
    def node_name(self) -> str:
        """
        Camel-case name of the op for use as a node name.

        The `overload_name` is appended as well, e.g. `bitwise_and.Tensor`
        becomes `BitwiseAndTensor`.
        """
        lowered = f"{self.name.name}_{self.name.overload_name}".lower()
        pieces = [piece.capitalize() for piece in lowered.split("_")]
        return "".join(pieces)

    @property
    def aten_name(self) -> str:
        operator = self.name.name
        return str(operator)

    @property
    def base_name(self) -> str:
        operator = self.name.name
        return f"{operator.base}"

    def filtered_args(
        self,
        positional: bool = True,
        keyword: bool = True,
        values: bool = True,
        scalars: bool = True,
        generator: bool = True,
    ) -> list[LazyArgument]:
        # This function maintains the sorted order of arguments but provides different filtered views.
        # Some parts of the code care about kwargs vs args (TS lowerings),
        # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
        # Generators are special cased, as they are needed for fallback/shape-inference but not supported
        # in TS lowerings and therefore also omitted from lazy IR.
        pool: list[LazyArgument] = []
        if positional:
            pool.extend(self.positional_args)
        if keyword:
            pool.extend(self.keyword_args)

        if values and scalars:
            if generator:
                return pool
            return [a for a in pool if not a.is_generator]
        if values:
            return [a for a in pool if a.is_lazy_value]
        if scalars:
            return [
                a
                for a in pool
                if not a.is_lazy_value and (generator or not a.is_generator)
            ]
        return []

    @property
    def positional_values(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=True, scalars=False
        )

    @property
    def positional_scalars(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=False, scalars=True
        )

    @property
    def keyword_values(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=True, scalars=False
        )

    @property
    def keyword_scalars(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=False, scalars=True
        )
.venv/lib/python3.11/site-packages/torchgen/api/meta.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchgen.model import NativeFunctionsGroup
2
+
3
+
4
+ # Follows dispatcher calling convention, but:
5
+ # - Mutable arguments not allowed. Meta functions are always
6
+ # written in functional form. Look at FunctionSchema.signature()
7
+ # - No tensor returns; instead we return a TensorMeta describing
8
+ # the tensor in question
9
+
10
+
11
def name(g: NativeFunctionsGroup) -> str:
    """Return the functional variant's operator name with every '.' turned into '_'."""
    # use the overload name from the functional version
    qualified = str(g.functional.func.name)
    return "_".join(qualified.split("."))
.venv/lib/python3.11/site-packages/torchgen/api/python.py ADDED
@@ -0,0 +1,1519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Sequence
5
+
6
+ from torchgen.api import cpp
7
+ from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
8
+ from torchgen.gen import pythonify_default
9
+ from torchgen.model import (
10
+ Argument,
11
+ BaseTy,
12
+ BaseType,
13
+ FunctionSchema,
14
+ ListType,
15
+ NativeFunction,
16
+ OptionalType,
17
+ Return,
18
+ Type,
19
+ Variant,
20
+ )
21
+
22
+
23
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
24
+ #
25
+ # Data Models
26
+ #
27
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
28
+ #
29
+ # [Notes] python binding codegen
30
+ #
31
+ # The Python binding codegen produces code that takes the input list of
32
+ # PyObjects, finds the matching ATen C++ function using PythonArgParser,
33
+ # converts the PyObjects into C++ types and calls the ATen C++ function:
34
+ #
35
+ # +--------+ parsing +------------------------+ binding +-----------------------+
36
+ # | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
37
+ # +--------+ +------------------------+ +-----------------------+
38
+ #
39
+ # The following examples demonstrate the data models the Python binding
40
+ # codegen needs to deal with and the tasks it needs to accomplish. It
41
+ # helps understand the purpose of the new data types we introduced below.
42
+ #
43
+ # - Function Schema (source of truth)
44
+ #
45
+ # aten::empty.names(int[] size, *, Dimname[]? names,
46
+ # ScalarType? dtype=None, Layout? layout=None,
47
+ # Device? device=None, bool? pin_memory=None,
48
+ # MemoryFormat? memory_format=None) -> Tensor
49
+ #
50
+ # - Python Signature
51
+ #
52
+ # It's used to generate input schema string for PythonArgParser.
53
+ # Note: TensorOptions fields are reordered and the additional
54
+ # 'requires_grad' field is added:
55
+ #
56
+ # empty(IntArrayRef size, *, DimnameList? names,
57
+ # MemoryFormat? memory_format=None, ScalarType dtype=None,
58
+ # Layout layout=torch.strided, Device device=None,
59
+ # bool pin_memory=False, bool requires_grad=False)
60
+ #
61
+ # - C++ Signature
62
+ #
63
+ # It's used to generate C++ lambda formals & dispatch call.
64
+ # Note: the scattered TensorOptions fields are packed into 'options'.
65
+ #
66
+ # auto dispatch_empty =
67
+ # [](IntArrayRef size, std::optional<DimnameList> names,
68
+ # const TensorOptions & options,
69
+ # std::optional<MemoryFormat> memory_format) -> Tensor {
70
+ # pybind11::gil_scoped_release no_gil;
71
+ # return torch::empty(size, names, options, memory_format);
72
+ # };
73
+ #
74
+ # - Binding between Python Arguments and C++ Arguments
75
+ #
76
+ # Given a set of Python Arguments in scope, we need to produce the
77
+ # binding expressions that translate the Python API into C++ API:
78
+ #
79
+ # Python Args Cpp Args Binding Exprs
80
+ # -----------------------------------------------------------------
81
+ # 0: size size '_r.intlist(0)'
82
+ # 1: names names 'names' [special init]
83
+ # 2: memory_format -------+
84
+ # 3: dtype -----+-|--> options 'options' [special packing]
85
+ # 4: layout / |
86
+ # 5: device / +--> memory_format '_r.memoryformatOptional(2)'
87
+ # 6: pin_memory /
88
+ # 7: requires_grad -+
89
+ #
90
+ # So the full dispatch expression would look like:
91
+ #
92
+ # dispatch_empty(_r.intlist(0), names, options,
93
+ # _r.memoryformatOptional(2))
94
+ #
95
+ # Where does 'names' come from? It involves special local init:
96
+ #
97
+ # auto __names = _r.toDimnameListOptional(1);
98
+ # std::optional<DimnameList> names =
99
+ # __names ? std::make_optional(DimnameList(__names.value()))
100
+ # : std::nullopt;
101
+ #
102
+ # Where does 'options' come from? It involves special local init
103
+ # for TensorOptions. Note that Python side has the additional
104
+ # 'requires_grad' field:
105
+ #
106
+ # const auto options = TensorOptions()
107
+ # .dtype(_r.scalartype(3))
108
+ # .device(_r.device(5))
109
+ # .layout(_r.layoutOptional(4))
110
+ # .requires_grad(_r.toBool(7))
111
+ # .pinned_memory(_r.toBool(6));
112
+ #
113
+ # In some other cases one Python Argument can map to multiple C++
114
+ # Arguments. For example:
115
+ #
116
+ # aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
117
+ # -> (Tensor values, Tensor indices)
118
+ #
119
+ # Python Args Cpp Args Binding Exprs
120
+ # ---------------------------------------------------------------------
121
+ # +----> max 'out[0]'
122
+ # /-----> max_values 'out[1]'
123
+ # 0: input / self '_r.tensor(0)'
124
+ # 1: dim / dim '_r.dimname(1)'
125
+ # 2: keepdim / keepdim '_r.toBool(2)'
126
+ # 3: out -----+ [local init] out '_r.tensorlist_n<2>(3)'
127
+ #
128
+ # As demonstrated above, the binding can involve reordering,
129
+ # packing, unpacking and special local inits.
130
+ #
131
+ #
132
+ # Let's look at a concrete example:
133
+ #
134
+ # static PythonArgParser parser({
135
+ # "abs(Tensor input, *, Tensor out=None)",
136
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
137
+ # ^
138
+ # +--- Python Schema, represented by PythonSignature and PythonArgument
139
+ #
140
+ # }, /*traceable=*/true);
141
+ #
142
+ # ParsedArgs<2> parsed_args;
143
+ # auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
144
+ #
145
+ # ...
146
+ #
147
+ # if (_r.isNone(1)) {
148
+ # ~~~~~~~~~~~~ <--- Scattered PythonArgParser output (arg name = 'out')
149
+ # represented by PythonArgParserOutputExpr
150
+ #
151
+ # // aten::abs(Tensor self) -> Tensor
152
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
153
+ # ^
154
+ # +--- NativeFunction schema, base version
155
+ #
156
+ # auto dispatch_abs = [](const Tensor & self) -> Tensor {
157
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
158
+ # ^
159
+ # +--- dispatch_lambda_args / dispatch_lambda_return_str
160
+ # generated from NativeFunction / CppSignature
161
+ # (deprecated PythonSignature is special)
162
+ # arguments are represented by DispatchLambdaArgument
163
+ #
164
+ # pybind11::gil_scoped_release no_gil;
165
+ # return self.abs();
166
+ # ~~~~~~~~~~~ <--- cpp_dispatch_target / cpp_dispatch_exprs
167
+ # generated from NativeFunction / CppSignature
168
+ # };
169
+ # return wrap(dispatch_abs(_r.tensor(0)));
170
+ # ~~~~~~~~~~~~~
171
+ # ^
172
+ # +--- dispatch_lambda_exprs
173
+ # binding PythonArgParserOutputExpr (python args)
174
+ # and DispatchLambdaArgument (c++ args)
175
+ #
176
+ # } else {
177
+ # // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
178
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
179
+ # ^
180
+ # +--- NativeFunction schema, out-variant
181
+ #
182
+ # auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
183
+ # pybind11::gil_scoped_release no_gil;
184
+ # return at::abs_out(out, self);
185
+ # };
186
+ # return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
187
+ # }
188
+ #
189
+ #
190
+ # [Notes] python interface codegen
191
+ # The python dataclasses below are used to generate both python binding code
192
+ # and pyi type hint signatures.
193
+ # In theory these two should look very similar, but there are number of differences
194
+ # in how pyi signatures vs. python_arg_parser signatures are generated.
195
+ # These differences have been encapsulated in signature_str() vs. signature_str_pyi()
196
+ # to display the full signatures, and argument_str() vs argument_str_pyi() to display arguments.
197
+ # For example, only pyi signatures include return types.
198
+
199
+
200
# Immutable holder for the schema's return values.
@dataclass(frozen=True)
class PythonReturns:
    # The Return entries of the function, in declaration order.
    returns: tuple[Return, ...]
203
+
204
+
205
@dataclass(frozen=True)
class PythonArgument:
    # One argument of a Python signature, able to render itself both for the
    # PythonArgParser format string and for .pyi stubs.
    name: str
    type: Type
    default: str | None

    # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
    #
    #   _r.layoutWithDefault(3, layout_from_backend(self.options().backend())))
    #                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    #                           ^
    #                           +--- default_init str
    default_init: str | None

    # Compute argument formal for python argument parsing.
    # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
    def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
        cpp_type = argument_type_str(self.type, symint=symint)
        type_str = cpp_type.replace("const ", "").replace(" &", "")

        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        shown_name = self.name
        if not method and shown_name == "self" and type_str in ["Tensor", "Number"]:
            shown_name = "input"

        if self.default is None:
            return f"{type_str} {shown_name}"

        # C++ spellings of "no value" are all rendered as None.
        none_spellings = {
            "nullptr": "None",
            "::std::nullopt": "None",
            "std::nullopt": "None",
            "{}": "None",
        }
        shown_default = none_spellings.get(self.default, self.default)
        return f"{type_str} {shown_name}={shown_default}"

    def argument_str_pyi(
        self, *, method: bool = False, deprecated: bool = False
    ) -> str:
        type_str = argument_type_str_pyi(self.type)

        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        shown_name = self.name
        if shown_name == "self" and type_str == "Tensor" and not (method or deprecated):
            shown_name = "input"

        if shown_name == "from":  # from is a Python keyword...
            shown_name += "_"

        # pyi merges the _out and functional variants into the same signature, with an optional out arg
        if shown_name == "out" and type_str == "Tensor" and not deprecated:
            type_str = "Optional[" + type_str + "]"

        # pyi deprecated signatures don't get defaults for their out arg
        suppress_default = (
            deprecated
            and isinstance(self, PythonOutArgument)
            and self.default == "None"
        )
        if self.default is None or suppress_default:
            return f"{shown_name}: {type_str}"

        braced_int_list = (
            isinstance(self.type, ListType)
            and self.type.elem == BaseType(BaseTy.int)
            and self.default.startswith("{")
            and self.default.endswith("}")
        )
        if braced_int_list:
            # "{a, b}" -> "(a, b)"
            elements = [piece.strip() for piece in self.default[1:-1].split(",")]
            shown_default = "(" + ", ".join(elements) + ")"
        else:
            renames = {
                "nullptr": "None",
                "::std::nullopt": "None",
                "std::nullopt": "None",
                "{}": "None",
                "c10::MemoryFormat::Contiguous": "contiguous_format",
                "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
            }
            shown_default = renames.get(self.default, self.default)
        return f"{shown_name}: {type_str} = {shown_default}"
296
+
297
+
298
@dataclass(frozen=True)
class PythonOutArgument(PythonArgument):
    # In Python signature multiple output fields are packed into one 'out' argument.
    # When binding to C++, it's first bound to a local 'out' variable:
    #   'auto out = _r.tensorlist_n<2>(2);',
    # then bound to scattered C++ output arguments as 'out[0]', 'out[1]', and etc.
    # TODO: maybe don't need keep scattered out fields for python signature?
    outputs: tuple[PythonArgument, ...]

    @staticmethod
    def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
        # No outputs -> None; one output -> reuse its name and type; several
        # outputs -> a single fixed-size Tensor list named 'out'. Every variant
        # defaults to "None".
        if not outputs:
            return None

        count = len(outputs)
        if count == 1:
            only = outputs[0]
            return PythonOutArgument(
                name=only.name,
                type=only.type,
                default="None",
                default_init=None,
                outputs=outputs,
            )
        if count > 1:
            # Packing into a Tensor list requires every output to be tensor-like.
            if not all(a.type.is_tensor_like() for a in outputs):
                raise RuntimeError(f"Unsupported output type: {outputs}")
            return PythonOutArgument(
                name="out",
                # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
                type=ListType(BaseType(BaseTy.Tensor), count),
                default="None",
                default_init=None,
                outputs=outputs,
            )
        raise AssertionError(r"Unexpected PythonOutArgument size")
333
+
334
+
335
@dataclass(frozen=True)
class PythonSignature:
    # A Python-facing signature: positional args, keyword args, an optional
    # packed 'out' argument and scattered TensorOptions kwargs, with renderers
    # for the PythonArgParser format and for .pyi stubs.

    # Base operator name, without inplace/outplace suffix.
    name: str

    # Positional arguments.
    # TODO: create a dedicated SelfArgument type for 'self'?
    input_args: tuple[PythonArgument, ...]

    # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
    # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
    input_kwargs: tuple[PythonArgument, ...]

    output_args: PythonOutArgument | None

    # Return types, which are only used by pyi
    returns: PythonReturns

    # These are scattered kwargs arguments belonging to TensorOptions.
    # When binding to C++, they are packed into a TensorOptions object 'options'.
    # It's possible that the C++ signature doesn't take TensorOptions object (e.g.
    # for out variant), in which case they will be used as scattered fields without
    # being packed into 'options'.
    # TODO: maybe create a PythonTensorOptionsArgument?
    tensor_options_args: tuple[PythonArgument, ...]

    # method or function signature?
    method: bool

    @property
    def deprecated(self) -> bool:
        # Always False here; PythonSignatureDeprecated overrides it.
        return False

    def arguments(
        self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
    ) -> tuple[PythonArgument | PythonOutArgument, ...]:
        # All arguments in signature order: positional, keyword, then the
        # packed 'out' argument and the TensorOptions kwargs unless skipped.
        result: list[PythonArgument | PythonOutArgument] = []
        result.extend(self.input_args)
        result.extend(self.input_kwargs)
        if self.output_args is not None and not skip_outputs:
            result.append(self.output_args)
        if not skip_tensor_options:
            result.extend(self.tensor_options_args)
        return tuple(result)

    def arguments_count(self) -> int:
        # Number of arguments with nothing skipped.
        return len(self.arguments())

    def output_idx(self) -> int:
        # Position of the 'out' argument within arguments().
        return len(self.input_args) + len(self.input_kwargs)

    # [old codegen] Compute the Python function signature for argument parsing,
    # as specified in torch/csrc/utils/python_arg_parser.h.  WARNING:
    # this is NOT the same type signature as specified by PEP 484
    # as understood by mypy; our format was independently developed
    # and has some quirks to make it more suitable specifically
    # for error parsing.
    #
    # For a translation to mypy-valid type signatures, see
    # signature_str_pyi().
    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str(method=self.method, symint=symint) for a in args
        ]
        # A bare '*' separates positional formals from keyword-only ones.
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        return f'{self.name}({", ".join(schema_formals)})'

    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        # only pyi signatures include returns
        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        # only pyi uses vararg signatures
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]

        # vararg only applies to pyi signatures. vararg variants are not generated for all signatures.
        # The emptiness check uses the same (possibly output-skipping) argument
        # tuple that is indexed below; counting the unfiltered arguments could
        # report a non-zero count while `args` is empty and then fail on args[0].
        if not args:
            return None

        # A vararg form exists only when the sole positional argument comes
        # first and is a list of int/SymInt.
        vararg_type = args[0].type
        if not (
            isinstance(vararg_type, ListType)
            and str(vararg_type.elem) in ["int", "SymInt"]
            and len(self.input_args) == 1
        ):
            return None

        # Below are the major changes in vararg vs. regular pyi signatures
        # vararg signatures also omit the asterisk
        schema_formals[0] = (
            "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem)
        )

        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
457
+
458
+
459
+ # The deprecated python signature involves some special logic, so create a
460
+ # dedicated data model to store these extra properties.
461
@dataclass(frozen=True)
class PythonSignatureDeprecated(PythonSignature):
    # Schema of the deprecated function.
    deprecated_schema: FunctionSchema

    # A deprecated signature may lack arguments that the matching C++
    # signature expects, so the constant defaults to pass are stored here.
    # Example:
    #   [deprecated signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
    #   [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
    #   [func call]: self.addmm(mat1, mat2, beta, 1)
    # which is stored as ['self', 'mat1', 'mat2', 'beta', '1'].
    deprecated_args_exprs: tuple[str, ...]

    @property
    def deprecated(self) -> bool:
        return True

    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        # Same text as the regular signature, tagged with a '|deprecated' suffix.
        regular = super().signature_str(skip_outputs=skip_outputs, symint=symint)
        return regular + "|deprecated"

    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        # Like the regular pyi rendering, except every argument is rendered
        # with deprecated=True and no 'self' formal is prepended.
        formals = [
            arg.argument_str_pyi(method=self.method, deprecated=True)
            for arg in self.arguments(skip_outputs=skip_outputs)
        ]
        num_positional = len(self.input_args)
        if num_positional < len(formals):
            formals.insert(num_positional, "*")

        ret = returns_str_pyi(self)
        return f'def {self.name}({", ".join(formals)}) -> {ret}: ...'

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        # The codegen emits no vararg variants for deprecated signatures.
        return None
502
+
503
+
504
+ # This struct is used to hold the PythonSignature and its corresponding
505
+ # NativeFunction BEFORE grouping base and out-variant functions.
506
+ # Why not store NativeFunction in PythonSignature or construct PythonSignature
507
+ # from NativeFunction? Because they are not 1-1 mapped.
508
+ # One native function could have both deprecated and non-deprecated python
509
+ # signatures - NativeFunction doesn't contain information to construct the
510
+ # deprecated python signature.
511
+ # One python signature is used to handle both the base and the out-variant
512
+ # function - see 'PythonSignatureGroup'.
513
@dataclass(frozen=True)
class PythonSignatureNativeFunctionPair:
    # The Python-facing signature.
    signature: PythonSignature
    # The native function that the signature is paired with.
    function: NativeFunction
517
+
518
+
519
+ # We merge pairs of functions with signatures that are equivalent mod
520
+ # output arguments, and use a single entry in the python_arg_parser sig
521
+ # list for both (output arguments become optional).
522
@dataclass(frozen=True)
class PythonSignatureGroup:
    # Signature used for Python argument parsing. When an out-variant exists
    # its signature is preferred, since it can parse inputs both for the
    # out-place variant and for the base version (output omitted).
    signature: PythonSignature

    # The regular ATen declaration (e.g. conv2d)
    base: NativeFunction

    # The out variant (e.g. conv2d_out), if any
    outplace: NativeFunction | None

    @classmethod
    def from_pairs(
        cls,
        functional: PythonSignatureNativeFunctionPair,
        out: PythonSignatureNativeFunctionPair | None,
    ) -> PythonSignatureGroup:
        # Build a group from the functional pair and an optional out pair.
        if out is not None:
            # The signature carrying the optional out=... arguments is the
            # superset that parses input for both base and outplace, so start
            # from a copy of its fields.
            merged_fields = dict(out.signature.__dict__)

            # Out overloads in C++ have no TensorOptions arguments; take
            # those from the functional variant instead.
            merged_fields["tensor_options_args"] = (
                functional.signature.tensor_options_args
            )

            # Rebuild with the same concrete signature class as the out one.
            return PythonSignatureGroup(
                signature=type(out.signature)(**merged_fields),
                base=functional.function,
                outplace=out.function,
            )

        return PythonSignatureGroup(
            signature=functional.signature,
            base=functional.function,
            outplace=None,
        )
563
+
564
+
565
+ # C++ function dispatch is wrapped in a lambda function. The lambda function
566
+ # has almost the same signature as the C++ function, only with some small
567
+ # variants - see details below.
568
+ # This data model is used to represent arguments of the lambda function
569
+ # signature.
570
@dataclass(frozen=True)
class DispatchLambdaArgument:
    # Argument name as it appears in the lambda signature.
    name: str
    # Type of the argument, already rendered as a string.
    type_str: str
    # Whether the argument is one of the schema's out arguments.
    is_out_arg: bool
575
+
576
+
577
+ # To pass PyObjects arguments to C++ function (via the lambda wrapper),
578
+ # we need first convert PyObjects into simple C++ objects. This work
579
+ # is done by PythonArgParser.
580
+ # This data model is used to represent the output of PythonArgParser.
581
+ # It has 1-1 mapping with PythonArgument in PythonSignature.
582
@dataclass(frozen=True)
class PythonArgParserOutputExpr:
    # Name of the argument.
    name: str

    # RHS expression that references the PythonArgParser output.
    expr: str

    # Index used when a different expr has to be built in special cases,
    # e.g. '_r.isNone(1)' instead of '_r.tensor(1)'.
    index: int

    # The python argument this expression maps to.
    argument: PythonArgument

    @property
    def is_none_expr(self) -> str:
        # Expression text that tests whether the parsed value at `index` is None.
        return "_r.isNone({})".format(self.index)
600
+
601
+
602
+ # To pass PythonArgParser output to the lambda wrapper, we need bind
603
+ # PythonArgParserOutputExpr to DispatchLambdaArgument.
604
+ # They are not always 1-1 mapped, e.g. scattered TensorOptions fields
605
+ # need be packed into a TensorOptions object, which is the argument
606
+ # that the lambda function wrapper takes.
607
@dataclass(frozen=True)
class DispatchLambdaArgumentExprs:
    # Expressions providing the binding for each lambda argument, e.g.:
    #
    #   'self' -> '_r.tensor(0)'
    #   'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
    #   'options' -> 'options'
    #
    # Maps 1-1 onto DispatchLambdaArgument.
    exprs: Sequence[str]

    # Special local initializers, which may introduce new variables that
    # the 'exprs' above refer to, e.g.:
    #
    #   'auto out = _r.tensorlist_n<2>(2);'
    #
    inits: Sequence[str]
624
+
625
+
626
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
627
+ #
628
+ # Helper Functions
629
+ #
630
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
631
+
632
+
633
def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
    # Return the `signature` member of the C++ signature group built for `f`.
    group = CppSignatureGroup.from_native_function(f, method=method)
    return group.signature
635
+
636
+
637
def has_tensor_options(f: NativeFunction) -> bool:
    # True when the schema of `f` has a tensor_options argument group.
    tensor_options = f.func.arguments.tensor_options
    return tensor_options is not None
639
+
640
+
641
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
642
+ #
643
+ # Python Signature
644
+ #
645
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
646
+
647
+
648
+ # 'simple_type' was introduced by the old codegen, which is slightly
649
+ # different from the python schema type, e.g.: doesn't have '?' suffix
650
+ # for optional Tensor/TensorList; doesn't have '[size]' suffix for list type.
651
def argument_type_str(
    t: Type, *, simple_type: bool = False, symint: bool = True
) -> str:
    # Translate a schema type into its python-schema type string.
    #   simple_type: drop the '[size]' suffix on list types, and use the
    #                by-value spelling for lists of optional Tensors.
    #   symint:      render SymInt[] as SymIntArrayRef rather than IntArrayRef.
    # Raises RuntimeError for any type that is not recognised.

    def recurse(inner: Type) -> str:
        return argument_type_str(inner, simple_type=simple_type, symint=symint)

    if isinstance(t, BaseType):
        renamed = {
            BaseTy.Tensor: "Tensor",
            BaseTy.int: "int64_t",
            BaseTy.float: "double",
            BaseTy.str: "c10::string_view",
        }
        if t.name in renamed:
            return renamed[t.name]
        if t.name in [
            BaseTy.bool,
            BaseTy.QScheme,
            BaseTy.Scalar,
            BaseTy.ScalarType,
            BaseTy.Generator,
            BaseTy.Storage,
            BaseTy.Layout,
            BaseTy.Device,
            BaseTy.DeviceIndex,
            BaseTy.MemoryFormat,
            BaseTy.Dimname,
            BaseTy.Stream,
            BaseTy.ConstQuantizerPtr,
            BaseTy.SymInt,
        ]:
            # These python schema type names line up with their function schema names
            return t.name.name
        # Any other base type falls through to the error at the bottom.

    if isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            # Is it desired to keep '?' for simple_type with new style dispatcher?
            return "Tensor?"
        return recurse(t.elem) + "?"

    if isinstance(t, ListType):
        elem_name = str(t.elem)

        if elem_name == "bool":
            # bool lists always carry their size, regardless of simple_type.
            assert t.size is not None
            return f"::std::array<bool,{t.size}>"

        if elem_name == "Tensor?":
            if simple_type:
                return "c10::List<::std::optional<Tensor>>"
            return "const c10::List<::std::optional<Tensor>> &"

        # List kinds whose name optionally takes a '[size]' suffix.
        sized_names = {
            "int": "IntArrayRef",
            "SymInt": "SymIntArrayRef" if symint else "IntArrayRef",
            "Tensor": "TensorList",
            "Scalar": "ScalarList",
            "Dimname": "DimnameList",
        }
        list_name = sized_names.get(elem_name)
        if list_name is None:
            return f"ArrayRef<{recurse(t.elem)}>"

        size = None if simple_type else t.size
        return list_name if size is None else f"{list_name}[{size}]"

    raise RuntimeError(f"unrecognized type {repr(t)}")
717
+
718
+
719
def argument_type_size(t: Type) -> int | None:
    # Size of `t` when it is list-like with a non-bool element; otherwise None.
    list_type = t.is_list_like()
    if list_type is None or str(list_type.elem) == "bool":
        return None
    return list_type.size
725
+
726
+
727
def argument(a: Argument) -> PythonArgument:
    # Convert a schema Argument into a PythonArgument with the same name and
    # type. default_init is always None here.
    # TODO: directly translate a.default to python default
    py_default = None
    if a.default is not None:
        cpp_default = cpp.default_expr(a.default, a.type, symint=False)
        py_default = str(pythonify_default(cpp_default))

    return PythonArgument(
        name=a.name,
        type=a.type,
        default=py_default,
        default_init=None,
    )
739
+
740
+
741
+ # Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
742
def signature(
    f: NativeFunction, *, method: bool = False, pyi: bool = False
) -> PythonSignature:
    # Thin wrapper: build the signature from the function's schema and its
    # category override.
    schema = f.func
    override = f.category_override
    return signature_from_schema(
        schema, category_override=override, method=method, pyi=pyi
    )
748
+
749
+
750
def signature_from_schema(
    func: FunctionSchema,
    *,
    category_override: str | None,
    method: bool = False,
    pyi: bool = False,
) -> PythonSignature:
    # Build the PythonSignature for `func`.
    #   category_override: "factory", "new", "like" or "dummy" steer whether
    #                      Python-side TensorOptions arguments are generated.
    #   method:            when True the schema's self argument is left out.
    #   pyi:               accepted, but not read anywhere in this function.
    # Raises ValueError if the schema itself declares 'requires_grad'.
    schema_args = func.arguments

    # Schema arguments in Python order. SelfArgument is skipped for methods,
    # and TensorOptionsArguments are skipped entirely: their Python
    # counterparts are created by different rules further down.
    ordered: list[Argument] = [*schema_args.pre_self_positional]
    if schema_args.self_arg is not None and not method:
        ordered.append(schema_args.self_arg.argument)
    ordered += [
        *schema_args.post_self_positional,
        *schema_args.pre_tensor_options_kwarg_only,
        *schema_args.post_tensor_options_kwarg_only,
        *schema_args.out,
    ]

    def select(names: set[str]) -> tuple[PythonArgument, ...]:
        # Convert the ordered arguments whose name is in `names`.
        return tuple(argument(a) for a in ordered if a.name in names)

    positional_names = {a.name for a in schema_args.flat_positional}
    kwarg_only_names = {a.name for a in schema_args.flat_kwarg_only}
    out_names = {a.name for a in schema_args.out}

    input_args = select(positional_names)
    input_kwargs = select(kwarg_only_names)
    outputs = select(out_names)

    # Reintroduce the scattered fields of TensorOptions for Python.
    # Compared to the cpp counterpart, the python arguments have a new
    # property (default_init) and a new argument 'requires_grad', both of
    # which need special handling.
    # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
    # to the original versions in the yaml, this recreation is a potential
    # source of drift between eager and JIT. Pull this logic out to a shared place.

    has_tensor_input_arg = any(
        a.type.is_tensor_like() for a in schema_args.flat_non_out
    )
    for schema_arg in func.schema_order_arguments():
        if schema_arg.name == "requires_grad":
            raise ValueError(
                "argument named requires_grad is reserved, should not explicitly add it in the schema"
            )

    # [old codegen] this probably won't work if one of the returns is not a tensor,
    # but it will produce a compile-time error that is obvious.
    has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)

    name: str = cpp.name(func)
    is_factory_function = category_override == "factory" or (
        has_tensor_return and not has_tensor_input_arg
    )
    is_like_or_new_function = (
        category_override in ("new", "like")
        or name.startswith("new_")
        or name.endswith("_like")
    )
    is_dummy_function = category_override == "dummy"

    tensor_options_args: list[PythonArgument] = []
    if (is_factory_function or is_like_or_new_function) and not is_dummy_function:

        def topt_default_init(name: str) -> str | None:
            # C++ default expression of the named TensorOptions field, or
            # None when there is no such group or no usable default.
            topt_args = schema_args.tensor_options
            if topt_args is None:
                return None
            a = getattr(topt_args, name)
            if a.default is None or a.default == "None":
                return None
            return cpp.default_expr(a.default, a.type, symint=False)

        # like/new functions never get a default_init; others take it from
        # the schema, with device falling back to the default-device getter.
        if is_like_or_new_function:
            dtype_init = layout_init = device_init = None
        else:
            dtype_init = topt_default_init("dtype")
            layout_init = topt_default_init("layout")
            device_init = (
                topt_default_init("device") or "torch::tensors::get_default_device()"
            )

        # (name, element type, python default, default_init)
        option_specs = (
            ("dtype", BaseTy.ScalarType, "None", dtype_init),
            ("layout", BaseTy.Layout, "None", layout_init),
            ("device", BaseTy.Device, "None", device_init),
            ("pin_memory", BaseTy.bool, "False", None),
            ("requires_grad", BaseTy.bool, "False", None),
        )
        for opt_name, opt_base, opt_default, opt_init in option_specs:
            tensor_options_args.append(
                PythonArgument(
                    name=opt_name,
                    type=OptionalType(BaseType(opt_base)),
                    default=opt_default,
                    default_init=opt_init,
                )
            )

    returns = PythonReturns(returns=func.returns)

    return PythonSignature(
        name=str(func.name.name),
        input_args=input_args,
        input_kwargs=input_kwargs,
        output_args=PythonOutArgument.from_outputs(outputs),
        tensor_options_args=tuple(tensor_options_args),
        returns=returns,
        method=method,
    )
885
+
886
+
887
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
888
+ #
889
+ # Python Interface
890
+ #
891
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
892
+
893
+
894
def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
    # Field names for a structseq return type. The result is empty when
    # there are fewer than two returns or when none of them is named.
    # Raises ValueError when only some of the returns are named.
    if len(returns) <= 1:
        return []

    names = [r.name for r in returns]
    unnamed = sum(1 for n in names if n is None)
    if unnamed == len(names):
        return []

    if unnamed:
        # When building on Windows, `PyStructSequence_UnnamedField` could not be
        # resolved by the linker for some reason, which causes a build error:
        #
        # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
        # PyStructSequence_UnnamedField
        #
        # Thus, at this point in time, unnamed fields are not supported in
        # structseq; you must either name all fields, or none of them.
        raise ValueError("Unnamed field is not supported by codegen")

    return [str(n) for n in names]
911
+
912
+
913
def argument_type_str_pyi(t: Type) -> str:
    # Translate a schema type into the annotation string used for an
    # argument in generated .pyi stubs.
    # An OptionalType is unwrapped first and the result is wrapped in
    # 'Optional[...]' at the end. Raises RuntimeError for any type that has
    # no translation, including base types not listed below.
    add_optional = False
    if isinstance(t, OptionalType):
        t = t.elem
        add_optional = True

    if isinstance(t, BaseType):
        if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
            ret = "_int"
        elif t.name == BaseTy.SymInt:
            ret = "Union[_int, SymInt]"
        elif t.name == BaseTy.float:
            ret = "_float"
        elif t.name == BaseTy.str:
            ret = "str"
        elif t.name == BaseTy.Scalar:
            ret = "Union[Number, _complex]"
        elif t.name == BaseTy.ScalarType:
            ret = "_dtype"
        elif t.name == BaseTy.bool:
            ret = "_bool"
        elif t.name == BaseTy.QScheme:
            ret = "_qscheme"
        elif t.name == BaseTy.Layout:
            ret = "_layout"
        elif t.name == BaseTy.Device:
            ret = "Optional[DeviceLikeType]"
        elif t.name == BaseTy.MemoryFormat:
            ret = "memory_format"
        elif t.name == BaseTy.Dimname:
            ret = "Union[str, ellipsis, None]"
        elif t.name == BaseTy.Storage:
            ret = "Union[Storage, UntypedStorage]"
        elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]:
            # These python schema type names line up with their function schema names
            ret = t.name.name
        else:
            # Without this branch `ret` would stay unbound and the final
            # return would fail with an opaque UnboundLocalError.
            raise RuntimeError(f"unrecognized type {repr(t)}")

    elif isinstance(t, ListType):
        if str(t.elem) == "int":
            # A sized int list also accepts a single int.
            ret = "Union[_int, _size]" if t.size is not None else "_size"
        elif t.is_tensor_like():
            # TODO: this doesn't seem right...
            # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]]
            # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]]
            if isinstance(t.elem, OptionalType):
                add_optional = True
            ret = (
                "Union[Tensor, Tuple[Tensor, ...], List[Tensor]]"
                if t.size is not None
                else "Union[Tuple[Tensor, ...], List[Tensor]]"
            )
        elif str(t.elem) == "float":
            ret = "Sequence[_float]"
        elif str(t.elem) == "SymInt" and t.size is not None:
            # A sized SymInt list also accepts a single element.
            elem = argument_type_str_pyi(t.elem)
            ret = f"Union[{elem}, Sequence[{elem}]]"
        else:
            elem = argument_type_str_pyi(t.elem)
            ret = f"Sequence[{elem}]"

    else:
        raise RuntimeError(f"unrecognized type {repr(t)}")

    if add_optional:
        ret = "Optional[" + ret + "]"

    return ret
980
+
981
+
982
def return_type_str_pyi(t: Type) -> str:
    # Translate a schema type into the annotation string used for a return
    # value in generated .pyi stubs.
    # Where arguments are open to accepting a Union, return types should be
    # concrete, so Device, Dimname and list types get their own spelling here;
    # everything else defers to argument_type_str_pyi().

    if isinstance(t, OptionalType):
        inner = return_type_str_pyi(t.elem)
        return f"Optional[{inner}]"

    if isinstance(t, BaseType):
        if t.name == BaseTy.Device:
            return "_device"
        elif t.name == BaseTy.Dimname:
            # Previously this value was assigned to a local and then dropped,
            # so Dimname fell through to the argument-style Union instead.
            return "Optional[str]"
        else:
            return argument_type_str_pyi(t)

    if isinstance(t, ListType):
        inner = return_type_str_pyi(t.elem)
        return f"Tuple[{inner}, ...]"

    return argument_type_str_pyi(t)
1003
+
1004
+
1005
def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
    # Produce the .pyi class definition for the structseq return type of
    # `signature`. Returns (class name, class source text), or None when
    # structseq_fieldnames() yields no field names for the returns.
    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
    structseq_name = signature.name
    field_names = structseq_fieldnames(signature.returns.returns)
    if not field_names:
        return None

    # These types are structseq objects which act like NamedTuples, but
    # the constructor acts like the constructor of tuple. Using typing.NamedTuple
    # does not allow us to override __init__.
    seq_type = f"Tuple[{', '.join(python_returns)}]"
    structseq_def_lines = [
        f"class {structseq_name}({seq_type}):",
    ]
    for name, typ in zip(field_names, python_returns):
        structseq_def_lines.extend(
            [
                "    @property",
                f"    def {name}(self) -> {typ}: ...",
            ]
        )
    structseq_def_lines.extend(
        [
            f"    def __new__(cls, sequence: {seq_type}): ...",
            f"    n_fields: _int = {len(field_names)}",
            # The attribute is spelled `n_sequence_fields`; the stub used to
            # declare a misspelled name instead.
            f"    n_sequence_fields: _int = {len(field_names)}",
            "    n_unnamed_fields: _int = 0",
            "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing",
            "",  # add an extra newline
        ]
    )
    structseq_def = "\n".join(structseq_def_lines)
    # Example:
    # structseq_def = (
    #     "class max(Tuple[Tensor, Tensor]):\n"
    #     "    @property\n"
    #     "    def values(self) -> Tensor: ...\n"
    #     "    @property\n"
    #     "    def indices(self) -> Tensor: ...\n"
    #     "    def __new__(cls, sequence: Tuple[Tensor, Tensor]): ...\n"
    #     "    n_fields: _int = 2\n"
    #     "    n_sequence_fields: _int = 2\n"
    #     "    n_unnamed_fields: _int = 0\n"
    #     "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing\n"
    # )
    return structseq_name, structseq_def
1050
+
1051
+
1052
def returns_str_pyi(signature: PythonSignature) -> str:
    # Return annotation for the .pyi stub of `signature`:
    #   - structseq field names present -> 'torch.return_types.<name>'
    #   - several returns               -> 'Tuple[a, b, ...]'
    #   - one return                    -> that return's type
    #   - no returns                    -> 'None'
    schema_returns = signature.returns.returns
    if structseq_fieldnames(schema_returns):
        return f"torch.return_types.{signature.name}"

    return_types = [return_type_str_pyi(r.type) for r in schema_returns]
    if not return_types:
        return "None"
    if len(return_types) == 1:
        return return_types[0]
    return "Tuple[" + ", ".join(return_types) + "]"
1063
+
1064
+
1065
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1066
+ #
1067
+ # C++ Function Dispatch
1068
+ #
1069
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1070
+ # This section provides APIs to generate the code that does C++ function
1071
+ # dispatch. The C++ function call is wrapped by a lambda function.
1072
+ # For example:
1073
+ #
1074
+ # // aten::selu_(Tensor(a!) self) -> Tensor(a!)
1075
+ # auto dispatch_selu_ = [](Tensor self) -> Tensor {
1076
+ # pybind11::gil_scoped_release no_gil;
1077
+ # return at::selu_(self);
1078
+ # };
1079
+ #
1080
+ # The lambda function's signature follows the C++ signature in common
1081
+ # cases, e.g.:
1082
+ #
1083
+ # // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
1084
+ # [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
1085
+ #
1086
+ # For out variant the 'out' argument's type is changed from 'Tensor &'
1087
+ # to 'Tensor'. It's because when calling the lambda it passes in the
1088
+ # PythonArgParser output '_r.tensor(3)', which is stack allocated object
1089
+ # and needs to pass by value. Also see comments in 'dispatch_lambda_return_str()'.
1090
+ #
1091
+ # // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
1092
+ # [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
1093
+ #
1094
+ # For multi-output case it can keep using reference type because the
1095
+ # PythonArgParser output has been unpacked to local variables, e.g.:
1096
+ #
1097
+ # // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
1098
+ # // Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
1099
+ # [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor,Tensor>
1100
+ #
1101
+ # For deprecated python signature, it should follow deprecated python arg order.
1102
+ # TODO: This is to keep same byte-for-byte result as the old codegen - maybe unnecessary?
1103
+
1104
+
1105
def dispatch_lambda_args(
    ps: PythonSignature, f: NativeFunction, symint: bool = True
) -> tuple[DispatchLambdaArgument, ...]:
    # Compute the arguments of the dispatch lambda wrapping the C++ call.
    # A deprecated python signature supplies its own schema; otherwise the
    # native function's schema is used.
    schema = ps.deprecated_schema if isinstance(ps, PythonSignatureDeprecated) else f.func

    # Start with cpp arguments - dispatch lambda signature always include 'self'
    cpp_args = cpp.arguments(
        arguments=schema.arguments,
        faithful=False,
        symint=symint,
        method=False,
        cpp_no_default_args=f.cpp_no_default_args,
    )
    out_names: set[str] = {a.name for a in schema.arguments.out}
    at_most_one_out = len(out_names) <= 1

    lambda_args: list[DispatchLambdaArgument] = []
    for binding in cpp_args:
        cpp_type = binding.type
        is_out = binding.name in out_names

        if ps.method and binding.name == "self":
            # For method's 'self', we can use 'const Tensor &' and simply ignore mutability!
            lambda_type = "const at::Tensor &"
        elif (at_most_one_out or not is_out) and cpp_type == "at::Tensor &":
            # Take mutable tensors by value so no reference to a temporary
            # is formed. Out arguments of a multi-output function are the
            # exception and keep their reference type.
            # TODO: avoid this special handling?
            lambda_type = "at::Tensor"
        else:
            lambda_type = cpp_type

        lambda_args.append(
            DispatchLambdaArgument(
                name=binding.name,
                type_str=lambda_type,
                is_out_arg=is_out,
            )
        )

    return tuple(lambda_args)
1147
+
1148
+
1149
# C++ return types that dispatch_lambda_return_str() accepts.
# [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
    # Tensors and tuples of tensors
    "at::Tensor",
    "::std::tuple<at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    # Tuples mixing tensors with scalars or tensor vectors
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>",
    "::std::tuple<double,int64_t>",
    "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
    "::std::vector<at::Tensor>",
    # Needed for flash attention forw/backward
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
    # Non-tensor returns
    "at::Scalar",
    "bool",
    "int64_t",
    "void*",
    "void",
    "at::QScheme",
    "double",
    "at::IntArrayRef",
    "at::ScalarType",
    "at::Stream",
}
1179
+
1180
+
1181
def dispatch_lambda_return_str(f: NativeFunction) -> str:
    # C++ return type of the dispatch lambda for `f`.
    # Raises RuntimeError when that type is not in SUPPORTED_RETURN_TYPES.
    #
    # [old codegen] The annotation is stripped from every return (giving
    # e.g. 'Tensor' rather than 'Tensor &') because the dispatch lambdas take
    # mutable arguments *by value*, not by reference. Returning a reference
    # to such an argument would leave a pointer to a dangling stack entry.
    #
    # You want:
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
    #                                            ^^^^^^
    #
    # *not*
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
    #                                            ^^^^^^^
    #
    # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
    # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
    # mutable reference to temporary. Maybe we could assign it to a
    # variable itself.)
    unannotated_returns = tuple(
        Return(ret.name, ret.type, None) for ret in f.func.returns
    )
    cpp_return = cpp.returns_type(unannotated_returns, symint=True).cpp_type()
    if cpp_return in SUPPORTED_RETURN_TYPES:
        return cpp_return
    raise RuntimeError(f"{f.func.name} returns unsupported type {cpp_return}")
1208
+
1209
+
1210
def cpp_dispatch_target(f: NativeFunction) -> str:
    # Name of the C++ callable the dispatch lambda invokes for `f`:
    # 'self.<name>' for method variants, otherwise '<namespace>::<name>'.
    # Raises RuntimeError when `f` is neither a method nor a function variant.
    name = cpp.name(f.func, symint_overload=f.func.has_symint())

    if Variant.method in f.variants:
        return f"self.{name}"

    if Variant.function not in f.variants:
        raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}")

    # Functions taking TensorOptions, and '*_like' functions, are called
    # through the torch namespace; all others through at.
    uses_torch_namespace = has_tensor_options(f) or f.func.name.name.base.endswith(
        "_like"
    )
    namespace = "torch" if uses_torch_namespace else "at"
    return f"{namespace}::{name}"
1222
+
1223
+
1224
def cpp_dispatch_exprs(
    f: NativeFunction,
    *,
    python_signature: PythonSignature | None = None,
) -> tuple[str, ...]:
    # Expressions passed as arguments in the C++ call made by the dispatch lambda.
    cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()

    if isinstance(python_signature, PythonSignatureDeprecated):
        # For deprecated python signature we may need fill in some constants.
        # An 'out' expression is kept only when `f` is an out function.
        exprs: tuple[str, ...] = tuple(
            expr
            for expr in python_signature.deprecated_args_exprs
            if expr != "out" or f.func.is_out_fn()
        )
    else:
        # By default the exprs are consistent with the C++ signature.
        exprs = tuple(binding.name for binding in cpp_args)

    if Variant.method in f.variants:
        # Method variants do not pass 'self' as an argument expression.
        exprs = tuple(expr for expr in exprs if expr != "self")

    return exprs
1248
+
1249
+
1250
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1251
+ #
1252
+ # Python / C++ Args Binding
1253
+ #
1254
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1255
+
1256
+
1257
+ # We explicitly enumerate the PythonArgParser unpacking methods for all
1258
+ # supported types. This might be more verbose than necessary, partially
1259
+ # because of the irregularity of unpacking method naming, partially
1260
+ # because we want to mimic the old codegen behavior - to reject
1261
+ # unexpected and/or unsupported cases which the old codegen rejects.
1262
+ # For certain cases it is intentionally more restrictive than necessary,
1263
+ # e.g.: it doesn't accept doublelist with definite size.
1264
+ def arg_parser_unpack_method(
1265
+ t: Type, default: str | None, default_init: str | None, *, symint: bool = True
1266
+ ) -> str:
1267
+ has_default_init = default_init is not None
1268
+ if has_default_init and str(t) not in (
1269
+ "ScalarType?",
1270
+ "ScalarType",
1271
+ "Device",
1272
+ "Device?",
1273
+ "Layout",
1274
+ "Layout?",
1275
+ "bool",
1276
+ "bool?",
1277
+ ):
1278
+ raise RuntimeError(f"type '{t}' does not supported unpacking with default")
1279
+
1280
+ if isinstance(t, BaseType):
1281
+ if t.name in [
1282
+ BaseTy.Tensor,
1283
+ BaseTy.Stream,
1284
+ BaseTy.Storage,
1285
+ BaseTy.Scalar,
1286
+ BaseTy.Dimname,
1287
+ ]:
1288
+ # These unpack methods line up with their schema names
1289
+ return t.name.name.lower()
1290
+ elif t.name == BaseTy.ScalarType:
1291
+ return "scalartypeWithDefault" if has_default_init else "scalartype"
1292
+ elif t.name == BaseTy.Device:
1293
+ return "deviceWithDefault" if has_default_init else "device"
1294
+ elif t.name == BaseTy.DeviceIndex:
1295
+ return "toInt64"
1296
+ elif t.name == BaseTy.int:
1297
+ return "toInt64"
1298
+ elif t.name == BaseTy.SymInt:
1299
+ return "toSymInt" if symint else "toInt64"
1300
+ elif t.name == BaseTy.bool:
1301
+ return "toBoolWithDefault" if has_default_init else "toBool"
1302
+ elif t.name == BaseTy.float:
1303
+ return "toDouble"
1304
+ elif t.name == BaseTy.str:
1305
+ return "stringView"
1306
+ elif t.name == BaseTy.Layout:
1307
+ return "layoutWithDefault" if has_default_init else "layout"
1308
+ elif t.name == BaseTy.MemoryFormat:
1309
+ return "memoryformat"
1310
+
1311
+ elif isinstance(t, OptionalType):
1312
+ if str(t.elem) == "Tensor":
1313
+ return "optionalTensor"
1314
+ elif str(t.elem) == "Generator":
1315
+ return "generator"
1316
+ elif str(t.elem) == "Dimname[]":
1317
+ return "toDimnameListOptional"
1318
+ elif not has_default_init and default in (
1319
+ None,
1320
+ "None",
1321
+ "::std::nullopt",
1322
+ "std::nullopt",
1323
+ ):
1324
+ # If default is None: append 'Optional' to elem's unpacking method
1325
+ return (
1326
+ arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
1327
+ )
1328
+ else:
1329
+ # Otherwise, load as underlying type with default
1330
+ return arg_parser_unpack_method(
1331
+ t.elem, default, default_init, symint=symint
1332
+ )
1333
+
1334
+ elif isinstance(t, ListType):
1335
+ if str(t.elem) == "Tensor":
1336
+ # accept and use definite size
1337
+ return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist"
1338
+ elif str(t.elem) == "Tensor?":
1339
+ return "list_of_optional_tensors"
1340
+ elif str(t.elem) == "Dimname":
1341
+ # accept definite size
1342
+ return "dimnamelist"
1343
+ elif str(t.elem) == "int":
1344
+ # accept definite size
1345
+ return "intlist"
1346
+ elif str(t.elem) == "float":
1347
+ return "doublelist"
1348
+ elif str(t.elem) == "SymInt":
1349
+ # accept definite size
1350
+ return "symintlist" if symint else "intlist"
1351
+ elif str(t.elem) == "Scalar":
1352
+ return "scalarlist"
1353
+ raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
1354
+
1355
+
# Return RHS expression for python argument using PythonArgParser output.
# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
def arg_parser_output_expr(
    arg_index: int, a: PythonArgument, *, symint: bool = True
) -> PythonArgParserOutputExpr:
    """Build the ``_r.<unpack>(<index>[, <default_init>])`` expression for ``a``.

    Args:
        arg_index: position of the argument in the parsed signature.
        a: the Python argument being unpacked.
        symint: forwarded to ``arg_parser_unpack_method``.

    Returns:
        A ``PythonArgParserOutputExpr`` carrying the argument name, the C++
        expression string, the index and the argument itself.
    """
    has_default = a.default_init is not None
    unpack_method = arg_parser_unpack_method(
        t=a.type, default=a.default, default_init=a.default_init, symint=symint
    )
    # The default initializer, when present, is passed as a second call argument.
    default = f", {a.default_init}" if has_default else ""
    expr = f"_r.{unpack_method}({arg_index}{default})"

    return PythonArgParserOutputExpr(
        name=a.name,
        expr=expr,
        index=arg_index,
        argument=a,
    )
1374
+
1375
+
# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
def arg_parser_output_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> dict[str, PythonArgParserOutputExpr]:
    """Map every argument name of ``ps`` to its PythonArgParser output expr.

    Args:
        ps: Python signature whose ``arguments()`` are enumerated; the
            enumeration index becomes the parser index of each argument.
        f: not used by this function; kept so the signature stays unchanged
            for callers.
        symint: forwarded to ``arg_parser_output_expr``.
    """
    outputs = (
        arg_parser_output_expr(i, a, symint=symint)
        for i, a in enumerate(ps.arguments())
    )
    return {e.name: e for e in outputs}
1385
+
1386
+
# argument name to type for scattered tensor options fields
# (keys are the Python binding argument names, values the expected schema
# type strings they are compared against with str(a.type)).
TENSOR_OPTIONS_FIELDS = {
    "dtype": "ScalarType?",
    "device": "Device?",
    "layout": "Layout?",
    "pin_memory": "bool?",
    "requires_grad": "bool?",
}
1395
+
1396
+
1397
+ # bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
+ #
+ # Returns a DispatchLambdaArgumentExprs whose `exprs` holds one C++ expression
+ # string per entry of dispatch_lambda_args(ps, f, symint=symint), in that order,
+ # and whose `inits` is the list of C++ statement strings collected below.
+ #
+ # Raises RuntimeError when:
+ #   - `f` has tensor options and is also an out function;
+ #   - a tensor options arg has a name or type not listed in TENSOR_OPTIONS_FIELDS;
+ #   - not every TENSOR_OPTIONS_FIELDS name is present among the tensor options args;
+ #   - "dtype" is among the scattered tensor options args but `f` is not an out
+ #     function, or "layout"/"device" are missing;
+ #   - "requires_grad" is missing from the scattered tensor options args.
+ #
+ # NOTE(review): this rendering has lost all indentation, including the
+ # whitespace inside the multi-line f-strings below. Confirm the nesting of the
+ # checks in section 3 and the layout of the emitted C++ against the real file
+ # before relying on this text.
1398
+ def dispatch_lambda_exprs(
1399
+ ps: PythonSignature, f: NativeFunction, *, symint: bool = True
1400
+ ) -> DispatchLambdaArgumentExprs:
1401
+ # This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing
1402
+ # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser
1403
+ # outputs.
1404
+ arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
1405
+ lambda_args = dispatch_lambda_args(ps, f, symint=symint)
1406
+ inits: list[str] = []
1407
+ lambda_args_exprs: dict[str, str] = {}
1408
+
1409
+ has_toptions = has_tensor_options(f)
1410
+
1411
+ # 1. special inits/unpacking to provide binding exprs for lambda arguments.
1412
+ for a in ps.arguments(skip_tensor_options=True):
1413
+ name = a.name
1414
+ arg_parser_expr = arg_parser_outputs[a.name].expr
1415
+
1416
+ if has_toptions and name == "self":
1417
+ # TODO: why this needs to be special case?
1418
+ inits.extend(
1419
+ [
1420
+ f"auto self = {arg_parser_expr};",
1421
+ ]
1422
+ )
1423
+ lambda_args_exprs[name] = name
1424
+ elif (
1425
+ isinstance(a, PythonOutArgument)
1426
+ and len(a.outputs) > 1
1427
+ and f.func.is_out_fn()
1428
+ ):
1429
+ inits.extend(
1430
+ [
1431
+ f"auto out = {arg_parser_expr};",
1432
+ ]
1433
+ )
1434
+ for i, out_arg in enumerate(a.outputs):
1435
+ lambda_args_exprs[out_arg.name] = f"out[{i}]"
1436
+ elif str(a.type) == "Dimname[]?":
1437
+ # [old codegen]
1438
+ # TODO: make this part of something more general, or get rid of it.
1439
+ # optional<ArrayRef<T>> are special. The PythonArgParser returns an
1440
+ # optional<vector<T>>, which cannot be implicitly converted to
1441
+ # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
1442
+ inits.extend(
1443
+ [
1444
+ f"auto __{name} = {arg_parser_expr};",
1445
+ f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;", # noqa: B950
1446
+ ]
1447
+ )
1448
+ lambda_args_exprs[name] = name
1449
+ else:
1450
+ # default case - directly using PythonArgParser output expr
1451
+ lambda_args_exprs[name] = arg_parser_expr
1452
+
1453
+ # method's self is passed directly to python binding, rather than parsed
1454
+ if ps.method:
1455
+ lambda_args_exprs["self"] = "self"
1456
+
1457
+ # 2. special packing/checking for TensorOptions.
1458
+ tensor_options_args_names = [a.name for a in ps.tensor_options_args]
1459
+ if has_toptions:
1460
+ if f.func.is_out_fn():
1461
+ raise RuntimeError(f"{f.func}: tensor options with output arg")
1462
+ for a in ps.tensor_options_args:
1463
+ if a.name not in TENSOR_OPTIONS_FIELDS:
1464
+ raise RuntimeError(
1465
+ f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
1466
+ )
1467
+ if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
1468
+ raise RuntimeError(
1469
+ f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
1470
+ )
1471
+ if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS):
1472
+ raise RuntimeError(
1473
+ f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
1474
+ )
1475
+
1476
+ inits.append(
1477
+ f"""\
1478
+ const auto options = TensorOptions()
1479
+ .dtype({arg_parser_outputs['dtype'].expr})
1480
+ .device({arg_parser_outputs['device'].expr})
1481
+ .layout({arg_parser_outputs['layout'].expr})
1482
+ .requires_grad({arg_parser_outputs['requires_grad'].expr})
1483
+ .pinned_memory({arg_parser_outputs['pin_memory'].expr});
1484
+ torch::utils::maybe_initialize_device(options);
1485
+ """
1486
+ )
1487
+ lambda_args_exprs["options"] = "options"
1488
+
1489
+ # 3. special case - access scattered TensorOptions fields without packing
1490
+ # TODO: maybe move to the generator side as it's not related to binding.
1491
+ if not has_toptions and tensor_options_args_names:
1492
+ if "dtype" in tensor_options_args_names:
1493
+ # we're an output-arg variant, check these args against output tensor
1494
+ if not f.func.is_out_fn():
+ # NOTE(review): the message below interpolates `ps.arguments` without calling
+ # it, while the rest of this function calls `ps.arguments(...)`; the text will
+ # presumably show a bound-method repr rather than the argument list - confirm
+ # whether `ps.arguments()` was intended.
1495
+ raise RuntimeError(
1496
+ f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}"
1497
+ )
1498
+ if not all(a in tensor_options_args_names for a in ("layout", "device")):
1499
+ raise RuntimeError(
1500
+ f"{f.func}: incomplete tensor options for output check"
1501
+ )
1502
+
1503
+ inits.append(
1504
+ f"""\
1505
+ check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr},
1506
+ {arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr},
1507
+ {arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr});
1508
+ """
1509
+ )
1510
+ # we'll set requires_grad on outgoing tensor
1511
+ if "requires_grad" not in tensor_options_args_names:
1512
+ raise RuntimeError(
1513
+ f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]'
1514
+ )
1515
+
1516
+ return DispatchLambdaArgumentExprs(
1517
+ exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
1518
+ inits=inits,
1519
+ )
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ If you add a file to this directory, you **MUST** update
2
+ `torch/CMakeLists.txt` and add the file as a dependency to
3
+ the `add_custom_command` call.
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (199 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/context.cpython-311.pyc ADDED
Binary file (2.32 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_annotated_fn_args.cpython-311.pyc ADDED
Binary file (6.72 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd.cpython-311.pyc ADDED
Binary file (5.28 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd_functions.cpython-311.pyc ADDED
Binary file (35.1 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_inplace_or_view_type.cpython-311.pyc ADDED
Binary file (25.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_python_functions.cpython-311.pyc ADDED
Binary file (46.5 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_trace_type.cpython-311.pyc ADDED
Binary file (21.5 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_factories.cpython-311.pyc ADDED
Binary file (6.65 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_type.cpython-311.pyc ADDED
Binary file (79.2 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/gen_view_funcs.cpython-311.pyc ADDED
Binary file (16.4 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/__pycache__/load_derivatives.cpython-311.pyc ADDED
Binary file (43.6 kB). View file
 
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/build.bzl ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def define_targets(rules):
    """Declare the `autograd` py_library through the supplied `rules` object.

    Args:
        rules: object providing `py_library`, `glob` and `requirement`.
    """
    rules.py_library(
        name = "autograd",
        srcs = rules.glob(["*.py"]),
        # Non-Python inputs shipped next to the sources.
        data = rules.glob([
            "*.yaml",
            "templates/*",
        ]),
        visibility = ["//:__subpackages__"],
        deps = [
            rules.requirement("PyYAML"),
            "//torchgen",
        ],
    )
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/context.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from typing import Callable
3
+
4
+ from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI
5
+ from torchgen.context import native_function_manager
6
+ from torchgen.utils import T
7
+
8
+
# Like torchgen.context.with_native_function, but for
# NativeFunctionWithDifferentiabilityInfo.
def with_native_function_with_differentiability_info(
    func: Callable[[NFWDI], T]
) -> Callable[[NFWDI], T]:
    """Decorator: run ``func(f)`` inside ``native_function_manager(f.func)``.

    The wrapped callable keeps ``func``'s metadata via ``functools.wraps``.
    """

    @functools.wraps(func)
    def wrapper(f: NFWDI) -> T:
        with native_function_manager(f.func):
            return func(f)

    return wrapper
20
+
21
+
# Like with_native_function_with_differentiability_info, but with an
# additional dispatch key string argument
def with_native_function_with_differentiability_info_and_key(
    func: Callable[[NFWDI, str], T]
) -> Callable[[NFWDI, str], T]:
    """Decorator: run ``func(f, key)`` inside ``native_function_manager(f.func)``.

    ``key`` is passed through to ``func`` untouched.
    """

    @functools.wraps(func)
    def wrapper(f: NFWDI, key: str) -> T:
        with native_function_manager(f.func):
            return func(f, key)

    return wrapper
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/derivatives.yaml ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ For procedural tests needed for __torch_function__, we use this function
3
+ to export method names and signatures as needed by the tests in
4
+ test/test_overrides.py.
5
+
6
+ python -m tools.autograd.gen_annotated_fn_args \
7
+ aten/src/ATen/native/native_functions.yaml \
8
+ aten/src/ATen/native/tags.yaml \
9
+ $OUTPUT_DIR \
10
+ tools/autograd
11
+
12
+ Where $OUTPUT_DIR is where you would like the files to be
13
+ generated. In the full build system, OUTPUT_DIR is
14
+ torch/testing/_internal/generated
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import os
21
+ import textwrap
22
+ from collections import defaultdict
23
+ from typing import Any, Sequence, TYPE_CHECKING
24
+
25
+ import torchgen.api.python as python
26
+ from torchgen.context import with_native_function
27
+ from torchgen.gen import parse_native_yaml
28
+ from torchgen.utils import FileManager
29
+
30
+ from .gen_python_functions import (
31
+ is_py_fft_function,
32
+ is_py_linalg_function,
33
+ is_py_nn_function,
34
+ is_py_special_function,
35
+ is_py_torch_function,
36
+ is_py_variable_method,
37
+ should_generate_py_binding,
38
+ )
39
+
40
+
41
+ if TYPE_CHECKING:
42
+ from torchgen.model import Argument, BaseOperatorName, NativeFunction
43
+
44
+
+ # Collects one annotated-args line per selected native function and writes
+ # them into "annotated_fn_args.py" from the "annotated_fn_args.py.in" template.
+ #
+ # native_yaml_path / tags_yaml_path: inputs to parse_native_yaml.
+ # out: install directory handed to FileManager.
+ # autograd_dir: directory that contains the "templates" sub-directory.
+ #
+ # For each (predicate, namespace) pair in `mappings`, functions that pass both
+ # should_generate_py_binding and the predicate are grouped by
+ # f.func.name.name, and each yields "<namespace>.<gen_annotated_args(f)>".
+ #
+ # NOTE(review): indentation and runs of spaces are lost in this rendering; in
+ # particular the prefix passed to textwrap.indent below is shown as a single
+ # space and may be wider in the real file - confirm before reuse.
45
+ def gen_annotated(
46
+ native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str
47
+ ) -> None:
48
+ native_functions = parse_native_yaml(
49
+ native_yaml_path, tags_yaml_path
50
+ ).native_functions
51
+ mappings = (
52
+ (is_py_torch_function, "torch._C._VariableFunctions"),
53
+ (is_py_nn_function, "torch._C._nn"),
54
+ (is_py_linalg_function, "torch._C._linalg"),
55
+ (is_py_special_function, "torch._C._special"),
56
+ (is_py_fft_function, "torch._C._fft"),
57
+ (is_py_variable_method, "torch.Tensor"),
58
+ )
59
+ annotated_args: list[str] = []
60
+ for pred, namespace in mappings:
61
+ groups: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
62
+ for f in native_functions:
63
+ if not should_generate_py_binding(f) or not pred(f):
64
+ continue
65
+ groups[f.func.name.name].append(f)
66
+ for group in groups.values():
67
+ for f in group:
68
+ annotated_args.append(f"{namespace}.{gen_annotated_args(f)}")
69
+
70
+ template_path = os.path.join(autograd_dir, "templates")
71
+ fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
72
+ fm.write_with_template(
73
+ "annotated_fn_args.py",
74
+ "annotated_fn_args.py.in",
75
+ lambda: {
76
+ "annotated_args": textwrap.indent("\n".join(annotated_args), " "),
77
+ },
78
+ )
79
+
80
+
@with_native_function
def gen_annotated_args(f: NativeFunction) -> str:
    """Render ``"<name>: [<arg dicts>],"`` for the required arguments of ``f``.

    Only arguments without a default are listed. Each entry records
    ``is_kwarg_only`` (as the string ``"True"``/``"False"``), ``name``,
    ``simple_type`` and, when ``python.argument_type_size`` is truthy, ``size``.
    Keyword-only arguments are skipped for the function names in the
    exclusion list.
    """

    def _get_kwargs_func_exclusion_list() -> list[str]:
        # functions that currently don't work with kwargs in test_overrides.py
        return [
            "diagonal",
            "round_",
            "round",
            "scatter_",
        ]

    def _add_out_arg(
        out_args: list[dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool
    ) -> None:
        # Appends one description dict per argument that has no default.
        for arg in args:
            if arg.default is not None:
                continue
            out_arg: dict[str, Any] = {}
            out_arg["is_kwarg_only"] = str(is_kwarg_only)
            out_arg["name"] = arg.name
            out_arg["simple_type"] = python.argument_type_str(
                arg.type, simple_type=True
            )
            size_t = python.argument_type_size(arg.type)
            if size_t:
                out_arg["size"] = size_t
            out_args.append(out_arg)

    out_args: list[dict[str, Any]] = []
    _add_out_arg(out_args, f.func.arguments.flat_positional, is_kwarg_only=False)
    if f"{f.func.name.name}" not in _get_kwargs_func_exclusion_list():
        _add_out_arg(out_args, f.func.arguments.flat_kwarg_only, is_kwarg_only=True)

    return f"{f.func.name.name}: {repr(out_args)},"
115
+
116
+
def main() -> None:
    """Command-line entry point: parse the four positional paths and generate.

    The AUTOGRAD argument is the directory that contains ``templates/``
    (``gen_annotated`` joins ``"templates"`` onto it), not the template
    directory itself.
    """
    parser = argparse.ArgumentParser(description="Generate annotated_fn_args script")
    parser.add_argument(
        "native_functions", metavar="NATIVE", help="path to native_functions.yaml"
    )
    parser.add_argument("tags", metavar="TAGS", help="path to tags.yaml")
    parser.add_argument("out", metavar="OUT", help="path to output directory")
    parser.add_argument(
        "autograd", metavar="AUTOGRAD", help="path to autograd directory"
    )
    args = parser.parse_args()
    gen_annotated(args.native_functions, args.tags, args.out, args.autograd)


if __name__ == "__main__":
    main()
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_autograd.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ To run this file by hand from the root of the PyTorch
3
+ repository, run:
4
+
5
+ python -m tools.autograd.gen_autograd \
6
+ aten/src/ATen/native/native_functions.yaml \
7
+ aten/src/ATen/native/tags.yaml \
8
+ $OUTPUT_DIR \
9
+ tools/autograd
10
+
11
+ Where $OUTPUT_DIR is where you would like the files to be
12
+ generated. In the full build system, OUTPUT_DIR is
13
+ torch/csrc/autograd/generated/
14
+ """
15
+
16
+ # gen_autograd.py generates C++ autograd functions and Python bindings.
17
+ #
18
+ # It delegates to the following scripts:
19
+ #
20
+ # gen_autograd_functions.py: generates subclasses of torch::autograd::Node
21
+ # gen_variable_type.py: generates VariableType.h which contains all tensor methods
22
+ # gen_python_functions.py: generates Python bindings to THPVariable
23
+ #
24
+
25
+ from __future__ import annotations
26
+
27
+ import argparse
28
+ import os
29
+
30
+ from torchgen.api import cpp
31
+ from torchgen.api.autograd import (
32
+ match_differentiability_info,
33
+ NativeFunctionWithDifferentiabilityInfo,
34
+ )
35
+ from torchgen.gen import parse_native_yaml
36
+ from torchgen.selective_build.selector import SelectiveBuilder
37
+
38
+ from . import gen_python_functions
39
+ from .gen_autograd_functions import (
40
+ gen_autograd_functions_lib,
41
+ gen_autograd_functions_python,
42
+ )
43
+ from .gen_inplace_or_view_type import gen_inplace_or_view_type
44
+ from .gen_trace_type import gen_trace_type
45
+ from .gen_variable_factories import gen_variable_factories
46
+ from .gen_variable_type import gen_variable_type
47
+ from .gen_view_funcs import gen_view_funcs
48
+ from .load_derivatives import load_derivatives
49
+
50
+
+ # Drives generation of the autograd C++ sources into `out`.
+ #
+ # native_functions_path / tags_path: yaml inputs for load_derivatives and
+ #   parse_native_yaml.
+ # autograd_dir: directory holding "derivatives.yaml" and the "templates"
+ #   sub-directory.
+ # operator_selector: its is_native_function_selected_for_training filters the
+ #   native functions, which are then sorted by cpp.name(f.func).
+ # disable_autograd: when true, at least gen_variable_type is skipped.
+ #
+ # NOTE(review): indentation is lost in this rendering, so it cannot be told
+ # from here whether gen_inplace_or_view_type and gen_trace_type are also
+ # guarded by `if not disable_autograd:` - confirm against the real file.
51
+ def gen_autograd(
52
+ native_functions_path: str,
53
+ tags_path: str,
54
+ out: str,
55
+ autograd_dir: str,
56
+ operator_selector: SelectiveBuilder,
57
+ disable_autograd: bool = False,
58
+ ) -> None:
59
+ # Parse and load derivatives.yaml
60
+ differentiability_infos, used_dispatch_keys = load_derivatives(
61
+ os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
62
+ )
63
+
64
+ template_path = os.path.join(autograd_dir, "templates")
65
+
66
+ native_funcs = parse_native_yaml(native_functions_path, tags_path).native_functions
67
+ fns = sorted(
68
+ filter(
69
+ operator_selector.is_native_function_selected_for_training, native_funcs
70
+ ),
71
+ key=lambda f: cpp.name(f.func),
72
+ )
73
+ fns_with_diff_infos: list[
74
+ NativeFunctionWithDifferentiabilityInfo
75
+ ] = match_differentiability_info(fns, differentiability_infos)
76
+
77
+ # Generate VariableType.h/cpp
78
+ if not disable_autograd:
79
+ gen_variable_type(
80
+ out,
81
+ native_functions_path,
82
+ tags_path,
83
+ fns_with_diff_infos,
84
+ template_path,
85
+ used_dispatch_keys,
86
+ )
87
+
88
+ gen_inplace_or_view_type(
89
+ out, native_functions_path, tags_path, fns_with_diff_infos, template_path
90
+ )
91
+
92
+ # operator filter not applied as tracing sources are excluded in selective build
93
+ gen_trace_type(out, native_funcs, template_path)
94
+ # Generate Functions.h/cpp
95
+ gen_autograd_functions_lib(out, differentiability_infos, template_path)
96
+
97
+ # Generate variable_factories.h
98
+ gen_variable_factories(out, native_functions_path, tags_path, template_path)
99
+
100
+ # Generate ViewFuncs.h/cpp
101
+ gen_view_funcs(out, fns_with_diff_infos, template_path)
102
+
103
+
def gen_autograd_python(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
) -> None:
    """Generate the Python-facing autograd sources into ``out``.

    Args:
        native_functions_path: path to native_functions.yaml.
        tags_path: path to tags.yaml.
        out: output directory.
        autograd_dir: directory holding ``derivatives.yaml``,
            ``deprecated.yaml`` and the ``templates`` sub-directory.
    """
    # The second value returned by load_derivatives is not needed here.
    differentiability_infos, _ = load_derivatives(
        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
    )

    template_path = os.path.join(autograd_dir, "templates")

    # Generate Functions.h/cpp
    gen_autograd_functions_python(out, differentiability_infos, template_path)

    # Generate Python bindings
    deprecated_path = os.path.join(autograd_dir, "deprecated.yaml")
    gen_python_functions.gen(
        out, native_functions_path, tags_path, deprecated_path, template_path
    )
124
+
125
+
def main() -> None:
    """Command-line entry point: parse the four positional paths and generate.

    Runs ``gen_autograd`` with the no-op selector, i.e. without any selective
    build filtering supplied from the command line.
    """
    parser = argparse.ArgumentParser(description="Generate autograd C++ files script")
    parser.add_argument(
        "native_functions", metavar="NATIVE", help="path to native_functions.yaml"
    )
    # metavar is TAGS (not NATIVE) so the usage line names each positional once.
    parser.add_argument("tags", metavar="TAGS", help="path to tags.yaml")
    parser.add_argument("out", metavar="OUT", help="path to output directory")
    parser.add_argument(
        "autograd", metavar="AUTOGRAD", help="path to autograd directory"
    )
    args = parser.parse_args()
    gen_autograd(
        args.native_functions,
        args.tags,
        args.out,
        args.autograd,
        SelectiveBuilder.get_nop_selector(),
    )


if __name__ == "__main__":
    main()
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py ADDED
@@ -0,0 +1,675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generates ADInplaceOrViewType.h/cpp
2
+ #
3
+ # NOTE: If any changes are being made to the ADInplaceOrView codegen please also check
4
+ # if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
5
+ # The fallback is expected to mimic this codegen, so we should keep the two in sync.
6
+
7
+ from __future__ import annotations
8
+
9
+ from torchgen.api import cpp
10
+ from torchgen.api.autograd import (
11
+ dispatch_strategy,
12
+ gen_differentiable_outputs,
13
+ NativeFunctionWithDifferentiabilityInfo,
14
+ )
15
+ from torchgen.api.types import (
16
+ BaseCType,
17
+ Binding,
18
+ boolT,
19
+ ConstRefCType,
20
+ CType,
21
+ DispatcherSignature,
22
+ intArrayRefT,
23
+ longT,
24
+ OptionalCType,
25
+ symIntArrayRefT,
26
+ SymIntT,
27
+ tensorT,
28
+ )
29
+ from torchgen.code_template import CodeTemplate
30
+ from torchgen.context import with_native_function
31
+ from torchgen.model import (
32
+ NativeFunction,
33
+ SchemaKind,
34
+ SelfArgument,
35
+ TensorOptionsArguments,
36
+ Type,
37
+ )
38
+ from torchgen.utils import FileManager
39
+
40
+ from .context import with_native_function_with_differentiability_info
41
+ from .gen_trace_type import (
42
+ get_return_value,
43
+ MANUAL_AUTOGRAD,
44
+ tie_return_values,
45
+ type_wrapper_name,
46
+ )
47
+
48
+
# See NOTE [ Autograd View Variables ] in variable.h for details.
# If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT,
# you **MUST** also update the public list of view ops accordingly in
# docs/source/tensor_view.rst. Note not all ATen functions are exposed to public,
# e.g alias & sparse_coo_tensor_with_dims_and_tensors.

# View ops kept in their own list; each one is also registered in
# VIEW_FUNCTIONS below as a view of "self".
VIEW_FUNCTIONS_WITH_METADATA_CHANGE = [
    "view_as_complex",
    "view_as_real",
    "_conj",
    "_neg_view",
    "_nested_get_values",
    "_nested_view_from_buffer",
    "_nested_view_from_jagged",
]

# A map: function name => name of the argument that all outputs are view of
VIEW_FUNCTIONS = {
    "numpy_T": "self",
    "alias": "self",
    "as_strided": "self",
    "diagonal": "self",
    "expand": "self",
    "permute": "self",
    "select": "self",
    "slice": "self",
    "slice_inverse": "self",
    "split": "self",
    "split_with_sizes": "self",
    "squeeze": "self",
    "t": "self",
    "transpose": "self",
    "unfold": "self",
    "unsqueeze": "self",
    "flatten": "self",
    "view": "self",
    "unbind": "self",
    "_indices": "self",
    "_values": "self",
    "indices": "self",
    "values": "self",
    "crow_indices": "self",
    "col_indices": "self",
    "ccol_indices": "self",
    "row_indices": "self",
    # sparse_coo ctor output should really be views of both indices and values,
    # but we only support making it a view of a single variable, and indices is
    # discrete anyways.
    # FIXME: clone indices on construction.
    "sparse_coo_tensor_with_dims_and_tensors": "values",
    "_reshape_alias": "self",
    "_test_autograd_multiple_dispatch_view": "self",
}

for key in VIEW_FUNCTIONS_WITH_METADATA_CHANGE:
    VIEW_FUNCTIONS[key] = "self"

# note: some VIEW_FUNCTIONS are just compositions of the view functions above
# this list contains both the root view functions and any that are purely composed
# of viewing functions, and is used by the JIT to determine when an operator
# may return a view of its inputs; however they may sometimes return a copy.
# (e.g. `contiguous`)
RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union(
    {
        "chunk",
        "detach",
        "contiguous",
        "reshape",
        "reshape_as",
        "expand_as",
        "view_as",
        "real",
        "imag",
        "narrow",
        "movedim",
        "tensor_split",
        "swapdims",
        "swapaxes",
        "mT",
        "mH",
        "adjoint",
        "matrix_H",
    }
)

# These are the functions we consider views for the purposes of validating
# StorageImpl and TensorImpl in gen_variable_type.
# `_unsafe_view` is not included in VIEW_FUNCTIONS above because it is not a
# view for the purposes of ADInplaceOrView kernel, we do not want to call as_view
# See NOTE [Unsafe View] for more info.
ALL_VIEW_FUNCTIONS = {
    **VIEW_FUNCTIONS,
    "_unsafe_view": "self",
}
144
+
+ # C++ snippet templates used by the ADInplaceOrView code generator.
+ # NOTE(review): leading whitespace inside these template strings is lost in
+ # this rendering; the placeholder indentation may matter to CodeTemplate, so
+ # take the exact text from the real file rather than from here.
+ #
+ # Statement: declares `${vec}` from `${arg}.vec()`.
145
+ ARRAYREF_TO_VEC = CodeTemplate(
146
+ """\
147
+ auto ${vec} = ${arg}.vec();
148
+ """
149
+ )
150
+
+ # Statement: declares `${val}` from `${arg}.value_or(${default})`.
151
+ OPTIONAL_TO_VAL = CodeTemplate(
152
+ """\
153
+ auto ${val} = ${arg}.value_or(${default});
154
+ """
155
+ )
156
+
+ # Expression: calls at::_ops::${unambiguous_name}::call with the unpacked args.
157
+ CALL_DISPATCH = CodeTemplate(
158
+ """\
159
+ at::_ops::${unambiguous_name}::call(${unpacked_args})"""
160
+ )
161
+
+ # Expression: calls `${reverse_name}` with the unpacked args.
162
+ REVERSE_VIEW_DISPATCH = CodeTemplate(
163
+ """\
164
+ ${reverse_name}(${unpacked_args})"""
165
+ )
166
+
+ # Loop: iterates `${view_idx}` over c10::irange(${var}.size()) around `${body}`.
167
+ MULTI_OUTPUT_VIEW_ITERATION = CodeTemplate(
168
+ """\
169
+ for (auto ${view_idx} : c10::irange(${var}.size())) {
170
+ ${body}
171
+ }
172
+ """
173
+ )
174
+
+ # Declares `func` and `rev_func` as null, then fills them via
+ # `${replay_view_func}` / `${reverse_replay_view_func}` when any of the listed
+ # conditions on `self` or the view-replay TLS flag holds.
175
+ SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate(
176
+ """\
177
+ std::unique_ptr<torch::autograd::ViewFunc> func(nullptr);
178
+ std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
179
+ if (${is_view_with_metadata_change} ||
180
+ !self.unsafeGetTensorImpl()->support_as_strided() ||
181
+ self.unsafeGetTensorImpl()->is_python_dispatch() ||
182
+ c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
183
+ ${replay_view_func}
184
+ ${reverse_replay_view_func}
185
+ }
186
+ """
187
+ )
188
+
+ # Statement: assigns `func` a std::make_unique<${view_func_name}>.
189
+ REPLAY_VIEW_FUNC = CodeTemplate(
190
+ """\
191
+ func = std::make_unique<${view_func_name}>(${view_func_args});
192
+ """
193
+ )
194
+
+ # Statement: assigns `rev_func` a by-value-capturing lambda that returns
+ # `${reverse_replay_view_call}`.
195
+ REVERSE_REPLAY_VIEW_LAMBDA_FUNC = CodeTemplate(
196
+ """\
197
+ rev_func = [=](const at::Tensor& ${input_view}) {
198
+ return ${reverse_replay_view_call};
199
+ };
200
+ """
201
+ )
202
+
+ # Skeleton of a generated function definition.
203
+ METHOD_DEFINITION = CodeTemplate(
204
+ """\
205
+ ${return_type} ${type_wrapper_name}(${formals}) {
206
+ ${type_definition_body}
207
+ }
208
+ """
209
+ )
210
+
+ # m.impl(...) registration of a generated wrapper through TORCH_FN.
211
+ WRAPPER_REGISTRATION = CodeTemplate(
212
+ """\
213
+ m.impl("${unqual_operator_name_with_overload}",
214
+ TORCH_FN(${class_type}::${type_wrapper_name})
215
+ );
216
+ """
217
+ )
218
+
+ # m.impl(...) registration of torch::autograd::autogradNotImplementedFallback().
219
+ AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION = CodeTemplate(
220
+ """\
221
+ m.impl("${unqual_operator_name_with_overload}", torch::autograd::autogradNotImplementedFallback());
222
+ """
223
+ )
224
+
+ # Braced scope: AutoDispatchBelowADInplaceOrView guard plus a redispatch call
+ # whose result is not used.
225
+ INPLACE_REDISPATCH = CodeTemplate(
226
+ """\
227
+ {
228
+ at::AutoDispatchBelowADInplaceOrView guard;
229
+ at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
230
+ }
231
+ """
232
+ )
233
+
+ # Statement: `${return_values} = ${rhs_value};`.
234
+ ASSIGN_RETURN_VALUE = CodeTemplate(
235
+ """\
236
+ ${return_values} = ${rhs_value};
237
+ """
238
+ )
239
+
+ # Immediately-invoked lambda: takes the guard, then returns the redispatch
+ # result; `${assign_return_values}` prefixes the call.
240
+ VIEW_REDISPATCH = CodeTemplate(
241
+ """\
242
+ ${assign_return_values} ([&]() {
243
+ at::AutoDispatchBelowADInplaceOrView guard;
244
+ return at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
245
+ })();
246
+ """
247
+ )
248
+
+ # Name string "_tmp"; its use is outside this chunk.
249
+ TMP_VAR = "_tmp"
250
+
251
+
# FIXME: Ideally these functions should be methods on Type class, but we have a
# comment in codegen/model.py there saying these concepts are not well defined.
# Thus we put a version that is commonly used by autograd codegen here.
def is_tensor_type(t: Type) -> bool:
    """True when ``t`` is tensor-like and not list-like (``is_list_like()`` is None)."""
    # TODO: Should handle optional here?
    return t.is_tensor_like() and t.is_list_like() is None
258
+
259
+
def is_tensor_list_type(t: Type) -> bool:
    """True when ``t`` is tensor-like and list-like (``is_list_like()`` is not None)."""
    # TODO: Should handle optional here?
    return t.is_tensor_like() and t.is_list_like() is not None
263
+
264
+
+ # C++ statement template: declares `${arg_name}_` (optionally as a reference via
+ # `${ref}`) from `unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos})`.
+ # NOTE(review): any leading whitespace inside the template string is lost in
+ # this rendering - take the exact text from the real file.
265
+ UNPACK_TENSOR = CodeTemplate(
266
+ """\
267
+ auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});"""
268
+ )
269
+
270
+
271
def unpacked_name(arg_name: str) -> str:
    """Return the name of the unpacked variable for `arg_name` (the name plus a trailing underscore)."""
    suffix = "_"
    return arg_name + suffix
273
+
274
+
275
+ # e.g. select.int -> select_copy_int_inverse()
276
def inverse_view_name(f: NativeFunction) -> str:
    """Build the inverse-function name `<root>_copy[_<overload>]_inverse` for `f`.

    The overload segment is omitted when the overload name is empty.
    """
    segments = [f"{f.root_name}_copy"]
    overload_name = f"{f.func.name.overload_name}"
    if overload_name:
        segments.append(overload_name)
    segments.append("inverse")
    return "_".join(segments)
282
+
283
+
284
def extract_bindings(f: NativeFunction) -> list[Binding]:
    """Collect the C++ bindings for every argument of `f`, in schema order.

    Each schema argument is expanded with `cpp.argument` (function form,
    SymInt enabled, no default-arg exclusions, non-faithful, no tensor options)
    and the resulting bindings are concatenated.
    """
    bindings: list[Binding] = []
    for schema_arg in f.func.schema_order_arguments():
        bindings.extend(
            cpp.argument(
                schema_arg,
                method=False,
                symint=True,
                cpp_no_default_args=set(),
                faithful=False,
                has_tensor_options=False,
            )
        )
    return bindings
297
+
298
+
299
@with_native_function
def unpack_args(f: NativeFunction) -> tuple[list[str], list[Binding]]:
    """Emit C++ `unpack(...)` statements for the non-nullable tensor-like arguments of `f`.

    Returns:
        A pair `(body, unpacked_bindings)`. `body` holds one UNPACK_TENSOR
        statement per unpacked argument. `unpacked_bindings` has one entry per
        binding of `f`: the original binding when the argument is not
        tensor-like or is nullable, otherwise a copy renamed with
        `unpacked_name`.

    Raises:
        RuntimeError: if any binding is a TensorOptionsArguments.
    """
    body: list[str] = []
    unpacked_bindings: list[Binding] = []

    for i, binding in enumerate(extract_bindings(f)):
        assert not isinstance(binding.argument, SelfArgument)
        if isinstance(binding.argument, TensorOptionsArguments):
            raise RuntimeError("VariableKernel shouldn't take TensorOptions")

        arg_type = binding.argument.type
        is_nullable = arg_type.is_nullable()
        # Non-tensor and nullable arguments are passed through without unpacking.
        if not arg_type.is_tensor_like() or is_nullable:
            unpacked_bindings.append(binding)
            continue

        # Nullable arguments never reach this point, so the unpack helper never
        # needs an "_opt" suffix; only the list-ness decides how to bind:
        # a single tensor is bound by reference, a tensor list by value.
        ref = "" if is_tensor_list_type(arg_type) else "&"
        body.append(
            UNPACK_TENSOR.substitute(
                arg_name=binding.name,
                arg_pos=i,
                suffix="",
                ref=ref,
            )
        )
        unpacked_bindings.append(
            Binding(
                name=unpacked_name(binding.name),
                nctype=binding.nctype,
                argument=binding.argument,
                default=binding.default,
            )
        )

    return body, unpacked_bindings
335
+
336
+
337
def get_base_name(f: NativeFunction) -> str:
    """Return the `base` of the operator name of `f` (i.e. `f.func.name.name.base`)."""
    operator_name = f.func.name.name
    return operator_name.base  # TODO: should be str(f.func.name.name)?
339
+
340
+
341
def get_view_info(f: NativeFunction) -> str | None:
    """Look up the view info for `f` by its base name.

    An entry in VIEW_FUNCTIONS wins; otherwise ops listed in
    RETURNS_VIEWS_OF_INPUT map to "self"; anything else yields None.
    """
    op_name = get_base_name(f)
    explicit_info = VIEW_FUNCTIONS.get(op_name)
    if explicit_info is not None:
        return explicit_info
    return "self" if op_name in RETURNS_VIEWS_OF_INPUT else None
347
+
348
+
349
def emit_view_func(
    f: NativeFunction, bindings: list[Binding], view_idx: str | None = None
) -> str:
    """Generate an additional lambda function to recover views in backward when as_strided is not supported.
    See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.

    Args:
        f: the view op being generated.
        bindings: argument bindings of `f`; the one named "self" is treated as the base.
        view_idx: optional name of an index expression, appended to the replay
            arguments and inserted into the reverse-call arguments when given.

    Returns:
        The C++ text produced by substituting the replay and reverse-replay
        snippets into
        SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.

    Raises:
        TypeError: if a non-"self" argument has a type outside the known list below.
    """
    # TODO: Clean this logic up if we get rid of reverse view funcs or reify them.
    input_base = "input_base"
    replay_view_func = ""
    updated_args: list[str] = []
    # Argument types this generator knows how to capture in the lambdas.
    known_view_arg_simple_types: list[CType] = [
        BaseCType(longT),
        OptionalCType(BaseCType(longT)),
        BaseCType(SymIntT),
        OptionalCType(BaseCType(SymIntT)),
        BaseCType(boolT),
        BaseCType(intArrayRefT),
        BaseCType(symIntArrayRefT),
        ConstRefCType(BaseCType(tensorT)),
        ConstRefCType(OptionalCType(BaseCType(tensorT))),
    ]
    for binding in bindings:
        arg, arg_type = binding.name, binding.nctype.type
        if arg == "self":
            # The base tensor is referred to as `input_base` in the updated argument list.
            updated_args.append(input_base)
            continue
        if arg_type not in known_view_arg_simple_types:
            known_types_str = ", ".join([str(t) for t in known_view_arg_simple_types])
            raise TypeError(
                f"You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: "
                f"{known_types_str}. Please update the list or materialize it so that it can be closed "
                "over by value, also add a test in pytorch/xla/test/test_operations.py where this code "
                "is exercised."
            )
        if arg_type == BaseCType(intArrayRefT) or arg_type == BaseCType(
            symIntArrayRefT
        ):
            # It's not safe to close over IntArrayRef by value, since this is a
            # reference type, so materialize a vector to close over by value
            arg_vec = arg + "_vec"
            replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec)
            updated_args.append(arg_vec)
        elif arg_type == OptionalCType(BaseCType(longT)):
            # Materialize int64_t? to int64_t
            arg_value = arg + "_val"
            replay_view_func += OPTIONAL_TO_VAL.substitute(
                arg=arg, val=arg_value, default="0"
            )
            updated_args.append(arg_value)
        elif arg_type == ConstRefCType(BaseCType(tensorT)) or arg_type == ConstRefCType(
            OptionalCType(BaseCType(tensorT))
        ):
            # NB: Closing over a tensor. If a user modifies this tensor, this will be silently
            # incorrect. The proper thing to do is to store the version counter and copy on write.
            updated_args.append(arg)
        else:
            # Remaining known types are used as-is.
            updated_args.append(arg)

    # Imported here rather than at module level.
    # NOTE(review): presumably to avoid an import cycle — confirm.
    from .gen_view_funcs import view_func_name

    # The replay call receives the ORIGINAL binding names (not the `_vec`/`_val`
    # names materialized above), minus "self".
    view_func_args = [b.name for b in bindings if b.name != "self"]
    if view_idx is not None:
        view_func_args.append(f"{view_idx}")
    replay_view_func += REPLAY_VIEW_FUNC.substitute(
        view_func_name=view_func_name(f, include_namespace=True),
        view_func_args=view_func_args,
    )

    input_view = "input_view"
    reverse_unpacked_args = [
        "self",
        f"{input_view}",
        # inverse_return_mode=
        "at::functionalization::InverseReturnMode::AlwaysView",
        *(() if view_idx is None else (f"{view_idx}",)),
        # skip input_base arg
        # NOTE(review): the [1:] slice drops the FIRST entry, which is
        # `input_base` only if "self" is the first binding — confirm.
        *updated_args[1:],
    ]

    from torchgen.api.functionalization import reverse_name

    reverse_replay_view_call = REVERSE_VIEW_DISPATCH.substitute(
        reverse_name=reverse_name(f, include_namespace=True),
        unpacked_args=reverse_unpacked_args,
    )
    reverse_replay_view_func = REVERSE_REPLAY_VIEW_LAMBDA_FUNC.substitute(
        input_view=input_view, reverse_replay_view_call=reverse_replay_view_call
    )

    # Emitted as a C++ boolean literal.
    is_view_with_metadata_change = (
        "true" if cpp.name(f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else "false"
    )

    return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute(
        is_view_with_metadata_change=is_view_with_metadata_change,
        replay_view_func=replay_view_func,
        reverse_replay_view_func=reverse_replay_view_func,
    )
447
+
448
+
449
def emit_view_body(
    fn: NativeFunctionWithDifferentiabilityInfo, var: str
) -> tuple[str, str]:
    """Build the C++ that wraps the output `var` of a view op with `as_view(...)`.

    Args:
        fn: the native function together with its differentiability info.
        var: name of the C++ variable holding the redispatch result.

    Returns:
        A pair `(call, rhs_value)`: `call` is setup code to emit first (may be
        empty), and `rhs_value` is the C++ expression to assign to the return
        value(s).

    Raises:
        TypeError: if the view info for the op is not a string.
        RuntimeError: if the single differentiable output is neither Tensor nor
            Tensor[], or if there is more than one differentiable output.
    """
    # See NOTE [ Autograd View Variables ] in variable.h for details.
    f = fn.func
    base_name = get_base_name(f)
    view_info = get_view_info(f)
    call = ""
    differentiable_outputs = gen_differentiable_outputs(fn)
    differentiable_output_vars = {r.name for r in differentiable_outputs}
    if not isinstance(view_info, str):
        raise TypeError(
            f"The view info should be a string for {base_name}, but it is: {view_info}"
        )
    if len(differentiable_output_vars) == 0:
        # no output is differentiable (.indices() for SparseTensors for example)
        rhs_value = (
            f"as_view({view_info}, {var}, "
            f"/* is_bw_differentiable */ false, /* is_fw_differentiable */ false)"
        )
    elif len(differentiable_output_vars) == 1:
        # Single differentiable output (Tensor or Tensor[])
        return_info = differentiable_outputs[0]
        # We only support simple Tensor or a TensorList for functions that return views
        if not is_tensor_type(return_info.type) and not is_tensor_list_type(
            return_info.type
        ):
            raise RuntimeError(
                f"{base_name} that return differentiable views can only return Tensor or Tensor[]"
            )

        # See Note [ View + Inplace detection]
        # Wraps `original` in C++ ternaries: INFERENCE_MODE when inference mode is
        # enabled, else NO_GRAD_MODE when grad mode is disabled, else `original`.
        def get_creation_meta_in_mode(original: str) -> str:
            creation_meta_with_grad_mode = f"(at::GradMode::is_enabled() ? {original} : CreationMeta::NO_GRAD_MODE)"
            return f"InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : {creation_meta_with_grad_mode}"

        # Only allow rebasing of the history if we return a single Tensor
        # If we are in a no grad block, raise a warning
        # See NOTE [ View + Inplace detection ] for more details about this logic
        if is_tensor_list_type(return_info.type):
            creation_meta = get_creation_meta_in_mode("CreationMeta::MULTI_OUTPUT_NODE")
            view_idx = "view_idx"
            view_func = emit_view_func(
                f, extract_bindings(f), view_idx=view_idx
            ).strip()
            # Statement form (trailing ';'): it is placed inside the per-output iteration body.
            as_view_call = (
                f"as_view(/* base */ {view_info}, /* output */ {var}[{view_idx}], "
                "/* is_bw_differentiable */ true, /* is_fw_differentiable */ true, "
                "/* view_func */ std::move(func), /* rev_view_func */ rev_func, "
                f"/* creation_meta */ {creation_meta});"
            )
            call += MULTI_OUTPUT_VIEW_ITERATION.substitute(
                var=var, view_idx=view_idx, body=f"{view_func}\n{as_view_call}"
            )
            rhs_value = f"std::move({var})"
        else:
            call += emit_view_func(f, extract_bindings(f), view_idx=None)
            creation_meta = get_creation_meta_in_mode("CreationMeta::DEFAULT")
            # Expression form (no trailing ';'): the caller assigns it to the return value.
            rhs_value = (
                f"as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, "
                "/* is_fw_differentiable */ true, "
                f"/* view_func */ std::move(func), /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})"
            )
    else:
        # This could be supported but we don't need it at the moment, so keeping things simple.
        raise RuntimeError(
            "Function that return multiple differentiable output "
            "when at least one of them is view is not supported."
        )
    return call, rhs_value
519
+
520
+
521
def modifies_arguments(f: NativeFunction) -> bool:
    """Return True when the schema kind of `f` is `inplace` or `out`."""
    kind = f.func.kind()
    return kind == SchemaKind.inplace or kind == SchemaKind.out
523
+
524
+
525
@with_native_function_with_differentiability_info
def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> list[str]:
    """Produce the C++ statements forming the body of an ADInplaceOrView kernel.

    Ops that modify their arguments get a plain redispatch followed by an
    `increment_version(...)` per return name. All other ops must be views:
    the redispatch result is stored in TMP_VAR, wrapped via `emit_view_body`,
    and assigned to the return values. A final `return` is added when the
    schema has returns.
    """
    f = fn.func
    statements: list[str] = []

    # code-generated ADInplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    arg_exprs = ["ks & c10::after_ADInplaceOrView_keyset"]
    arg_exprs.extend(e.expr for e in DispatcherSignature.from_schema(f.func).exprs())
    redispatch_args = ", ".join(arg_exprs)

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
    if not modifies_arguments(f):
        assert get_view_info(f) is not None
        statements.append(
            VIEW_REDISPATCH.substitute(
                assign_return_values=f"auto {TMP_VAR} = ",
                unambiguous_name=f.func.name.unambiguous_name(),
                unpacked_args=redispatch_args,
            )
        )
        setup_code, rhs_value = emit_view_body(fn, TMP_VAR)
        statements.append(setup_code)
        assert rhs_value is not None
        statements.append(
            ASSIGN_RETURN_VALUE.substitute(
                return_values=tie_return_values(f), rhs_value=rhs_value
            )
        )
    else:  # inplace op
        statements.append(
            INPLACE_REDISPATCH.substitute(
                unambiguous_name=f.func.name.unambiguous_name(),
                unpacked_args=redispatch_args,
            )
        )
        statements.extend(
            f"increment_version({ret_name});" for ret_name in cpp.return_names(f)
        )

    if f.func.returns:
        statements.append(f"return {get_return_value(f)};")
    return statements
569
+
570
+
571
@with_native_function
def gen_formals(f: NativeFunction) -> str:
    """Return the comma-separated C++ formal parameter list for the kernel of `f`.

    The list always starts with `c10::DispatchKeySet ks`, followed by one
    `<cpp type> <name>` entry per schema argument, in schema order.
    """
    # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    formals = ["c10::DispatchKeySet ks"]
    for a in f.func.schema_order_arguments():
        cpp_type = cpp.argument_type(
            a, binds="__placeholder__", symint=True
        ).cpp_type()
        formals.append(f"{cpp_type} {a.name}")
    return ", ".join(formals)
582
+
583
+
584
@with_native_function_with_differentiability_info
def inplace_or_view_method_definition(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> str | None:
    """Render the C++ kernel definition for `fn`, or None when no kernel is generated.

    A kernel is generated for view ops, and for ops that modify their
    arguments and have at least one return.
    """
    f = fn.func
    if get_view_info(f) is None:
        # For functions that modify their inputs but don't return them,
        # we can't give them autograd support.
        # See https://github.com/pytorch/pytorch/issues/53796
        returns_modified_input = modifies_arguments(f) and len(f.func.returns) != 0
        if not returns_modified_input:
            return None

    return_type = cpp.returns_type(f.func.returns, symint=True).cpp_type()
    return METHOD_DEFINITION.substitute(
        return_type=return_type,
        type_wrapper_name=type_wrapper_name(f),
        formals=gen_formals(f),
        type_definition_body=emit_inplace_or_view_body(fn),
    )
603
+
604
+
605
@with_native_function_with_differentiability_info
def inplace_or_view_method_registration(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> str | None:
    """Render the `m.impl(...)` registration for `fn`, or None when no kernel is generated.

    Skipped for non-view ops that either do not modify their arguments or
    have no returns.
    """
    f = fn.func
    is_view = get_view_info(f) is not None
    if not is_view and (not modifies_arguments(f) or len(f.func.returns) == 0):
        return None

    wrapper = type_wrapper_name(f)
    return WRAPPER_REGISTRATION.substitute(
        unqual_operator_name_with_overload=f.func.name,
        type_wrapper_name=wrapper,
        class_type="ADInplaceOrView",
    )
619
+
620
+
621
def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:
    """Return True when `fn` is not in MANUAL_AUTOGRAD and its dispatch strategy is "use_derived"."""
    if cpp.name(fn.func.func) in MANUAL_AUTOGRAD:
        return False
    return dispatch_strategy(fn) == "use_derived"
625
+
626
+
627
def gen_inplace_or_view_type_env(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> dict[str, list[str]]:
    """Build the per-function template environment for ADInplaceOrViewType.cpp.

    Each of the three keys maps to a list with zero or one entries: the ops
    header include and the kernel definition are present only when a
    definition was generated, the registration only when one was generated.
    """
    definition = inplace_or_view_method_definition(fn)
    registration = inplace_or_view_method_registration(fn)

    env: dict[str, list[str]] = {
        "ops_headers": [],
        "inplace_or_view_method_definitions": [],
        "inplace_or_view_wrapper_registrations": [],
    }
    if definition is not None:
        env["ops_headers"].append(f"#include <ATen/ops/{fn.func.root_name}_ops.h>")
        env["inplace_or_view_method_definitions"].append(definition)
    if registration is not None:
        env["inplace_or_view_wrapper_registrations"].append(registration)
    return env
646
+
647
+
648
def gen_inplace_or_view_type(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write the sharded ADInplaceOrViewType.cpp files into `out`.

    Only functions accepted by `use_derived` are emitted; they are sharded by
    `root_name`. `native_yaml_path` and `tags_yaml_path` are accepted but not
    read by this function.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 2

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_sharded(
        "ADInplaceOrViewType.cpp",
        [fn for fn in fns_with_infos if use_derived(fn)],
        key_fn=lambda fn: fn.func.root_name,
        base_env={
            # The marker is assembled from two pieces so that this source file
            # does not itself contain the literal marker string.
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/ADInplaceOrViewType.cpp",
        },
        env_callable=gen_inplace_or_view_type_env,
        # Use the named constant above instead of a second, duplicated literal.
        num_shards=num_shards,
        sharded_keys={
            "ops_headers",
            "inplace_or_view_method_definitions",
            "inplace_or_view_wrapper_registrations",
        },
    )
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_python_functions.py ADDED
@@ -0,0 +1,1402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generates Python bindings for ATen functions
2
+ #
3
+ # The bindings are generated as methods on python_variable or functions on the
4
+ # torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse
5
+ # or torch._C._special objects.
6
+ #
7
+
8
+ # Code tries to stick to the following rules:
9
+ #
10
+ # - templates should be colocated with the functions that use them.
11
+ # no templates are currently shared between functions, but if that
12
+ # happens, maybe put the template with the first one
13
+ #
14
+ # - don't use environment dictionaries when calling template.substitute().
15
+ # pass named arguments directly for everything, otherwise it's much too
16
+ # hard to track what's actually being used and by whom
17
+ #
18
+ # - colocate any new hacks/adjustments with existing ones of the same kind.
19
+ # ideally in a data structure rather than code if possible. See e.g.
20
+ # SCHEMA_DEFAULT_CONVERSION_HACKS, etc.
21
+ #
22
+ # - similarly, conversions from one format to another should ideally happen
23
+ # all at once in a single place.
24
+ #
25
+ # - no nontrivial nested functions. couple-liners are ok but please no more.
26
+ # especially avoid functions that read/write outer variables defined far away.
27
+ #
28
+ # - raise RuntimeError instead of asserting, and put as much
29
+ # information as is available into the message. I.e. no need to
30
+ # plumb in new params whose only purpose is to fill out an error
31
+ # message, but use what's there
32
+ #
33
+
34
+ from __future__ import annotations
35
+
36
+ import itertools
37
+ import re
38
+ from collections import defaultdict
39
+ from typing import Callable, Iterable, Sequence
40
+
41
+ import yaml
42
+
43
+ from torchgen.api import cpp
44
+ from torchgen.api.python import (
45
+ arg_parser_output_exprs,
46
+ cpp_dispatch_exprs,
47
+ cpp_dispatch_target,
48
+ dispatch_lambda_args,
49
+ dispatch_lambda_exprs,
50
+ dispatch_lambda_return_str,
51
+ has_tensor_options,
52
+ PythonSignature,
53
+ PythonSignatureDeprecated,
54
+ PythonSignatureGroup,
55
+ PythonSignatureNativeFunctionPair,
56
+ signature,
57
+ signature_from_schema,
58
+ structseq_fieldnames,
59
+ )
60
+ from torchgen.code_template import CodeTemplate
61
+ from torchgen.context import with_native_function
62
+ from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
63
+ from torchgen.model import (
64
+ Argument,
65
+ BaseOperatorName,
66
+ FunctionSchema,
67
+ NativeFunction,
68
+ SchemaKind,
69
+ Type,
70
+ Variant,
71
+ )
72
+ from torchgen.utils import FileManager, split_name_params
73
+ from torchgen.yaml_utils import YamlLoader
74
+
75
+ from .gen_inplace_or_view_type import is_tensor_list_type
76
+ from .gen_trace_type import should_trace
77
+
78
+
79
+ #
80
+ # declarations blocklist
81
+ # We skip codegen for these functions, for various reasons.
82
+ # Future PRs will categorize this list and eliminate or hoist
83
+ # them out of eager-only codegen.
84
+ # See https://github.com/pytorch/pytorch/issues/30788
85
+ #
86
+
87
+ # These functions require manual Python bindings or are not exposed to Python
88
# Each entry is a regular expression source string; SKIP_PYTHON_BINDINGS
# compiles every entry wrapped in ^...$, so an entry must match a whole name.
# NOTE(review): the "." in "fill.Tensor" / "fill.Scalar" is an unescaped regex
# wildcard — confirm that is intended.
_SKIP_PYTHON_BINDINGS = [
    "alias",
    "contiguous",
    "is_cuda",
    "is_sparse",
    "is_sparse_csr",
    "size",
    "stride",
    "sym_size",
    "sym_stride",
    "sym_storage_offset",
    "sym_numel",
    ".*_backward",
    ".*_backward_(out|input|weight|bias)",
    ".*_forward",
    ".*_forward_out",
    ".*_jvp",
    "_unsafe_view",
    "tensor",
    "_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*",
    "_range.*",
    "_sparse_add_out",
    "_sparse_div.*",
    "_sparse_mul.*",
    "_sparse_sub.*",
    "_sparse_dense_add_out",
    "index",
    "index_out",
    "unique_dim_consecutive",
    "_cumsum.*",
    "_cumprod.*",
    "_sum.*",
    "_prod.*",
    "_th_.*",
    "_thnn_.*",
    "range.*",
    "_solve.*",
    "_inverse.*",
    "_cholesky.*",
    "_triangular_solve.*",
    "_qr.*",
    "_svd.*",
    "slice",
    "item",
    "_local_scalar_dense",
    "to",
    "_to_copy",
    "_to_copy_out",
    "_reshape_copy",
    "_reshape_copy_out",
    "copy_sparse_to_sparse_",
    "copy_",
    "_foreach_copy",
    "numpy_T",
    "matrix_H",
    "mT",
    "mH",  # these need to be attributes in Python, not functions
    "nonzero(_(out|numpy))?",
    "set_data",
    ".*_overrideable",  # overrideable functions for backend extension
    "data",
    "is_leaf",
    "output_nr",
    "_version",
    "requires_grad_",
    "retains_grad",
    "set_",
    "_fw_primal",
    "fake_quantize_per_tensor_affine_cachemask",
    "fake_quantize_per_channel_affine_cachemask",
    "_new_zeros_with_same_feature_meta",
    "_has_same_storage_numel",  # used for forward AD internals
    "_reshape_alias",
    "replace_",  # only used by the functionalization pass, doesn't need to be exposed to python
    "copy",  # only used by the functionalization pass
    "fill.Tensor",  # only used by the functionalization pass
    "fill.Scalar",  # only used by the functionalization pass
    "lift.*",
    "normal_functional",  # only used by the functionalization pass
    "nbytes",
    "itemsize",
    "_batch_norm_with_update",
    "_batch_norm_with_update_out",
    "_batch_norm_no_update",
]
173
+
174
# Compiled form of the blocklist: every entry is anchored at both ends so it
# has to match an entire function name.
SKIP_PYTHON_BINDINGS = list(
    map(lambda entry: re.compile("^" + entry + "$"), _SKIP_PYTHON_BINDINGS)
)
177
+
178
+ # These function signatures are not exposed to Python. Note that this signature
179
+ # list does not support regex.
180
# Full schema strings; an entry is compared for exact equality against
# `str(f.func)` in should_generate_py_binding.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "mul.Scalar(Tensor self, Scalar other) -> Tensor",
    "mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
    "div.Scalar(Tensor self, Scalar other) -> Tensor",
    "div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
]
190
+
191
+
192
@with_native_function
def should_generate_py_binding(f: NativeFunction) -> bool:
    """Decide whether a Python binding should be generated for `f`.

    Returns False when `f` is tagged "generated" without "view_copy", when its
    C++ name matches any SKIP_PYTHON_BINDINGS regex, or when its full schema
    string equals an entry of SKIP_PYTHON_BINDINGS_SIGNATURES; True otherwise.
    """
    # NativeFunctions that are entirely code-generated should not get python bindings
    # because these codegen implementations are often inefficient. A handful of
    # view_copy style ops were exposed accidentally when they were handwritten and now
    # that we are moving them to codegen for bc reasons we need to keep them exposed in
    # python.
    if "generated" in f.tags and "view_copy" not in f.tags:
        return False

    name = cpp.name(f.func)
    if any(skip_regex.match(name) for skip_regex in SKIP_PYTHON_BINDINGS):
        return False

    # Named `schema_str` (not `signature`) so it does not shadow the
    # `signature` helper imported at module level. Matching is exact string
    # equality; the signature list does not support regex.
    schema_str = str(f.func)
    return schema_str not in SKIP_PYTHON_BINDINGS_SIGNATURES
212
+
213
+
214
def get_pycname(name: BaseOperatorName) -> str:
    """Return the C function name used for the binding of `name`: `THPVariable_<name>`."""
    return "THPVariable_{}".format(name)
216
+
217
+
218
def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool:
    """Return True when there is exactly one overload and its signature has zero arguments."""
    if len(overloads) != 1:
        return False
    only_overload = overloads[0]
    return only_overload.signature.arguments_count() == 0
220
+
221
+
222
def is_py_variable_method(f: NativeFunction) -> bool:
    """Return True when `f` has no python_module and includes the `method` variant."""
    if f.python_module is not None:
        return False
    return Variant.method in f.variants
224
+
225
+
226
def is_py_torch_function(f: NativeFunction) -> bool:
    """Return True when `f` has no python_module and includes the `function` variant."""
    if f.python_module is not None:
        return False
    return Variant.function in f.variants
228
+
229
+
230
def is_py_nn_function(f: NativeFunction) -> bool:
    """Return True when `f.python_module` is "nn"."""
    module_name = f.python_module
    return module_name == "nn"
232
+
233
+
234
def is_py_fft_function(f: NativeFunction) -> bool:
    """Return True when `f.python_module` is "fft"."""
    module_name = f.python_module
    return module_name == "fft"
236
+
237
+
238
def is_py_linalg_function(f: NativeFunction) -> bool:
    """Return True when `f.python_module` is "linalg"."""
    module_name = f.python_module
    return module_name == "linalg"
240
+
241
+
242
def is_py_nested_function(f: NativeFunction) -> bool:
    """Return True when `f.python_module` is "nested"."""
    module_name = f.python_module
    return module_name == "nested"
244
+
245
+
246
def is_py_sparse_function(f: NativeFunction) -> bool:
    """Return True when `f.python_module` is "sparse"."""
    module_name = f.python_module
    return module_name == "sparse"
248
+
249
+
250
def is_py_special_function(f: NativeFunction) -> bool:
    """Return True when `f.python_module` is "special"."""
    module_name = f.python_module
    return module_name == "special"
252
+
253
+
254
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
255
+ #
256
+ # Main Function
257
+ #
258
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
259
+
260
+
261
def gen(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    template_path: str,
    *,
    symint: bool = True,
) -> None:
    """Generate all Python binding source files into `out`.

    Parses the native functions, drops those rejected by
    `should_generate_py_binding`, then writes, in order: the variable-method
    bindings, the sharded `torch` function bindings, one file per
    `python_module` submodule, the return-type bindings (source and header),
    and the tag enum.

    Args:
        out: install directory for the generated files.
        native_yaml_path: path passed to `parse_native_yaml`.
        tags_yaml_path: path passed to `parse_native_yaml` and `parse_tags_yaml`.
        deprecated_yaml_path: path passed to `load_signatures`.
        template_path: directory containing the templates.
        symint: forwarded to every binding generator.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm,
        methods,
        is_py_variable_method,
        None,
        "python_variable_methods.cpp",
        method=True,
        symint=symint,
    )

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    # torch/csrc/autograd/python_torch_functions_manual.cpp
    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    create_python_bindings_sharded(
        fm,
        functions,
        is_py_torch_function,
        "torch",
        "python_torch_functions.cpp",
        method=False,
        num_shards=3,
        symint=symint,
    )

    # One output file per python_module submodule. Driving these from a single
    # table guarantees that every submodule (including torch.nested, which
    # previously fell back to the default) receives the caller's `symint`.
    submodule_bindings = (
        (is_py_nn_function, "torch.nn", "python_nn_functions.cpp"),
        (is_py_fft_function, "torch.fft", "python_fft_functions.cpp"),
        (is_py_linalg_function, "torch.linalg", "python_linalg_functions.cpp"),
        (is_py_nested_function, "torch.nested", "python_nested_functions.cpp"),
        (is_py_sparse_function, "torch.sparse", "python_sparse_functions.cpp"),
        (is_py_special_function, "torch.special", "python_special_functions.cpp"),
    )
    for pred, module, filename in submodule_bindings:
        create_python_bindings(
            fm,
            functions,
            pred,
            module,
            filename,
            method=False,
            symint=symint,
        )

    # Currently, we only use `functions` to generate `return_types` bindings.
    # All methods which return structseq have a function variant at this point.
    # If any method-only operator with structseq is added in the future,
    # we will have to address that.
    create_python_return_type_bindings(
        fm, functions, lambda fn: True, "python_return_types.cpp"
    )
    create_python_return_type_bindings_header(
        fm, functions, lambda fn: True, "python_return_types.h"
    )

    valid_tags = parse_tags_yaml(tags_yaml_path)

    def gen_tags_enum() -> dict[str, str]:
        # One `.value("<tag>", at::Tag::<tag>)` line per tag, in sorted order.
        return {
            "enum_of_valid_tags": (
                "".join(
                    [f'\n.value("{tag}", at::Tag::{tag})' for tag in sorted(valid_tags)]
                )
            )
        }

    fm.write("python_enum_tag.cpp", gen_tags_enum)
383
+
384
+
385
def group_filter_overloads(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]:
    """Group the pairs whose function satisfies `pred` by operator name.

    The key is `pair.function.func.name.name`; input order is preserved within
    each group. The returned mapping is a `defaultdict(list)`.
    """
    by_name: dict[
        BaseOperatorName, list[PythonSignatureNativeFunctionPair]
    ] = defaultdict(list)
    for accepted in filter(lambda p: pred(p.function), pairs):
        by_name[accepted.function.func.name.name].append(accepted)
    return by_name
396
+
397
+
398
def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: str | None,
    filename: str,
    *,
    method: bool,
    symint: bool = True,
) -> None:
    """Generates Python bindings to ATen functions

    Pairs accepted by `pred` are grouped by operator name and processed in
    order of `str(name)`. For each group the method implementation, method
    def entry, forward declarations and ops header include are collected, and
    the result is written through `fm` using `filename` as both template and
    output name.
    """
    overloads_by_name = group_filter_overloads(pairs, pred)

    impls: list[str] = []
    method_defs: list[str] = []
    forwards: list[str] = []
    headers: list[str] = []
    for op_name in sorted(overloads_by_name, key=str):
        op_overloads = overloads_by_name[op_name]
        impls.append(
            method_impl(op_name, module, op_overloads, method=method, symint=symint)
        )
        method_defs.append(method_def(op_name, module, op_overloads, method=method))
        forwards.extend(forward_decls(op_name, op_overloads, method=method))
        headers.append(f"#include <ATen/ops/{op_name.base}.h>")

    def template_env():
        # The marker is assembled from two pieces so that this source file does
        # not itself contain the literal marker string.
        return {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
            "ops_headers": headers,
            "py_forwards": forwards,
            "py_methods": impls,
            "py_method_defs": method_defs,
        }

    fm.write_with_template(filename, filename, template_env)
437
+
438
+
439
def create_python_return_type_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
    """
    Render the named-tuple getter definitions and their registration calls
    for the native functions accepted by ``pred`` (the content destined for
    `python_return_types.cpp`).

    Operators are visited in ``str``-sorted base-name order; each contributes
    exactly one entry to both lists, which is the empty string when the
    operator yields no definitions or registrations.
    """
    definition_chunks: list[str] = []
    registration_chunks: list[str] = []

    by_name = group_filter_overloads(pairs, pred)
    for op_name in sorted(by_name, key=str):
        defs, regs = generate_return_type_definition_and_registrations(
            by_name[op_name]
        )
        # Joining an empty list already produces "", so no special case is needed.
        definition_chunks.append("\n".join(defs))
        registration_chunks.append("\n".join(regs))

    def template_env() -> dict[str, object]:
        return {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
            "py_return_types": definition_chunks,
            "py_return_types_registrations": registration_chunks,
        }

    fm.write_with_template(filename, filename, template_env)
476
+
477
+
478
def create_python_return_type_bindings_header(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
    """
    Render the declarations of the named-tuple getter functions for the
    native functions accepted by ``pred``.

    Operators are visited in ``str``-sorted base-name order; each contributes
    one entry, which is the empty string when it has no declarations.
    """
    declaration_chunks: list[str] = []

    by_name = group_filter_overloads(pairs, pred)
    for op_name in sorted(by_name, key=str):
        decls = generate_return_type_declarations(by_name[op_name])
        # Joining an empty list already produces "", so no special case is needed.
        declaration_chunks.append("\n".join(decls))

    def template_env() -> dict[str, object]:
        return {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
            "py_return_types_declarations": declaration_chunks,
        }

    fm.write_with_template(filename, filename, template_env)
508
+
509
+
510
def create_python_bindings_sharded(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: str | None,
    filename: str,
    *,
    method: bool,
    num_shards: int,
    symint: bool = True,
) -> None:
    """Generates Python bindings to ATen functions

    Same per-operator content as the unsharded generator, but the
    (name, overloads) items are handed to ``fm.write_sharded`` together with a
    key function (the operator's ``base`` string) and a per-item environment
    builder.
    """
    by_name = group_filter_overloads(pairs, pred)

    def shard_key(
        item: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
    ) -> str:
        op_name, _ = item
        return op_name.base

    def item_env(
        item: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]
    ) -> dict[str, list[str]]:
        op_name, op_pairs = item
        headers = [f"#include <ATen/ops/{op_name.base}.h>"]
        forwards = list(forward_decls(op_name, op_pairs, method=method))
        impls = [method_impl(op_name, module, op_pairs, method=method, symint=symint)]
        defs = [method_def(op_name, module, op_pairs, method=method)]
        return {
            "ops_headers": headers,
            "py_forwards": forwards,
            "py_methods": impls,
            "py_method_defs": defs,
        }

    shared_env = {
        "generated_comment": "@"
        + f"generated from {fm.template_dir_for_comments()}/{filename}",
    }

    fm.write_sharded(
        filename,
        by_name.items(),
        base_env=shared_env,
        key_fn=shard_key,
        env_callable=item_env,
        num_shards=num_shards,
        sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"},
    )
554
+
555
+
556
def load_signatures(
    native_functions: list[NativeFunction],
    deprecated_yaml_path: str,
    *,
    method: bool,
    skip_deprecated: bool = False,
    pyi: bool = False,
) -> Sequence[PythonSignatureNativeFunctionPair]:
    """Pair every native function with its Python signature.

    The deprecated signatures are always loaded from ``deprecated_yaml_path``
    (so the file is read even when ``skip_deprecated`` is set); they are
    appended to the result only when ``skip_deprecated`` is false.
    """

    @with_native_function
    def to_pair(fn: NativeFunction) -> PythonSignatureNativeFunctionPair:
        return PythonSignatureNativeFunctionPair(
            signature=signature(fn, method=method, pyi=pyi),
            function=fn,
        )

    pairs = [to_pair(native_fn) for native_fn in native_functions]
    deprecated = load_deprecated_signatures(
        pairs, deprecated_yaml_path, method=method, pyi=pyi
    )
    if skip_deprecated:
        return pairs
    return pairs + deprecated
576
+
577
+
578
def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> list[PythonSignatureNativeFunctionPair]:
    """Build signature pairs for the entries of the deprecated-signature YAML.

    Each YAML entry provides a deprecated schema (``name``) and the ATen call
    it delegates to (``aten``). For every original pair whose schema is
    compatible with that call, a ``PythonSignatureDeprecated`` pair bound to
    the same native function is produced.

    Raises ``AssertionError`` when a call argument is neither a schema
    argument nor a known constant, or when no native function matches a
    deprecated entry.
    """
    # The deprecated.yaml doesn't have complete type information, we need
    # find and leverage the original ATen signature (to which it delegates
    # the call) to generate the full python signature.
    # We join the deprecated and the original signatures using type-only form.

    # group the original ATen signatures by name
    grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        grouped[pair.signature.name].append(pair)

    # find matching original signatures for each deprecated signature
    results: list[PythonSignatureNativeFunctionPair] = []

    # NOTE(review): whether this yaml.load call is safe on untrusted input
    # depends on YamlLoader being a safe loader - confirm where it is defined.
    with open(deprecated_yaml_path) as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)

    for deprecated in deprecated_defs:
        schema = FunctionSchema.parse(deprecated["name"])
        aten_name, call_args = split_name_params(deprecated["aten"])
        is_out = aten_name.endswith("_out")
        if is_out:
            # Strip only the trailing "_out"; str.replace would also remove
            # any earlier "_out" occurring inside the operator name.
            aten_name = aten_name[: -len("_out")]

        # HACK: these are fixed constants used to pass the aten function.
        # The type must be known ahead of time
        known_constants = {
            "1": Type.parse("Scalar"),
        }
        schema_args_by_name = {a.name: a for a in schema.arguments.flat_all}
        for name in call_args:
            assert (
                name in schema_args_by_name or name in known_constants
            ), f"deprecation definiton: Unrecognized value {name}"

        # Map deprecated signature arguments to their aten signature and test
        # if the types and alias annotation match.
        def is_schema_compatible(
            aten_schema: FunctionSchema,
        ) -> bool:
            arguments: Iterable[Argument]
            if is_out:
                # For out variants the out arguments are matched first.
                arguments = itertools.chain(
                    aten_schema.arguments.out, aten_schema.arguments.flat_non_out
                )
            else:
                arguments = aten_schema.arguments.flat_all

            for i, arg in enumerate(arguments):
                if i < len(call_args):
                    arg_name = call_args[i]
                    if arg_name in known_constants:
                        schema_type = known_constants[arg_name]
                        schema_annotation = None
                    else:
                        schema_arg = schema_args_by_name[arg_name]
                        schema_type = schema_arg.type
                        schema_annotation = schema_arg.annotation

                    if schema_type != arg.type or schema_annotation != arg.annotation:
                        return False
                else:
                    # ATen arguments beyond the call's arguments must have a default.
                    if arg.default is None:
                        return False

            return len(schema.returns) == len(aten_schema.returns) and all(
                a == b for a, b in zip(schema.returns, aten_schema.returns)
            )

        any_schema_found = False
        for pair in grouped[aten_name]:
            if not is_schema_compatible(pair.function.func):
                continue
            any_schema_found = True

            python_sig = signature_from_schema(
                schema,
                category_override=pair.function.category_override,
                method=method,
                pyi=pyi,
            )

            results.append(
                PythonSignatureNativeFunctionPair(
                    signature=PythonSignatureDeprecated(
                        name=python_sig.name,
                        input_args=python_sig.input_args,
                        input_kwargs=python_sig.input_kwargs,
                        output_args=python_sig.output_args,
                        tensor_options_args=python_sig.tensor_options_args,
                        method=python_sig.method,
                        deprecated_schema=schema,
                        deprecated_args_exprs=tuple(call_args),
                        returns=python_sig.returns,
                    ),
                    function=pair.function,
                )
            )
        assert (
            any_schema_found
        ), f"No native function with name {aten_name} matched signature:\n  {str(schema)}"

    return results
687
+
688
+
689
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
690
+ #
691
+ # Named Tuple Codegen
692
+ #
693
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
694
+
695
+
696
@with_native_function
def gen_structseq_typename_key(f: NativeFunction) -> str:
    """Key made of the function's C++ name and its return field names, joined by ``_``."""
    cpp_name = cpp.name(f.func)
    field_names = structseq_fieldnames(f.func.returns)
    key_parts = [cpp_name] + field_names
    return "_".join(key_parts)
701
+
702
+
703
def emit_structseq_call(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> tuple[list[str], dict[str, str]]:
    """
    Build the static named-tuple type initializers for a group of overloads.

    Returns the list of C++ initializer lines together with a map from each
    typename key to the local variable name (``NamedTuple``, ``NamedTuple1``,
    ...) that its initializer declares. Overloads without return field names
    are skipped, and a key is only initialized once.
    """
    typenames: dict[str, str] = {}  # typename key -> local typedef name
    typedefs: list[str] = []  # one initializer line per distinct key

    for overload in overloads:
        if not structseq_fieldnames(overload.function.func.returns):
            continue

        cpp_name = cpp.name(overload.function.func)  # use @with_native_function?
        key = gen_structseq_typename_key(overload.function)
        if key in typenames:
            continue

        # The first typedef has no numeric suffix; later ones use the running count.
        suffix = str(len(typedefs)) if typedefs else ""
        local_name = f"NamedTuple{suffix}"
        typenames[key] = local_name
        typedefs.append(
            f"static PyTypeObject* {local_name} = generated::get_{cpp_name}_structseq();"
        )

    return typedefs, typenames
732
+
733
+
734
def generate_return_type_definition_and_registrations(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> tuple[list[str], list[str]]:
    """
    Produce, for the overloads of one operator, the C++ getter functions that
    lazily initialize each named-tuple type and the matching
    ``addReturnType`` registration calls (both destined for
    `python_return_types.cpp`).

    Overloads without return field names are skipped; each distinct typename
    key yields exactly one definition and one registration.
    """
    typenames: dict[str, str] = {}  # typename key -> C++ type variable name
    definitions: list[str] = []  # getter function definitions
    registrations: list[str] = []  # addReturnType(...) calls

    for overload in overloads:
        fieldnames = structseq_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        name = cpp.name(overload.function.func)  # use @with_native_function?
        key = gen_structseq_typename_key(overload.function)
        if key in typenames:
            continue

        fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames)
        # The first definition has no numeric suffix; later ones use the running count.
        suffix = str(len(definitions)) if definitions else ""
        typename = f"{name}NamedTuple{suffix}"
        typenames[key] = typename
        definitions.append(
            f"""\
PyTypeObject* get_{name}_structseq() {{
static PyStructSequence_Field NamedTuple_fields[] = {{ {fields}, {{nullptr}} }};
static PyTypeObject {typename};
static bool is_initialized = false;
static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
if (!is_initialized) {{
PyStructSequence_InitType(&{typename}, &desc);
{typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
is_initialized = true;
}}
return &{typename};
}}
"""
        )
        registrations.append(
            f'addReturnType(return_types_module, "{name}", generated::get_{name}_structseq());'
        )

    return definitions, registrations
783
+
784
+
785
def generate_return_type_declarations(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> list[str]:
    """
    Generate block of function declarations in `python_return_types.h` to initialize
    and return named tuple for a native function.

    Overloads without return field names are skipped, and each distinct
    typename key contributes a single ``PyTypeObject* get_<name>_structseq();``
    declaration.
    """
    # Only membership matters here, so a set of keys replaces the former
    # key -> typename map whose values were never read.
    seen_keys: set[str] = set()
    declarations: list[str] = []  # function declaration to register the typedef

    for overload in overloads:
        fieldnames = structseq_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        tn_key = gen_structseq_typename_key(overload.function)
        if tn_key in seen_keys:
            continue
        seen_keys.add(tn_key)

        name = cpp.name(overload.function.func)  # use @with_native_function?
        declarations.append(f"PyTypeObject* get_{name}_structseq();")

    return declarations
814
+
815
+
816
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
817
+ #
818
+ # Method Impl Codegen
819
+ #
820
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
821
+
822
# python binding for all overloads of a particular function/method
# The generated function parses its arguments once with a PythonArgParser
# holding ${signatures}, then switches on the matched signature index
# (_r.idx); each ${dispatch} entry is expected to be one `case` block.
# NOTE: this literal is a raw string, so the backslash after the opening
# quotes is kept as literal text instead of acting as a line continuation.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate(
    r"""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
${method_header}
static PythonArgParser parser({
${signatures}
}, /*traceable=*/${traceable});

ParsedArgs<${max_args}> parsed_args;
auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
${check_has_torch_function}
switch (_r.idx) {
${dispatch}
}
${method_footer}
}

"""
)
844
+
845
# handler for a single parsed signature - may be a single overload or
# a pair of overloads whose signatures only differ in output params
# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch})
# ${overload_index} is the value of the `case` label and ${body} is the
# dispatch code placed inside it.
PY_VARIABLE_CASE = CodeTemplate(
    """\
case ${overload_index}: {
${body}
}
"""
)
855
+
856
# python binding for single-overload function/method
# Same layout as the multi-overload template, except that ${dispatch} is
# emitted directly after argument parsing, without a switch on _r.idx.
PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
${method_header}
static PythonArgParser parser({
${signatures}
}, /*traceable=*/${traceable});

ParsedArgs<${max_args}> parsed_args;
auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
${check_has_torch_function}
${dispatch}
${method_footer}
}

"""
)
876
+
877
# python binding for a method with no args, shortcuts parsing
# The generated function takes no kwargs parameter and contains no
# PythonArgParser; ${dispatch} runs right after the torch-function check.
PY_VARIABLE_METHOD_NOARGS = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
${method_header}
${check_has_torch_function}
${dispatch}
${method_footer}
}

"""
)
891
+
892
+
893
def method_impl(
    name: BaseOperatorName,
    module: str | None,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
    symint: bool = True,
) -> str:
    """
    Render the C++ binding function covering every overload of one op.

    The overloads are grouped into parsed signatures; each contributes one
    parser signature string and one dispatch body. The no-arg, singleton or
    multi-overload template is then chosen and filled in.
    """
    pycname = get_pycname(name)
    noarg = is_noarg(overloads)
    structseq_inits, structseq_typenames = emit_structseq_call(overloads)

    header_lines = ["HANDLE_TH_ERRORS", *structseq_inits]
    if method:
        header_lines.append("const Tensor& self = THPVariable_Unpack(self_);")

    footer_lines = ["END_HANDLE_TH_ERRORS"]
    if not noarg:
        footer_lines.insert(0, "Py_RETURN_NONE;")

    every_overload_traced = all(should_trace(o.function) for o in overloads)
    traceable = "true" if every_overload_traced else "false"

    grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(
        overloads, symint=symint
    )
    is_singleton = len(grouped_overloads) == 1

    signatures: list[str] = []
    dispatch: list[str] = []
    for overload_index, group in enumerate(grouped_overloads):
        sig_text = group.signature.signature_str(symint=symint)
        signatures.append(f"{cpp_string(str(sig_text))},")

        body = emit_dispatch_case(group, structseq_typenames, symint=symint)
        if is_singleton:
            # A lone signature needs no `case` wrapper.
            dispatch.append(body)
        else:
            dispatch.append(
                PY_VARIABLE_CASE.substitute(overload_index=overload_index, body=body)
            )

    if noarg:
        template = PY_VARIABLE_METHOD_NOARGS
    elif is_singleton:
        template = PY_VARIABLE_METHOD_VARARGS_SINGLETON
    else:
        template = PY_VARIABLE_METHOD_VARARGS

    return template.substitute(
        name=name,
        pycname=pycname,
        method_header=header_lines,
        max_args=max(o.signature.arguments_count() for o in overloads),
        signatures=signatures,
        traceable=traceable,
        check_has_torch_function=gen_has_torch_function_check(
            name=name,
            module=module,
            noarg=noarg,
            method=method,
        ),
        dispatch=dispatch,
        method_footer=footer_lines,
        self_="self_" if method else "nullptr",
    )
960
+
961
+
962
def gen_has_torch_function_check(
    name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool
) -> str:
    """Return the C++ snippet that routes a call through ``handle_torch_function``.

    No-arg functions get no check, and no-arg methods check ``self_``
    directly. Everything else checks the parser result ``_r`` and passes the
    namespace object matching ``module`` (``THPVariableClass`` when ``module``
    is falsy). A non-empty ``module`` missing from the table raises
    ``KeyError``.
    """
    if noarg:
        if not method:
            return ""
        return f"""\
if(check_has_torch_function(self_)) {{
return handle_torch_function(self_, "{name}");
}}
"""

    namespace_by_module = {
        "torch": "THPVariableFunctionsModule",
        "torch.nn": "THPNNVariableFunctionsModule",
        "torch.fft": "THPFFTVariableFunctionsModule",
        "torch.linalg": "THPLinalgVariableFunctionsModule",
        "torch.nested": "THPNestedVariableFunctionsModule",
        "torch.sparse": "THPSparseVariableFunctionsModule",
        "torch.special": "THPSpecialVariableFunctionsModule",
    }
    self_ = "self_" if method else "nullptr"
    namespace = namespace_by_module[module] if module else "THPVariableClass"
    owner = module or "torch.Tensor"

    return f"""\
if(_r.has_torch_function()) {{
return handle_torch_function(_r, {self_}, args, kwargs, {namespace}, "{owner}");
}}
"""
995
+
996
+
997
# handler for output/no-output overload pair
# The generated code tests whether the parsed argument at ${out_idx} is None:
# if so it runs ${call_dispatch}, otherwise ${call_dispatch_out}.
PY_VARIABLE_OUT = CodeTemplate(
    """\
if (_r.isNone(${out_idx})) {
${call_dispatch}
} else {
${call_dispatch_out}
}
"""
)
1007
+
1008
+
1009
def emit_dispatch_case(
    overload: PythonSignatureGroup,
    structseq_typenames: dict[str, str],
    *,
    symint: bool = True,
) -> str:
    """
    Emit the dispatch code for one parsed signature.

    A group without an out variant dispatches straight to its base function.
    A group with one shares a single python signature between both functions,
    and the emitted code branches on whether the output argument was passed.
    """
    if overload.outplace is None:
        # no-output version only
        return emit_single_dispatch(
            overload.signature, overload.base, structseq_typenames, symint=symint
        )

    # dispatch output and no-output variants, branch on _r.isNone(<out_idx>)
    out_index = overload.signature.output_idx()
    without_out = emit_single_dispatch(
        overload.signature, overload.base, structseq_typenames, symint=symint
    )
    with_out = emit_single_dispatch(
        overload.signature,
        overload.outplace,
        structseq_typenames,
        symint=symint,
    )
    return PY_VARIABLE_OUT.substitute(
        out_idx=out_index,
        call_dispatch=without_out,
        call_dispatch_out=with_out,
    )
1040
+
1041
+
1042
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1043
+ #
1044
+ # Forward Declarations Codegen
1045
+ #
1046
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1047
+
1048
+
1049
def forward_decls(
    name: BaseOperatorName,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> tuple[str, ...]:
    """Return the forward declaration of an op's binding function.

    Methods get none (an empty tuple). Otherwise a one-element tuple is
    returned whose prototype omits the ``kwargs`` parameter for no-arg
    functions.
    """
    if method:
        return ()

    pycname = get_pycname(name)
    parameters = "PyObject* self_, PyObject* args"
    if not is_noarg(overloads):
        parameters += ", PyObject* kwargs"
    return (f"static PyObject * {pycname}({parameters});\n",)
1071
+
1072
+
1073
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1074
+ #
1075
+ # Method Def (Binding Table Entry) Codegen
1076
+ #
1077
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1078
+
1079
+
1080
def method_def(
    name: BaseOperatorName,
    module: str | None,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> str:
    """
    Build the ``PyMethodDef`` table entry for one op.

    Dunder methods are wrapped in ``TypeError_to_NotImplemented_<...>``,
    functions taking arguments are wrapped in ``castPyCFunctionWithKeywords``,
    and entries of the ``torch`` module also carry ``METH_STATIC``.
    """
    entry_point = get_pycname(name)

    if name.dunder_method:
        # PyMethodDef entry for binary op, throws not implemented error
        entry_point = f"TypeError_to_NotImplemented_<{entry_point}>"

    keyword_flags = "METH_VARARGS | METH_KEYWORDS"
    if not is_noarg(overloads):
        entry_point = f"castPyCFunctionWithKeywords({entry_point})"
        flag_parts = [keyword_flags]
    elif method:
        flag_parts = ["METH_NOARGS"]
    else:
        flag_parts = [keyword_flags]

    if module == "torch":
        flag_parts.append("METH_STATIC")

    flags = " | ".join(flag_parts)
    return f'{{"{name}", {entry_point}, {flags}, NULL}},'
1106
+
1107
+
1108
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1109
+ #
1110
+ # Overload Sorting and Grouping
1111
+ #
1112
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1113
+
1114
+
1115
def group_overloads(
    overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
) -> Sequence[PythonSignatureGroup]:
    """Pair each functional overload with its out variant and sort the groups.

    Overloads are keyed by their signature string with outputs skipped.
    Raises ``RuntimeError`` when two out overloads (or two non-out overloads)
    share a key, or when an out overload has no non-out counterpart.
    """
    bases: dict[str, PythonSignatureNativeFunctionPair] = {}
    outplaces: dict[str, PythonSignatureNativeFunctionPair] = {}

    # first group by signature ignoring out arguments
    for overload in overloads:
        sig = overload.signature.signature_str(skip_outputs=True, symint=symint)
        bucket = outplaces if overload.function.func.is_out_fn() else bases
        if sig in bucket:
            raise RuntimeError(
                f"Found duplicated function definition:\n- {overload.function.func}.\n"
                f"Existing definition:\n- {bucket[sig].function.func}."
            )
        bucket[sig] = overload

    # every out variant must have a functional counterpart
    for sig, out in outplaces.items():
        if sig in bases:
            continue
        candidates: list[str] = [
            other.signature.signature_str(skip_outputs=True, symint=symint)
            for other in overloads
            if (
                str(other.function.func.name.name) == str(out.function.func.name.name)
                and not other.function.func.is_out_fn()
                and not other.signature.deprecated
            )
        ]
        out_sig = out.signature.signature_str(symint=symint)
        raise RuntimeError(
            f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. "
            f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema "
            "correctly in native_functions.yaml. We discovered the following candidate(s): \n"
            + "\n".join(f"- {candidate}" for candidate in candidates)
        )

    grouped = []
    for sig, base in bases.items():
        grouped.append(
            PythonSignatureGroup.from_pairs(functional=base, out=outplaces.get(sig))
        )
    return sort_overloads(grouped, symint=symint)
1170
+
1171
+
1172
+ # This function declares a partial order on declarations, and sorts them according
1173
+ # to its linear extension. This is necessary, because there's some ambiguity in the
1174
+ # choice of overload, and we want a different order.
1175
+ #
1176
+ # See Note[Order of overloads matters]
1177
+ #
1178
+ # A few examples of ambiguous python signature pairs.
1179
+ #
1180
+ # All parameters have the same type, except one taking Tensor the other taking
1181
+ # Scalar. A numeric PyObject can be cast into Tensor, and a zero-dim Tensor
1182
+ # object can be accepted as Scalar type parameter (see python_arg_parser.cpp).
1183
+ # Therefore, same input arguments might be accepted by either python signature.
1184
+ # We want to always parse the one taking Tensor first.
1185
+ #
1186
+ # bitwise_and(Tensor input, Tensor other, *, Tensor out=None)
1187
+ # bitwise_and(Tensor input, Scalar other, *, Tensor out=None)
1188
+ #
1189
+ # If they have different number of parameters then they are not ambiguous - but
1190
+ # the difference on output param can be ignored as it's optional.
1191
+ #
1192
+ # multiply(Tensor input, Tensor other, *, Tensor out=None)
1193
+ # multiply(Tensor input, Scalar other)
1194
+ #
1195
+ # Both positional args and keyword-only args are considered together.
1196
+ #
1197
+ # subtract(Tensor other, *, Scalar alpha=1)
1198
+ # subtract(Scalar other, Scalar alpha=1)
1199
+ #
1200
+ # A few ambiguous cases which it does NOT handle yet.
1201
+ #
1202
+ # If there is any difference in other parameters besides the Tensor/Scalar
1203
+ # difference, then they are not considered ambiguous by this method anymore.
1204
+ # However, the difference could be too trivial to disambiguate.
1205
+ #
1206
+ # foo(Tensor input, Scalar other, Scalar bar)
1207
+ # foo(Tensor input, Tensor other, double bar)
1208
+ #
1209
+ # If they are taking different number of parameters then they are not considered
1210
+ # ambiguous anymore, even if the difference is only on optional kwargs.
1211
+ #
1212
+ # foo(Scalar other, Scalar alpha=1)
1213
+ # foo(Tensor other, *, Scalar alpha=1, Scalar beta=1)
1214
+ #
1215
+
1216
+
1217
def sort_overloads(
    grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True
) -> Sequence[PythonSignatureGroup]:
    """Order signature groups so that higher-priority overloads come first.

    The groups are first sorted by signature string. A "smaller than"
    relation is then computed between every pair of signatures; if no pair is
    related, the signature-sorted list is returned as is. Otherwise the groups
    are emitted so that each one appears only after every group it is smaller
    than.
    """
    # NB: Smaller here means lower priority

    def is_arg_smaller(t1: Type, t2: Type) -> bool:
        # Compares the string forms of two argument types. `and` binds tighter
        # than `or`, so each `x and y` pair below is one alternative.
        return (
            str(t1) == "Scalar"
            and str(t2) == "Tensor"
            or str(t1) == "Scalar?"
            and str(t2) == "Tensor?"
            or "Dimname" in str(t1)
            and "Dimname" not in str(t2)
            or
            # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been
            # discussed why it is important to prioritize int/int? over int[]
            str(t1) == "int[]"
            and (str(t2) == "int" or str(t2) == "int?")
            or
            # TensorList currently throws an error during argument parsing, that's why it needs to be
            # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087
            str(t1) == "Tensor[]"
            and str(t2).find("[]") != -1
            or
            # Prioritize IntArrayRef overload over SymIntArrayRef
            str(t1) == "SymInt[]"
            and str(t2) == "int[]"
            or
            # Make sure both int and SymInt are sorted consistently w.r.t. Tensor since Tensor can be implicitly
            # converted to either int or SymInt. Prioritize the Tensor overload since it otherwise gets shadowed.
            (str(t1) == "SymInt" or str(t1) == "int")
            and str(t2) == "Tensor"
        )

    def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool:
        """Returns True if s1 < s2 in the partial order."""
        args1, args2 = s1.arguments(skip_outputs=True), s2.arguments(skip_outputs=True)
        # Signatures with a different number of non-output arguments are unrelated.
        if len(args1) != len(args2):
            return False
        # TODO: should use some canonical form instead of 'str(arg.type)' - see comments
        # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which
        # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'.
        equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2))
        smaller_or_equal = all(
            str(arg1.type) == str(arg2.type) or is_arg_smaller(arg1.type, arg2.type)
            for arg1, arg2 in zip(args1, args2)
        )
        return smaller_or_equal and not equal

    # First sort by signature
    grouped_overloads = sorted(
        grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint)
    )

    # Construct the relation graph
    # larger_than[i] holds the indices of every overload that overload i is smaller than.
    larger_than: dict[int, set[int]] = defaultdict(set)
    for i1, overload1 in enumerate(grouped_overloads):
        for i2, overload2 in enumerate(grouped_overloads):
            if is_smaller(overload1.signature, overload2.signature):
                larger_than[i1].add(i2)

    if not larger_than:
        return list(grouped_overloads)

    # Use a topological sort to sort overloads according to the partial order.
    N = len(grouped_overloads)
    # Start with the overloads that are not smaller than any other overload.
    sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N)))

    for idx in range(N):
        # The size of sorted_ids will grow to N eventually.
        i = sorted_ids[idx]
        # Iterate over a sorted snapshot of the keys because entries are deleted inside the loop.
        for j in sorted(larger_than.keys()):
            larger = larger_than[j]
            larger.discard(i)
            # Once everything larger than j has been emitted, j can be emitted too.
            if not larger:
                del larger_than[j]
                sorted_ids.append(j)

    return [grouped_overloads[x] for x in sorted_ids]
1296
+
1297
+
1298
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1299
+ #
1300
+ # Codegen API Integration
1301
+ #
1302
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1303
+
1304
+
1305
def emit_single_dispatch(
    ps: PythonSignature,
    f: NativeFunction,
    structseq_typenames: dict[str, str],
    *,
    symint: bool = True,
) -> str:
    """
    Emit dispatch code for a single native function.

    The emitted C++ consists of a schema comment, the argument initializers,
    a `dispatch_<name>` lambda that releases the GIL and calls the dispatch
    target, and finally the call of that lambda. For a `void` lambda return
    the call is followed by a separate return statement; otherwise its result
    is passed to `wrap(...)`, preceded by the named-tuple type when
    `structseq_typenames` has an entry for `f`.
    """

    @with_native_function
    def go(f: NativeFunction) -> str:
        # header comments
        if isinstance(ps, PythonSignatureDeprecated):
            schema_comment = f"// [deprecated] aten::{ps.deprecated_schema}"
        else:
            schema_comment = f"// aten::{f.func}"

        # NOTE: `deprecated` is not referenced again within this function.
        deprecated = "[deprecated] " if ps.deprecated else ""

        # dispatch lambda signature
        name = cpp.name(f.func)
        lambda_formals = ", ".join(
            f"{a.type_str} {a.name}" for a in dispatch_lambda_args(ps, f, symint=symint)
        )
        lambda_return = dispatch_lambda_return_str(f)

        # dispatch lambda body
        dispatch_callee = cpp_dispatch_target(f)
        dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps))

        # from arg parser outputs to dispatch lambda arguments
        parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
        lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint)
        inits = "\n".join(lambda_arg_exprs.inits)
        lambda_args = ", ".join(lambda_arg_exprs.exprs)

        # scatter fields
        # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky
        #       solution for enabling the 'requires_grad' argument for tensor methods
        #       new_full, new_empty, and new_zeros. A much better but more difficult to
        #       implement solution involves refactoring according to Ed's description here:
        #       https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589
        need_set_requires_grad = ps.tensor_options_args and (
            not has_tensor_options(f)
            or (ps.method and ("requires_grad" in parser_outputs))
        )
        # The "requires_grad" parser output is only looked up when the flag above is truthy.
        set_requires_grad = (
            f'.set_requires_grad({parser_outputs["requires_grad"].expr})'
            if need_set_requires_grad
            else ""
        )

        if lambda_return == "void":
            # Make in-place foreach return `self` at python-binding level.
            # ref: https://github.com/pytorch/pytorch/pull/118622#pullrequestreview-1904804954
            self_arg = f.func.arguments.self_arg
            return_stmt: str
            if (
                str(f.func.name).startswith("_foreach_")
                and f.func.kind() == SchemaKind.inplace
            ):
                # note(crcrpar): `_foreach_pow.ScalarAndTensor` does NOT have an in-place
                # variant and is unlikely to get one in the future. Thus it's safe to have the following assert.
                assert self_arg is not None and is_tensor_list_type(
                    self_arg.argument.type
                )
                return_stmt = """PyObject* self_tensorlist = _r.args[0];
Py_INCREF(self_tensorlist);
return self_tensorlist;
"""
            else:
                return_stmt = "Py_RETURN_NONE;"
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
pybind11::gil_scoped_release no_gil;
{dispatch_callee}({dispatch_args});
}};
dispatch_{name}({lambda_args}){set_requires_grad};
{return_stmt}
"""
        else:
            # Look up the named-tuple type registered for this function, if any.
            typename = structseq_typenames.get(gen_structseq_typename_key(f))
            structseq_typeref = f"{typename}, " if typename is not None else ""
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
pybind11::gil_scoped_release no_gil;
return {dispatch_callee}({dispatch_args});
}};
return wrap({structseq_typeref}dispatch_{name}({lambda_args}){set_requires_grad});
"""

    return go(f)
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_trace_type.py ADDED
@@ -0,0 +1,536 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import itertools
4
+ from typing import Sequence
5
+
6
+ from torchgen.api import cpp
7
+ from torchgen.api.types import DispatcherSignature
8
+ from torchgen.code_template import CodeTemplate
9
+ from torchgen.context import with_native_function
10
+ from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments
11
+ from torchgen.utils import FileManager
12
+
13
+
14
+ # Note [Manual Backend kernels]
15
+ # For these ops, we want to manually register to dispatch key Backend and
16
+ # skip codegen-ed registration to all keys before Backend.
17
+ # For codegen this means:
18
+ # - op set below must match ops with manual_kernel_registration=True in native_functions.yaml
19
+ # where we skip codegen backend kernels
20
+ # - all ops below are part of MANUAL_AUTOGRAD to skip codegen Autograd kernel registration
21
+ # - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration
22
+ # Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now.
23
+ # You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
24
+ MANUAL_BACKEND = {
25
+ "options",
26
+ "data",
27
+ "set_data",
28
+ "is_leaf",
29
+ "output_nr",
30
+ "_version",
31
+ "retain_grad",
32
+ "_backward",
33
+ "requires_grad_",
34
+ }
35
+
36
+ # For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys.
37
+ # You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
38
+ MANUAL_AUTOGRAD_AND_TRACER = {
39
+ "resize_",
40
+ "resize_as_",
41
+ "detach",
42
+ "detach_",
43
+ "copy_",
44
+ "_fw_primal",
45
+ "_make_dual",
46
+ }
47
+
48
+ # Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops:
49
+ # union(MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER)
50
+ # You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
51
+ MANUAL_AUTOGRAD = MANUAL_TRACER = MANUAL_BACKEND | MANUAL_AUTOGRAD_AND_TRACER
52
+
53
+ # These functions we don't want to record for tracing, because we always want
54
+ # to trace their constituent parts. This is a temporary hack in lieu
55
+ # of proper scopes, where subsequent compilation passes can ask for the unfolding
56
+ # on demand. Only concrete ATen methods can be disabled this way; it will have
57
+ # NO EFFECT otherwise.
58
+ DONT_RECORD_TRACE = {
59
+ "convolution",
60
+ "conv1d",
61
+ "conv2d",
62
+ "conv3d",
63
+ "conv_transpose1d",
64
+ "conv_transpose2d",
65
+ "conv_transpose3d",
66
+ "lstm_cell",
67
+ "gru_cell",
68
+ "rnn_tanh_cell",
69
+ "rnn_relu_cell",
70
+ # FIXME: figure out a better way when we support sparse tensors in jit
71
+ "_coalesced",
72
+ }
73
+
74
+
75
def should_trace(f: NativeFunction) -> bool:
    """Return whether trace-recording code should be generated for ``f``.

    ``f`` is rejected when any schema-order argument has type ``Storage``,
    ``Type`` or ``ConstQuantizerPtr``, when none of its returns is
    tensor-like, or when its base name is listed in ``DONT_RECORD_TRACE``.
    """
    # Operations involving Storage or Type are not traceable at the moment
    untraceable_types = {"Storage", "Type", "ConstQuantizerPtr"}
    for arg in f.func.schema_order_arguments():
        if str(arg.type) in untraceable_types:
            return False

    # We can't trace functions which don't have any Tensor or TensorList returns
    for ret in f.func.returns:
        if ret.type.is_tensor_like():
            break
    else:
        # The loop finished without finding a tensor-like return.
        return False

    return f.func.name.name.base not in DONT_RECORD_TRACE
86
+
87
+
88
+ SELECT = CodeTemplate(
89
+ """\
90
+
91
+ if (${cond}) {
92
+ ${true}
93
+ } else {
94
+ ${false}
95
+ }
96
+ """
97
+ )
98
+
99
+ OP_NAME = CodeTemplate(
100
+ """\
101
+ op_name = c10::Symbol::fromQualString("aten::${trace_name}");
102
+ """
103
+ )
104
+
105
+ # These functions are recorded in the trace under a different name.
106
+ RENAME_TRACE = {
107
+ "zero": "zeros_like", # replacing aten::zero_ with aten::zeros_like
108
+ "fill": "full_like", # replacing aten::fill_ with aten::full_like
109
+ }
110
+
111
+
112
def format_trace_op_name(f: NativeFunction) -> str:
    """Render the C++ statement(s) that assign ``op_name`` for ``f``.

    Functional ops, out ops and dunder methods get a single ``OP_NAME``
    assignment. Every other op gets a ``SELECT`` on
    ``tracer_state->force_outplace`` that chooses between the out-of-place
    name (the base name) and the in-place name (``cpp.name``). All names are
    first passed through ``RENAME_TRACE``.
    """

    def op_name_stmt(raw_name: str) -> str:
        # Apply the trace renaming table, then render the assignment.
        return OP_NAME.substitute(trace_name=RENAME_TRACE.get(raw_name, raw_name))

    op = f.func.name.name

    # TODO: byte-for-byte compatible with old codegen behavior - should clean up
    uses_single_name = (
        f.func.kind() in (SchemaKind.functional, SchemaKind.out) or op.dunder_method
    )
    if uses_single_name:
        # special case for *_out functions: the in-place and out-of-place ops
        # are overloaded with the same name in the JIT
        return op_name_stmt(str(op))

    # otherwise, this is an in-place op and we need to emit both in- and
    # out-of-place versions
    outplace_stmt = op_name_stmt(op.base)
    inplace_stmt = op_name_stmt(cpp.name(f.func))
    return SELECT.substitute(
        cond="tracer_state->force_outplace",
        true=outplace_stmt,
        false=inplace_stmt,
    )
136
+
137
+
138
+ ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""")
139
+
140
+
141
def format_trace_inputs(f: NativeFunction) -> str:
    """Render the ``jit::tracer::addInputs`` statements for ``f``.

    One statement is emitted per schema-order argument, and four per
    ``TensorOptionsArguments``. For ``*_out`` functions the trailing out
    arguments are dropped from that list and replaced by a ``SELECT`` on
    ``tracer_state->force_outplace``. The statements are joined with newlines.
    """

    def dispatch_trace_input(arg: Argument | TensorOptionsArguments) -> Sequence[str]:
        if not isinstance(arg, TensorOptionsArguments):
            name = arg.name
            if str(arg.type) == "Tensor?[]":
                return [f'jit::tracer::addInputs(node, "{name}", {name});']
            return [ADD_TRACE_INPUT.substitute(name=name, input=name)]

        # TensorOptions are traced as four separate inputs, all named "options".
        option_exprs = (
            "c10::optTypeMetaToScalarType(options.dtype_opt())",
            "options.layout()",
            "options.device()",
            "options.pinned_memory()",
        )
        return [
            ADD_TRACE_INPUT.substitute(name="options", input=expr)
            for expr in option_exprs
        ]

    is_out_fn = f.func.is_out_fn()

    args: list[Argument | TensorOptionsArguments] = list(
        f.func.schema_order_arguments()
    )
    if is_out_fn:
        # *_out functions take the result as a separate argument, but we don't want to
        # trace that argument directly. Instead, we trace its TensorOptions.
        # So first, we need to remove the out argument from the list of arguments to trace.
        num_out_args = len(f.func.arguments.out)
        args = args[:-num_out_args]

    trace_inputs = itertools.chain.from_iterable(
        dispatch_trace_input(arg) for arg in args
    )

    if is_out_fn:
        # for *_out functions, handle the result argument differently for inplace/outplace.
        # For inplace: just add the input to the end to conform with the JIT schema
        inplace = [
            ADD_TRACE_INPUT.substitute(name=out_arg.name, input=out_arg.name)
            for out_arg in f.func.arguments.out
        ]

        # for outplace: do nothing, except if the function is a factory.
        # Factories are a bit special because their out-of-place overloads
        # take an extra TensorOptions argument, which is missing in the _out function
        has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns)
        has_tensor_input_arg = any(
            a.type.is_tensor_like() for a in f.func.arguments.flat_non_out
        )
        is_factory_method = f.category_override == "factory" or (
            has_tensor_return and not has_tensor_input_arg
        )

        # HACK: preserve old codegen behavior - the old codegen set the `is_factory_method`
        # flag for the whole family of ops with the same basename if any of them is a
        # factory method. For most cases the whole family of ops are indeed all factory
        # method - 'normal' is the only exception. So we handle it specially here to avoid
        # cloning the old logic.
        if f.func.name.name.base == "normal":
            is_factory_method = True

        outplace: list[str] = []
        if is_factory_method:
            out_option_exprs = (
                "c10::optTypeMetaToScalarType(out.options().dtype_opt())",
                "out.options().layout()",
                "out.options().device()",
                "out.options().pinned_memory()",
            )
            outplace = [
                ADD_TRACE_INPUT.substitute(name="out", input=expr)
                for expr in out_option_exprs
            ]

        out_selection = SELECT.substitute(
            cond="tracer_state->force_outplace",
            true="\n".join(outplace),
            false="\n".join(inplace),
        )
        trace_inputs = itertools.chain(trace_inputs, [out_selection])

    return "\n".join(trace_inputs)
231
+
232
+
233
+ # `torch.jit.trace` has an undocumented keyword argument `_force_outplace`,
234
+ # which forces the jit to replace functions with outplace variants (for
235
+ # example `aten::add_` becomes `aten::add`).
236
+ #
237
+ # This replacement is implemented in-place with minimum modifications of
238
+ # arguments stack (as it assumes that outplace call has the same arguments
239
+ # as inplace version).
240
+ #
241
+ # However there are no such substitutions available for `aten::fill_`
242
+ # and `aten::zero_` operators, as we never implemented `aten::fill`
243
+ # and `aten::zero`. So the jit tracing hack replaces `aten::zero_` with
243
+ # `aten::zeros_like` and `aten::fill_` with `aten::full_like`.
245
+ #
246
+ # But as they potentially can have different arguments, we also have
247
+ # to hack into the stack and add missing ones.
248
+ #
249
+ # A possible alternative would be:
250
+ #
251
+ # - Add `aten::fill` and `aten::zero`
252
+ #
253
+ # - Or keep `aten::zeros_like` arguments aligned with `aten::zero_`
254
+ # arguments (inside of the `native_functions.yaml`)
255
+ RENAME_TRACE_ADD_ARGS = {
256
+ "fill": """\
257
+ jit::tracer::addInputs(node, "options", ::std::optional<ScalarType>());
258
+ jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt));
259
+ jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt));
260
+ jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt));
261
+ ::std::optional<MemoryFormat> memory_format = c10::MemoryFormat::Preserve;
262
+ jit::tracer::addInputs(node, "memory_format", memory_format);
263
+ """,
264
+ "zero": """\
265
+ jit::tracer::addInputs(node, "options", ::std::optional<ScalarType>());
266
+ jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt));
267
+ jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt));
268
+ jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt));
269
+ ::std::optional<MemoryFormat> memory_format = c10::MemoryFormat::Preserve;
270
+ jit::tracer::addInputs(node, "memory_format", memory_format);
271
+ """,
272
+ }
273
+
274
+ INPLACE_GUARD = CodeTemplate(
275
+ """\
276
+ jit::tracer::ensureUniqueIfOutOfPlaced("${name}", ${mutable_input});
277
+ """
278
+ )
279
+
280
+ PRE_RECORD_TRACE = CodeTemplate(
281
+ """\
282
+ torch::jit::Node* node = nullptr;
283
+ std::shared_ptr<jit::tracer::TracingState> tracer_state;
284
+ if (jit::tracer::isTracing()) {
285
+ tracer_state = jit::tracer::getTracingState();
286
+ at::Symbol op_name;
287
+ ${set_op_name}
288
+ node = tracer_state->createNode(op_name, /*num_outputs=*/0);
289
+ jit::tracer::recordSourceLocation(node);
290
+ ${add_trace_inputs}
291
+ tracer_state->insertNode(node);
292
+ ${inplace_guard}
293
+ jit::tracer::setTracingState(nullptr);
294
+ }
295
+ """
296
+ )
297
+
298
+
299
def format_prerecord_trace(f: NativeFunction) -> str:
    """Render the ``PRE_RECORD_TRACE`` block for ``f``.

    Returns an empty string when ``should_trace(f)`` is false. For in-place
    and out ops that are not dunder methods, the block also gets an
    ``INPLACE_GUARD`` and, when the base name has an entry in
    ``RENAME_TRACE_ADD_ARGS``, the extra inputs wrapped in a ``SELECT`` on
    ``tracer_state->force_outplace``.
    """
    if not should_trace(f):
        return ""

    op = f.func.name.name

    # TODO: clean up old codegen behavior
    is_inplace = (
        f.func.kind() in (SchemaKind.inplace, SchemaKind.out)
        and not op.dunder_method
    )

    additional_inputs = ""
    inplace_guard = ""
    if is_inplace:
        add_args = RENAME_TRACE_ADD_ARGS.get(op.base, "")
        if add_args:
            additional_inputs = SELECT.substitute(
                cond="tracer_state->force_outplace",
                true=add_args,
                false="",
            )

        # Guard the first out argument when there is one, otherwise `self`.
        out_args = f.func.arguments.out
        mutable_input = out_args[0].name if out_args else "self"
        inplace_guard = INPLACE_GUARD.substitute(
            name=cpp.name(f.func),
            mutable_input=mutable_input,
        )

    return PRE_RECORD_TRACE.substitute(
        set_op_name=format_trace_op_name(f),
        add_trace_inputs=format_trace_inputs(f) + additional_inputs,
        inplace_guard=inplace_guard,
    )
333
+
334
+
335
+ POST_RECORD_TRACE = CodeTemplate(
336
+ """\
337
+ if (tracer_state) {
338
+ jit::tracer::setTracingState(std::move(tracer_state));
339
+ ${add_trace_outputs}
340
+ }
341
+ """
342
+ )
343
+
344
+
345
def format_postrecord_trace(f: NativeFunction) -> str:
    """Render the ``POST_RECORD_TRACE`` block for ``f``.

    Returns an empty string when ``should_trace(f)`` is false. Otherwise the
    block holds one ``jit::tracer::addOutput`` statement per output name.
    """

    def add_output_stmts(names: Sequence[str]) -> list[str]:
        return [f"jit::tracer::addOutput(node, {n});" for n in names]

    if not should_trace(f):
        return ""

    if not f.func.is_out_fn():
        return POST_RECORD_TRACE.substitute(
            add_trace_outputs=add_output_stmts(cpp.return_names(f))
        )

    # For outplacing ops, *_out overloads require special handling to move the
    # output *argument* to a return value
    output_names_outplace = [arg.name for arg in f.func.arguments.out]
    output_names_inplace = cpp.return_names(f)

    # Code size optimization: the common case is that the return value is
    # the same for both variants
    if output_names_outplace == output_names_inplace:
        return POST_RECORD_TRACE.substitute(
            add_trace_outputs=add_output_stmts(output_names_outplace)
        )

    # NOTE(review): this condition is a bare `force_outplace`, while the other
    # SELECT uses in this file test `tracer_state->force_outplace`. Confirm that
    # the identifier is in scope in the generated C++ before relying on this path.
    selection = SELECT.substitute(
        cond="force_outplace",
        true="\n".join(add_output_stmts(output_names_outplace)),
        false="\n".join(add_output_stmts(output_names_inplace)),
    )
    return POST_RECORD_TRACE.substitute(add_trace_outputs=selection)
377
+
378
+
379
def tie_return_values(f: NativeFunction) -> str:
    """Render the C++ declaration that receives the return value(s) of ``f``.

    A single return gives ``auto <name>``, using ``result`` when the return is
    unnamed. Any other count gives a structured binding
    ``auto [<names>]`` built from ``cpp.return_names``.
    """
    returns = f.func.returns
    if len(returns) != 1:
        bound_names = ", ".join(cpp.return_names(f))
        return f"auto [{bound_names}]"

    sole_name = returns[0].name
    if not sole_name:
        sole_name = "result"
    return f"auto {sole_name}"
384
+
385
+
386
def get_return_value(f: NativeFunction) -> str:
    """Render the C++ expression returned by the generated wrapper of ``f``.

    A single return gives its name. Multiple returns give
    ``std::forward_as_tuple(...)`` for out ops and ``std::make_tuple(...)``
    over ``std::move``-d names for every other kind.
    """
    names = cpp.return_names(f)
    if len(f.func.returns) == 1:
        return names[0]

    if f.func.kind() != SchemaKind.out:
        moved_names = [f"std::move({name})" for name in names]
        return "std::make_tuple({})".format(", ".join(moved_names))

    return "std::forward_as_tuple({})".format(", ".join(names))
395
+
396
+
397
+ TRACE_DISPATCH = CodeTemplate(
398
+ """\
399
+ ${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args});"""
400
+ )
401
+
402
+
403
def emit_trace_body(f: NativeFunction) -> list[str]:
    """Build the body of the generated tracing kernel for ``f``.

    The returned list holds, in order: the pre-record block, the redispatch
    call, the post-record block, and a ``return`` statement when ``f`` has
    returns.
    """
    pre_record = format_prerecord_trace(f)

    dispatcher_exprs = DispatcherSignature.from_schema(f.func).exprs()

    # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = "ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)"
    redispatch_args = ", ".join(
        [dispatch_key_set, *(binding.expr for binding in dispatcher_exprs)]
    )

    # Only functional and mutable ops with returns bind the redispatch result.
    binds_result = (
        f.func.kind() in [SchemaKind.functional, SchemaKind.mutable]
        and f.func.returns
    )
    assign_return_values = f"{tie_return_values(f)} = " if binds_result else ""

    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
    # We could probably work harder to ensure that the fast variants are
    # called instead, but the perf benefit would be minimal.
    dispatch_call = TRACE_DISPATCH.substitute(
        assign_return_values=assign_return_values,
        unambiguous_name=f.func.name.unambiguous_name(),
        unpacked_args=redispatch_args,
    )

    trace_body: list[str] = [pre_record, dispatch_call, format_postrecord_trace(f)]
    if f.func.returns:
        trace_body.append(f"return {get_return_value(f)};")
    return trace_body
438
+
439
+
440
+ METHOD_DEFINITION = CodeTemplate(
441
+ """\
442
+ ${return_type} ${type_wrapper_name}(${formals}) {
443
+ ${type_definition_body}
444
+ }
445
+ """
446
+ )
447
+
448
+
449
def type_wrapper_name(f: NativeFunction, key: str = "Default") -> str:
    """Return the name of the generated wrapper function for ``f``.

    The name is ``cpp.name(f.func)``, followed by ``_<overload_name>`` when
    the op has an overload name, followed by ``_<key>`` when ``key`` is not
    ``"Default"``.
    """
    wrapper = cpp.name(f.func)
    overload = f.func.name.overload_name
    if overload:
        wrapper = f"{wrapper}_{overload}"

    # The key argument is only used in gen_variable_type where we need fns per autograd dispatch key.
    # In gen_trace_type and gen_inplace_view_type where only one fn per native_fn must be generated,
    # the key argument should not be passed.
    # We do not append key if it is Default so that generated functions from
    # before per-dispatch-key derivatives were added retain the same names.
    if key == "Default":
        return wrapper
    return wrapper + f"_{key}"
463
+
464
+
465
@with_native_function
def method_definition(f: NativeFunction) -> str:
    """Render the C++ definition of the tracing kernel for ``f``.

    Asserts that ``f`` is not in ``MANUAL_TRACER``. The formals are a leading
    ``c10::DispatchKeySet ks`` followed by one ``<cpp type> <name>`` entry per
    schema-order argument.
    """
    assert cpp.name(f.func) not in MANUAL_TRACER

    arg_decls = []
    for a in f.func.schema_order_arguments():
        cpp_type = cpp.argument_type(
            a, binds="__placeholder__", symint=True
        ).cpp_type()
        arg_decls.append(f"{cpp_type} {a.name}")

    # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    formals = ", ".join(["c10::DispatchKeySet ks", *arg_decls])

    return_type = cpp.returns_type(f.func.returns, symint=True).cpp_type()
    return METHOD_DEFINITION.substitute(
        return_type=return_type,
        type_wrapper_name=type_wrapper_name(f),
        formals=formals,
        type_definition_body=emit_trace_body(f),
    )
485
+
486
+
487
+ WRAPPER_REGISTRATION = CodeTemplate(
488
+ """\
489
+ m.impl("${name}",
490
+ TORCH_FN(${class_type}::${type_wrapper_name})
491
+ );
492
+ """
493
+ )
494
+
495
+
496
@with_native_function
def method_registration(f: NativeFunction) -> str:
    """Render the ``m.impl(...)`` registration of the tracing kernel for ``f``.

    Asserts that ``f`` is not in ``MANUAL_TRACER``. The kernel is registered
    under the ``TraceType`` class type.
    """
    assert cpp.name(f.func) not in MANUAL_TRACER

    wrapper = type_wrapper_name(f)
    return WRAPPER_REGISTRATION.substitute(
        name=f.func.name,
        type_wrapper_name=wrapper,
        class_type="TraceType",
    )
505
+
506
+
507
def gen_trace_type_func(fn: NativeFunction) -> dict[str, list[str]]:
    """Build the template environment entries contributed by ``fn``.

    Each of the three keys maps to a one-element list: the ops header include,
    the kernel definition, and the kernel registration.
    """
    ops_header = f"#include <ATen/ops/{fn.root_name}_ops.h>"
    definition = method_definition(fn)
    registration = method_registration(fn)
    return {
        "ops_headers": [ops_header],
        "trace_method_definitions": [definition],
        "trace_wrapper_registrations": [registration],
    }
513
+
514
+
515
def gen_trace_type(
    out: str, native_functions: list[NativeFunction], template_path: str
) -> None:
    """Write the sharded ``TraceType.cpp`` files into ``out``.

    Functions whose C++ name is in ``MANUAL_TRACER`` are excluded. The rest
    are split into 5 shards keyed by ``root_name``.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    traced_functions = [
        fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER
    ]
    # The leading "@" is concatenated separately, exactly as in the original.
    generated_comment = (
        "@" + f"generated from {fm.template_dir_for_comments()}/TraceType.cpp"
    )

    def shard_key(fn: NativeFunction) -> str:
        return fn.root_name

    fm.write_sharded(
        "TraceType.cpp",
        traced_functions,
        key_fn=shard_key,
        base_env={"generated_comment": generated_comment},
        env_callable=gen_trace_type_func,
        num_shards=5,
        sharded_keys={
            "ops_headers",
            "trace_method_definitions",
            "trace_wrapper_registrations",
        },
    )
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_variable_factories.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generates C++ functions that wrap ATen tensor factory methods to turn them into Variables.
2
+ #
3
+ # This writes one file: variable_factories.h
4
+
5
+ from __future__ import annotations
6
+
7
+ import re
8
+
9
+ import torchgen.api.python as python
10
+ from torchgen.api import cpp
11
+ from torchgen.api.types import CppSignatureGroup
12
+ from torchgen.context import with_native_function
13
+ from torchgen.gen import parse_native_yaml
14
+ from torchgen.model import NativeFunction, TensorOptionsArguments, Variant
15
+ from torchgen.utils import FileManager, mapMaybe
16
+
17
+
18
+ OPTIONAL_TYPE_PATTERN = re.compile(r"std::optional<(.+)>")
19
+ TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")
20
+
21
+
22
+ # Add 'at::' to types defined in ATen namespace, e.g. Tensor, TensorList, IntArrayRef and etc.
23
+ # TODO: maybe update the cpp argument API to take optional namespace argument?
24
def fully_qualified_type(argument_type: str) -> str:
    """Prefix the type name in ``argument_type`` with ``at::``.

    When ``argument_type`` matches ``OPTIONAL_TYPE_PATTERN``, the prefix is
    applied to the type inside ``std::optional<...>`` and the wrapper is put
    back afterwards. A type that ``TYPE_PATTERN`` does not match is returned
    without a prefix.
    """
    optional_match = OPTIONAL_TYPE_PATTERN.match(argument_type)
    inner_type = (
        optional_match.group(1) if optional_match is not None else argument_type
    )

    type_match = TYPE_PATTERN.match(inner_type)
    if type_match is not None:
        # Insert the namespace right before the captured type name, which
        # keeps any leading `const ` in place.
        insert_at = type_match.start(1)
        inner_type = inner_type[:insert_at] + "at::" + inner_type[insert_at:]

    if optional_match is not None:
        return f"std::optional<{inner_type}>"
    return inner_type
38
+
39
+
40
def gen_variable_factories(
    out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str
) -> None:
    """Write ``variable_factories.h`` into ``out``.

    The native functions are parsed from the two YAML paths and filtered with
    ``is_factory_function``. The template environment is passed as a callable,
    so it is built only when the file manager invokes it.
    """
    parsed = parse_native_yaml(native_yaml_path, tags_yaml_path)
    factory_functions = [
        fn for fn in parsed.native_functions if is_factory_function(fn)
    ]
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    def build_env() -> dict:
        ops_headers = [
            f"#include <ATen/ops/{fn.root_name}.h>" for fn in factory_functions
        ]
        function_definitions = list(mapMaybe(process_function, factory_functions))
        return {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/variable_factories.h",
            "ops_headers": ops_headers,
            "function_definitions": function_definitions,
        }

    fm.write_with_template(
        "variable_factories.h",
        "variable_factories.h",
        build_env,
    )
60
+
61
+
62
@with_native_function
def is_factory_function(f: NativeFunction) -> bool:
    """Return whether ``f`` is treated as a tensor factory function.

    ``f`` must have the ``function`` variant, and must either take tensor
    options or have a C++ name ending in ``_like``.
    """
    if Variant.function in f.variants:
        op_name = cpp.name(f.func)
        takes_tensor_options = python.has_tensor_options(f)
        return takes_tensor_options or op_name.endswith("_like")
    return False
70
+
71
+
72
@with_native_function
def process_function(f: NativeFunction) -> str | None:
    """Render the inline C++ factory wrapper(s) for ``f``.

    Returns ``None`` when ``f`` lacks the ``function`` variant or is not a
    factory (no tensor options and a name not ending in ``_like``). Otherwise
    one wrapper is rendered for the plain signature, plus one for the symint
    signature when it exists, and the wrappers are concatenated.
    """
    name = cpp.name(f.func)
    has_tensor_options = python.has_tensor_options(f)
    is_factory = has_tensor_options or name.endswith("_like")

    if not (Variant.function in f.variants and is_factory):
        return None

    sig_group = CppSignatureGroup.from_native_function(f, method=False)
    sigs = [sig_group.signature]
    if sig_group.symint_signature is not None:
        sigs.append(sig_group.symint_signature)

    definitions: list[str] = []
    for sig in sigs:
        formals: list[str] = []
        exprs: list[str] = []
        requires_grad = "false"
        for arg in sig.arguments():
            qualified_type = fully_qualified_type(arg.type)
            default_suffix = f" = {arg.default}" if arg.default else ""
            formals.append(f"{qualified_type} {arg.name}{default_suffix}")

            if not isinstance(arg.argument, TensorOptionsArguments):
                exprs.append(arg.name)
                continue

            # note: we remove the requires_grad setting from the TensorOptions because
            # it is ignored anyways (and we actually have an assertion that it isn't set
            # which would fail otherwise). We handle requires_grad explicitly here
            # instead of passing it through to the kernel.
            exprs.append(
                f"at::TensorOptions({arg.name}).requires_grad(::std::nullopt)"
            )
            # Manually set the requires_grad bit on the result tensor.
            requires_grad = f"{arg.name}.requires_grad()"

        definitions.append(
            f"""\
inline at::Tensor {sig.name()}({', '.join(formals)}) {{
  at::AutoDispatchBelowADInplaceOrView guard;
  return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad});
}}
"""
        )
    return "".join(definitions)
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_variable_type.py ADDED
@@ -0,0 +1,2180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generates VariableType.h/cpp
2
+ #
3
+ # **If any changes are being made to the VariableType codegen please also check
4
+ # if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
5
+ #
6
+ # VariableType is a subclass of at::Type that provides the binding code
7
+ # necessary to provide a differentiable version of ATen operators. There are a
8
+ # number of different things we could mean:
9
+ #
10
+ # - Given a non-differentiable forward implementation, we might
11
+ # directly associate it with a backward implementation to make
12
+ # it differentiable. This is the common case.
13
+ #
14
+ # - Some functions don't need a backwards implementation, because
15
+ # backpropagation will never propagate beyond them. There are a
16
+ # number of different reasons why this may be the case:
17
+ #
18
+ # - The function has no differentiable inputs
19
+ # - The function's output is not differentiable
20
+ # - The function has no data dependency on its input
21
+ #
22
+ # - Some functions don't need a backwards implementation because they
23
+ # are implemented as a composition of other (differentiable) ATen
24
+ # functions. These are dispatched directly to the Type superclass,
25
+ # which will in turn dispatch back to VariableType for its
26
+ # differentiable subcomponents.
27
+ #
28
+
29
+ from __future__ import annotations
30
+
31
+ import re
32
+ from typing import Callable, Sequence
33
+
34
+ from torchgen.api import cpp
35
+ from torchgen.api.autograd import (
36
+ DifferentiableInput,
37
+ dispatch_strategy,
38
+ ForwardDerivative,
39
+ gen_differentiable_outputs,
40
+ is_differentiable,
41
+ NativeFunctionWithDifferentiabilityInfo,
42
+ SavedAttribute,
43
+ )
44
+ from torchgen.api.types import (
45
+ ArrayRefCType,
46
+ BaseCppType,
47
+ BaseCType,
48
+ Binding,
49
+ DispatcherSignature,
50
+ intArrayRefT,
51
+ iTensorListRefT,
52
+ ListCType,
53
+ MutRefCType,
54
+ OptionalCType,
55
+ scalarT,
56
+ SpecialArgName,
57
+ stringT,
58
+ symIntArrayRefT,
59
+ TENSOR_LIST_LIKE_CTYPES,
60
+ tensorListT,
61
+ tensorT,
62
+ TupleCType,
63
+ VectorCType,
64
+ )
65
+ from torchgen.code_template import CodeTemplate
66
+ from torchgen.context import (
67
+ native_function_manager,
68
+ with_native_function,
69
+ with_native_function_and,
70
+ )
71
+ from torchgen.model import (
72
+ Argument,
73
+ BaseType,
74
+ ListType,
75
+ NativeFunction,
76
+ SchemaKind,
77
+ SelfArgument,
78
+ TensorOptionsArguments,
79
+ )
80
+ from torchgen.utils import FileManager, mapMaybe
81
+
82
+ from .context import with_native_function_with_differentiability_info_and_key
83
+ from .gen_inplace_or_view_type import (
84
+ ALL_VIEW_FUNCTIONS,
85
+ ASSIGN_RETURN_VALUE,
86
+ AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION,
87
+ gen_formals,
88
+ get_base_name,
89
+ get_view_info,
90
+ is_tensor_list_type,
91
+ is_tensor_type,
92
+ METHOD_DEFINITION,
93
+ modifies_arguments,
94
+ TMP_VAR,
95
+ unpack_args,
96
+ unpacked_name,
97
+ use_derived,
98
+ WRAPPER_REGISTRATION,
99
+ )
100
+ from .gen_trace_type import (
101
+ get_return_value,
102
+ MANUAL_AUTOGRAD_AND_TRACER,
103
+ MANUAL_BACKEND,
104
+ tie_return_values,
105
+ type_wrapper_name,
106
+ )
107
+
108
+
109
+ # We don't set or modify grad_fn on these methods. Generally, they return
110
+ # tensors that have requires_grad=False. In-place functions listed here will
111
+ # not examine or modify requires_grad or grad_fn.
112
+ # NB: this does NOT include overload name
113
+ DONT_REQUIRE_DERIVATIVE = {
114
+ # These only depend on the input Tensor's shape and device, not the data
115
+ "empty_like",
116
+ "ones_like",
117
+ "full_like",
118
+ "zeros_like",
119
+ "rand_like",
120
+ "randn_like",
121
+ "new_empty",
122
+ "new_empty_strided",
123
+ "new_full",
124
+ "new_zeros",
125
+ "new_ones",
126
+ # These are only implemented on integral types
127
+ "__and__",
128
+ "__iand__",
129
+ "__ilshift__",
130
+ "__ior__",
131
+ "__irshift__",
132
+ "__ixor__",
133
+ "__lshift__",
134
+ "__or__",
135
+ "__rshift__",
136
+ "__xor__",
137
+ # These work on integral data types, and hence don't require derivative
138
+ "_sobol_engine_draw",
139
+ "_sobol_engine_ff",
140
+ "_sobol_engine_scramble_",
141
+ "_sobol_engine_initialize_state_",
142
+ # This is an unsafe method that is meant to be out of reach of autograd.
143
+ "_coalesced_",
144
+ # Quantize functions should not record gradients
145
+ "quantize_per_tensor",
146
+ "quantize_per_channel",
147
+ # Functions that return integers should not have output that require gradients
148
+ "argmax",
149
+ "argmin",
150
+ "argsort",
151
+ "searchsorted",
152
+ "bucketize",
153
+ # Functions that return booleans are not differentiable
154
+ "isnan",
155
+ "isposinf",
156
+ "isneginf",
157
+ "isinf",
158
+ "signbit",
159
+ "isin",
160
+ "allclose",
161
+ # Functions that return None are not differentiable
162
+ "record_stream",
163
+ # These functions are not differentiable
164
+ "logical_and",
165
+ "logical_xor",
166
+ "logical_not",
167
+ "logical_or",
168
+ # This function returns nested_tensor shape as a tensor that is non-differentiable
169
+ "_nested_tensor_size",
170
+ "_nested_tensor_strides",
171
+ "_nested_tensor_storage_offsets",
172
+ }
173
+
174
+ # The C -> R functions at the time of adding this are still being audited and tested
175
+ # but will not error out.
176
+ # C -> C, R -> C functions for which backward is correctly implemented and tested
177
+ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
178
+ "fill",
179
+ "t",
180
+ "t_copy",
181
+ "view",
182
+ "reshape",
183
+ "reshape_as",
184
+ "view_as",
185
+ "view_copy",
186
+ "roll",
187
+ "clone",
188
+ "block_diag",
189
+ "diag_embed",
190
+ "repeat",
191
+ "expand",
192
+ "expand_copy",
193
+ "flip",
194
+ "fliplr",
195
+ "flipud",
196
+ "rot90",
197
+ "nanmean",
198
+ "nansum",
199
+ "transpose",
200
+ "permute",
201
+ "squeeze",
202
+ "unsqueeze",
203
+ "unsqueeze_copy",
204
+ "resize",
205
+ "resize_as",
206
+ "tril",
207
+ "triu",
208
+ "chunk",
209
+ "zero_",
210
+ "eq_",
211
+ "ne_",
212
+ "add",
213
+ "__radd__",
214
+ "sum",
215
+ "_conj",
216
+ "sin",
217
+ "cos",
218
+ "mul",
219
+ "sinc",
220
+ "sinh",
221
+ "cosh",
222
+ "__rmul__",
223
+ "sgn",
224
+ "asin",
225
+ "acos",
226
+ "sub",
227
+ "div",
228
+ "cat",
229
+ "view_as_complex",
230
+ "index_put",
231
+ "neg",
232
+ "complex",
233
+ "select",
234
+ "where",
235
+ "as_strided",
236
+ "as_strided_copy",
237
+ "as_strided_scatter",
238
+ "slice",
239
+ "constant_pad_nd",
240
+ "unbind",
241
+ "split",
242
+ "split_with_sizes",
243
+ "unsafe_split",
244
+ "split_with_sizes_backward",
245
+ "dot",
246
+ "vdot",
247
+ "cholesky",
248
+ "triangular_solve",
249
+ "mm",
250
+ "_unsafe_view",
251
+ "mv",
252
+ "outer",
253
+ "bmm",
254
+ "diagonal",
255
+ "alias",
256
+ "atan",
257
+ "log",
258
+ "log10",
259
+ "log1p",
260
+ "log2",
261
+ "logaddexp",
262
+ "logsumexp",
263
+ "logcumsumexp",
264
+ "reciprocal",
265
+ "tan",
266
+ "pow",
267
+ "rsqrt",
268
+ "tanh",
269
+ "tanh_backward",
270
+ "asinh",
271
+ "acosh",
272
+ "atanh",
273
+ "take",
274
+ "fill_",
275
+ "exp",
276
+ "exp2",
277
+ "expm1",
278
+ "nonzero",
279
+ "mean",
280
+ "std_mean",
281
+ "var_mean",
282
+ "inverse",
283
+ "solve",
284
+ "linalg_cholesky",
285
+ "addcmul",
286
+ "addcdiv",
287
+ "matrix_exp",
288
+ "linalg_matrix_exp",
289
+ "_linalg_eigh",
290
+ "cholesky_solve",
291
+ "linalg_qr",
292
+ "_linalg_svd",
293
+ "_fft_c2c",
294
+ "_fft_r2c",
295
+ "linalg_solve",
296
+ "sqrt",
297
+ "stack",
298
+ "gather",
299
+ "index_select",
300
+ "index_add_",
301
+ "linalg_inv",
302
+ "linalg_inv_ex",
303
+ "baddbmm",
304
+ "addbmm",
305
+ "addmm",
306
+ "addmv",
307
+ "addr",
308
+ "linalg_householder_product",
309
+ "ormqr",
310
+ "reflection_pad1d",
311
+ "reflection_pad2d",
312
+ "reflection_pad3d",
313
+ "linalg_cholesky_ex",
314
+ "linalg_eig",
315
+ "diagonal_copy",
316
+ "diagonal_scatter",
317
+ "alias_copy",
318
+ "select_backward",
319
+ "diagonal_backward",
320
+ "slice_backward",
321
+ "reflection_pad1d_backward",
322
+ "reflection_pad2d_backward",
323
+ "reflection_pad3d_backward",
324
+ "_sparse_sparse_matmul",
325
+ "replication_pad1d",
326
+ "replication_pad2d",
327
+ "replication_pad3d",
328
+ "put",
329
+ "put_",
330
+ "_to_copy",
331
+ "replication_pad1d_backward",
332
+ "replication_pad2d_backward",
333
+ "replication_pad3d_backward",
334
+ "diag",
335
+ "masked_scatter",
336
+ "masked_select",
337
+ "index_add",
338
+ "index_fill",
339
+ "trace",
340
+ "polar",
341
+ "cumsum",
342
+ "rsub",
343
+ "eig",
344
+ "lerp",
345
+ "linalg_vector_norm",
346
+ "cumprod",
347
+ "prod",
348
+ "index_copy",
349
+ "lu",
350
+ "unfold",
351
+ "unfold_backward",
352
+ "index",
353
+ "masked_fill",
354
+ "masked_scatter_backward",
355
+ "linalg_cross",
356
+ "lu_unpack",
357
+ "renorm",
358
+ "_conj_physical",
359
+ "linalg_lu_factor_ex",
360
+ "scatter",
361
+ "scatter_add",
362
+ "sigmoid",
363
+ "sigmoid_backward",
364
+ "sparse_mask",
365
+ "trapezoid",
366
+ "cumulative_trapezoid",
367
+ "conj_physical_",
368
+ "_neg_view",
369
+ "_reshape_alias",
370
+ "_reshape_copy",
371
+ "_linalg_det",
372
+ "lu_solve",
373
+ "linalg_solve_triangular",
374
+ "linalg_pinv",
375
+ "linalg_lstsq",
376
+ "unfold_copy",
377
+ "col2im",
378
+ "im2col",
379
+ "cholesky_inverse",
380
+ "to_sparse",
381
+ "sparse_sampled_addmm",
382
+ "linalg_lu",
383
+ "pixel_shuffle",
384
+ "pixel_unshuffle",
385
+ "channel_shuffle",
386
+ "linalg_lu_solve",
387
+ "_linalg_slogdet",
388
+ "_linalg_solve_ex",
389
+ "_unsafe_index",
390
+ "_unsafe_index_put",
391
+ "_unsafe_masked_index",
392
+ "_unsafe_masked_index_put_accumulate",
393
+ }
394
+
395
+ GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
396
+ "_to_dense",
397
+ "_coalesce",
398
+ "coalesce",
399
+ "values",
400
+ "_sparse_coo_tensor_with_dims_and_tensors",
401
+ "_sparse_addmm",
402
+ }
403
+
404
+ GRADIENT_IMPLEMENTED_FOR_COMPLEX.update(GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX)
405
+
406
+ # Some operators invalidate the grad_accumulator. Let's reset it.
407
+ RESET_GRAD_ACCUMULATOR = {"set_", "resize_"}
408
+
409
+ # NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
410
+ #
411
+ # We check the following properties:
412
+ # 1) A function should never change the input tensors' underlying c10::TensorImpl
413
+ # pointers or c10::Storage pointers, even if it modifies its input tensors (via
414
+ # inplace or out-variants)
415
+ # If the function does not modify its arguments, we also check the following properties
416
+ # pertaining to its output:
417
+ # 2) Its TensorImpl has use_count of 1
418
+ # 3) If the function is a view function, it has the same StorageImpl as that of
419
+ # the input it is aliased with. Otherwise, its StorageImpl has use_count of 1
420
+ #
421
+ # The following code templates implement the checks for this invariant:
422
+ SAVE_TENSOR_STORAGE = CodeTemplate(
423
+ """\
424
+ auto ${tensor_name}_storage_saved =
425
+ ${tensor_name}.has_storage() ? ::std::optional<Storage>(${tensor_name}.storage()) : ::std::nullopt;
426
+ """
427
+ )
428
+
429
+
430
+ # If tensor_name == out_tensor_name, used to enforce (1), otherwise used for (3)
431
+ ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate(
432
+ """\
433
+ if (${tensor_name}_storage_saved.has_value() &&
434
+ !at::impl::dispatch_mode_enabled() &&
435
+ !at::impl::tensor_has_dispatch(${tensor_name}) &&
436
+ !at::impl::tensor_has_dispatch(${out_tensor_name}))
437
+ TORCH_INTERNAL_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${out_tensor_name}.storage()));
438
+ """
439
+ )
440
+
441
# Snapshot the storage of every tensor in a tensor list before the op runs so
# that ENFORCE_SAME_TENSORLIST_STORAGE can compare against it afterwards.
#
# The vector is default-constructed and reserved rather than constructed with
# `size()` elements: the loop below uses push_back, so a pre-sized vector would
# keep `size()` empty optionals at indices [0, size) and append the real
# snapshots after them, which makes the enforcement check (indexing [0, size))
# never see a saved storage.
SAVE_TENSORLIST_STORAGE = CodeTemplate(
    """\
std::vector<::std::optional<Storage>> ${tensorlist_name}_storage_saved;
${tensorlist_name}_storage_saved.reserve(${tensorlist_name}.size());
for (const Tensor& tensor : ${tensorlist_name})
  ${tensorlist_name}_storage_saved.push_back(
    tensor.has_storage() ? ::std::optional<Storage>(tensor.storage()) : ::std::nullopt);
"""
)
449
+
450
+ ENFORCE_SAME_TENSORLIST_STORAGE = CodeTemplate(
451
+ """\
452
+ for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
453
+ if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name}))
454
+ TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(${tensorlist_name}[i].storage()));
455
+ }
456
+ """
457
+ )
458
+
459
# Snapshot the storage of every (optional) tensor in a list before the op runs
# so that ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE can compare against it
# afterwards. Entries that are absent or have no storage are saved as nullopt.
#
# The vector is default-constructed and reserved rather than constructed with
# `size()` elements: the loop below uses push_back, so a pre-sized vector would
# keep `size()` empty optionals at indices [0, size) and append the real
# snapshots after them, which makes the enforcement check (indexing [0, size))
# never see a saved storage.
SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate(
    """\
std::vector<::std::optional<Storage>> ${tensorlist_name}_storage_saved;
${tensorlist_name}_storage_saved.reserve(${tensorlist_name}.size());
for (const ::std::optional<Tensor>& tensor : ${tensorlist_name})
  ${tensorlist_name}_storage_saved.push_back(
    tensor.has_value() && tensor->has_storage() ? ::std::optional<Storage>(tensor->storage()) : ::std::nullopt);
"""
)
467
+
468
+ ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate(
469
+ """\
470
+ for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
471
+ if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name}))
472
+ TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(
473
+ static_cast<::std::optional<Tensor>>(${tensorlist_name}[i])->storage()));
474
+ }
475
+ """
476
+ )
477
+
478
+ SAVE_TENSOR_IMPL = CodeTemplate(
479
+ """\
480
+ c10::intrusive_ptr<TensorImpl> ${tensor_name}_impl_saved;
481
+ if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr();
482
+ """
483
+ )
484
+
485
+ ENFORCE_SAME_TENSOR_IMPL = CodeTemplate(
486
+ """\
487
+ if (${tensor_name}_impl_saved && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name}))
488
+ TORCH_INTERNAL_ASSERT(${tensor_name}_impl_saved == ${tensor_name}.getIntrusivePtr());
489
+ """
490
+ )
491
+
492
+ ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE = CodeTemplate(
493
+ """\
494
+ if (!at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name}))
495
+ TORCH_INTERNAL_ASSERT(${tensor_name}.use_count() <= 1, "function: ${fn_name}");
496
+ """
497
+ )
498
+
499
+ ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE = CodeTemplate(
500
+ """\
501
+ if (${tensor_name}.has_storage() && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) {
502
+ TORCH_INTERNAL_ASSERT(${tensor_name}.storage().use_count() == 1, "function: ${fn_name}");
503
+ }
504
+ """
505
+ )
506
+
507
+ SAVE_TENSORLIST_IMPL = CodeTemplate(
508
+ """\
509
+ std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
510
+ for (size_t i=0; i<${tensorlist_name}.size(); i++)
511
+ if (${tensorlist_name}[i].defined()) ${tensorlist_name}_impl_saved[i] = ${tensorlist_name}[i].getIntrusivePtr();
512
+ """
513
+ )
514
+
515
+ ENFORCE_SAME_TENSORLIST_IMPL = CodeTemplate(
516
+ """\
517
+ for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
518
+ if (${tensorlist_name}_impl_saved[i] && !at::impl::tensorlist_has_dispatch(${tensorlist_name}))
519
+ TORCH_INTERNAL_ASSERT(${tensorlist_name}_impl_saved[i] == ${tensorlist_name}[i].getIntrusivePtr());
520
+ }
521
+ """
522
+ )
523
+
524
+ SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate(
525
+ """\
526
+ std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
527
+ for (size_t i=0; i<${tensorlist_name}.size(); i++) {
528
+ ::std::optional<Tensor> t = ${tensorlist_name}[i];
529
+ if (t.has_value() && t->defined()) ${tensorlist_name}_impl_saved[i] = t->getIntrusivePtr();
530
+ }
531
+ """
532
+ )
533
+
534
+ ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate(
535
+ """\
536
+ for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
537
+ if (${tensorlist_name}_impl_saved[i])
538
+ TORCH_INTERNAL_ASSERT(
539
+ ${tensorlist_name}_impl_saved[i] == static_cast<::std::optional<Tensor>>(${tensorlist_name}[i])->getIntrusivePtr());
540
+ }
541
+ """
542
+ )
543
+
544
+ # The following list contains functions that we don't enforce the invariant on.
545
+ DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
546
+ # These functions are expected to change impl or storage of input tensors
547
+ "set_",
548
+ "_cudnn_rnn_flatten_weight",
549
+ "_unsafe_masked_index",
550
+ "_unsafe_masked_index_put_accumulate",
551
+ }
552
+ DONT_ENFORCE_TENSOR_IMPL_USE_COUNT = {
553
+ # These non-inplace, non-out functions return tensors with use_count > 1
554
+ # Therefore, they MAY (but do not necessarily) return one of their inputs as-is
555
+ # See https://github.com/pytorch/pytorch/issues/60426 for more information
556
+ "_embedding_bag",
557
+ "_embedding_bag_forward_only",
558
+ "q_per_channel_scales",
559
+ "q_per_channel_zero_points",
560
+ "lu_unpack",
561
+ "_cudnn_rnn_backward",
562
+ # The below failed StorageImpl use_count check but we skip tensor_impl check
563
+ # just in case
564
+ "_cudnn_rnn",
565
+ "dequantize_self",
566
+ # lift() should never actually be called with a requires_grad=True tensor,
567
+ "lift",
568
+ "lift_fresh",
569
+ "lift_fresh_copy",
570
+ # Nested Tensors related functions
571
+ # _nested_tensor_size() should never actually be called with requires_grad=True tensor
572
+ "_nested_tensor_size",
573
+ "_nested_tensor_strides",
574
+ "_nested_tensor_storage_offsets",
575
+ }
576
+
577
+ DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = {
578
+ # These non-view functions return tensors with storage use_count != 1
579
+ "_slow_conv2d_forward",
580
+ "slow_conv3d_forward",
581
+ "channel_shuffle",
582
+ # If an input is returned as-is in output, we cannot guarantee its storage_impl
583
+ # use count to be 1 either.
584
+ *DONT_ENFORCE_TENSOR_IMPL_USE_COUNT,
585
+ }
586
+ # END CHECKS FOR [ TensorImpl and Storage Pointer Sanity Checks ]
587
+
588
+ DECLARE_GRAD_FN = CodeTemplate(
589
+ """\
590
+ std::shared_ptr<${op}> grad_fn;
591
+ """
592
+ )
593
+
594
+ DECLARE_VECTOR_OF_GRAD_FN = CodeTemplate(
595
+ """\
596
+ std::vector<std::shared_ptr<${op}>> grad_fns;
597
+ """
598
+ )
599
+
600
+ SETUP_ANY_REQUIRES_GRAD = CodeTemplate(
601
+ """\
602
+ [[maybe_unused]] auto _any_requires_grad = compute_requires_grad( ${args_with_derivatives} );
603
+ ${extra_differentiability_conditions}
604
+ """
605
+ )
606
+
607
+ SETUP_DERIVATIVE = CodeTemplate(
608
+ """\
609
+ if (_any_requires_grad) {
610
+ ${setup}
611
+ }
612
+ """
613
+ )
614
+
615
+ SETUP_NONE_REQUIRES_GRAD = CodeTemplate(
616
+ """\
617
+ if (compute_requires_grad( ${args_to_check} )) {
618
+ throw_error_out_requires_grad("${base_name}");
619
+ }
620
+ """
621
+ )
622
+
623
+ ASSIGN_GRAD_FN = CodeTemplate(
624
+ """\
625
+ grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode);
626
+ grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} ));
627
+ """
628
+ )
629
+
630
+ # note(crcrpar): `compute_requires_grad` in the template below is supplied with arguments indexed with `i`
631
+ # while the `SETUP_ANY_REQUIRES_GRAD` above takes whole tensors and scalars.
632
+ ASSIGN_VECTOR_OF_GRAD_FN = CodeTemplate(
633
+ """\
634
+ for (const auto& i : c10::irange( ${irange} )) {
635
+ const auto ith_requires_grad = compute_requires_grad(${args_with_derivatives});
636
+ check_inplace(self[i], ith_requires_grad);
637
+ grad_fns.push_back([&]() -> std::shared_ptr<${op}> {
638
+ if (!ith_requires_grad) {
639
+ return nullptr;
640
+ } else {
641
+ auto grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode);
642
+ grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} ));
643
+ return grad_fn;
644
+ }
645
+ }());
646
+ }
647
+ """
648
+ )
649
+
650
+ CALL_REDISPATCH = CodeTemplate(
651
+ """\
652
+ at::redispatch::${api_name}(${unpacked_args})"""
653
+ )
654
+ # If the non-variable operation has return values, we use the `tmp` variable to hold the
655
+ # values temporarily and pass the values to the return variables outside of the
656
+ # `at::AutoDispatchBelowAutograd` guard block.
657
+ DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP = CodeTemplate(
658
+ """\
659
+ auto ${tmp_var} = ([&]() {
660
+ if (${any_has_forward_grad}) {
661
+ static c10::OperatorName full_name("aten::${op_name}", "${op_overload}");
662
+ static ::std::optional<c10::OperatorHandle> opt_op = c10::Dispatcher::singleton().findSchema(full_name);
663
+ return impl::run_jit_decomposition_with_args_for_jvp<${return_types}>("${op_name}", *opt_op, ks, ${arg_names});
664
+ } else {
665
+ ${guard}
666
+ return ${base_type_call};
667
+ }
668
+ })();
669
+ """
670
+ )
671
+
672
+ DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate(
673
+ """\
674
+ auto ${tmp_var} = ([&]() {
675
+ ${guard}
676
+ return ${base_type_call};
677
+ })();
678
+ """
679
+ )
680
+
681
+ DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES = CodeTemplate(
682
+ """\
683
+ {
684
+ ${guard}
685
+ ${base_type_call};
686
+ }
687
+ """
688
+ )
689
+
690
+ SET_HISTORY = CodeTemplate(
691
+ """\
692
+ if (grad_fn) {
693
+ ${fn}_history(${differentiable_outputs}, grad_fn);
694
+ }
695
+ """
696
+ )
697
+
698
+ LOOP_OVER_VECTOR_OF_GRAD_FNS = CodeTemplate(
699
+ """\
700
+ if (!grad_fns.empty()) {
701
+ ${preamble}
702
+ for (const auto& i : c10::irange(grad_fns.size())) {
703
+ auto grad_fn = grad_fns[i];
704
+ if (grad_fn != nullptr) {
705
+ ${statements}
706
+ }
707
+ }
708
+ }
709
+ """
710
+ )
711
+
712
+ CONDITIONAL = CodeTemplate(
713
+ """\
714
+ if (${cond}) {
715
+ ${statements}
716
+ }
717
+ """
718
+ )
719
+
720
+ RUN_ONLY_IN_DEBUG_MODE = CodeTemplate(
721
+ """\
722
+ #ifndef NDEBUG
723
+ ${statements}
724
+ #endif
725
+ """
726
+ )
727
+
728
+ FW_DERIVATIVE_CHECK_TEMPLATE = CodeTemplate(
729
+ """\
730
+ isFwGradDefined(${req_inp})\
731
+ """
732
+ )
733
+ FW_DERIVATIVE_SIZE_CHECK_TEMPLATE = CodeTemplate(
734
+ """\
735
+ TORCH_CHECK(
736
+ self.size() == ${inp_name}.size(),
737
+ "Tensor lists must have the same number of tensors, got ",
738
+ self.size(),
739
+ " and ",
740
+ ${inp_name}.size());
741
+ """
742
+ )
743
+
744
+ FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE = CodeTemplate(
745
+ """\
746
+ isFwGradDefinedTensorList(${req_inp})\
747
+ """
748
+ )
749
+
750
+ FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate(
751
+ """\
752
+ auto ${inp_name}_t_raw = toNonOptFwGrad(${inp});
753
+ auto ${inp_name}_tensor = toNonOptTensor(${inp});
754
+ auto ${inp_name}_t = (${inp_name}_t_raw.defined() || !${inp_name}_tensor.defined())
755
+ ? ${inp_name}_t_raw : at::${zeros_fn}(${inp_name}_tensor.sym_sizes(), ${inp_name}_tensor.options());
756
+ """
757
+ )
758
+
759
+ FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate(
760
+ """\
761
+ auto ${inp_name}_p = toNonOptPrimal(${inp});
762
+ """
763
+ )
764
+
765
+ FW_DERIVATIVE_SETTER_TENSOR = CodeTemplate(
766
+ """\
767
+ if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}.defined()) {
768
+ // The hardcoded 0 here will need to be updated once we support multiple levels.
769
+ ${out_arg}._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace});
770
+ }
771
+ """
772
+ )
773
+
774
+ FW_DERIVATIVE_SETTER_TENSOR_FOREACH = CodeTemplate(
775
+ """\
776
+ for (const auto& i : c10::irange(${out_arg}_new_fw_grad_opts.size())) {
777
+ auto& ${out_arg}_new_fw_grad_opt = ${out_arg}_new_fw_grad_opts[i];
778
+ if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}[i].defined()) {
779
+ // The hardcoded 0 here will need to be updated once we support multiple levels.
780
+ ${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace});
781
+ }
782
+ }
783
+ """
784
+ )
785
+
786
+ FW_DERIVATIVE_SETTER_MULTI_OUTPUT = CodeTemplate(
787
+ """\
788
+ if (${all_res}_new_fw_grad_opt.has_value() && std::get<${idx}>(${all_res}_new_fw_grad_opt.value()).defined()
789
+ && ${out_arg}.defined()) {
790
+ ${out_arg}._set_fw_grad(std::get<${idx}>(${all_res}_new_fw_grad_opt.value()), /* level */ 0, /* is_inplace_op */ false);
791
+ }
792
+ """
793
+ )
794
+
795
+ FW_DERIVATIVE_SETTER_TENSOR_LIST = CodeTemplate(
796
+ """\
797
+ if (${out_arg}_new_fw_grad_opt.has_value()) {
798
+ auto ${out_arg}_new_fw_grad = ${out_arg}_new_fw_grad_opt.value();
799
+ TORCH_INTERNAL_ASSERT(${out_arg}.size() == ${out_arg}_new_fw_grad.size());
800
+ for (const auto i : c10::irange(${out_arg}.size())) {
801
+ if (${out_arg}_new_fw_grad[i].defined() && ${out_arg}[i].defined()) {
802
+ // The hardcoded 0 here will need to be updated once we support multiple levels.
803
+ ${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad[i], /* level */ 0, /* is_inplace_op */ ${is_inplace});
804
+ }
805
+ }
806
+ }
807
+ """
808
+ )
809
+
810
+ FW_DERIVATIVE_TEMPLATE = CodeTemplate(
811
+ """\
812
+ ${fw_grad_opt_definition}
813
+ if (${requires_fw_grad}) {
814
+ ${unpacked_arguments}
815
+ ${out_arg}_new_fw_grad_opt = ${formula};
816
+ }
817
+ """
818
+ )
819
+
820
+ FW_DERIVATIVE_FOREACH_TEMPLATE = CodeTemplate(
821
+ """\
822
+ ${fw_grad_opt_definition}
823
+ for (const auto& i : c10::irange(${vector_of_optional_tensor}.size())) {
824
+ if (${any_has_forward_grad_for_current_index}) {
825
+ ${unpacked_arguments}
826
+ ${vector_of_optional_tensor}[i] = ${formula};
827
+ }
828
+ }
829
+ """
830
+ )
831
+
832
+ FW_DERIVATIVE_FORBID_TEMPLATE = CodeTemplate(
833
+ """\
834
+ TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}");
835
+ """
836
+ )
837
+
838
+ FW_DERIVATIVE_FORBID_LIST_TEMPLATE = CodeTemplate(
839
+ """\
840
+ for (const auto& _t: ${arg}) {
841
+ TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}");
842
+ }
843
+ """
844
+ )
845
+
846
+
847
def gen_variable_type(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
    used_keys: set[str],
) -> None:
    """Write VariableType.h and the sharded VariableType.cpp outputs.

    Three steps, each with its own FileManager:
      1. VariableType.h is written into `out` from `template_path`.
      2. An intermediate VariableType.cpp template is written into
         `out + "/templates"`; it carries one `${...}` placeholder pair
         (method definitions + TORCH_LIBRARY_IMPL registration block) per
         entry of `used_keys`.
      3. That intermediate template is expanded into sharded files in `out`,
         using `gen_variable_type_func` on every function for which
         `use_derived` is true.

    `native_yaml_path` and `tags_yaml_path` are accepted but not read here.
    """
    ordered_keys = sorted(used_keys)

    # Step 1: the header only needs its "generated from" banner.
    # The "@" is kept as a separate literal, exactly as the banner is assembled.
    header_fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    header_fm.write(
        "VariableType.h",
        lambda: {
            "generated_comment": "@"
            + f"generated from {header_fm.template_dir_for_comments()}/VariableType.h"
        },
    )

    # One TORCH_LIBRARY_IMPL block per key; the "Default" key registers under
    # the "Autograd" dispatch key, every other key is used verbatim.
    def wrapper_registrations(used_keys: set[str]) -> str:
        macros: list[str] = []
        for key in sorted(used_keys):
            dispatch_key = "Autograd" if key == "Default" else key
            macros.append(
                f"TORCH_LIBRARY_IMPL(aten, {dispatch_key}, m) "
                "{\n"
                "${"
                f"wrapper_registrations_{key}"
                "}\n}"
            )
        return "\n\n".join(macros)

    # Step 2: derive a per-key template from VariableType.cpp.
    template_fm = FileManager(
        install_dir=out + "/templates", template_dir=template_path, dry_run=False
    )
    method_definition_slots = "\n\n".join(
        "${" + f"type_derived_method_definitions_{key}" + "}" for key in ordered_keys
    )
    template_fm.write(
        "VariableType.cpp",
        lambda: {
            "type_derived_method_definitions": method_definition_slots,
            "wrapper_registrations": wrapper_registrations(used_keys),
        },
    )

    # Step 3: expand the derived template into the final sharded sources.
    shard_fm = FileManager(
        install_dir=out, template_dir=out + "/templates", dry_run=False
    )
    sharded_keys = {
        f"{prefix}_{key}"
        for prefix in ("type_derived_method_definitions", "wrapper_registrations")
        for key in ordered_keys
    }
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    shard_fm.write_sharded(
        "VariableType.cpp",
        list(filter(use_derived, fns_with_diff_infos)),
        key_fn=lambda fn: cpp.name(fn.func.func),
        base_env={
            "generated_comment": "@"
            + f"generated from {header_fm.template_dir_for_comments()}/VariableType.cpp",
        },
        env_callable=gen_variable_type_func,
        num_shards=5,
        sharded_keys=sharded_keys,
    )
927
+
928
+
929
@with_native_function_and
def gen_wrapper_registration(f: NativeFunction, key: str = "Default") -> str:
    """Fill in WRAPPER_REGISTRATION for `f` under the given key.

    The template receives the operator name (with overload), the wrapper name
    computed by `type_wrapper_name(f, key)`, and the fixed class name
    "VariableType".
    """
    substitutions = {
        "unqual_operator_name_with_overload": f.func.name,
        "type_wrapper_name": type_wrapper_name(f, key),
        "class_type": "VariableType",
    }
    return WRAPPER_REGISTRATION.substitute(**substitutions)
936
+
937
+
938
def gen_variable_type_func(
    fn: NativeFunctionWithDifferentiabilityInfo,
) -> dict[str, list[str]]:
    """Build the template environment entries for a single native function.

    Returns a dict keyed by `type_derived_method_definitions_<key>` and
    `wrapper_registrations_<key>`, each mapping to a one-element list.

    Also asserts that membership in MANUAL_BACKEND agrees with
    `f.manual_kernel_registration`, and that a function which is listed in
    MANUAL_AUTOGRAD_AND_TRACER or has derivatives is abstract.
    """
    f = fn.func
    result = {}
    with native_function_manager(f):
        name = cpp.name(f.func)
        formals = gen_formals(f)

        # NOTE: [ Registering AutogradNotImplemented boxed kernel ]
        #
        # Without a derivatives.yaml entry we register a generic boxed
        # NotImplemented kernel that sets grad_fn to NotImplemented, so forward
        # runs as usual and backward raises a proper error.
        # TODO: it would be nice to not have these special cases
        #
        # Codegen still handles the function itself when:
        # 1) the op needs to reset the grad accumulator (that list is currently
        #    only accessible in Python);
        # 2) the op is listed in DONT_REQUIRE_DERIVATIVE, which makes autograd a
        #    fallthrough with NDEBUG checks (useful when all outputs are
        #    integral);
        # 3) there are no differentiable outputs (similar to 2);
        # 4) the op is one for which some NDEBUG checks are skipped (similar
        #    to 1).
        register_not_implemented_kernel = (
            fn.info is None
            and str(f.func.name.name) not in RESET_GRAD_ACCUMULATOR
            and get_base_name(f) not in DONT_REQUIRE_DERIVATIVE
            and len(gen_differentiable_outputs(fn)) > 0
            and cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE
            and type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT
            and type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT
        )

        if register_not_implemented_kernel:
            # No method body is emitted; only the fallback registration.
            not_implemented_registration = (
                AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION.substitute(
                    unqual_operator_name_with_overload=f.func.name
                )
            )
            result["type_derived_method_definitions_Default"] = [""]
            result["wrapper_registrations_Default"] = [not_implemented_registration]
        else:
            # A missing/empty info maps to the single "Default" key; otherwise
            # one definition + registration is produced per key in the info.
            keys = fn.info.keys() if fn.info else ("Default",)
            for key in keys:
                return_type = cpp.returns_type(f.func.returns, symint=True).cpp_type()
                method_definition = METHOD_DEFINITION.substitute(
                    return_type=return_type,
                    type_wrapper_name=type_wrapper_name(f, key),
                    type_definition_body=emit_body(fn, key),
                    formals=formals,
                )
                registration = gen_wrapper_registration(f, key)
                result[f"type_derived_method_definitions_{key}"] = [method_definition]
                result[f"wrapper_registrations_{key}"] = [registration]

        # See Note [Manual Backend kernels]
        assert (name in MANUAL_BACKEND) == f.manual_kernel_registration

        # If you want to register a kernel to Autograd, you must make the op abstract.
        # In other words, this op must have dispatch section in native_functions.yaml.
        must_be_abstract = name in MANUAL_AUTOGRAD_AND_TRACER or (
            fn.info and any(info.has_derivatives for info in fn.info.values())
        )
        if must_be_abstract:
            msg = (
                f"There's a formula for {name}(or its functional variant) in derivatives.yaml. "
                f"It's required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA "
                f"or CompositeExplicitAutograd in native_functions.yaml. Please see "
                f"https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword "
                f"for instructions to choose the right dispatch keyword."
            )
            assert f.is_abstract, msg

    return result
1022
+
1023
+
1024
+ _foreach_ops_without_differentiability_info = {
1025
+ # No reference backward available due to the lack of `{maximum, minimum}(tensor, scalar)`.
1026
+ ("_foreach_maximum", "Scalar"),
1027
+ ("_foreach_maximum", "ScalarList"),
1028
+ ("_foreach_minimum", "Scalar"),
1029
+ ("_foreach_minimum", "ScalarList"),
1030
+ # No reference backward available as addcdiv/addcmul don't support Tensor as scaling factor.
1031
+ ("_foreach_addcdiv", "Tensor"),
1032
+ ("_foreach_addcmul", "Tensor"),
1033
+ ("_foreach_copy", ""),
1034
+ }
1035
+
1036
+ _foreach_ops_with_different_arity = {
1037
+ # These ops lack the `alpha` scaling factor that is applied to the right-hand-side argument.
1038
+ ("_foreach_add", "Scalar"),
1039
+ ("_foreach_add", "ScalarList"),
1040
+ ("_foreach_sub", "Scalar"),
1041
+ ("_foreach_sub", "ScalarList"),
1042
+ }
1043
+
1044
+
1045
+ @with_native_function_with_differentiability_info_and_key
1046
+ def emit_body(
1047
+ fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
1048
+ ) -> list[str]:
1049
+ assert dispatch_strategy(fn) == "use_derived"
1050
+ f = fn.func
1051
+ info = fn.info[key] if fn.info else None
1052
+ fw_derivatives = fn.fw_derivatives.get(key, []) if fn.fw_derivatives else []
1053
+
1054
+ name = cpp.name(f.func)
1055
+ inplace = f.func.kind() == SchemaKind.inplace
1056
+ is_out_fn = f.func.kind() == SchemaKind.out
1057
+ returns_void = len(f.func.returns) == 0
1058
+ base_name = get_base_name(f)
1059
+ view_info = get_view_info(f)
1060
+
1061
+ is_foreach = name.startswith("_foreach")
1062
+ is_inplace_foreach = is_foreach and inplace
1063
+ if is_inplace_foreach:
1064
+ inplace_foreacharg2refarg: dict[Argument, Argument] = {}
1065
+ refargname2inplace_foreacharg: dict[str, Argument] = {}
1066
+ base_name_and_overload_name = (f.func.name.name.base, f.func.name.overload_name)
1067
+ if info is None:
1068
+ assert (
1069
+ base_name_and_overload_name
1070
+ in _foreach_ops_without_differentiability_info
1071
+ ), f"{'.'.join(base_name_and_overload_name)} should have a differentiability info"
1072
+ else:
1073
+ assert (
1074
+ len(f.func.arguments.flat_non_out)
1075
+ == len(info.func.func.arguments.flat_non_out)
1076
+ ) or (base_name_and_overload_name in _foreach_ops_with_different_arity), (
1077
+ f"{'.'.join(base_name_and_overload_name)} has {len(f.func.arguments.flat_non_out)} args "
1078
+ f"but the reference has {len(info.func.func.arguments.flat_non_out)}"
1079
+ )
1080
+ for foreach_arg, ref_arg in zip(
1081
+ f.func.arguments.flat_non_out, info.func.func.arguments.flat_non_out
1082
+ ):
1083
+ foreach_arg_type = foreach_arg.type
1084
+ if isinstance(foreach_arg_type, ListType):
1085
+ foreach_arg_type = foreach_arg_type.elem
1086
+ assert foreach_arg_type == ref_arg.type
1087
+ inplace_foreacharg2refarg[foreach_arg] = ref_arg
1088
+ refargname2inplace_foreacharg[ref_arg.name] = foreach_arg
1089
+
1090
def gen_differentiable_input(
    arg: Argument | SelfArgument | TensorOptionsArguments,
) -> DifferentiableInput | None:
    """Turn one schema argument into a `DifferentiableInput`.

    Returns None for `TensorOptionsArguments` and for any argument that
    `is_differentiable` rejects for the current `info`.
    """
    if isinstance(arg, TensorOptionsArguments):
        return None

    # A SelfArgument wraps the real Argument; everything else is used as-is.
    if isinstance(arg, SelfArgument):
        unwrapped: Argument = arg.argument
    else:
        unwrapped = arg

    # TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove.
    # NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are
    # not handled properly as they are irrelevant for this codegen.
    rendered_cpp_type = cpp.argument_type(
        unwrapped, binds=unwrapped.name, symint=True
    ).cpp_type()

    if is_differentiable(unwrapped.name, unwrapped.type, info):
        return DifferentiableInput(
            name=unwrapped.name,
            type=unwrapped.type,
            cpp_type=rendered_cpp_type,
        )
    return None
1109
+
1110
@with_native_function
def gen_differentiable_inputs(f: NativeFunction) -> list[DifferentiableInput]:
    """Collect the differentiable inputs among the non-out arguments of `f`.

    For an in-place foreach function that has differentiability info, each
    argument with a known reference counterpart is first swapped for a copy of
    that reference argument before being filtered.
    """
    candidates = list(f.func.arguments.non_out)
    if is_inplace_foreach and info is not None:
        for position, foreach_arg in enumerate(f.func.arguments.flat_non_out):
            ref_arg = inplace_foreacharg2refarg.get(foreach_arg)
            if ref_arg is None:
                continue
            # note(crcrpar): From what I understand, what matters is only the name.
            # Thus originally I only replace argument only when the names are different.
            # TODO(crcrpar): Make it simpler.
            candidates[position] = Argument(
                ref_arg.name,
                ref_arg.type,
                ref_arg.default,
                ref_arg.annotation,
            )
    return list(mapMaybe(gen_differentiable_input, candidates))
1127
+
1128
def find_args_with_derivatives(
    differentiable_inputs: list[DifferentiableInput],
) -> list[DifferentiableInput]:
    """Keep only the inputs that some derivative formula is defined for.

    When there is no info, or the info has no derivatives, the input list is
    returned unchanged. Raises RuntimeError if a derivative names a variable
    that is not among `differentiable_inputs`.
    """
    if not (info is not None and info.has_derivatives):
        return differentiable_inputs

    derived_names = {
        var_name for derivative in info.derivatives for var_name in derivative.var_names
    }
    selected = [inp for inp in differentiable_inputs if inp.name in derived_names]
    if len(selected) == len(derived_names):
        return selected

    missing = derived_names - {inp.name for inp in selected}
    raise RuntimeError(
        f"Missing arguments for derivatives: {missing} in {info.name}"
    )
1142
+
1143
+ differentiable_inputs = gen_differentiable_inputs(f)
1144
+ args_with_derivatives = find_args_with_derivatives(differentiable_inputs)
1145
+ differentiable_outputs = gen_differentiable_outputs(fn, key)
1146
+
1147
+ undifferentiable = (base_name in DONT_REQUIRE_DERIVATIVE) or (
1148
+ name in DONT_REQUIRE_DERIVATIVE
1149
+ )
1150
+
1151
+ requires_derivative = (
1152
+ (not undifferentiable)
1153
+ and (len(differentiable_inputs) > 0)
1154
+ and (
1155
+ (len(differentiable_outputs) > 0)
1156
+ # note(crcrpar): In-place foreach functions are a void function.
1157
+ or is_inplace_foreach
1158
+ )
1159
+ )
1160
+
1161
+ if (
1162
+ info is not None
1163
+ and info.has_derivatives
1164
+ and not requires_derivative
1165
+ # out= ops are allowed to have zero returns which cause requires_derivative to be False
1166
+ # we shouldn't error out though (out= ops for autograd just redispatch)
1167
+ and len(f.func.returns) > 0
1168
+ ):
1169
+ raise RuntimeError(
1170
+ f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative"
1171
+ )
1172
+
1173
+ # note(crcrpar): In-place foreach functions do not support forward AD
1174
+ if requires_derivative and len(fw_derivatives) > 0 and not is_inplace_foreach:
1175
+ assert sum(len(derivative.var_names) for derivative in fw_derivatives) == len(
1176
+ differentiable_outputs
1177
+ ), (
1178
+ "Expected the number of forward derivatives implemented to match the "
1179
+ "number of differentiable outputs. NB: This only applies when at least "
1180
+ "one forward derivative is implemented. Not implementing any forward "
1181
+ "derivatives is also okay, and we would require inputs to the op to "
1182
+ "not have associated tangents in that case."
1183
+ )
1184
+
1185
+ try_jit_decomposition = (
1186
+ requires_derivative
1187
+ and len(fw_derivatives) == 0
1188
+ and (not modifies_arguments(f))
1189
+ and (not returns_void)
1190
+ )
1191
+
1192
def emit_save_inputs() -> list[str]:
    """Emit the statements that store saved inputs on the grad_fn.

    Returns an empty list when there is no differentiability info or it has no
    derivatives. For in-place foreach functions the save statements are wrapped
    in `LOOP_OVER_VECTOR_OF_GRAD_FNS`; otherwise they are emitted directly,
    followed by one `<name>_size_` assignment per tensor-list argument.
    """
    setup: list[str] = []
    if info is None or not info.has_derivatives:
        return setup

    has_tensorlist_arg = any(
        is_tensor_list_type(arg.type) for arg in args_with_derivatives
    )

    # We don't want to save tensors if we know that they will never be used
    # when computing the derivative, so we add guards to those statements
    def guard_for(arg: SavedAttribute) -> str | None:
        """Return the C++ condition guarding the save of `arg`, or None for no guard."""
        assert info is not None

        # It's hard to determine the edge offset if we have TensorLists
        # NOTE(crcrpar): in-place foreach functions' arguments include tensorlist
        # but their derivatives don't use it, so let them bypass this check.
        if has_tensorlist_arg and (not is_inplace_foreach):
            return None

        # Empirical evaluation of the cases where we insert those guards in
        # backward show that they are somewhat useless. E.g. there's no need
        # to guard on some values captured from forward, because they had to
        # require_grad if the backward function even gets executed. I don't
        # have any good ideas for detecting those cases, so I simply disabled the
        # checks.
        if "backward" in info.name:
            return None

        # If there's a single derivative we could compute, we already have
        # a requires_grad check that is sufficient
        if len(args_with_derivatives) <= 1:
            return None

        # We really only care about trimming down the amount of tensors we save
        if arg.nctype.type != BaseCType(tensorT):
            return None

        # We want to emit simple guards, so we only allow that if checking one
        # input is enough to determine whether we need that value
        used_in = [d for d in info.derivatives if arg in d.saved_inputs]
        assert len(used_in) > 0
        if len(used_in) != 1:
            return None
        derivative = used_in[0]

        # Case with multioutput formulas
        # TODO: process all derivative formulas!!!
        if len(derivative.var_names) != 1:
            wrap_opt_if_start = derivative.formula.find(
                f"wrap_opt_if({arg.nctype.name}"
            )
            if wrap_opt_if_start == -1:
                return None

            wrap_opt_if_match = re.match(
                rf"wrap_opt_if\({arg.nctype.name},(.*?)\)",
                derivative.formula[wrap_opt_if_start:],
            )
            assert wrap_opt_if_match is not None

            # Condition is between 'wrap_opt_if(var_name,' and ')'; that is
            # exactly capture group 1. Slicing the whole match by the length of
            # the regex source would count the escape backslash and drop the
            # first character of a condition written without a leading space.
            wrap_opt_if_condition = wrap_opt_if_match.group(1).strip()
            # replace 'grad_input_mask[num]' with 'grad_fn->should_compute_output(num)'
            wrap_opt_if_condition = re.sub(
                r"grad_input_mask\[(\d+)\]",
                r"grad_fn->should_compute_output(\1)",
                wrap_opt_if_condition,
            )
            return f"{wrap_opt_if_condition}"

        # Figure out the offset of the edge that uses this variable
        derivative_var_name = derivative.var_names[0]
        for edge_off, a in enumerate(args_with_derivatives):
            if a.name == derivative_var_name:
                break
        else:
            raise AssertionError
        return f"grad_fn->should_compute_output({edge_off})"

    if is_inplace_foreach:
        save_input_stmts = save_variables(info.all_saved_inputs, False, guard_for)
        if save_input_stmts:
            setup.append(
                LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute(
                    preamble="", statements=save_input_stmts
                )
            )
    else:
        setup.extend(save_variables(info.all_saved_inputs, False, guard_for))
        for arg in args_with_derivatives:
            if is_tensor_list_type(arg.type):
                setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();")
    return setup
1289
+
1290
def setup_derivative(differentiable_inputs: list[DifferentiableInput]) -> list[str]:
    """Emit the statements that declare and set up the grad_fn for this op.

    Out functions only get checks that no differentiable input or output
    requires grad. Other functions get a grad_fn (or a vector of grad_fns for
    in-place foreach functions) plus the statements saving their inputs.
    """
    body: list[str] = []
    if is_out_fn:
        # For out functions, ensure that no input or output requires grad
        body.append(DECLARE_GRAD_FN.substitute(op="Node"))
        for checked in (differentiable_inputs, differentiable_outputs):
            body.append(
                SETUP_NONE_REQUIRES_GRAD.substitute(
                    base_name=base_name,
                    args_to_check=[arg.name for arg in checked],
                )
            )
        return body

    op = info.op if info is not None and info.has_derivatives else "NotImplemented"
    setup = []
    if is_inplace_foreach:
        # note(crcrpar): Assuming in-place foreach function's self_arg is always TensorList.
        list_like_arg = "self"
        args = [arg.name for arg in args_with_derivatives]
        for i, arg in enumerate(args):
            if info is not None:
                # Translate reference argument names back to the foreach names,
                # indexing into list-typed arguments.
                if arg in refargname2inplace_foreacharg:
                    foreach_arg = refargname2inplace_foreacharg[arg]
                    index_suffix = (
                        "[i]" if isinstance(foreach_arg.type, ListType) else ""
                    )
                    args[i] = foreach_arg.name + index_suffix
            elif arg == list_like_arg:
                args[i] = arg + "[i]"
        assign_stmt = ASSIGN_VECTOR_OF_GRAD_FN.substitute(
            op=op,
            op_ctor=""
            if info is not None and info.has_derivatives
            else f'"{cpp.name(f.func)}"',
            args_with_derivatives=args,
            irange=f"{list_like_arg}.size()",
        )
    else:
        assign_stmt = ASSIGN_GRAD_FN.substitute(
            op=op,
            op_ctor=""
            if info is not None and info.has_derivatives
            else f'"{cpp.name(f.func)}"',
            args_with_derivatives=[arg.name for arg in args_with_derivatives],
        )
    setup.extend(assign_stmt.split("\n"))
    setup.extend(emit_save_inputs())

    body.extend(
        emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives)
    )
    if is_inplace_foreach:
        declare_grad_fn_template = DECLARE_VECTOR_OF_GRAD_FN
    else:
        declare_grad_fn_template = DECLARE_GRAD_FN
    body.append(declare_grad_fn_template.substitute(op=op))
    body.append(SETUP_DERIVATIVE.substitute(setup=setup))
    return body
1356
+
1357
def emit_check_if_in_complex_autograd_allowlist() -> list[str]:
    """Emit a `throw_error_for_complex_autograd` call per tensor-like output.

    Nothing is emitted when `base_name` is in `GRADIENT_IMPLEMENTED_FOR_COMPLEX`.
    """
    if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
        return []
    # TODO: should be `arg.type.is_tensor_like()`?
    return [
        f'throw_error_for_complex_autograd({out.name}, "{base_name}");'
        for out in differentiable_outputs
        if out.cpp_type == "at::Tensor" or out.cpp_type in TENSOR_LIST_LIKE_CTYPES
    ]
1367
+
1368
def emit_check_no_requires_grad(
    tensor_args: list[DifferentiableInput],
    args_with_derivatives: list[DifferentiableInput],
) -> list[str]:
    """Emit `check_no_requires_grad` calls for arguments that have no derivative.

    Arguments are skipped when they have a derivative, are listed in
    `info.non_differentiable_arg_names`, or are named "output".
    """

    def needs_check(candidate: DifferentiableInput) -> bool:
        if candidate in args_with_derivatives:
            return False
        if info and candidate.name in info.non_differentiable_arg_names:
            return False
        # Double-backwards definitions sometimes take in 'input' and
        # 'output', but only define the derivative for input.
        return candidate.name != "output"

    return [
        f'check_no_requires_grad({candidate.name}, "{candidate.name}", "{name}");'
        for candidate in tensor_args
        if needs_check(candidate)
    ]
1386
+
1387
def emit_original_self_definition() -> list[str]:
    """Emit the declaration of `original_self` (or `original_selfs`) for in-place ops.

    Only in-place functions get any output. When a forward derivative needs the
    original value of self, a conditional clone of self is emitted as well.
    """
    body: list[str] = []
    if not inplace:
        return body

    if is_inplace_foreach:
        body.append(
            "std::vector<::std::optional<at::Tensor>> original_selfs(self.size());"
        )
    else:
        body.append("::std::optional<at::Tensor> original_self;")

    # Names of the forward-grad flags of every derivative that needs the
    # original value of self.
    fw_grad_flags = [
        get_any_has_forward_grad_name(derivative.var_names)
        for derivative in fw_derivatives
        if derivative.required_original_self_value
    ]
    if not fw_grad_flags:
        return body

    if is_inplace_foreach:
        indexed_flags = [f"{cond}[i]" for cond in fw_grad_flags]
        body.append("for (const auto& i : c10::irange(self.size())) {")
        body.append(f" if ({' || '.join(indexed_flags)}) {{")
        body.append(" original_selfs[i] = self[i].clone();")
        body.append(" }")
        body.append("}")
    else:
        body.append(f'if ({" || ".join(fw_grad_flags)}) {{')
        body.append(" original_self = self.clone();")
        body.append("}")
    return body
1422
+
1423
def save_variables(
    saved_variables: Sequence[SavedAttribute],
    is_output: bool,
    guard_for: Callable[[SavedAttribute], str | None] = lambda name: None,
) -> Sequence[str]:
    """Emit the statements that assign each saved variable to the generated grad_fn.

    Variables are processed in order of their name. `guard_for` may return a
    C++ condition for a variable, in which case its assignment is wrapped in an
    `if` on that condition.
    """
    single_tensor_ctypes = (
        BaseCType(tensorT),
        OptionalCType(BaseCType(tensorT)),
        MutRefCType(OptionalCType(BaseCType(tensorT))),
    )
    tensor_list_ctypes = (
        BaseCType(tensorListT),
        ListCType(OptionalCType(BaseCType(tensorT))),
        BaseCType(iTensorListRefT),
        VectorCType(BaseCType(tensorT)),
    )
    # Types whose saved expression only needs a trailing `.vec()`.
    vec_ctypes = (
        BaseCType(intArrayRefT),
        BaseCType(symIntArrayRefT),
        ArrayRefCType(elem=BaseCType(type=BaseCppType(ns="at", name="Scalar"))),
    )

    statements: list[str] = []
    for saved in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)):
        raw_name = saved.nctype.name
        name = raw_name.name if isinstance(raw_name, SpecialArgName) else raw_name
        foreacharg: Argument | None = None
        is_foreacharg_list_type: bool = False
        ctype = saved.nctype.type
        expr = saved.expr
        stmts_prepend = None

        if is_inplace_foreach and info is not None:
            # todo(crcrpar): See if we can add some check e.g. `assert foreacharg is not None`.
            # for now the example assert would fail.
            name_to_query = name.split("_scalar_type")[0]
            if name_to_query in refargname2inplace_foreacharg:
                foreacharg = refargname2inplace_foreacharg[name_to_query]
                is_foreacharg_list_type = isinstance(foreacharg.type, ListType)
            if foreacharg is not None:
                index_suffix = "[i]" if is_foreacharg_list_type else ""
                name_in_expr = f"{foreacharg.name}{index_suffix}"
                src_name = name
                if "_scalar_type" in src_name:
                    split_src_name = src_name.split("_scalar_type")
                    assert len(split_src_name) == 2
                    src_name = split_src_name[0]
                expr = expr.replace(src_name, name_in_expr)

        if ctype in single_tensor_ctypes or (
            is_output and ctype == BaseCType(scalarT)
        ):
            # note(crcrpar): Here `expr` is generated from scratch, `arg.expr` is ignored.
            var = name
            name += "_"
            if var == "self" and inplace:
                if is_inplace_foreach:
                    original_self_var = "original_selfs[i]"
                    self_var = var + "[i]"
                else:
                    original_self_var = "original_self"
                    self_var = var
                stmts_prepend = f"if (!{original_self_var}.has_value()) {original_self_var} = {self_var}.clone()"
                var = f"{original_self_var}.value()"
                assert not is_output
            if inplace and is_output:
                assert name == "result_"
                if is_inplace_foreach or is_foreacharg_list_type:
                    var = "self[i]"
                else:
                    var = "self"
                is_inplace_view = f"{var}.is_view()"
                expr = f"SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})"
            else:
                expr = f"SavedVariable({var}, {str(is_output).lower()})"
                if foreacharg is not None and "original_selfs" not in expr:
                    expr = expr.replace(src_name, name_in_expr)
        elif ctype in tensor_list_ctypes:
            # See Note [nuanced return type of out-of-place foreach functions]
            if ctype == VectorCType(BaseCType(tensorT)):
                assert is_foreach and is_output
            expr = f"make_saved_variable_list({name}, {str(is_foreach and is_output).lower()})"
            name += "_"
        elif ctype in vec_ctypes:
            expr = expr + ".vec()"
        elif ctype == BaseCType(stringT):
            expr = f"std::string({expr})"
        elif ctype == OptionalCType(BaseCType(stringT)):
            expr = f"{expr}.has_value() ? ::std::optional<std::string>(std::string({expr}.value())) : ::std::nullopt"

        guard = guard_for(saved)
        if guard is None:
            if stmts_prepend:
                statements.append(f"{stmts_prepend};")
            statements.append(f"grad_fn->{name} = {expr};")
        else:
            statements.append(f"if ({guard}) {{")
            if stmts_prepend:
                statements.append(f" {stmts_prepend};")
            statements.append(f" grad_fn->{name} = {expr};")
            statements.append("}")
    return statements
1526
+
1527
+ # Generates a Dispatcher::redispatch() call into the dispatcher. We do this mainly for performance reasons:
1528
+ # - Pre-compute the full DispatchKeySet. This saves the dispatcher from having to read from TLS.
1529
+ # - redispatch() avoids a redundant call to RecordFunction, which was already called right before
1530
+ # we entered this autograd kernel.
1531
def emit_dispatch_call(
    f: NativeFunction, input_base: str, unpacked_args: Sequence[str]
) -> str:
    """Dispatch call via function in a namespace or method on Tensor."""
    # NOTE(review): the dispatcher expressions are computed but not used by
    # this function; the evaluation is kept so behavior stays unchanged.
    DispatcherSignature.from_schema(f.func).exprs()

    # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # Ops also always have a function variant of the redispatch API.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    redispatch_args = ["ks & c10::after_autograd_keyset"]
    redispatch_args += list(unpacked_args)

    api_name = cpp.name(
        f.func,
        faithful_name_for_out_overloads=True,
        symint_overload=f.func.has_symint(),
    )
    return CALL_REDISPATCH.substitute(
        api_name=api_name,
        unpacked_args=redispatch_args,
    )
1551
+
1552
def wrap_output(
    f: NativeFunction, unpacked_bindings: list[Binding], var: str
) -> str:
    """Emit the assignment of `var` to the tied return values of `f`.

    `var` is wrapped in `std::move` when any return of `f` is tensor-like.
    """
    has_tensor_like_return = any(r.type.is_tensor_like() for r in f.func.returns)
    rhs_value = f"std::move({var})" if has_tensor_like_return else var
    return ASSIGN_RETURN_VALUE.substitute(
        return_values=tie_return_values(f), rhs_value=rhs_value
    )
1566
+
1567
def check_tensorimpl_and_storage(
    call: str, unpacked_bindings: list[Binding]
) -> str:
    """Wrap `call` with debug-only sanity checks on TensorImpl and Storage.

    Statements recorded for the inputs are placed before `call` and the
    matching enforcement statements after it, each inside
    `RUN_ONLY_IN_DEBUG_MODE`. `call` is returned unchanged when the op is in
    `DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE`, or when either the before or
    the after list ends up empty.
    """
    # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
    stmts_before_call: list[str] = []
    stmts_after_call: list[str] = []

    if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
        return call

    # Check properties of inputs (enforce (1))
    for unpacked_binding in unpacked_bindings:
        arg = unpacked_binding.name
        noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
        if noref_cpp_type == BaseCType(tensorListT) or noref_cpp_type == BaseCType(
            iTensorListRefT
        ):
            stmts_before_call += [
                SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg),
            ]
            stmts_after_call += [
                ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg),
            ]
        elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
            stmts_before_call += [
                SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg),
            ]
            stmts_after_call += [
                ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(
                    tensorlist_name=arg
                ),
                ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(
                    tensorlist_name=arg
                ),
            ]
        elif noref_cpp_type == BaseCType(tensorT):
            stmts_before_call += [
                SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                SAVE_TENSOR_IMPL.substitute(tensor_name=arg),
            ]
            stmts_after_call += [
                ENFORCE_SAME_TENSOR_STORAGE.substitute(
                    tensor_name=arg, out_tensor_name=arg
                ),
                ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg),
            ]

    # Input checks always add to both lists together, so at this point they
    # are either both empty or both non-empty.
    assert (stmts_before_call and stmts_after_call) or (
        not stmts_before_call and not stmts_after_call
    )

    # Check properties of outputs (enforce (2), (3))
    if f.func.kind() not in (SchemaKind.inplace, SchemaKind.out):
        base_name = f.func.name.name.base  # TODO: should be str(f.func.name.name)?
        aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None)
        if aliased_arg_name is not None:
            aliased_arg_name = unpacked_name(aliased_arg_name)
        for i, (ret, ret_name) in enumerate(
            zip(f.func.returns, cpp.return_names(f))
        ):
            noref_cpp_type = cpp.return_type(ret, symint=True).remove_const_ref()
            if noref_cpp_type == BaseCType(tensorT):
                if aliased_arg_name is not None:
                    # The message must be an f-string naming `base_name`;
                    # otherwise the placeholder is printed literally.
                    assert (
                        i == 0
                    ), f"Expect non-CompositeImplicitAutograd view function {base_name} to return single output"
                    stmts_after_call += [
                        ENFORCE_SAME_TENSOR_STORAGE.substitute(
                            tensor_name=aliased_arg_name, out_tensor_name=ret_name
                        )
                    ]
                else:
                    if (
                        type_wrapper_name(f)
                        not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT
                    ):
                        stmts_after_call += [
                            ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute(
                                tensor_name=ret_name, fn_name=type_wrapper_name(f)
                            )
                        ]

                if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
                    stmts_after_call += [
                        ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute(
                            tensor_name=ret_name, fn_name=type_wrapper_name(f)
                        )
                    ]

            # Currently we don't have any functions that return the following types, but
            # we should update the checks once we do
            elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
                raise AssertionError(
                    f"Please add use_count checks for {noref_cpp_type}"
                )
            elif noref_cpp_type == BaseCType(tensorListT):
                raise AssertionError(
                    f"Please add use_count checks for {noref_cpp_type}"
                )

    # The call is only wrapped when there is something to run on both sides.
    if stmts_before_call and stmts_after_call:
        call = (
            RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call)
            + call
            + RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call)
        )
    return call
1677
+
1678
def emit_call(
    f: NativeFunction, unpacked_bindings: list[Binding], try_jit_decomposition: bool
) -> str:
    """Emit the guarded redispatch call for `f`, with debug sanity checks.

    Functions that neither modify their arguments nor return void store the
    result in `TMP_VAR` and then assign it to the return values; all other
    functions are called without capturing a return value.
    """
    # We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
    # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
    # the baseType operations still dispatch to non-Variable type, even if the arguments passed
    # in are now Variables.
    # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details.
    base_type_call = emit_dispatch_call(
        f, "self_", [binding.name for binding in unpacked_bindings]
    )

    guard = (
        "at::AutoDispatchBelowAutograd guard;"
        if get_view_info(f) is not None or modifies_arguments(f)
        else "at::AutoDispatchBelowADInplaceOrView guard;"
    )

    if requires_derivative:
        any_has_forward_grad = get_any_has_fw_grad_cond(derivative=None)
    else:
        any_has_forward_grad = "false"

    return_types = ", ".join(
        cpp.return_type(ret, symint=True).cpp_type() for ret in f.func.returns
    )
    if len(f.func.returns) > 1:
        return_types = f"std::tuple<{return_types}>"

    cpp_args = cpp.arguments(
        f.func.arguments,
        faithful=True,
        symint=True,
        method=False,
        cpp_no_default_args=set(),
    )
    arg_names = [cpp_arg.name for cpp_arg in cpp_args]

    if modifies_arguments(f) or returns_void:
        assert not try_jit_decomposition
        call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
            base_type_call=base_type_call, guard=guard
        )
    else:
        if try_jit_decomposition:
            call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP.substitute(
                base_type_call=base_type_call,
                tmp_var=TMP_VAR,
                guard=guard,
                any_has_forward_grad=any_has_forward_grad,
                op_name=cpp.name(f.func),
                op_overload=f.func.name.overload_name,
                return_types=return_types,
                arg_names=arg_names,
            )
        else:
            call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute(
                base_type_call=base_type_call,
                tmp_var=TMP_VAR,
                guard=guard,
            )
        call += wrap_output(f, unpacked_bindings, TMP_VAR)

    return check_tensorimpl_and_storage(call, unpacked_bindings)
1743
+
1744
def emit_history() -> str:
    """Emit the statement(s) attaching the grad_fn(s) to the differentiable outputs.

    Uses `rebase_history` when the function modifies its arguments and is not a
    view, `set_history` otherwise. In-place foreach functions loop over the
    vector of grad_fns, using the flattened `self` as the outputs.
    """
    history_kind = "rebase" if modifies_arguments(f) and view_info is None else "set"
    output_names = [out.name for out in differentiable_outputs]
    flatten_target = "self" if is_inplace_foreach else output_names
    # TODO: flatten allocates a std::vector, which could be expensive
    outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(
        outs=flatten_target
    )
    if is_inplace_foreach:
        return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute(
            preamble=(
                f"auto differentiable_outputs = {outs};\n"
                f"TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());"
            ),
            statements=f"{history_kind}_history(differentiable_outputs[i], grad_fns[i]);",
        )
    return SET_HISTORY.substitute(fn=history_kind, differentiable_outputs=outs)
1761
+
1762
def emit_save_outputs() -> str:
    """Emit the code that stores saved outputs on the grad_fn(s).

    Returns an empty string for out functions, when there are no derivatives,
    or when there is nothing to save.
    """
    if is_out_fn:
        # out functions don't currently support differentiation
        return ""
    if info is None or not info.has_derivatives:
        return ""

    stmts = save_variables(info.all_saved_outputs, True)
    if len(stmts) == 0:
        return ""
    if is_inplace_foreach:
        return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute(preamble="", statements=stmts)
    return CONDITIONAL.substitute(cond="grad_fn", statements=stmts)
1777
+
1778
def emit_any_requires_grad() -> list[str]:
    """Emit the statement computing `_any_requires_grad`.

    The check covers the arguments that have derivatives. For in-place foreach
    functions with info, reference argument names are first translated back to
    the names used by the foreach function.
    """
    extra_condition = ""
    if info and info.output_differentiability_conditions:
        assert len(info.output_differentiability_conditions) == 1
        extra_condition = f"_any_requires_grad &= ({info.output_differentiability_conditions[0]});"

    checked_names = [arg.name for arg in args_with_derivatives]
    if is_inplace_foreach and info is not None:
        # Later entries overwrite earlier ones, so the last foreach argument
        # mapped to a given reference name wins.
        ref_name_to_foreach_name = {
            ref_arg.name: foreach_arg.name
            for foreach_arg, ref_arg in inplace_foreacharg2refarg.items()
        }
        checked_names = [
            ref_name_to_foreach_name.get(ref_name, ref_name)
            for ref_name in checked_names
        ]

    return [
        SETUP_ANY_REQUIRES_GRAD.substitute(
            args_with_derivatives=checked_names,
            extra_differentiability_conditions=extra_condition,
        )
    ]
1795
+
1796
def get_any_has_forward_grad_name(var_names: tuple[str, ...]) -> str:
    """Build the name of the forward-grad flag variable for `var_names`.

    The names are joined with underscores after a fixed prefix; a single name
    is appended as-is.
    """
    joined_names = "_".join(var_names)
    return f"_any_has_forward_grad_{joined_names}"
1801
+
1802
def emit_any_has_forward_grad() -> list[str]:
    """Emit the declarations of the forward-grad flags, one per forward derivative.

    Non-foreach functions get a single boolean per derivative. Foreach
    functions get a `std::vector<bool>` per derivative, filled in a loop over
    `self.size()`.
    """
    content: list[str] = []
    if is_foreach:
        for derivative in fw_derivatives:
            bool_vector_name = get_any_has_forward_grad_name(derivative.var_names)
            required = derivative.required_inputs_fw_grad
            cur_derivative_conditions = []
            for inp in differentiable_inputs:
                if required is None or inp.name not in required:
                    continue
                # In-place variants are described by their reference function,
                # so map the input back to the foreach argument.
                if inplace:
                    foreach_arg = refargname2inplace_foreacharg[inp.name]
                    inp_name, inp_type = foreach_arg.name, foreach_arg.type
                else:
                    inp_name, inp_type = inp.name, inp.type

                if is_tensor_list_type(inp_type):
                    if inp_name != "self":
                        content.append(
                            FW_DERIVATIVE_SIZE_CHECK_TEMPLATE.substitute(
                                inp_name=inp_name
                            )
                        )
                    req_inp = inp_name + "[i]"
                else:
                    req_inp = inp_name
                cur_derivative_conditions.append(
                    FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=req_inp)
                )

            content.append(f"std::vector<bool> {bool_vector_name}(self.size());")
            content.append("for (const auto& i : c10::irange(self.size())) {")
            content.append(
                f" {bool_vector_name}[i] = {' || '.join(cur_derivative_conditions)};"
            )
            content.append("}")
        return content

    for derivative in fw_derivatives:
        requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
        if info and info.output_differentiability_conditions:
            assert len(info.output_differentiability_conditions) == 1
            requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && {requires_fw_grad}"
        content.append(
            f"[[maybe_unused]] auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};"
        )
    return content
1857
+
1858
def emit_check_inplace() -> list[str]:
    """Emit one `check_inplace` call per differentiable output of an in-place op.

    Functions that are not in-place get an empty list.
    """
    if inplace:
        return [
            f"check_inplace({out.name}, _any_requires_grad);"
            for out in differentiable_outputs
        ]
    return []
1865
+
1866
def emit_fw_derivatives() -> list[str]:
    """Emit the C++ snippets that compute and then set forward-mode gradients.

    For every entry of ``fw_derivatives`` this builds
      * the code unpacking the tangents / primals the formula refers to,
      * the snippet computing the new forward grad (regular or foreach template),
      * a setter statement, which is collected separately.

    All free names used here (``fw_derivatives``, ``inplace``, ``is_foreach``,
    ``is_inplace_foreach``, ``differentiable_inputs``, the templates, ...) are
    not defined in this function and come from the enclosing scope.

    Returns:
        The computation snippets, followed by one final string holding every
        setter joined by newlines.
    """
    snippets: list[str] = []
    setters: list[str] = []
    is_inplace_str = "true" if inplace else "false"

    for derivative in fw_derivatives:
        res = derivative.var_names
        if f.func.name.name.inplace:
            assert (
                len(res) == 1
            ), "Expected number of outputs to be 1 if function is inplace"
            # TODO update this when inplace namings are unified
            res = ("self",)

        assert derivative.required_inputs_fw_grad is not None

        # --- unpack the tangents / primals required by the formula ----------
        unpack_parts: list[str] = []
        for inp in differentiable_inputs:
            # The lookup in refargname2inplace_foreacharg must stay behind the
            # `is_foreach and ...` short-circuit, exactly as before.
            is_input_tensorlist = is_foreach and is_tensor_list_type(
                refargname2inplace_foreacharg[inp.name].type if inplace else inp.type
            )
            input_suffix = "[i]" if is_input_tensorlist else ""

            inp_name = inp.name
            if is_inplace_foreach and inp.name in refargname2inplace_foreacharg:
                inp_name = refargname2inplace_foreacharg[inp.name].name

            if inplace and inp.name == "self":
                zeros_fn = "zeros_symint"
            else:
                zeros_fn = "_efficientzerotensor_symint"

            accessed_input = inp_name + input_suffix
            if inp.name in derivative.required_inputs_fw_grad:
                unpack_parts.append(
                    FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
                        inp_name=inp.name,
                        inp=accessed_input,
                        zeros_fn=zeros_fn,
                    )
                )
            if inp.name in (derivative.required_inputs_primal or []):
                unpack_parts.append(
                    FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
                        inp_name=inp.name,
                        inp=accessed_input,
                    )
                )

        if derivative.required_original_self_value:
            original_self = "original_self" + ("s[i]" if is_inplace_foreach else "")
            # NOTE: `zeros_fn` here is whatever the last iteration of the loop
            # above selected (unchanged from the previous implementation).
            unpack_parts.append(
                FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
                    inp_name="original_self",
                    inp=original_self,
                    zeros_fn=zeros_fn,
                )
            )
            unpack_parts.append(
                FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
                    inp_name="original_self",
                    inp=original_self,
                )
            )
        elif inplace and derivative.is_reusing_outplace_formula:
            # The gradient wasn't already cloned, do it if grad mode is enabled
            unpack_parts.append(
                "self_t = GradMode::is_enabled() ? self_t.clone() : self_t;"
            )
        unpacked_arguments = "".join(unpack_parts)

        requires_fw_grad = get_any_has_forward_grad_name(derivative.var_names)
        joined_res = "_".join(res)
        var_types = derivative.var_types

        # --- pick the C++ type of the new grad and the matching setter -------
        if all(
            isinstance(var_type, BaseType) and var_type.is_tensor_like()
            for var_type in var_types
        ):
            # Is there a way to get from BaseType to BaseCType
            if len(var_types) == 1:
                opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
                if is_foreach:
                    assert res[0] == ("result" if not inplace else "self")
                    setter_template = FW_DERIVATIVE_SETTER_TENSOR_FOREACH
                else:
                    setter_template = FW_DERIVATIVE_SETTER_TENSOR
                setters.append(
                    setter_template.substitute(
                        out_arg=res[0], is_inplace=is_inplace_str
                    )
                )
                requires_fw_grad += f" && ({derivative.var_names[0]}.defined())"
            else:
                tuple_type = TupleCType([BaseCType(tensorT)] * len(var_types))
                opt_res_grad_type = OptionalCType(tuple_type).cpp_type()
                for idx, single_res in enumerate(res):
                    setters.append(
                        FW_DERIVATIVE_SETTER_MULTI_OUTPUT.substitute(
                            idx=idx, all_res=joined_res, out_arg=single_res
                        )
                    )
        elif isinstance(var_types[0], ListType) and var_types[0].is_tensor_like():
            assert (
                len(var_types) == 1
            ), "Expected number of outputs to be 1 if function returns ListType"
            if is_foreach:
                # TODO(crcrpar): Should this (= the foreach specific logic) be refactored somehow?
                # Only out-place foreach functions that have entries in `tools/autograd/derivatives.yaml`
                # can reach here.
                opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
                setter_template = FW_DERIVATIVE_SETTER_TENSOR_FOREACH
            else:
                opt_res_grad_type = OptionalCType(
                    VectorCType(BaseCType(tensorT))
                ).cpp_type()
                setter_template = FW_DERIVATIVE_SETTER_TENSOR_LIST
            setters.append(
                setter_template.substitute(out_arg=res[0], is_inplace=is_inplace_str)
            )
        else:
            raise RuntimeError("Unsupported output type for forward derivative")

        # --- emit the computation of the new forward grad --------------------
        if not is_foreach:
            fw_grad_opt_definition = (
                f"{opt_res_grad_type} {joined_res}_new_fw_grad_opt = ::std::nullopt;"
            )
            # View ops create fw_grad that already is a view of the base's fw_grad so just use that
            snippets.append(
                FW_DERIVATIVE_TEMPLATE.substitute(
                    fw_grad_opt_definition=fw_grad_opt_definition,
                    requires_fw_grad=requires_fw_grad,
                    formula=derivative.formula,
                    out_arg=joined_res,
                    unpacked_arguments=unpacked_arguments,
                )
            )
            continue

        # note(crcrpar): Assuming `self` is TensorList.
        fw_grad_opt_definition = (
            f"std::vector<{opt_res_grad_type}> {joined_res}_new_fw_grad_opts"
            "(self.size(), ::std::nullopt);"
        )
        formula = derivative.formula
        if inplace:
            for foreach_arg, ref_arg in inplace_foreacharg2refarg.items():
                # note(crcrpar): Massage only Scalar and ArrayRef<Scalar> here.
                if is_tensor_type(foreach_arg.type) or is_tensor_list_type(
                    foreach_arg.type
                ):
                    continue
                replacement = foreach_arg.name
                if isinstance(foreach_arg.type, ListType):
                    replacement += "[i]"
                formula = formula.replace(ref_arg.name, replacement)
        elif "result" in formula and "result[i]" not in formula:
            formula = formula.replace("result", "result[i]")

        snippets.append(
            FW_DERIVATIVE_FOREACH_TEMPLATE.substitute(
                fw_grad_opt_definition=fw_grad_opt_definition,
                vector_of_optional_tensor=f"{joined_res}_new_fw_grad_opts",
                any_has_forward_grad_for_current_index=" || ".join(
                    get_any_has_forward_grad_name(d.var_names) + "[i]"
                    for d in fw_derivatives
                ),
                formula=formula,
                unpacked_arguments=unpacked_arguments,
            )
        )

    # Set all the grads at the end to avoid: https://github.com/pytorch/pytorch/issues/67367
    snippets.append("\n".join(setters))
    return snippets
2057
+
2058
def get_any_has_fw_grad_cond(derivative: ForwardDerivative | None) -> str:
    """Build the C++ condition telling whether any relevant input has a forward grad.

    Produces a condition string
    (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)").

    Args:
        derivative: the forward derivative whose required tangents should be
            checked, or ``None`` to check every differentiable input.

    Returns:
        A parenthesised ``||`` chain of checks, or the literal ``"true"`` when
        the derivative needs no input tangent.

    Raises:
        RuntimeError: for an unsupported input type (``derivative is None``),
            or when a formula needs no tangent but the function does not take
            exactly one TensorList as differentiable input.
    """
    if derivative is not None:
        # (2) If derivative is provided, use that information to determine which inputs
        #     to check fw_grad for
        required = derivative.required_inputs_fw_grad
        assert required is not None

        if len(required) == 0:
            # Handle functions like stack
            # For these, we don't unpack anything and always call the user function
            takes_single_tensorlist = len(
                differentiable_inputs
            ) == 1 and is_tensor_list_type(differentiable_inputs[0].type)
            if not takes_single_tensorlist:
                raise RuntimeError(
                    f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
                    "forward AD formula does not use any input tangent) even though a forward gradient "
                    "formula has been defined for it. This case should only happen for function that "
                    "take a single TensorList as input. All other cases are not supported right now."
                )
            return "true"

        checks: list[str] = []
        for inp in differentiable_inputs:
            if inp.name not in required:
                continue
            if is_tensor_list_type(inp.type):
                check_template = FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE
            else:
                check_template = FW_DERIVATIVE_CHECK_TEMPLATE
            checks.append(check_template.substitute(req_inp=inp.name))
        return "(" + " || ".join(checks) + ")"

    # (1) If a derivative is NOT provided, cond will check fw_grad of ALL differentiable inputs
    # - Used in the out_fn case when we want to forbid fw derivatives
    # - Used in the case where the fw_derivative is not defined, but we want
    #   To check if there is a decomposition registered for jvp
    candidates = list(
        mapMaybe(
            gen_differentiable_input,
            f.func.arguments.non_out + list(f.func.arguments.out),  # type: ignore[operator]
        )
    )
    to_check: list[str] = []
    for inp in candidates:
        if is_tensor_type(inp.type):
            check = FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
        elif is_tensor_list_type(inp.type):
            check = FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE.substitute(
                req_inp=inp.name
            )
        else:
            raise RuntimeError(
                f'Unsupported input type for "{name}" when forbidding forward AD usage.'
            )
        to_check.append(check)
    return "(" + " || ".join(to_check) + ")"
2123
+
2124
def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str:
    """Render the snippet that rejects forward-mode AD usage for this function.

    Args:
        is_out_fn: selects the explanation placed in the message: either
            "it is an out= function" or "not implemented yet".

    Returns:
        ``FW_DERIVATIVE_FORBID_TEMPLATE`` filled with the condition obtained
        from ``get_any_has_fw_grad_cond(derivative=None)``, or an empty string
        if that condition is empty.
    """
    not_implemented_msg = (
        "because it has not been implemented yet.\\nPlease file an issue "
        "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
        "so that we can prioritize its implementation."
    )
    msg = "because it is an out= function" if is_out_fn else not_implemented_msg

    cond = get_any_has_fw_grad_cond(derivative=None)
    # NOTE(review): with derivative=None the helper wraps its result in
    # parentheses, so this guard looks like it can never trigger - confirm.
    if cond == "":
        return ""
    return FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond, name=name, msg=msg)
2139
+
2140
+ body: list[str] = []
2141
+ unpack_args_stats, unpacked_bindings = unpack_args(f)
2142
+
2143
+ body.extend(unpack_args_stats)
2144
+ if requires_derivative:
2145
+ body.extend(emit_any_requires_grad())
2146
+ body.extend(emit_any_has_forward_grad())
2147
+ body.extend(emit_check_inplace())
2148
+ body.extend(emit_original_self_definition())
2149
+ body.extend(setup_derivative(differentiable_inputs))
2150
+
2151
+ body.append(emit_call(f, unpacked_bindings, try_jit_decomposition))
2152
+ if requires_derivative:
2153
+ # set_flags has to appear after version_counter, because rebase_history
2154
+ # requires that the counter is incremented before it is called
2155
+ body.append(emit_history())
2156
+ body.extend(emit_check_if_in_complex_autograd_allowlist())
2157
+
2158
+ if is_out_fn:
2159
+ body.append(emit_forbid_fw_derivatives(is_out_fn=True))
2160
+ else:
2161
+ if requires_derivative and not try_jit_decomposition:
2162
+ if len(fw_derivatives) > 0:
2163
+ body.extend(emit_fw_derivatives())
2164
+ else:
2165
+ body.append(emit_forbid_fw_derivatives())
2166
+
2167
+ if requires_derivative:
2168
+ # Save only after the forward AD has been set up
2169
+ body.append(emit_save_outputs())
2170
+
2171
+ if str(f.func.name.name) in RESET_GRAD_ACCUMULATOR:
2172
+ # `inplace` implies that there is exactly one output named `self`,
2173
+ # so we can keep the generated code easy. If you need to
2174
+ # `reset_grad_accumulator` in an operator that's not `inplace`, you can
2175
+ # remove this assert but the code generation will get more elaborate
2176
+ assert inplace
2177
+ body.append("reset_grad_accumulator(self);")
2178
+ if not returns_void:
2179
+ body.append(f"return {get_return_value(f)};")
2180
+ return body
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/gen_view_funcs.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generates ViewFuncs.h/cpp
2
+ #
3
+ # NOTE: If any changes are being made to the ViewFunc codegen please also check
4
+ # if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
5
+ # The fallback is expected to mimic this codegen, so we should keep the two in sync.
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import TYPE_CHECKING
10
+
11
+ import torchgen.api.dispatcher as dispatcher
12
+ from torchgen.api.translate import translate
13
+ from torchgen.api.types import (
14
+ BaseCType,
15
+ Binding,
16
+ NamedCType,
17
+ SymIntT,
18
+ tensorT,
19
+ VectorCType,
20
+ )
21
+ from torchgen.code_template import CodeTemplate
22
+ from torchgen.model import Argument, NativeFunction, OptionalType
23
+ from torchgen.utils import FileManager
24
+
25
+ from .gen_inplace_or_view_type import (
26
+ CALL_DISPATCH,
27
+ extract_bindings,
28
+ get_view_info,
29
+ modifies_arguments,
30
+ use_derived,
31
+ )
32
+
33
+
34
+ if TYPE_CHECKING:
35
+ from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo
36
+
37
+
38
# C++ declaration of one generated struct deriving from ${superclass}.
# Placeholders are filled in by `process_function`.
#
# Compared with the previous template text:
#   * the stray `;` after the constructor and destructor bodies is gone
#     (it is an empty declaration that only triggers extra-semicolon warnings),
#   * overriding members are spelled with `override` only; `virtual` is
#     redundant once `override` is present.
# Neither change alters the meaning of the generated declarations.
FUNCTION_DECLARATION = CodeTemplate(
    """\
#define ${uppercase_op}_AVAILABLE
struct ${op} : public ${superclass} {
  ${op}(${constructor_args}) ${initializer_list}
  {}
  ~${op}() override {}
  std::vector<c10::SymInt> get_symints() const override;
  size_t num_symints() const override;
  std::vector<at::Tensor> get_tensors() const override;
  size_t num_tensors() const override;
  at::Tensor operator()(const at::Tensor&) const override;
  std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = ::std::nullopt,
      std::optional<std::vector<at::Tensor>> = ::std::nullopt) const override;

protected:
  void set_symints(std::vector<c10::SymInt>) override;
  void set_tensors(std::vector<at::Tensor>) override;

private:
  ${state}
};

"""
)
64
+
65
# C++ out-of-line member definitions for one generated struct named ${op}.
# The ${...} placeholders are the keyword arguments that `process_function`
# passes to `template.substitute(...)`.
+ FUNCTION_DEFINITION = CodeTemplate(
66
+ """\
67
+ std::vector<c10::SymInt> ${op}::get_symints() const {
68
+ ${get_symints}
69
+ }
70
+
71
+ size_t ${op}::num_symints() const {
72
+ return static_cast<size_t>(${num_symints});
73
+ }
74
+
75
+ void ${op}::set_symints(std::vector<c10::SymInt> ${symints_vec}) {
76
+ TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints());
77
+ ${set_symints}
78
+ }
79
+
80
+ std::vector<at::Tensor> ${op}::get_tensors() const {
81
+ ${get_tensors}
82
+ }
83
+
84
+ size_t ${op}::num_tensors() const {
85
+ return static_cast<size_t>(${num_tensors});
86
+ }
87
+
88
+ void ${op}::set_tensors(std::vector<at::Tensor> ${tensors_vec}) {
89
+ TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors());
90
+ ${set_tensors}
91
+ }
92
+
93
+ at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const {
94
+ return ${op_call};
95
+ }
96
+
97
+ std::unique_ptr<ViewFunc> ${op}::clone_and_set(
98
+ std::optional<std::vector<c10::SymInt>> ${symints_vec},
99
+ std::optional<std::vector<at::Tensor>> ${tensors_vec}) const {
100
+ auto output = std::make_unique<${op}>(${clone_args});
101
+ if (${symints_vec}.has_value()) {
102
+ output->set_symints(std::move(*(${symints_vec})));
103
+ }
104
+ if (${tensors_vec}.has_value()) {
105
+ output->set_tensors(std::move(*(${tensors_vec})));
106
+ }
107
+ return output;
108
+ }
109
+
110
+ """
111
+ )
112
+
113
+
114
# e.g. as_strided -> AsStridedViewFunc for camel case or
# as_strided_view_func otherwise
def view_func_name(
    f: NativeFunction, include_namespace: bool = False, camel_case: bool = True
) -> str:
    """Return the name of the generated view-func struct for ``f``.

    Args:
        f: function whose ``func.name.unambiguous_name()`` seeds the name.
        include_namespace: prefix the result with
            ``torch::autograd::generated::``.
        camel_case: produce ``AsStridedViewFunc`` style instead of
            ``as_strided_view_func``; a leading underscore is preserved.

    Returns:
        The (optionally namespaced) struct name.
    """
    # Dots are replaced once, up front; the result is kept in a local that
    # does not shadow this function's own name.
    base = f.func.name.unambiguous_name().replace(".", "_")
    result = f"{base}_view_func"
    if camel_case:
        is_private = result.startswith("_")
        result = "".join(part.title() for part in result.split("_"))
        if is_private:
            # put the leading underscore back in
            result = f"_{result}"
    namespace = "torch::autograd::generated::" if include_namespace else ""
    return f"{namespace}{result}"
131
+
132
+
133
def is_symint_or_tensor(arg: Argument) -> bool:
    """Tell whether ``arg``'s type reports itself as tensor-like or symint-like."""
    arg_type = arg.type
    tensor_like = arg_type.is_tensor_like()
    # The symint check is only consulted when the tensor check is falsy.
    return tensor_like if tensor_like else arg_type.is_symint_like()
135
+
136
+
137
def remove_const_ref(binding: Binding) -> Binding:
    """Copy ``binding``, replacing its nctype by ``nctype.remove_const_ref()``.

    Name, argument and default are carried over unchanged.
    """
    stripped_nctype = binding.nctype.remove_const_ref()
    return Binding(
        name=binding.name,
        nctype=stripped_nctype,
        argument=binding.argument,
        default=binding.default,
    )
144
+
145
+
146
def returns_multi_tensor(fn: NativeFunction) -> bool:
    """Tell whether the single return of ``fn`` is a list-like of tensors.

    Asserts that ``fn.func.returns`` has exactly one entry.
    """
    outputs = fn.func.returns
    assert len(outputs) == 1
    return_type = outputs[0].type
    # Both queries are always made, in this order.
    list_like = return_type.is_list_like()
    tensor_like = return_type.is_tensor_like()
    if list_like is None:
        return False
    return tensor_like
152
+
153
+
154
# Generates strings with logic for getting / setting state of a particular type.
#
# Args:
#   bindings (list): List of state bindings of interest (may be empty)
#   state_vec_type (NamedCType): Type of vector to either return or copy from
#
# Returns:
#   tuple: (list of getter logic strings, list of setter logic strings, string
#     with num items expression)
def generate_state_getter_setter(
    bindings: list[Binding],
    state_vec_type: NamedCType,
) -> tuple[list[str], list[str], str]:
    vec = state_vec_type.name
    declaration = f"{state_vec_type.cpp_type()} {vec};"

    count_exprs: list[str] = []
    getters: list[str] = []
    # The setter walks the incoming vector with a running offset `i`.
    setter_logic: list[str] = ["auto i = 0;"] if len(bindings) > 0 else []

    last_index = len(bindings) - 1
    for index, binding in enumerate(bindings):
        assert isinstance(binding.argument, Argument)
        var = binding.name
        if binding.argument.type.is_list_like():
            # Handle list-likes.
            count = f"{var}.size()"
            getter = f"{vec}.insert({vec}.end(), {var}.begin(), {var}.end());"
            setter = f"std::copy({vec}.begin() + i, {vec}.begin() + i + {var}.size(), {var}.begin());"
        elif isinstance(binding.argument.type, OptionalType):
            # Handle optionals.
            count = f"({var}.has_value() ? 1 : 0)"
            guard = f"if({var}.has_value())"
            getter = f"{guard} {vec}.insert({vec}.end(), *({var}));"
            setter = f"{guard} {var} = {vec}[i];"
        else:
            count = "1"
            getter = f"{vec}.push_back({var});"
            setter = f"{var} = {vec}[i];"

        count_exprs.append(count)
        getters.append(getter)
        setter_logic.append(setter)
        # No offset bump is emitted after the final binding.
        if index < last_index:
            setter_logic.append(f"i += {count};")

    # Reserve / assert based on the total number of items expression.
    num_items = " + ".join(count_exprs) if count_exprs else "0"

    getter_logic = [declaration]
    if len(bindings) > 0:
        getter_logic.append(f"{vec}.reserve({num_items});")
    getter_logic.extend(getters)
    getter_logic.append(f"return {vec};")

    return getter_logic, setter_logic, num_items
212
+
213
+
214
def process_function(fn: NativeFunction, template: CodeTemplate) -> str:
    """Fill ``template`` with the pieces describing the view-func struct of ``fn``.

    Args:
        fn: the function to generate code for.
        template: template receiving the substitutions (struct name,
            constructor / clone arguments, state variables, getter / setter
            logic and the call to the underlying op).

    Returns:
        The result of ``template.substitute(...)``.
    """
    params = [b for b in extract_bindings(fn) if b.name != "self"]

    # Bindings built with remove_non_owning_ref_types=True for every flat
    # argument after the first one.
    value_params = [
        dispatcher.argument(arg, remove_non_owning_ref_types=True)
        for arg in fn.func.arguments.flat_all[1:]
    ]

    # Generate constructor / clone args for the generated struct.
    constructor_args = [b.defn() for b in params]
    clone_args = [b.name for b in params]

    # Generate state variable declarations for the generated struct.
    state_variables = [f"{remove_const_ref(b).defn()};" for b in value_params]

    # Generate initializer list expressions for the generated struct.
    # allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as
    # vector<SymInt>s.
    converted_exprs = translate(params, value_params, allow_expensive_conversions=True)
    initializers = []
    for param, converted in zip(params, converted_exprs):
        member = param.nctype.name
        assert isinstance(member, str)
        initializers.append(f"{member}({converted.expr})")

    # Generate call to underlying view op
    call_input_name = "input_base"
    op_call = CALL_DISPATCH.substitute(
        unambiguous_name=fn.func.name.unambiguous_name(),
        unpacked_args=[call_input_name] + clone_args,
    )

    # Multi-output views additionally require a view_idx for disambiguation.
    if returns_multi_tensor(fn):
        idx_name = "view_idx"
        idx_decl = f"int64_t {idx_name}"
        constructor_args.append(idx_decl)
        clone_args.append(idx_name)
        state_variables.append(f"{idx_decl};")
        initializers.append(f"{idx_name}({idx_name})")
        op_call += f"[{idx_name}]"

    # Generate initializer list for the generated struct.
    initializer_list = ": " + ", ".join(initializers) if initializers else ""

    # Generate getter / setter logic for any symints.
    symint_params = []
    for b in params:
        if isinstance(b.argument, Argument) and b.argument.type.is_symint_like():
            symint_params.append(b)
    symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT)))
    get_symints, set_symints, num_symints = generate_state_getter_setter(
        symint_params, symints_vec_type
    )

    # Generate getter / setter logic for any tensors.
    tensor_params = []
    for b in params:
        if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like():
            tensor_params.append(b)
    tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT)))
    get_tensors, set_tensors, num_tensors = generate_state_getter_setter(
        tensor_params, tensors_vec_type
    )

    return template.substitute(
        op=view_func_name(fn),
        uppercase_op=view_func_name(fn, camel_case=False).upper(),
        superclass="torch::autograd::ViewFunc",
        initializer_list=initializer_list,
        state=state_variables,
        constructor_args=constructor_args,
        clone_args=clone_args,
        symints_vec=symints_vec_type.name,
        get_symints=get_symints,
        set_symints=set_symints,
        num_symints=num_symints,
        tensors_vec=tensors_vec_type.name,
        get_tensors=get_tensors,
        set_tensors=set_tensors,
        num_tensors=num_tensors,
        call_input_name=call_input_name,
        op_call=op_call,
    )
307
+
308
+
309
def gen_view_funcs(
    out: str,
    fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write ``ViewFuncs.h`` and ``ViewFuncs.cpp`` through a ``FileManager``.

    Args:
        out: passed to ``FileManager`` as ``install_dir``.
        fns_with_infos: candidate functions; only those accepted by
            ``use_derived`` that have view info and do not modify their
            arguments are processed.
        template_path: passed to ``FileManager`` as ``template_dir``.
    """
    # don't need the info parts, just the function
    derived_fns = [entry.func for entry in fns_with_infos if use_derived(entry)]

    # only want out-of-place views
    view_fns = []
    for fn in derived_fns:
        if get_view_info(fn) is None:
            continue
        if modifies_arguments(fn):
            continue
        view_fns.append(fn)

    declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns]
    definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns]
    ops_headers = [f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in view_fns]

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for fname in ("ViewFuncs.h", "ViewFuncs.cpp"):

        def template_env() -> dict:
            # The "@" is kept as a separate literal, exactly as before.
            return {
                "generated_comment": "@"
                + f"generated from {fm.template_dir_for_comments()}/"
                + fname,
                "view_func_declarations": declarations,
                "view_func_definitions": definitions,
                "ops_headers": ops_headers,
            }

        fm.write_with_template(fname, fname, template_env)
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/load_derivatives.py ADDED
@@ -0,0 +1,1014 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Parses derivatives.yaml into autograd functions
2
+ #
3
+ # Each autograd function is represented by `DifferentiabilityInfo` containing
4
+ # a list of `Derivative`. See `torchgen.api.autograd` for the data models.
5
+
6
+ from __future__ import annotations
7
+
8
+ import re
9
+ from collections import defaultdict
10
+ from typing import Any, Counter, Dict, Sequence, Set, Tuple
11
+
12
+ import yaml
13
+
14
+ from torchgen.api import cpp
15
+ from torchgen.api.autograd import (
16
+ Derivative,
17
+ DifferentiabilityInfo,
18
+ ForwardDerivative,
19
+ SavedAttribute,
20
+ )
21
+ from torchgen.api.types import (
22
+ BaseCType,
23
+ Binding,
24
+ boolT,
25
+ CppSignatureGroup,
26
+ layoutT,
27
+ longT,
28
+ NamedCType,
29
+ OptionalCType,
30
+ scalarTypeT,
31
+ SpecialArgName,
32
+ stringT,
33
+ symIntArrayRefT,
34
+ SymIntT,
35
+ tensorGeometryT,
36
+ tensorOptionsT,
37
+ typeAndSizeT,
38
+ VectorCType,
39
+ )
40
+ from torchgen.context import with_native_function
41
+ from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml
42
+ from torchgen.model import (
43
+ AUTOGRAD_KEYS,
44
+ FunctionSchema,
45
+ NativeFunction,
46
+ NativeFunctionsViewGroup,
47
+ OperatorName,
48
+ SchemaKind,
49
+ Type,
50
+ Variant,
51
+ )
52
+ from torchgen.utils import concatMap, IDENT_REGEX, split_name_params
53
+ from torchgen.yaml_utils import YamlLoader
54
+
55
+
56
# Return type of `load_derivatives`: the per-schema, per-dispatch-key
# differentiability infos together with the set of dispatch keys in use.
DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]]

# Results of `load_derivatives`, keyed by
# (derivatives_yaml_path, native_yaml_path).
_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {}

_VALID_AUTOGRAD_KEYS = {*AUTOGRAD_KEYS}
61
+
62
+
63
+ # This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op.
64
+ # Since every {view} and {view}_copy op shares the same derivative formula,
65
+ # we generate them here instead of duplicating them in the yaml.
66
+ # See Note [Codegen'd {view}_copy Operators]
67
+ def add_view_copy_derivatives(
68
+ infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
69
+ view_groups: list[NativeFunctionsViewGroup],
70
+ ) -> None:
71
+ # Get the map from each view op's name to its corresponding view group
72
+ view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = {
73
+ g.view.func.name: g for g in view_groups
74
+ }
75
+
76
+ view_infos = {}
77
+
78
+ for info_dispatch_dict in infos.values():
79
+ # maybe_view_group only needs to be calculated once per info_dispatch_dict
80
+ maybe_view_group = None
81
+ view_copy_differentiability_infos = {}
82
+ for dispatch_key, info in info_dispatch_dict.items():
83
+ maybe_view_group = view_name_to_group.get(info.func.func.name, None)
84
+ if maybe_view_group is not None and maybe_view_group.view_copy is not None:
85
+ view_copy_info = info.create_view_copy_from_view_derivative(
86
+ maybe_view_group
87
+ )
88
+ if view_copy_info is not None:
89
+ fn_schema = view_copy_info.func.func
90
+ view_copy_differentiability_infos[dispatch_key] = view_copy_info
91
+ else:
92
+ break
93
+ # prefer manually-defined derivatives if any
94
+ if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos:
95
+ assert fn_schema is not None
96
+ view_infos[fn_schema] = view_copy_differentiability_infos
97
+
98
+ infos.update(view_infos)
99
+
100
+
101
def load_derivatives(
    derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str
) -> DerivativeRet:
    """Parse the derivatives yaml into per-schema differentiability infos.

    Results are memoised in ``_GLOBAL_LOAD_DERIVATIVE_CACHE``.

    Args:
        derivatives_yaml_path: yaml file with the derivative definitions.
        native_yaml_path: passed to ``parse_native_yaml``.
        tags_yaml_path: passed to ``parse_native_yaml``.

    Returns:
        ``(infos, used_dispatch_keys)`` as stored in the cache.
    """
    # Do some caching as this is a deterministic function
    # NOTE(review): the cache key does not include tags_yaml_path - confirm
    # that callers never vary it for the same pair of yaml paths.
    cache_key = (derivatives_yaml_path, native_yaml_path)
    if cache_key in _GLOBAL_LOAD_DERIVATIVE_CACHE:
        return _GLOBAL_LOAD_DERIVATIVE_CACHE[cache_key]

    with open(derivatives_yaml_path) as f:
        definitions = yaml.load(f, Loader=YamlLoader)

    funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
    # From the parsed native functions, separate out the (generated) view_copy functions,
    # so we can generate derivatives for them separately.
    native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
    native_functions = concatMap(
        lambda g: [g]
        if isinstance(g, NativeFunction)
        else list(g.functions(include_copy=True)),
        native_functions_with_view_groups,
    )
    view_groups = [
        group
        for group in native_functions_with_view_groups
        if isinstance(group, NativeFunctionsViewGroup)
    ]

    # What's the difference between function schema v.s. signature?
    # function schema is the complete declaration including mutability annotation / default value and etc.
    # signature is the canonical schema for a group of functions (in-place/out/functional variants)
    # that are semantically related.
    functions_by_signature: dict[FunctionSchema, list[NativeFunction]] = defaultdict(
        list
    )
    functions_by_schema: dict[str, NativeFunction] = {}
    for function in native_functions:
        functions_by_signature[function.func.signature()].append(function)
        schema_text = str(function.func)
        assert schema_text not in functions_by_schema
        functions_by_schema[schema_text] = function

    # Keep track of how many of which ops we've seen so we can
    # disambiguate them with a numeric suffix.
    op_counter = Counter[str]()

    # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos
    # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info
    # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema
    infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {}
    used_dispatch_keys: set[str] = set()
    for defn_dict in definitions:
        # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded.
        if "dispatch" not in defn_dict:
            specification = defn_dict.pop("name")
            output_differentiability = defn_dict.pop("output_differentiability", None)
            # Whatever remains of the entry becomes the "Default" dispatch section.
            defn_dict = {"name": specification, "dispatch": {"Default": defn_dict}}
            if output_differentiability:
                defn_dict["output_differentiability"] = output_differentiability
        schema, per_dispatch_diffinfos = create_differentiability_info(
            defn_dict,
            functions_by_signature,
            functions_by_schema,
            op_counter,
            used_dispatch_keys,
        )
        infos[schema] = per_dispatch_diffinfos

    add_view_copy_derivatives(infos, view_groups)

    # cache both loaded infos as well as a set of all the dispatch_keys/aliases
    # that appear in derivatives.yaml. used_dispatch_keys is useful for generating
    # VariableType.cpp where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used
    _GLOBAL_LOAD_DERIVATIVE_CACHE[cache_key] = infos, used_dispatch_keys

    return _GLOBAL_LOAD_DERIVATIVE_CACHE[cache_key]
176
+
177
+
178
# TODO: Why is this going through CppSignatureGroup, that doesn't make sense...
@with_native_function
def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
    """Return the C++ argument bindings of ``f``'s non-method signature.

    The SymInt signature is used when the signature group has one; otherwise
    the plain signature is used.
    """
    group = CppSignatureGroup.from_native_function(f, method=False)
    chosen = (
        group.symint_signature
        if group.symint_signature is not None
        else group.signature
    )
    return chosen.arguments()
186
+
187
+
188
def create_derivative(
    f: NativeFunction,
    formula: str,
    var_names: tuple[str, ...],
    available_named_gradients: Sequence[str],
) -> Derivative:
    """Build the Derivative record for one backward formula of ``f``.

    The formula is passed through ``saved_variables`` twice, first against the
    C++ arguments and then against the named returns, to collect what must be
    saved. It is then checked for named gradients and out-of-range
    ``grads[i]`` accesses.

    Raises:
        RuntimeError: if the formula indexes ``grads`` beyond the number of
            returns of ``f``.
    """
    original_formula = formula

    input_nctypes: list[NamedCType] = [
        binding.nctype.remove_const_ref() for binding in cpp_arguments(f)
    ]

    # A return called "self" is referred to as "result" inside formulas.
    output_names = tuple(
        "result" if ret_name == "self" else ret_name
        for ret_name in cpp.return_names(f)
    )
    output_types = tuple(
        cpp.return_type(ret, symint=True).remove_const_ref()
        for ret in f.func.returns
    )
    output_nctypes = [
        NamedCType(ret_name, ret_type)
        for ret_name, ret_type in zip(output_names, output_types)
    ]

    formula, saved_inputs = saved_variables(formula, input_nctypes, var_names)
    formula, saved_outputs = saved_variables(formula, output_nctypes, var_names)

    used_named_gradients = set()
    for grad_name in available_named_gradients:
        if re.search(IDENT_REGEX.format(grad_name), formula):
            used_named_gradients.add(grad_name)

    # Check that the referenced derivatives in the formula are in bounds
    for i in used_gradient_indices(formula):
        if i < len(f.func.returns):
            continue
        raise RuntimeError(
            f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} "
            f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs."
        )

    return Derivative(
        formula=formula,
        original_formula=original_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
        named_gradients=used_named_gradients,
    )
233
+
234
+
235
def create_forward_derivative(
    f: NativeFunction, formula: str, names: tuple[str, ...]
) -> ForwardDerivative:
    """Build a preliminary ForwardDerivative for the outputs listed in ``names``.

    Output types are resolved by matching ``names`` against the named returns
    of ``f``. If nothing matches, the default names are tried: ``result`` for
    a single return, or ``resultN`` for the N-th return. The ``required_*``
    fields are left unset (``None``) here.
    """
    returns = f.func.returns

    # First try the explicitly named returns.
    by_name = tuple(ret.type for ret in returns if ret.name in names)
    var_types: tuple[Type, ...] | None = by_name if by_name else None

    # Handle default return names
    if var_types is None:
        if names == ("result",):
            assert len(returns) == 1
            var_types = (returns[0].type,)
        else:
            by_position = []
            for candidate in names:
                hits = re.findall(r"^result(\d+)$", candidate)
                if len(hits) == 1:
                    by_position.append(returns[int(hits[0])].type)
            if by_position:
                var_types = tuple(by_position)

    assert var_types is not None, "No matching output for forward derivative definition"
    return ForwardDerivative(
        formula=formula,
        var_names=names,
        var_types=var_types,
        required_inputs_fw_grad=None,
        required_inputs_primal=None,
        required_original_self_value=False,
        is_reusing_outplace_formula=False,
    )
270
+
271
+
272
def postprocess_forward_derivatives(
    f: NativeFunction,
    defn_name: str,
    all_arg_names: list[str],
    derivatives: list[Derivative],
    forward_derivatives: list[ForwardDerivative],
    args_with_derivatives: Sequence[Binding],
) -> list[ForwardDerivative]:
    """Finalize forward derivative definitions once differentiable args are known.

    Expands the ``auto_element_wise`` and ``auto_linear`` shorthands into
    concrete formulas and records, for every definition, which inputs'
    tangents (``<arg>_t``) and primals (``<arg>_p``) the formula needs.

    Args:
        f: the native function the definitions belong to.
        defn_name: name of the derivatives.yaml entry (used in messages and to
            build the ``auto_linear`` call).
        all_arg_names: names of all C++ arguments of ``f``, in order.
        derivatives: backward derivatives of the same entry.
        forward_derivatives: preliminary forward derivative definitions.
        args_with_derivatives: bindings of the differentiable arguments.

    Returns:
        A new list of ForwardDerivative objects with final formulas and
        required-input information filled in.

    Raises:
        RuntimeError: if a formula uses the bare name of a differentiable
            argument, or if a shorthand is used where it is not applicable.
    """

    def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]:
        is_foreach = f.func.name.name.base.startswith("_foreach_")
        # A dict is used as an insertion-ordered set so the result follows the
        # argument order instead of hash-randomized set iteration order.
        required_inputs: dict[str, None] = {}
        for arg in args_with_derivatives:
            if (
                arg.type in ("at::TensorList", "const at::ITensorListRef &")
                and not is_foreach
            ):
                # The functions taking TensorList handle everything internally
                continue
            arg_name = arg.name

            found = re.search(IDENT_REGEX.format(arg_name), formula)
            if found:
                raise RuntimeError(
                    f"The forward formula for {defn_name} is using the base name of the {arg_name} "
                    f"argument which is ambiguous. You should use {arg_name}_p to access the primal "
                    f"value and {arg_name}_t to access the tangent."
                )

            found = re.search(IDENT_REGEX.format(arg_name + postfix), formula)
            if found:
                required_inputs[arg_name] = None

        return tuple(required_inputs)

    updated_derivatives: list[ForwardDerivative] = []

    for defn in forward_derivatives:
        formula = defn.formula
        required_inputs_tangent = find_required_inputs(formula, "_t")
        if formula == "auto_element_wise":
            assert (
                f.func.kind() != SchemaKind.inplace
            ), f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant"
            if (
                (not len(args_with_derivatives) == 1)
                or len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but this only "
                    "works for functions with a single differentiable input and a "
                    "single differentiable output."
                )
            if not len(derivatives) == 1:
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but it does not "
                    "defines the gradient formula for its argument which is required."
                )
            # This transformation is based on the observation that for element-wise functions, the Jacobian
            # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions)
            # For the complex case, we use hermitian transpose and get (v.conj() J).conj()
            # So here we are going to re-use the backward formula and replace two things:
            # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input.
            # 2) all usage of an original input "foo" with its primal value "foo_p".
            # 3) conjugate the final result
            # For example, for abs, the backward formula is:
            #   grad * self.sgn()
            # And this function generates a forward formula that is:
            #   (self_t.conj() * self_p.sgn()).conj()

            backward_formula = derivatives[0].original_formula
            input_name = args_with_derivatives[0].name

            # Do replacement 1) of the grad
            def repl(m: Any) -> str:
                return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}"

            fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula)

            # Do replacement 2) of the input variables
            for arg in args_with_derivatives:
                arg_name = arg.name

                def repl(m: Any) -> str:
                    return f"{m.group(1)}{arg_name}_p{m.group(2)}"

                fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula)

            # Do the final conjugate 3)
            fw_formula = f"({fw_formula}).conj()"

            # Since there is a single differentiable inputs and we necessarily need its tangent we can
            # simply require all differentiable input's tangent.
            required_inputs_tangent = tuple(all_arg_names)
            formula = fw_formula
        elif formula == "auto_linear":
            if (
                len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as linear but this only works "
                    "for functions with a single differentiable output."
                )
            # This transformation is based on the observation that linear functions can be written as:
            #   y = f(x) = A * x
            # For some matrix A and the Jacobian of the function f is also A.
            # So doing J * v = A * v = f(v).
            # Hence to do the jvp, we simply need to evaluate the function at the point v instead of x.
            # We do this by calling the forward again by replacing any occurrence of the differentiable
            # input "foo" by it's tangent "foo_t".
            # Note that multiple inputs are not a problem as long as the function is truly linear wrt to
            # the vector where all the differentiable inputs are stacked.

            diff_arg_names = [arg.name for arg in args_with_derivatives]
            assert len(diff_arg_names) > 0

            # Do replacement of input variables
            new_args = []
            for arg_name in all_arg_names:
                if arg_name in diff_arg_names:
                    arg_name = arg_name + "_t"
                new_args.append(arg_name)

            # TODO we are trolling
            # The suffixed name is kept in a local so the `defn_name` used in
            # error messages is not altered.
            call_name = defn_name + "_symint" if f.func.has_symint() else defn_name

            # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions.
            if Variant.function in f.variants:
                fw_formula = f"at::{call_name}({', '.join(new_args)})"
            else:
                assert Variant.method in f.variants
                fw_formula = f"{new_args[0]}.{call_name}({', '.join(new_args[1:])})"

            # All of the input tangents are always used so all of them are required here.
            required_inputs_tangent = tuple(diff_arg_names)
            formula = fw_formula

        # At this point, the formula is final and is not modified anymore.

        # During forward formula, we use the primal instead of the input Tensors.
        # This call inspects the formula to find for which input's primal are used.
        required_inputs_primal = find_required_inputs(formula, "_p")

        updated_derivatives.append(
            ForwardDerivative(
                formula=formula,
                var_names=defn.var_names,
                var_types=defn.var_types,
                required_inputs_fw_grad=required_inputs_tangent,
                required_inputs_primal=required_inputs_primal,
                required_original_self_value=False,
                is_reusing_outplace_formula=False,
            )
        )

    return updated_derivatives
433
+
434
+
435
def is_forward_derivative_definition(
    all_arg_names: list[str], names: tuple[str, ...]
) -> bool:
    """Tell whether a derivatives.yaml key describes a forward derivative.

    Only the first entry of ``names`` is inspected: the definition is treated
    as a forward derivative when that name is not one of the function's
    argument names.

    Raises:
        RuntimeError: if ``names`` is empty.
    """
    if not names:
        raise RuntimeError("Expected `names` to be non-empty")
    return names[0] not in all_arg_names
441
+
442
+
443
def create_differentiability_info(
    defn_dict: dict[Any, Any],
    functions_by_signature: dict[FunctionSchema, list[NativeFunction]],
    functions_by_schema: dict[str, NativeFunction],
    op_counter: Counter[str],
    used_dispatch_keys: set[str],
) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]:
    """Processes a single entry `defn` in derivatives.yaml

    Args:
        defn_dict: the entry, holding "name", "dispatch" and optionally
            "output_differentiability". "name" and "output_differentiability"
            are popped from it.
        functions_by_signature: native functions grouped by signature.
        functions_by_schema: native functions keyed by their schema string.
        op_counter: per-prefix counter used to number generated op names;
            updated in place.
        used_dispatch_keys: set of dispatch keys seen so far; updated in place.

    Returns:
        The schema of the canonical function and a dict mapping each dispatch
        key of the entry to its DifferentiabilityInfo.

    Raises:
        RuntimeError: on unknown schemas, invalid dispatch keys, reserved
            argument names, or malformed derivative definitions.
    """

    def canonical_function(
        functions: Sequence[NativeFunction], name: str
    ) -> NativeFunction:
        for f in functions:
            if (
                not f.func.is_functional_fn()
                and not f.func.is_out_fn()
                and name == str(f.func.name.name)
            ):
                return f
        # some functions only have in-place variants
        assert name + "_" == cpp.name(functions[0].func)
        return functions[0]

    def split_names(raw_names: str) -> tuple[str, ...]:
        """Given "foo, bar", return ["foo", "bar"]."""
        return tuple(x.strip() for x in raw_names.split(","))

    def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None:
        """
        Check for some subtle mistakes one might make when writing derivatives.
        These mistakes will compile, but will be latent until a function is
        used with double backwards.
        """

        uses_grad = False  # true if any derivative uses "grad"
        num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
        uses_named_grads = False  # true if any derivative uses "grad_{name}"
        used_grads_indices: list[int] = []  # which indices of grads are used
        for d in derivatives:
            formula = d.formula
            uses_grad = uses_grad or bool(
                re.findall(IDENT_REGEX.format("grad"), formula)
            )
            num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula))
            uses_named_grads = uses_named_grads or bool(d.named_gradients)
            used_grads_indices.extend(used_gradient_indices(formula))
        # This is a basic sanity check: the number of places we see
        # "grads" should be no fewer than the number of indices we see
        # inside "grads". They may not be equal because we may use
        # "grads" without an index.
        assert num_grads_uses >= len(used_grads_indices)
        # Thus if the number is equal, every use of grads is also
        # indexed.
        only_used_grads_indices = num_grads_uses == len(used_grads_indices)

        if uses_grad and num_grads_uses > 0:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                "mixes use of 'grad' and 'grads'. Consider replacing "
                "occurrences of 'grad' with 'grads[0]'"
            )

        if only_used_grads_indices and set(used_grads_indices) == {0}:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml solely "
                "refers to 'grads[0]'.  If the first output is indeed the "
                "only differentiable output, replace 'grads[0]' with 'grad'; "
                "otherwise, there is a likely error in your derivatives "
                "declaration."
            )

        if uses_named_grads and (uses_grad or num_grads_uses > 0):
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
                "only one method for identifying gradients."
            )

    @with_native_function
    def set_up_derivatives(
        f: NativeFunction,
    ) -> tuple[
        Sequence[Derivative],
        Sequence[ForwardDerivative],
        Sequence[Binding],
        Sequence[str],
        Sequence[str],
    ]:
        # NOTE: `defn`, `defn_name` and `output_differentiability` are read
        # from the enclosing scope; `defn` is the current item of the dispatch
        # loop below, so this must only be called from inside that loop.

        # Set up the derivative information
        derivatives: list[Derivative] = []
        forward_derivatives: list[ForwardDerivative] = []
        non_differentiable_arg_names: list[str] = []
        args_with_derivatives_set: set[str] = set()

        all_arg_names = [a.name for a in cpp_arguments(f)]
        all_ret_names = [
            r.name for r in f.func.returns
        ]  # only used for the assert below
        # output_differentiability is captured from the enclosed
        # scope. Don't modify it.
        #
        # If it is not present, then no output is explicitly
        # undifferentiable.
        #
        # It may be present and shorter than the length of return
        # values. If that's the case, any return value that does not
        # have a corresponding entry is considered not differentiable.
        differentiability = output_differentiability or [True] * len(f.func.returns)
        # A return is available as a named gradient ...
        available_named_gradients = [
            f"grad_{ret.name}"
            for ret, differentiable in zip(f.func.returns, differentiability)
            # if it has not been explicitly made undifferentiable
            if differentiable
            # and if it has a name
            and ret.name is not None
            # and if its type is differentiable
            and ret.type.is_tensor_like()
        ]

        for raw_names in sorted(defn.keys()):
            formula = defn[raw_names]
            names = split_names(raw_names)

            for name in names:
                assert not (name in all_arg_names and name in all_ret_names), (
                    f"While processing the derivative formula for '{f.func.name}' wrt '{name}', "
                    f"expected '{name}' to not be both an input arg and named return. "
                )

            if is_forward_derivative_definition(all_arg_names, names):
                forward_derivatives.append(create_forward_derivative(f, formula, names))
            else:
                if formula.lower().strip() == "non_differentiable":
                    non_differentiable_arg_names += names
                else:
                    derivative = create_derivative(
                        f, formula, names, available_named_gradients
                    )
                    derivatives.append(derivative)
                    args_with_derivatives_set |= set(names)

        overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names)
        if overlap:
            raise RuntimeError(
                f"derivatives definition for {defn} have overlapped non_differentiable "
                f"and differentiable variables: {overlap}"
            )

        # Next, let us determine the list of inputs in order.
        # TODO: do we need eagerly calculate and save it here? Can it be derived
        # from NativeFunction and `derivatives` on callsites instead?
        args_with_derivatives = [
            a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
        ]

        # Postprocess forward derivatives definitions now that we know the differentiable arguments
        forward_derivatives = postprocess_forward_derivatives(
            f,
            defn_name,
            all_arg_names,
            derivatives,
            forward_derivatives,
            args_with_derivatives,
        )

        # Test to see if the use of 'grads' makes sense.
        check_grad_usage(defn_name, derivatives)

        return (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        )

    # NB: Removes 'name' from defn dictionary
    specification = defn_dict.pop("name")
    defn_name, _ = split_name_params(specification)
    # NB: Removes 'output_differentiability' from defn dictionary
    #     `None` means all differentiable.
    output_differentiability = defn_dict.pop("output_differentiability", None)
    output_differentiability_conditions = None
    if output_differentiability and any(
        isinstance(diff, str) for diff in output_differentiability
    ):
        if len(output_differentiability) != 1:
            raise RuntimeError(
                f"Not supported: for {specification}, "
                f"output_differentiability must either be "
                f"List[bool] or a List[str] where each str is a "
                f"condition. In the case where it is a condition, "
                f"we only support single-output functions. "
                f"Please file us an issue. "
            )
        output_differentiability_conditions = output_differentiability
        output_differentiability = [True]

    schema_function = functions_by_schema.get(specification)
    if not schema_function:
        avail = "\n".join(
            k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for schema: {specification} "
            f".  Available signatures:\n{avail}"
        )

    # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here
    # to map in-place schemas to the out-of-place variants.
    # TODO: maybe the logic to handle the legacy schema is no longer necessary?
    signature = schema_function.func.signature()
    functions = functions_by_signature[signature]
    if len(functions) == 0:
        avail = "\n".join(
            str(k)
            for k, v in functions_by_signature.items()
            if cpp.name(k) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for legacy signature: {signature} "
            f"corresponding to schema {specification}.  Please report a bug to PyTorch. "
            f"Available signatures:\n{avail}"
        )

    canonical = canonical_function(functions, defn_name)
    if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named grad_input_mask, "
            "but this name would be shadowed by our codegen. "
            "Please use a different name in native_functions.yaml."
        )

    if "result" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named result, "
            "but this is only allowed for outputs. "
            "Please use a different name in native_functions.yaml."
        )

    diffinfo_dict = {}
    for key, defn in defn_dict["dispatch"].items():
        if key != "Default" and key not in _VALID_AUTOGRAD_KEYS:
            raise RuntimeError(
                f"Invalid dispatch key {key} in derivatives.yaml for {specification},"
                f" expected key to be one of {_VALID_AUTOGRAD_KEYS}"
            )
        # set.add is idempotent, so no membership test is needed first.
        used_dispatch_keys.add(key)

        (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        ) = set_up_derivatives(canonical)

        used_named_gradients: set[str] = set()
        for d in derivatives:
            used_named_gradients |= d.named_gradients

        # only assign an op name if we are actually going to calculate a derivative
        op = None
        if args_with_derivatives:
            op_prefix = _create_op_prefix(defn_name)
            if key != "Default":
                op_prefix = op_prefix + key
            op = f"{op_prefix}{op_counter[op_prefix]}"
            op_counter[op_prefix] += 1

        diffinfo_dict[key] = DifferentiabilityInfo(
            name=defn_name,
            func=canonical,
            op=op,
            derivatives=derivatives,
            forward_derivatives=forward_derivatives,
            all_saved_inputs=dedup_vars(
                [v for d in derivatives for v in d.saved_inputs]
            ),
            all_saved_outputs=dedup_vars(
                [v for d in derivatives for v in d.saved_outputs]
            ),
            available_named_gradients=available_named_gradients,
            used_named_gradients=used_named_gradients,
            args_with_derivatives=args_with_derivatives,
            non_differentiable_arg_names=non_differentiable_arg_names,
            output_differentiability=output_differentiability,
            output_differentiability_conditions=output_differentiability_conditions,
        )

    return canonical.func, diffinfo_dict
736
+
737
+
738
GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"


def used_gradient_indices(formula: str) -> list[int]:
    """Return every index ``i`` used as ``grads[i]`` in ``formula``.

    Indices are listed in order of appearance and repeated if the same index
    is used more than once.

    >>> used_gradient_indices("foo(grads[0], grads[1])")
    [0, 1]
    """
    indices: list[int] = []
    for match in re.finditer(GRAD_INDEX_REGEX, formula):
        indices.append(int(match.group(1)))
    return indices
749
+
750
+
751
def saved_variables(
    formula: str,
    nctypes: list[NamedCType],
    var_names: tuple[str, ...],
) -> tuple[str, tuple[SavedAttribute, ...]]:
    """Rewrite ``formula`` and list what must be saved to evaluate it.

    For every variable in ``nctypes``, known cheap sub-expressions (the
    ``REPLACEMENTS`` table below) are replaced by a derived name
    (``<name><suffix>``) and recorded as a SavedAttribute. If the bare
    variable still appears in the formula afterwards, the variable itself is
    recorded too.

    Args:
        formula: the derivative formula.
        nctypes: candidate variables (name plus C++ type) the formula may use.
        var_names: names the derivative is taken with respect to; only used to
            validate ``.sym_strides()`` replacements.

    Returns:
        The rewritten formula and a tuple of SavedAttributes.

    Raises:
        RuntimeError: if the formula uses ``.sizes()``, ``.size(int)`` or
            ``.strides()`` instead of their SymInt counterparts.
    """

    def stride_expr(name: str) -> str:
        assert var_names == (name,), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.'
        )
        return f'strides_or_error({name}, "{name}")'

    # Each entry is (regex template, info). The template is formatted with the
    # variable name. info["suffix"] names the saved attribute, info["nctype"]
    # builds its type, optional info["expr"] is the expression stored at save
    # time (defaults to the matched text) and optional info["res"] is the text
    # put back into the formula (defaults to name + suffix).
    # NOTE(review): the "." in these templates is an unescaped regex wildcard;
    # left as-is to avoid changing which formulas match.
    REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [
        # replace self.sym_sizes() with self_sym_sizes
        (
            r"{}.sym_sizes\(\)",
            {
                "suffix": "_sym_sizes",
                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
            },
        ),
        # replace self->sym_sizes() with self_sym_sizes_opt
        (
            r"{}->sym_sizes\(\)",
            {
                "suffix": "_sym_sizes_opt",
                "nctype": lambda name: NamedCType(
                    name, OptionalCType(BaseCType(symIntArrayRefT))
                ),
                "expr": lambda name: f"{name}.has_value() ? std::optional<c10::SymIntArrayRef>({name}->sym_sizes()) : std::nullopt",
            },
        ),
        # replace self.sym_blocksize() with self_self_sym_blocksize_opt
        # (the suffix itself contains "_self")
        (
            r"{}.sym_blocksize\(\)",
            {
                "suffix": "_self_sym_blocksize_opt",
                "nctype": lambda name: NamedCType(
                    name, OptionalCType(BaseCType(symIntArrayRefT))
                ),
                "expr": lambda name: f"at::sparse_csr::getSymIntBlockSize({name})",
            },
        ),
        # replace self.options() with self_options
        (
            r"{}.options\(\)",
            {
                "suffix": "_options",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
            },
        ),
        # replace zeros_like(self) with self_info
        (
            r"zeros_like\({}\)",
            {
                "suffix": "_info",
                "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
                "expr": lambda name: name,  # at save-time
                "res": lambda name: name + "_info.zeros()",  # at eval-time
            },
        ),
        # replace self.sym_size(2) with self_sym_argsize_2
        # (and self.sym_size(-1) with self_sym_argsize_minus_1)
        (
            r"{}.sym_size\((-?\w+)\)",
            {
                "suffix": lambda m: f"_sym_argsize_{m.groups()[0].replace('-', 'minus_')}",
                "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
            },
        ),
        # replace self.numel() with self_numel
        (
            r"{}.numel\(\)",
            {
                "suffix": "_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.sym_numel() with self_sym_numel
        (
            r"{}.sym_numel\(\)",
            {
                "suffix": "_sym_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
            },
        ),
        # replace to_args_sizes(self) with self_args_sizes
        (
            r"to_args_sizes\({}\)",
            {
                "suffix": "_args_sizes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(longT)))
                ),
            },
        ),
        # replace to_args_sizes_symint(self) with self_args_sizes_symint
        (
            r"to_args_sizes_symint\({}\)",
            {
                "suffix": "_args_sizes_symint",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(SymIntT)))
                ),
            },
        ),
        # replace to_args_scalartypes(self) with self_args_scalartypes
        (
            r"to_args_scalartypes\({}\)",
            {
                "suffix": "_args_scalartypes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(BaseCType(scalarTypeT))
                ),
            },
        ),
        # replace TensorGeometry(self) with self_geometry
        (
            r"TensorGeometry\({}\)",
            {
                "suffix": "_geometry",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
            },
        ),
        # replace self.scalar_type() with self_scalar_type
        (
            r"{}.scalar_type\(\)",
            {
                "suffix": "_scalar_type",
                "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)),
            },
        ),
        # replace self.dim() with self_dim
        (
            r"{}.dim\(\)",
            {
                "suffix": "_dim",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.sym_strides() with self_sym_strides
        (
            r"{}.sym_strides\(\)",
            {
                "suffix": "_sym_strides",
                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
                "expr": stride_expr,
            },
        ),
        # replace self.layout() with self_layout
        (
            r"{}.layout\(\)",
            {
                "suffix": "_layout",
                "nctype": lambda name: NamedCType(name, BaseCType(layoutT)),
            },
        ),
        # replace self.is_conj() with self_conjugate
        (
            r"{}.is_conj\(\)",
            {
                "suffix": "_conjugate",
                "nctype": lambda name: NamedCType(name, BaseCType(boolT)),
            },
        ),
    ]

    # find which arguments need to be saved
    saved: list[SavedAttribute] = []

    # Reject the non-SymInt accessors up front, before any rewriting.
    if ".sizes()" in formula or "->sizes()" in formula:
        raise RuntimeError(
            ".sizes() is not supported in derivative formulas. Instead, please use the SymInt version, "
            + f".sym_sizes(), which returns a c10::SymIntArrayRef. formula={formula}"
        )
    if re.search(r"\.size\([-]?\d+\)", formula) or re.search(
        r"->size\([-]?\d+\)", formula
    ):
        raise RuntimeError(
            ".size(int) is not supported in derivative formulas. Instead, please use the SymInt version, "
            + f".sym_size(int), which returns a c10::SymInt. formula={formula}"
        )
    if ".strides()" in formula or "->strides()" in formula:
        raise RuntimeError(
            ".strides() is not supported in derivative formulas. Instead, please use the SymInt version, "
            + f".sym_strides(), which returns a c10::SymIntArrayRef. formula={formula}"
        )
    for nctype in nctypes:
        name = (
            nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
        )
        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
        for regex, info in REPLACEMENTS:

            # `repl` reads the current `info` and `name`; it is only ever
            # called by the re.sub directly below, within this iteration.
            def repl(m: re.Match[str]) -> str:
                suffix: str = (
                    info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
                )
                expr: str = info["expr"](name) if "expr" in info else m.group(0)
                saved.append(
                    SavedAttribute(
                        nctype=info["nctype"](name + suffix),
                        expr=expr,
                    )
                )
                if "res" in info:
                    replacement: str = info["res"](name)
                    return replacement
                return name + suffix

            formula = re.sub(regex.format(name), repl, formula)

        # std::optional<std::string> types stored in Backward nodes must be
        # converted to std::optional<std::string_view> before being passed into
        # the backward function
        if nctype.type == OptionalCType(BaseCType(stringT)):
            formula = re.sub(
                rf"\b{name}\b",
                f"{name}.has_value() ? std::optional<c10::string_view>({name}.value()) : std::nullopt",
                formula,
            )

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            saved.append(
                SavedAttribute(
                    nctype=nctype,
                    expr=name,
                )
            )

    return formula, tuple(saved)
982
+
983
+
984
+ def _create_op_prefix(name: str) -> str:
985
+ """Takes a native function name converts to a op prefix name.
986
+
987
+ Note that the "name" parameter must be the native function name
988
+ without the optional variant suffix, so "add" instead of
989
+ "add.out".
990
+
991
+ OP names correspond to classes, hence the change to title case.
992
+
993
+ Example::
994
+ >>> _create_op_prefix('add')
995
+ 'AddBackward'
996
+ """
997
+ camel_case = "".join([p.title() for p in name.split("_")])
998
+ return (camel_case + "Backward").replace("ForwardBackward", "Backward")
999
+
1000
+
1001
def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
    """Drop saved attributes whose name was already seen, keeping the first.

    The name compared is ``nctype.name``, unwrapped through ``.name`` when it
    is a SpecialArgName. The order of first occurrences is preserved.
    """
    first_by_name: dict[str, SavedAttribute] = {}
    for attr in vars:
        raw_name = attr.nctype.name
        label = raw_name.name if isinstance(raw_name, SpecialArgName) else raw_name
        # setdefault keeps the earliest attribute for a given name.
        first_by_name.setdefault(label, attr)
    return list(first_by_name.values())
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/ADInplaceOrViewType.cpp ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2
+ #include "torch/csrc/autograd/VariableTypeUtils.h"
3
+ #include "torch/csrc/autograd/generated/ViewFuncs.h"
4
+
5
+ #include <torch/library.h>
6
+ #include <ATen/FunctionalInverses.h>
7
+ #include <ATen/FunctionalTensorWrapper.h>
8
+
9
+ // ${generated_comment}
10
+
11
+ #ifndef AT_PER_OPERATOR_HEADERS
12
+ #include <ATen/Operators.h>
13
+ #else
14
+ $ops_headers
15
+ #endif
16
+
17
+ using namespace at;
18
+ using torch::autograd::CreationMeta;
19
+ using torch::autograd::as_view;
20
+ using torch::autograd::increment_version;
21
+
22
+ namespace torch {
23
+
24
+ namespace ADInplaceOrView {
25
+
26
+ namespace {
27
+ ${inplace_or_view_method_definitions}
28
+ } // namespace
29
+ } // namespace ADInplaceOrView
30
+
31
+ namespace {
32
+
33
+ TORCH_LIBRARY_IMPL(aten, ADInplaceOrView, m) {
34
+ ${inplace_or_view_wrapper_registrations};
35
+ }
36
+
37
+ } // namespace
38
+ } // namespace torch
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/Functions.cpp ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "torch/csrc/autograd/FunctionsManual.h"
2
+ #include "torch/csrc/dynamo/compiled_autograd.h"
3
+
4
+ // ${generated_comment}
5
+
6
+ // The manual function definitions that used to be here are now in torch/csrc/autograd/FunctionsManual.cpp
7
+ // This speeds up re-compilation and allows sharing these implementations so that they can be
8
+ // used for forward mode AD formulas as well.
9
+
10
+ using namespace torch::autograd::generated::details;
11
+ using at::Tensor;
12
+ using at::Scalar;
13
+ using at::IntArrayRef;
14
+ using at::TensorList;
15
+
16
+ namespace torch::autograd::generated {
17
+
18
+ ${autograd_function_definitions}
19
+
20
+ } // namespace torch::autograd::generated
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/Functions.h ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ // ${generated_comment}
4
+
5
+ #include <ATen/ATen.h>
6
+ #include <ATen/core/functional.h>
7
+ #include <ATen/TensorGeometry.h>
8
+
9
+ #include "torch/csrc/autograd/function.h"
10
+ #include "torch/csrc/autograd/variable.h"
11
+ #include "torch/csrc/autograd/saved_variable.h"
12
+ #include <torch/csrc/Export.h>
13
+
14
+ #include <c10/core/SymIntArrayRef.h>
15
+
16
+ namespace torch { namespace autograd { namespace generated {
17
+
18
+ using at::Scalar;
19
+ using at::Tensor;
20
+ using at::IntArrayRef;
21
+ using at::ArrayRef;
22
+ using at::Type;
23
+ using at::TensorGeometry;
24
+ using at::ScalarType;
25
+ using std::optional;
26
+ using c10::fmap;
27
+
28
+ inline std::vector<Tensor> unpack_list(at::ArrayRef<SavedVariable> xs, std::shared_ptr<Node> saved_for = nullptr) {
29
+ // NB: we must explicitly do the conversion in the lambda, otherwise template
30
+ // deduction will give a Tensor of Variable which is not convertible
31
+ return fmap(xs, [&saved_for](const SavedVariable& x) {
32
+ // TODO(crcrpar): Use `std::move(saved_for)` to avoid incrementing refcount, which would need refactoring.
33
+ return static_cast<Tensor>(x.unpack(saved_for));
34
+ });
35
+ }
36
+
37
+ inline c10::List<std::optional<Tensor>> unpack_opt_list(at::ArrayRef<SavedVariable> xs, std::shared_ptr<Node> saved_for = nullptr) {
38
+ torch::List<std::optional<Tensor>> result;
39
+ result.reserve(xs.size());
40
+ for (const SavedVariable& v : xs) {
41
+ auto var = v.unpack(saved_for);
42
+ result.push_back(var.defined() ? std::optional<Tensor>(var) : ::std::nullopt);
43
+ }
44
+ return result;
45
+ }
46
+
47
+ using torch::autograd::TypeAndSize;
48
+
49
+ ${autograd_function_declarations}
50
+
51
+ }}} // namespace torch::autograd::generated
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/TraceType.cpp ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2
+ #include "torch/csrc/jit/frontend/tracer.h"
3
+
4
+ #include <torch/library.h>
5
+
6
+ #include "torch/csrc/autograd/function.h"
7
+
8
+ #include "ATen/quantized/Quantizer.h"
9
+
10
+ // ${generated_comment}
11
+
12
+ // See the `Tracer` section in `torch/csrc/jit/OVERVIEW.md`.
13
+ // NOTE See [Sharded File] comment in VariableType
14
+
15
+ #ifndef AT_PER_OPERATOR_HEADERS
16
+ #include <ATen/Operators.h>
17
+ #else
18
+ $ops_headers
19
+ #endif
20
+
21
+ using namespace at;
22
+
23
+ namespace torch {
24
+
25
+ namespace TraceType {
26
+
27
+ namespace {
28
+ ${trace_method_definitions}
29
+ } // namespace
30
+ } // namespace TraceType
31
+
32
+ namespace {
33
+
34
+ TORCH_LIBRARY_IMPL(aten, Tracer, m) {
35
+ ${trace_wrapper_registrations};
36
+ }
37
+
38
+ } // namespace
39
+
40
+ } // namespace torch
.venv/lib/python3.11/site-packages/torchgen/packaged/autograd/templates/VariableType.cpp ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "torch/csrc/autograd/VariableTypeUtils.h"
2
+ #include "torch/csrc/autograd/generated/VariableType.h"
3
+ #include "torch/csrc/autograd/FunctionsManual.h"
4
+
5
+ #include <ATen/RedispatchFunctions.h>
6
+ #include <c10/core/impl/TorchDispatchModeTLS.h>
7
+ #include <ATen/core/TorchDispatchUtils.h>
8
+ #include <torch/library.h>
9
+
10
+ #include <ATen/SparseCsrTensorUtils.h>
11
+
12
+
13
+ // ${generated_comment}
14
+
15
+ // NOTE [Sharded File]: on this file's split-into-shards state
16
+ //
17
+ // Back in the good old days, VariableType.cpp was generated as one
18
+ // file with every function in it, and everything was great and
19
+ // simple.
20
+ //
21
+ // However, this file was also very large (over 36,000 lines), and
22
+ // compiling it was very slow, and in fact was a significant
23
+ // bottleneck for incremental rebuilds. To address this, we now
24
+ // generate the file split across multiple shards, named
25
+ // VariableType_0.cpp and so on, which can be compiled in parallel.
26
+ //
27
+ // For ease of inspection and debugging, so that it's not necessary to
28
+ // go rooting around in multiple files, we also generate all the
29
+ // functions together in VariableTypeEverything.cpp. This generated
30
+ // file is only for convenience; it's not actually used in the
31
+ // build. If the file you're looking at now is one of the shards, you
32
+ // may want to switch over to the Everything variant to make your
33
+ // grepping smoother.
34
+
35
+ using namespace at;
36
+ using namespace torch::autograd::generated;
37
+ using namespace torch::autograd::generated::details;
38
+
39
+
40
+ namespace torch::autograd {
41
+
42
+ namespace VariableType {
43
+ namespace{
44
+ C10_UNUSED void reset_grad_accumulator(Variable & self) {
45
+ AutogradMeta* meta = torch::autograd::impl::get_autograd_meta(self);
46
+ if (meta != nullptr) {
47
+ meta->grad_accumulator_.reset();
48
+ }
49
+ }
50
+ }
51
+
52
+ namespace {
53
+
54
+
55
+ ${type_derived_method_definitions}
56
+ }
57
+ }
58
+
59
+ namespace {
60
+
61
+ ${wrapper_registrations}
62
+
63
+ }
64
+
65
+ } // namespace torch::autograd