ZAIDX11 commited on
Commit
28780eb
·
verified ·
1 Parent(s): aafbbaf

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. archive/.venv/Lib/site-packages/torch/_C/_cudnn.pyi +14 -0
  2. archive/.venv/Lib/site-packages/torch/_C/_cusparselt.pyi +1 -0
  3. archive/.venv/Lib/site-packages/torch/_C/_distributed_autograd.pyi +26 -0
  4. archive/.venv/Lib/site-packages/torch/_higher_order_ops/executorch_call_delegate.py +175 -0
  5. archive/.venv/Lib/site-packages/torch/_higher_order_ops/flat_apply.py +125 -0
  6. archive/.venv/Lib/site-packages/torch/_higher_order_ops/flex_attention.py +1268 -0
  7. archive/.venv/Lib/site-packages/torch/_higher_order_ops/foreach_map.py +23 -0
  8. archive/.venv/Lib/site-packages/torch/_higher_order_ops/hints_wrap.py +142 -0
  9. archive/.venv/Lib/site-packages/torch/_higher_order_ops/invoke_subgraph.py +658 -0
  10. archive/.venv/Lib/site-packages/torch/_higher_order_ops/map.py +291 -0
  11. archive/.venv/Lib/site-packages/torch/_higher_order_ops/out_dtype.py +163 -0
  12. archive/.venv/Lib/site-packages/torch/_higher_order_ops/run_const_graph.py +60 -0
  13. archive/.venv/Lib/site-packages/torch/_higher_order_ops/scan.py +929 -0
  14. archive/.venv/Lib/site-packages/torch/_higher_order_ops/schema.py +306 -0
  15. archive/.venv/Lib/site-packages/torch/_higher_order_ops/strict_mode.py +108 -0
  16. archive/.venv/Lib/site-packages/torch/_higher_order_ops/torchbind.py +164 -0
  17. archive/.venv/Lib/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py +2051 -0
  18. archive/.venv/Lib/site-packages/torch/_higher_order_ops/utils.py +1134 -0
  19. archive/.venv/Lib/site-packages/torch/_higher_order_ops/while_loop.py +420 -0
  20. archive/.venv/Lib/site-packages/torch/_higher_order_ops/wrap.py +286 -0
archive/.venv/Lib/site-packages/torch/_C/_cudnn.pyi ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import IntEnum
2
+
3
+ # Defined in torch/csrc/cuda/shared/cudnn.cpp
4
+ is_cuda: bool
5
+
6
+ def getRuntimeVersion() -> tuple[int, int, int]: ...
7
+ def getCompileVersion() -> tuple[int, int, int]: ...
8
+ def getVersionInt() -> int: ...
9
+
10
+ class RNNMode(IntEnum):
11
+ rnn_relu = ...
12
+ rnn_tanh = ...
13
+ lstm = ...
14
+ gru = ...
archive/.venv/Lib/site-packages/torch/_C/_cusparselt.pyi ADDED
@@ -0,0 +1 @@
 
 
1
+ def getVersionInt() -> int: ...
archive/.venv/Lib/site-packages/torch/_C/_distributed_autograd.pyi ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ import torch
4
+
5
+ # This module is defined in torch/csrc/distributed/autograd/init.cpp
6
+
7
+ class DistAutogradContext:
8
+ def _context_id(self) -> int: ...
9
+ def _recv_functions(self) -> dict[int, Any]: ...
10
+ def _send_functions(self) -> dict[int, Any]: ...
11
+ def _known_worker_ids(self) -> set[int]: ...
12
+
13
+ def _new_context() -> DistAutogradContext: ...
14
+ def _release_context(context_id: int) -> None: ...
15
+ def _get_max_id() -> int: ...
16
+ def _is_valid_context(worker_id: int) -> bool: ...
17
+ def _retrieve_context(context_id: int) -> DistAutogradContext: ...
18
+ def _current_context() -> DistAutogradContext: ...
19
+ def _init(worker_id: int) -> None: ...
20
+ def _get_debug_info() -> dict[str, str]: ...
21
+ def backward(
22
+ context_id: int,
23
+ roots: list[torch.Tensor],
24
+ retain_graph: bool = False,
25
+ ) -> None: ...
26
+ def get_gradients(context_id: int) -> dict[torch.Tensor, torch.Tensor]: ...
archive/.venv/Lib/site-packages/torch/_higher_order_ops/executorch_call_delegate.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+
3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
4
+ # All rights reserved.
5
+ #
6
+ # This source code is licensed under the BSD-style license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ # pyre-strict
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import Any, cast
14
+
15
+ import torch
16
+ import torch.utils._pytree as pytree
17
+ from torch._ops import HigherOrderOperator
18
+ from torch._subclasses.fake_tensor import FakeTensorMode
19
+ from torch.fx.experimental.proxy_tensor import (
20
+ disable_proxy_modes_tracing,
21
+ get_proxy_slot,
22
+ ProxyTorchDispatchMode,
23
+ track_tensor_tree,
24
+ )
25
+ from torch.utils._pytree import tree_flatten
26
+
27
+
28
+ class ExecutorchCallDelegate(HigherOrderOperator):
29
+ def __init__(self):
30
+ super().__init__("executorch_call_delegate")
31
+
32
+ def __call__(self, lowered_module, *args):
33
+ return super().__call__(lowered_module, *args)
34
+
35
+
36
+ executorch_call_delegate = ExecutorchCallDelegate()
37
+ executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
38
+ executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot)
39
+ executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView)
40
+ executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)
41
+
42
+ LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"
43
+
44
+
45
+ # pyre-ignore
46
+ def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
47
+ # pyre-ignore
48
+ def _unwrap_proxy(e):
49
+ if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
50
+ return e
51
+ return get_proxy_slot(
52
+ cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy # type: ignore[attr-defined]
53
+ )
54
+
55
+ if not is_lowered_module(lowered_module):
56
+ raise ValueError(
57
+ "executorch_call_delegate()'s first argument must be a LoweredBackendModule"
58
+ )
59
+
60
+ with disable_proxy_modes_tracing():
61
+ out = call_delegate_cpu(lowered_module, *args)
62
+
63
+ get_lowered_module_name(proxy_mode.tracer.root, lowered_module)
64
+
65
+ node_args = (lowered_module, *args)
66
+ proxy_args = pytree.tree_map(_unwrap_proxy, node_args)
67
+ out_proxy = proxy_mode.tracer.create_proxy(
68
+ "call_function", func_overload, proxy_args, {}, name="executorch_call_delegate"
69
+ )
70
+ return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
71
+
72
+
73
+ @executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
74
+ # pyre-ignore
75
+ def call_delegate_cpu(lowered_module, *args):
76
+ # FX creates this immutable_dict/list concept. Get rid of this.
77
+ map_types: dict[type, type] = {
78
+ torch.fx.immutable_collections.immutable_dict: dict,
79
+ torch.fx.immutable_collections.immutable_list: list,
80
+ }
81
+ new_args = pytree.tree_map_only(
82
+ tuple(map_types.keys()),
83
+ lambda a: map_types[type(a)](a),
84
+ args,
85
+ lambda a: isinstance(a, tuple(map_types.keys())),
86
+ )
87
+ return lowered_module.original_module.module()(*new_args)
88
+
89
+
90
+ @executorch_call_delegate.py_autograd_impl
91
+ # pyre-ignore
92
+ def call_delegate_autograd(lowered_module, *args):
93
+ # TODO: support autograd
94
+ flat_operands, _ = tree_flatten([lowered_module, *args])
95
+ requires_grad = any(
96
+ f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
97
+ )
98
+
99
+ with torch._C._ExcludeDispatchKeyGuard(
100
+ torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
101
+ ):
102
+ res = executorch_call_delegate(lowered_module, *args)
103
+
104
+ if requires_grad:
105
+ # Create aliases of the output that has requires_grad=True. We need
106
+ # at least one of the inputs to err_fn to require grad so that the
107
+ # output will have a grad_fn.
108
+
109
+ # pyre-ignore
110
+ def fake_requires_grad(var):
111
+ if var is not None:
112
+ var = var.detach()
113
+ if torch.is_floating_point(var) or torch.is_complex(var):
114
+ var.requires_grad = True
115
+ return var
116
+
117
+ return pytree.tree_map_only(torch.Tensor, fake_requires_grad, res)
118
+
119
+ return res
120
+
121
+
122
+ @executorch_call_delegate.py_impl(ProxyTorchDispatchMode)
123
+ # pyre-ignore
124
+ def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args):
125
+ res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args)
126
+ return res
127
+
128
+
129
+ @executorch_call_delegate.py_impl(FakeTensorMode)
130
+ # pyre-ignore
131
+ def call_delegate_fake_tensor_mode(mode, lowered_module, *args):
132
+ with mode:
133
+ return call_delegate_cpu(lowered_module, *args)
134
+
135
+
136
+ @executorch_call_delegate.py_functionalize_impl
137
+ # pyre-ignore
138
+ def call_delegate_functionalize(ctx, lowered_module, *args):
139
+ unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
140
+ with ctx.redispatch_to_next():
141
+ res = executorch_call_delegate(lowered_module, *unwrapped_args)
142
+ return ctx.wrap_tensors(res)
143
+
144
+
145
+ # pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre
146
+ def is_lowered_module(obj: Any) -> bool:
147
+ """
148
+ This function is added to avoid using isinstance(obj,
149
+ LoweredBackendModule) as it will import LoweredBackendModule, which may
150
+ cause a circular import.
151
+ """
152
+ return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE
153
+
154
+
155
+ def get_lowered_module_name(
156
+ root: torch.nn.Module,
157
+ # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
158
+ lowered_module: LOWERED_BACKEND_MODULE_TYPE, # type: ignore[valid-type]
159
+ ) -> str:
160
+ """
161
+ Adds the given lowered_module into the given root module and returns the
162
+ name of the module added.
163
+ """
164
+ # Find a qualifying name for the lowered submodule
165
+ qualname = None
166
+ i = 0
167
+ while True:
168
+ qualname = f"lowered_module_{i}"
169
+ if not hasattr(root, qualname):
170
+ break
171
+ i += 1
172
+ assert qualname is not None
173
+
174
+ root.add_module(qualname, lowered_module)
175
+ return qualname
archive/.venv/Lib/site-packages/torch/_higher_order_ops/flat_apply.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from dataclasses import dataclass
3
+ from typing import Callable
4
+
5
+ import torch
6
+ import torch.fx.node
7
+ import torch.utils._pytree as pytree
8
+ from torch._ops import HigherOrderOperator
9
+
10
+
11
+ def is_graphable(val) -> bool:
12
+ """Definition: a graphable type is a type that that is an acceptable input/output type to a FX node."""
13
+ return isinstance(val, torch.fx.node.base_types)
14
+
15
+
16
+ def is_graphable_type(typ) -> bool:
17
+ """Return whether the given type is graphable"""
18
+ return issubclass(typ, torch.fx.node.base_types)
19
+
20
+
21
+ def to_graphable(stuff):
22
+ """Flattens stuff into a flat list of graphable types."""
23
+ # We can consider preserving things like List[int] to improve
24
+ # perf and readability (right now that is all flattened out)
25
+ flat_args, spec = pytree.tree_flatten(stuff)
26
+ for arg in flat_args:
27
+ if not is_graphable(arg):
28
+ raise RuntimeError(
29
+ f"Expected all pytree.tree_leaves of (args, kwargs) to be graphable types, but found "
30
+ f"non-fx-graphable type {type(arg)}. If this type is meant to be constant, mark it as "
31
+ f"via pytree.register_constant; otherwise, register it as a pytree."
32
+ )
33
+ return flat_args, spec
34
+
35
+
36
+ def from_graphable(flat_args, spec):
37
+ """The inverse of to_graphable."""
38
+ stuff = pytree.tree_unflatten(flat_args, spec)
39
+ return stuff
40
+
41
+
42
+ def func_to_graphable(func):
43
+ """
44
+ Pack and flatten a function type into graphable types.
45
+ This is useful for legalizing the function argument of `flat_apply`.
46
+ """
47
+ return pytree.tree_flatten(_ConstantFunction(func))
48
+
49
+
50
+ @dataclass(frozen=True)
51
+ class _ConstantFunction:
52
+ func: Callable
53
+
54
+ def __call__(self, *args, **kwargs):
55
+ return self.func(*args, **kwargs)
56
+
57
+
58
+ pytree.register_constant(_ConstantFunction)
59
+
60
+ _op_types = (
61
+ torch._ops.OpOverload,
62
+ torch._ops.OpOverloadPacket,
63
+ torch._ops.HigherOrderOperator,
64
+ )
65
+
66
+
67
+ class FlatApply(HigherOrderOperator):
68
+ def __init__(self) -> None:
69
+ super().__init__("flat_apply")
70
+
71
+ def __call__(self, func, in_spec, *flat_args, **_unused):
72
+ """
73
+ Functions that take in non-graphable types cannot directly be put into FX graph.
74
+
75
+ Given func(*args, **kwargs), if all of the non-graphable types are pytrees,
76
+ then we're able to store a call to flat_apply(func, in_spec, *flat_args) in the FX graph.
77
+
78
+ The semantics of flat_apply(func, in_spec, *flat_args) are roughly equivalent to:
79
+
80
+ >>> def flat_apply_impl(func, in_spec, *flat_args):
81
+ >>> args, kwargs = pytree.tree_unflatten(flat_args, in_spec)
82
+ >>> output = func(*args, **kwargs)
83
+ >>> return output
84
+
85
+ flat_apply supports the following two cases:
86
+ - an input type is a container type (e.g. of tensors) registered as a pytree.
87
+ We'll tree_flatten the input type and store the spec.
88
+ - an input type is a constant type (i.e. torch.compile will specialize on it)
89
+ registered with pytree.register_constant. The constant type goes directly
90
+ into the spec.
91
+
92
+ """
93
+ assert isinstance(func, _op_types) or pytree._is_constant_holder(func)
94
+ assert len(_unused) == 0
95
+ return impl(func, in_spec, *flat_args)
96
+
97
+
98
+ def impl(func, in_spec, *flat_args):
99
+ if not isinstance(func, _op_types):
100
+ # assume _ConstantFunction
101
+ func = pytree._retrieve_constant(func)
102
+ assert isinstance(func, _ConstantFunction)
103
+
104
+ args, kwargs = from_graphable(flat_args, in_spec)
105
+ out = func(*args, **kwargs)
106
+
107
+ # Right now, all outputs must either be graphable or lists/tuples of graphables.
108
+ #
109
+ # TODO: The following can be updated to support non-graphable outputs and pytrees.
110
+ # For non-graphable constant outputs: the assumption would be that they are constant
111
+ # (everytime the function runs those MUST be the same)
112
+ # For pytree outputs:
113
+ # I'm not sure if we need to return (flat_output, spec) or just (flat_output,):
114
+ # in the latter case the tracers need to carry out the output specs
115
+ # (they need to know how to reconstruct the object from just the flat_output).
116
+ def is_valid_output(x):
117
+ if isinstance(x, (tuple, list)):
118
+ return all(map(is_valid_output, x))
119
+ return is_graphable(x)
120
+
121
+ assert is_valid_output(out)
122
+ return out
123
+
124
+
125
+ flat_apply = FlatApply()
archive/.venv/Lib/site-packages/torch/_higher_order_ops/flex_attention.py ADDED
@@ -0,0 +1,1268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections.abc import Sequence
3
+ from typing import Any, Callable, Optional, Union
4
+
5
+ import torch
6
+ import torch.utils._pytree as pytree
7
+ from torch import Tensor
8
+ from torch._C import DispatchKey
9
+ from torch._higher_order_ops.utils import (
10
+ _has_potential_branch_input_mutation,
11
+ _maybe_reenter_make_fx,
12
+ autograd_not_implemented,
13
+ has_user_subclass,
14
+ redirect_to_mode,
15
+ reenter_make_fx,
16
+ register_fake,
17
+ save_tensors_and_symints_for_backward,
18
+ saved_tensors_and_symints,
19
+ UnsupportedAliasMutationException,
20
+ validate_subgraph_args_types,
21
+ )
22
+ from torch._ops import HigherOrderOperator
23
+ from torch._subclasses import FakeTensor
24
+ from torch._subclasses.functional_tensor import FunctionalTensor
25
+ from torch.fx.experimental.proxy_tensor import (
26
+ make_fx,
27
+ ProxyTorchDispatchMode,
28
+ track_tensor_tree,
29
+ )
30
+ from torch.fx.graph_module import GraphModule
31
+ from torch.utils.checkpoint import _CachedTorchDispatchMode, _CachingTorchDispatchMode
32
+
33
+
34
+ # Duplicate of _inductor/kernel/flex_attention.py to avoid circular import
35
+ def _construct_strides(
36
+ sizes: Sequence[int],
37
+ fill_order: Sequence[int],
38
+ ) -> Sequence[int]:
39
+ """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
40
+ # Initialize strides
41
+ assert len(sizes) == len(
42
+ fill_order
43
+ ), "Length of sizes must match the length of the fill order"
44
+ strides = [0] * len(sizes)
45
+
46
+ # Start with stride 1 for the innermost dimension
47
+ current_stride = 1
48
+
49
+ # Iterate through the fill order populating strides
50
+ for dim in fill_order:
51
+ strides[dim] = current_stride
52
+ current_stride *= sizes[dim]
53
+
54
+ return strides
55
+
56
+
57
+ def _permute_strides(out: torch.Tensor, query_strides: tuple[int, ...]) -> torch.Tensor:
58
+ """
59
+ Create a new tensor with the same data and shape as the input,
60
+ but with strides permuted based on the input tensor's stride order.
61
+
62
+ Args:
63
+ out (torch.Tensor): The output tensor of attention.
64
+ query_strides (List[int]): The stride order of the input query tensor
65
+
66
+ Returns:
67
+ torch.Tensor: A new tensor with same shape and data as the input,
68
+ but with strides permuted based on the query tensor's stride order.
69
+ """
70
+ from torch._inductor.ir import get_fill_order
71
+
72
+ fill_order = get_fill_order(query_strides)
73
+ assert out.storage_offset() == 0, "Only support storage_offset == 0"
74
+ out_strides = _construct_strides(out.shape, fill_order)
75
+ new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides)
76
+ new_out.copy_(out)
77
+ return new_out
78
+
79
+
80
+ class FlexAttentionHOP(HigherOrderOperator):
81
+ def __init__(self) -> None:
82
+ super().__init__("flex_attention", cacheable=True)
83
+
84
+ def __call__(
85
+ self,
86
+ query: torch.Tensor,
87
+ key: torch.Tensor,
88
+ value: torch.Tensor,
89
+ score_mod: Callable,
90
+ block_mask: tuple,
91
+ scale: float,
92
+ kernel_options: dict[str, Any],
93
+ score_mod_other_buffers: tuple = (),
94
+ mask_mod_other_buffers: tuple = (),
95
+ ) -> tuple[torch.Tensor, torch.Tensor]:
96
+ validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
97
+ return super().__call__(
98
+ query,
99
+ key,
100
+ value,
101
+ score_mod,
102
+ block_mask,
103
+ scale,
104
+ kernel_options,
105
+ score_mod_other_buffers,
106
+ mask_mod_other_buffers,
107
+ )
108
+
109
+
110
+ flex_attention = FlexAttentionHOP()
111
+
112
+
113
+ class FlexAttentionBackwardHOP(HigherOrderOperator):
114
+ def __init__(self) -> None:
115
+ super().__init__("flex_attention_backward")
116
+
117
+ def __call__(
118
+ self,
119
+ query: torch.Tensor,
120
+ key: torch.Tensor,
121
+ value: torch.Tensor,
122
+ out: torch.Tensor,
123
+ logsumexp: torch.Tensor,
124
+ grad_out: torch.Tensor,
125
+ grad_logsumexp: torch.Tensor,
126
+ fw_graph: Union[Callable, GraphModule],
127
+ joint_graph: GraphModule,
128
+ block_mask: tuple,
129
+ scale: float,
130
+ kernel_options: dict[str, Any],
131
+ score_mod_other_buffers: tuple = (),
132
+ mask_mod_other_buffers: tuple = (),
133
+ ) -> tuple[
134
+ torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
135
+ ]:
136
+ validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
137
+ return super().__call__(
138
+ query,
139
+ key,
140
+ value,
141
+ out,
142
+ logsumexp,
143
+ grad_out,
144
+ grad_logsumexp,
145
+ fw_graph,
146
+ joint_graph,
147
+ block_mask,
148
+ scale,
149
+ kernel_options,
150
+ score_mod_other_buffers,
151
+ mask_mod_other_buffers,
152
+ )
153
+
154
+
155
+ flex_attention_backward = FlexAttentionBackwardHOP()
156
+
157
+
158
+ def _math_attention_inner(
159
+ query: torch.Tensor,
160
+ key: torch.Tensor,
161
+ value: torch.Tensor,
162
+ score_mod: Callable,
163
+ block_mask: tuple,
164
+ scale: float,
165
+ kernel_options: dict[str, Any],
166
+ score_mod_other_buffers: tuple = (),
167
+ mask_mod_other_buffers: tuple = (),
168
+ ) -> tuple[torch.Tensor, torch.Tensor]:
169
+ from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
170
+
171
+ working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32
172
+
173
+ scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision)
174
+
175
+ b = torch.arange(0, scores.size(0), device=scores.device)
176
+ h = torch.arange(0, scores.size(1), device=scores.device)
177
+ m = torch.arange(0, scores.size(2), device=scores.device)
178
+ n = torch.arange(0, scores.size(3), device=scores.device)
179
+
180
+ captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
181
+ from torch.nn.attention.flex_attention import _vmap_for_bhqkv
182
+
183
+ # first input is score
184
+ score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,), suffix=captured_buffers_in_dim)
185
+
186
+ mask_mod = block_mask[-1]
187
+ mask_mod_in_dim_buffers = (None,) * len(mask_mod_other_buffers)
188
+ mask_mod = _vmap_for_bhqkv(mask_mod, prefix=(), suffix=mask_mod_in_dim_buffers)
189
+
190
+ with TransformGetItemToIndex():
191
+ scores = (scores * scale).to(working_precision)
192
+ post_mod_scores = torch.where(
193
+ mask_mod(b, h, m, n, *mask_mod_other_buffers),
194
+ score_mod(scores, b, h, m, n, *score_mod_other_buffers),
195
+ torch.tensor(-float("inf"), dtype=working_precision, device=scores.device),
196
+ )
197
+
198
+ return scores, post_mod_scores
199
+
200
+
201
+ def math_attention(
202
+ query: torch.Tensor,
203
+ key: torch.Tensor,
204
+ value: torch.Tensor,
205
+ score_mod: Callable,
206
+ block_mask: tuple,
207
+ scale: float,
208
+ kernel_options: dict[str, Any],
209
+ score_mod_other_buffers: tuple = (),
210
+ mask_mod_other_buffers: tuple = (),
211
+ ) -> tuple[torch.Tensor, torch.Tensor]:
212
+ """Eager implementation
213
+
214
+ This implementation uses vmap to vectorize the score_mod function over the batch, head, m, and n dimensions.
215
+ We then apply the vectorized score_mod function to the scores matrix. Each wrap of vmap applies one of the
216
+ batch, head, m, or n dimensions. We need to apply vmap 4 times to vectorized over all 4 dimensions.
217
+
218
+ Args:
219
+ query: The query tensor
220
+ key: The key tensor
221
+ value: The value tensor
222
+ score_mod: The score_mod function
223
+ other_buffers: Other buffers that are passed to the score_mod function
224
+ """
225
+ # broadcast query & key along head dim for GQA
226
+ G = query.size(1) // key.size(1)
227
+ value = torch.repeat_interleave(value, G, dim=1)
228
+ key = torch.repeat_interleave(key, G, dim=1)
229
+
230
+ Bq, Bkv = query.size(0), key.size(0)
231
+ if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)):
232
+ raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}")
233
+
234
+ key = key.expand((Bq, *key.size()[1:]))
235
+ value = value.expand((Bq, *value.size()[1:]))
236
+
237
+ _, post_mod_scores = _math_attention_inner(
238
+ query,
239
+ key,
240
+ value,
241
+ score_mod,
242
+ block_mask,
243
+ scale,
244
+ kernel_options,
245
+ score_mod_other_buffers,
246
+ mask_mod_other_buffers,
247
+ )
248
+
249
+ # Set fully masked rows' sumexp to 0.0
250
+ logsumexp = post_mod_scores.logsumexp(dim=-1)
251
+ masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1)
252
+ logsumexp = torch.where(masked_rows, -float("inf"), logsumexp)
253
+
254
+ post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1)
255
+
256
+ return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2)
257
+
258
+
259
+ @flex_attention.py_impl(DispatchKey.CompositeExplicitAutograd)
260
+ def sdpa_dense(
261
+ query: torch.Tensor,
262
+ key: torch.Tensor,
263
+ value: torch.Tensor,
264
+ score_mod: Callable,
265
+ block_mask: tuple,
266
+ scale: float,
267
+ kernel_options: dict[str, Any],
268
+ score_mod_other_buffers: tuple = (),
269
+ mask_mod_other_buffers: tuple = (),
270
+ ) -> tuple[torch.Tensor, torch.Tensor]:
271
+ out, lse = math_attention(
272
+ query,
273
+ key,
274
+ value,
275
+ score_mod,
276
+ block_mask,
277
+ scale,
278
+ kernel_options,
279
+ score_mod_other_buffers,
280
+ mask_mod_other_buffers,
281
+ )
282
+ out = _permute_strides(out, query.stride())
283
+ return out, lse
284
+
285
+
286
+ def trace_flex_attention(
287
+ proxy_mode: ProxyTorchDispatchMode,
288
+ query: torch.Tensor,
289
+ key: torch.Tensor,
290
+ value: torch.Tensor,
291
+ score_mod: Callable,
292
+ block_mask: tuple,
293
+ scale: float,
294
+ kernel_options: dict[str, Any],
295
+ score_mod_other_buffers: tuple = (),
296
+ mask_mod_other_buffers: tuple = (),
297
+ ) -> tuple[torch.Tensor, torch.Tensor]:
298
+ """Traces the flex_attention operator with the given score_mod function and other_buffers.
299
+
300
+ Trace SDPA will call make_fx with "fake" example vals and then trace the score_mod function
301
+ This will produce a GraphModule that will be stored on the root tracer as "sdpa_score". We
302
+ access this graph module in inductor to inline the score_mod function to the triton template.
303
+ """
304
+ from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
305
+
306
+ example_out = flex_attention(
307
+ query,
308
+ key,
309
+ value,
310
+ score_mod,
311
+ block_mask,
312
+ scale,
313
+ kernel_options,
314
+ score_mod_other_buffers,
315
+ mask_mod_other_buffers,
316
+ )
317
+ example_vals = [query.new_zeros((), requires_grad=query.requires_grad)] + [
318
+ query.new_zeros((), dtype=torch.int) for _ in range(4)
319
+ ]
320
+ mask_example_vals = [query.new_zeros((), dtype=torch.int) for _ in range(4)]
321
+ mask_mod = block_mask[-1]
322
+ with TransformGetItemToIndex():
323
+ score_graph = reenter_make_fx(score_mod)(
324
+ *example_vals, *score_mod_other_buffers
325
+ )
326
+ mask_graph = reenter_make_fx(mask_mod)(
327
+ *mask_example_vals, *mask_mod_other_buffers
328
+ )
329
+ assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
330
+ block_mask = block_mask[:-1] + (mask_graph,)
331
+ qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_score")
332
+ proxy_mode.tracer.root.register_module(qualname, score_graph)
333
+ mask_qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_mask")
334
+ proxy_mode.tracer.root.register_module(mask_qualname, mask_graph)
335
+ node_args = (
336
+ query,
337
+ key,
338
+ value,
339
+ score_graph,
340
+ block_mask,
341
+ scale,
342
+ kernel_options,
343
+ score_mod_other_buffers,
344
+ mask_mod_other_buffers,
345
+ )
346
+ proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
347
+ out_proxy = proxy_mode.tracer.create_proxy(
348
+ "call_function", flex_attention, proxy_args, {}
349
+ )
350
+ return track_tensor_tree(
351
+ example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
352
+ )
353
+
354
+
355
+ @flex_attention.py_impl(ProxyTorchDispatchMode)
356
+ def flex_attention_proxy_torch_dispatch_mode(
357
+ mode: ProxyTorchDispatchMode,
358
+ query: torch.Tensor,
359
+ key: torch.Tensor,
360
+ value: torch.Tensor,
361
+ score_mod: Callable,
362
+ block_mask: tuple,
363
+ scale: float,
364
+ kernel_options: dict[str, Any],
365
+ score_mod_other_buffers: tuple = (),
366
+ mask_mod_other_buffers: tuple = (),
367
+ ) -> tuple[torch.Tensor, torch.Tensor]:
368
+ assert mode is not None, "Mode should always be enabled for python fallback key"
369
+ return trace_flex_attention(
370
+ mode,
371
+ query,
372
+ key,
373
+ value,
374
+ score_mod,
375
+ block_mask,
376
+ scale,
377
+ kernel_options,
378
+ score_mod_other_buffers,
379
+ mask_mod_other_buffers,
380
+ )
381
+
382
+
383
+ @flex_attention.py_functionalize_impl
384
+ def flex_attention_functionalize(
385
+ ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI,
386
+ query: torch.Tensor,
387
+ key: torch.Tensor,
388
+ value: torch.Tensor,
389
+ score_mod: Callable,
390
+ block_mask: tuple,
391
+ scale: float,
392
+ kernel_options: dict[str, Any],
393
+ score_mod_other_buffers: tuple = (),
394
+ mask_mod_other_buffers: tuple = (),
395
+ ) -> tuple[torch.Tensor, torch.Tensor]:
396
+ """Defines the functionalization rules for the flex_attention operator.
397
+
398
+ Write now we are unwrapping each tensor and then redispatching to the next, however we want to
399
+ guard against any mutations in the score_mod function, to the other_buffers since those
400
+ are free variables.
401
+ """
402
+ from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
403
+
404
+ if has_user_subclass(
405
+ (
406
+ query,
407
+ key,
408
+ value,
409
+ score_mod,
410
+ block_mask,
411
+ scale,
412
+ kernel_options,
413
+ score_mod_other_buffers,
414
+ mask_mod_other_buffers,
415
+ ),
416
+ allowed_subclasses=(FakeTensor, FunctionalTensor),
417
+ ):
418
+ return NotImplemented
419
+
420
+ query_unwrapped = ctx.unwrap_tensors(query)
421
+ key_unwrapped = ctx.unwrap_tensors(key)
422
+ value_unwrapped = ctx.unwrap_tensors(value)
423
+ block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
424
+ score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
425
+ mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)
426
+
427
+ # Appease the mypy overlords
428
+ assert isinstance(query_unwrapped, torch.Tensor)
429
+ assert isinstance(key_unwrapped, torch.Tensor)
430
+ assert isinstance(value_unwrapped, torch.Tensor)
431
+ assert isinstance(block_mask_unwrapped, tuple)
432
+ assert isinstance(score_mod_other_buffers_unwrapped, tuple)
433
+ assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
434
+
435
+ example_vals = (
436
+ [query_unwrapped.new_zeros(())]
437
+ + [query_unwrapped.new_zeros((), dtype=torch.int) for _ in range(4)]
438
+ + list(score_mod_other_buffers_unwrapped)
439
+ )
440
+ with ctx.redispatch_to_next():
441
+ functional_score_mod = ctx.functionalize(score_mod)
442
+ pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
443
+ with TransformGetItemToIndex():
444
+ # TODO: So far only the input mutations are checked
445
+ # In the other HOPs, also aliases are checked which is
446
+ # omitted here
447
+ mutates = _has_potential_branch_input_mutation(
448
+ score_mod, example_vals, pre_dispatch
449
+ )
450
+ # The only care about mutations of existing buffers since we can't replay these.
451
+ # However, we can just error if anything is detected
452
+ if mutates:
453
+ raise UnsupportedAliasMutationException("Mutations detected in score_mod")
454
+
455
+ out = flex_attention(
456
+ query_unwrapped,
457
+ key_unwrapped,
458
+ value_unwrapped,
459
+ functional_score_mod,
460
+ block_mask_unwrapped,
461
+ scale,
462
+ kernel_options,
463
+ score_mod_other_buffers_unwrapped,
464
+ mask_mod_other_buffers_unwrapped,
465
+ )
466
+ return ctx.wrap_tensors(out) # type: ignore[return-value, arg-type]
467
+
468
+
469
+ @register_fake(flex_attention)
470
+ def flex_attention_fake_impl(
471
+ query: torch.Tensor,
472
+ key: torch.Tensor,
473
+ value: torch.Tensor,
474
+ score_mod: Callable,
475
+ block_mask: tuple,
476
+ scale: float,
477
+ kernel_options: dict[str, Any],
478
+ score_mod_other_buffers: tuple = (),
479
+ mask_mod_other_buffers: tuple = (),
480
+ ) -> tuple[torch.Tensor, torch.Tensor]:
481
+ if has_user_subclass(
482
+ (
483
+ query,
484
+ key,
485
+ value,
486
+ score_mod,
487
+ block_mask,
488
+ scale,
489
+ kernel_options,
490
+ score_mod_other_buffers,
491
+ mask_mod_other_buffers,
492
+ ),
493
+ allowed_subclasses=(FakeTensor,),
494
+ ):
495
+ return NotImplemented
496
+
497
+ # TODO: Figure out a better way to handle this for NJT than using sum()
498
+ if query.is_nested:
499
+ out = torch.empty_like(query, memory_format=torch.contiguous_format)
500
+ logsumexp = query.sum(dim=-1)
501
+ return out, logsumexp
502
+
503
+ v_head_dim = value.size(-1)
504
+ batch_size, num_heads, seq_len_q, _q_head_dim = query.shape
505
+ logsumexp = query.new_empty(batch_size, num_heads, seq_len_q, dtype=torch.float32)
506
+ out_shape = (batch_size, num_heads, seq_len_q, v_head_dim)
507
+ out = query.new_empty(out_shape)
508
+ out = _permute_strides(out, query.stride())
509
+ return out, logsumexp
510
+
511
+
512
+ # Registers dispatches for SAC
513
+ redirect_to_mode(flex_attention, _CachingTorchDispatchMode)
514
+ redirect_to_mode(flex_attention, _CachedTorchDispatchMode)
515
+
516
+
517
+ # ---------------------------- Autograd Implementation ----------------------------
518
+ def create_fw_bw_graph(
519
+ score_mod: Callable,
520
+ index_values: tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
521
+ other_buffers: tuple[Tensor, ...],
522
+ ) -> tuple[Callable, Callable]:
523
+ # See Note:[HOP create fw_bw graph]
524
+
525
+ # All of these imports need to be here in order to avoid circular dependencies
526
+ from torch._dispatch.python import suspend_functionalization
527
+ from torch._functorch.aot_autograd import AOTConfig, create_joint
528
+ from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
529
+ from torch._subclasses.functional_tensor import disable_functional_mode
530
+ from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
531
+
532
+ dummy_aot_config = AOTConfig(
533
+ fw_compiler=None, # type: ignore[arg-type]
534
+ bw_compiler=None, # type: ignore[arg-type]
535
+ partition_fn=None, # type: ignore[arg-type]
536
+ decompositions={},
537
+ num_params_buffers=0,
538
+ aot_id=0,
539
+ keep_inference_input_mutations=False,
540
+ )
541
+
542
+ with suspend_functionalization(), disable_functional_mode():
543
+ with disable_proxy_modes_tracing():
544
+
545
+ def _from_fun(
546
+ t: Union[Tensor, torch.SymInt, int],
547
+ ) -> Union[Tensor, torch.SymInt, int]:
548
+ if isinstance(t, torch.Tensor):
549
+ return torch.empty_strided(
550
+ t.size(),
551
+ t.stride(),
552
+ device=t.device,
553
+ dtype=t.dtype,
554
+ requires_grad=t.requires_grad,
555
+ )
556
+ return t
557
+
558
+ # If someone runs this hop under the default compiler backend ("eager")
559
+ # Then this path will be run with the actual user inputs. We convert them
560
+ # to fake tensors in order to not perform any actual compute.
561
+ from torch._guards import detect_fake_mode
562
+
563
+ fake_mode = detect_fake_mode(index_values)
564
+ if fake_mode is None:
565
+ fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
566
+
567
+ with fake_mode:
568
+ unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
569
+ unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)
570
+
571
+ assert all(
572
+ isinstance(t, (FakeTensor, int, torch.SymInt))
573
+ for t in unwrapped_score_mod_indexes + unwrapped_other_buffers
574
+ )
575
+
576
+ example_flat_out = pytree.tree_map(
577
+ _from_fun,
578
+ score_mod(*unwrapped_score_mod_indexes, *unwrapped_other_buffers),
579
+ )
580
+ if not isinstance(example_flat_out, torch.Tensor):
581
+ raise RuntimeError(
582
+ "Expected output of score_mod to be a tensor."
583
+ f"Got type {type(example_flat_out)}."
584
+ )
585
+ example_grad = _from_fun(example_flat_out)
586
+
587
+ def joint_f(
588
+ score: Tensor,
589
+ b: Tensor,
590
+ h: Tensor,
591
+ m: Tensor,
592
+ n: Tensor,
593
+ example_grad: Tensor,
594
+ *other_buffers: tuple[Tensor, ...],
595
+ ) -> tuple[Tensor, ...]:
596
+ def fw_with_masks(
597
+ *args: tuple[Tensor, ...]
598
+ ) -> tuple[tuple[Tensor], tuple[bool]]:
599
+ fw_out = score_mod(*args)
600
+ out_requires_grad = fw_out.requires_grad
601
+ return ((fw_out,), (out_requires_grad,))
602
+
603
+ joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
604
+ args = [score, b, h, m, n] + list(other_buffers)
605
+ optional_grad = [example_grad] if example_grad.requires_grad else []
606
+ _, grads = joint(args, optional_grad)
607
+
608
+ return grads
609
+
610
+ joint_graph = make_fx(joint_f)(
611
+ *unwrapped_score_mod_indexes, example_grad, *unwrapped_other_buffers
612
+ )
613
+ return score_mod, joint_graph
614
+
615
+
616
+ class FlexAttentionAutogradOp(torch.autograd.Function):
617
+ @staticmethod
618
+ def forward(
619
+ ctx: Any,
620
+ query: Tensor,
621
+ key: Tensor,
622
+ value: Tensor,
623
+ fw_graph: Callable,
624
+ joint_graph: Callable,
625
+ block_mask: tuple[Any, ...],
626
+ scale: float,
627
+ kernel_options: dict[str, Any],
628
+ mask_mod_other_buffers: tuple[Any, ...],
629
+ *score_mod_other_buffers: tuple[Any, ...],
630
+ ) -> tuple[torch.Tensor, torch.Tensor]:
631
+ any_buffer_requires_grad = any(
632
+ buffer.requires_grad
633
+ for buffer in mask_mod_other_buffers
634
+ if isinstance(buffer, torch.Tensor)
635
+ )
636
+ assert (
637
+ not any_buffer_requires_grad
638
+ ), "Captured buffers from mask mod that require grad are not supported."
639
+ ctx._fw_graph = fw_graph
640
+ ctx._joint_graph = joint_graph
641
+ ctx._mask_graph = block_mask[-1]
642
+ ctx.scale = scale
643
+ ctx.kernel_options = kernel_options
644
+ ctx._score_mod_other_buffers_len = len(score_mod_other_buffers)
645
+ with torch._C._AutoDispatchBelowAutograd():
646
+ out, logsumexp = flex_attention(
647
+ query,
648
+ key,
649
+ value,
650
+ fw_graph,
651
+ block_mask,
652
+ scale,
653
+ kernel_options,
654
+ score_mod_other_buffers,
655
+ mask_mod_other_buffers,
656
+ )
657
+
658
+ save_tensors_and_symints_for_backward(
659
+ ctx,
660
+ (
661
+ query,
662
+ key,
663
+ value,
664
+ out,
665
+ logsumexp,
666
+ *block_mask[:-1],
667
+ *score_mod_other_buffers,
668
+ *mask_mod_other_buffers,
669
+ ),
670
+ )
671
+ return out, logsumexp
672
+
673
+ @staticmethod
674
+ def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> tuple[Optional[Tensor], ...]: # type: ignore[override]
675
+ fw_args = saved_tensors_and_symints(ctx)
676
+ (
677
+ query,
678
+ key,
679
+ value,
680
+ out,
681
+ logsumexp,
682
+ query_lengths,
683
+ kv_lengths,
684
+ kv_num_blocks,
685
+ kv_indices,
686
+ full_kv_num_blocks,
687
+ full_kv_indices,
688
+ q_num_blocks,
689
+ q_indices,
690
+ full_q_num_blocks,
691
+ full_q_indices,
692
+ Q_BLOCK_SIZE,
693
+ KV_BLOCK_SIZE,
694
+ *other_buffers,
695
+ ) = fw_args
696
+ fw_graph = ctx._fw_graph
697
+ joint_graph = ctx._joint_graph
698
+ mask_graph = ctx._mask_graph
699
+ scale = ctx.scale
700
+ kernel_options = ctx.kernel_options
701
+ score_mod_other_buffers = tuple(
702
+ other_buffers[: ctx._score_mod_other_buffers_len]
703
+ )
704
+ mask_mod_other_buffers = tuple(
705
+ other_buffers[ctx._score_mod_other_buffers_len :]
706
+ )
707
+ # We have asserted that mask_mod_other_buffers do not require grad,
708
+ # but score_mod_other_buffers can require grad.
709
+ none_grads = [None] * 6
710
+ (
711
+ grad_query,
712
+ grad_key,
713
+ grad_value,
714
+ grad_score_mod_captured,
715
+ ) = flex_attention_backward(
716
+ query,
717
+ key,
718
+ value,
719
+ out,
720
+ logsumexp,
721
+ grad_out,
722
+ grad_logsumexp,
723
+ fw_graph,
724
+ joint_graph,
725
+ (
726
+ query_lengths,
727
+ kv_lengths,
728
+ kv_num_blocks,
729
+ kv_indices,
730
+ full_kv_num_blocks,
731
+ full_kv_indices,
732
+ q_num_blocks,
733
+ q_indices,
734
+ full_q_num_blocks,
735
+ full_q_indices,
736
+ Q_BLOCK_SIZE,
737
+ KV_BLOCK_SIZE,
738
+ mask_graph,
739
+ ),
740
+ scale,
741
+ kernel_options,
742
+ score_mod_other_buffers,
743
+ mask_mod_other_buffers,
744
+ )
745
+ return grad_query, grad_key, grad_value, *none_grads, *grad_score_mod_captured
746
+
747
+
748
+ # TODO: Rework DispatchKey.Autograd to py_autograd_impl
749
+ @flex_attention.py_impl(DispatchKey.Autograd)
750
+ def flex_attention_autograd(
751
+ query: torch.Tensor,
752
+ key: torch.Tensor,
753
+ value: torch.Tensor,
754
+ score_mod: Callable,
755
+ block_mask: tuple,
756
+ scale: float,
757
+ kernel_options: dict[str, Any],
758
+ score_mod_other_buffers: tuple[Tensor, ...] = (),
759
+ mask_mod_other_buffers: tuple[Tensor, ...] = (),
760
+ ) -> tuple[torch.Tensor, torch.Tensor]:
761
+ from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
762
+
763
+ with TransformGetItemToIndex():
764
+ input_requires_grad = any(
765
+ isinstance(t, torch.Tensor) and t.requires_grad
766
+ for t in (query, key, value, *score_mod_other_buffers)
767
+ )
768
+ if torch.is_grad_enabled() and input_requires_grad:
769
+ example_vals = (
770
+ query.new_zeros((), requires_grad=input_requires_grad),
771
+ query.new_zeros((), dtype=torch.int),
772
+ query.new_zeros((), dtype=torch.int),
773
+ query.new_zeros((), dtype=torch.int),
774
+ query.new_zeros((), dtype=torch.int),
775
+ )
776
+ fw_graph, bw_graph = create_fw_bw_graph(
777
+ score_mod, example_vals, score_mod_other_buffers
778
+ )
779
+ else:
780
+ fw_graph, bw_graph = score_mod, None
781
+ out, logsumexp = FlexAttentionAutogradOp.apply(
782
+ query,
783
+ key,
784
+ value,
785
+ fw_graph,
786
+ bw_graph,
787
+ block_mask,
788
+ scale,
789
+ kernel_options,
790
+ mask_mod_other_buffers,
791
+ *score_mod_other_buffers,
792
+ )
793
+ return out, logsumexp
794
+
795
+
796
+ # ---------------------------- Backward HOP Implementation ----------------------------
797
+
798
+
799
+ @flex_attention_backward.py_impl(DispatchKey.CompositeExplicitAutograd)
800
+ def sdpa_dense_backward(
801
+ query: torch.Tensor,
802
+ key: torch.Tensor,
803
+ value: torch.Tensor,
804
+ out: torch.Tensor,
805
+ logsumexp: torch.Tensor,
806
+ grad_out: torch.Tensor,
807
+ grad_logsumexp: torch.Tensor,
808
+ fw_graph: Callable, # GraphModule type hint?
809
+ joint_graph: Callable,
810
+ block_mask: tuple,
811
+ scale: float,
812
+ kernel_options: dict[str, Any],
813
+ score_mod_other_buffers: tuple,
814
+ mask_mod_other_buffers: tuple,
815
+ ) -> tuple[
816
+ torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
817
+ ]:
818
+ from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
819
+
820
+ Bq, Hq, seq_len_q, qk_head_dim = query.shape
821
+ Bkv, Hkv, seq_len_kv, v_head_dim = value.shape
822
+
823
+ # Get outputs before calling repeat interleave and permute to input stride orders
824
+ actual_grad_query = query.new_empty((Bq, Hq, seq_len_q, qk_head_dim))
825
+ actual_grad_query = _permute_strides(actual_grad_query, query.stride())
826
+
827
+ actual_grad_key = key.new_empty((Bq, Hkv, seq_len_kv, qk_head_dim))
828
+ actual_grad_key = _permute_strides(actual_grad_key, key.stride())
829
+
830
+ actual_grad_value = value.new_empty((Bq, Hkv, seq_len_kv, v_head_dim))
831
+ actual_grad_value = _permute_strides(actual_grad_value, value.stride())
832
+
833
+ def _maybe_new_buffer(
834
+ buffer: Union[torch.Tensor, torch.SymInt, int],
835
+ ) -> Optional[Union[torch.Tensor, torch.SymInt, int]]:
836
+ if isinstance(buffer, torch.Tensor):
837
+ return (
838
+ torch.empty_like(buffer, memory_format=torch.contiguous_format)
839
+ if buffer.requires_grad
840
+ else None
841
+ )
842
+ return buffer
843
+
844
+ actual_grad_score_mod_captured = [
845
+ _maybe_new_buffer(buffer) for buffer in score_mod_other_buffers
846
+ ]
847
+
848
+ Bq, Bkv = query.size(0), key.size(0)
849
+ if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)):
850
+ raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}")
851
+
852
+ key = key.expand((Bq, *key.size()[1:]))
853
+ value = value.expand((Bq, *value.size()[1:]))
854
+
855
+ G = query.size(1) // key.size(1)
856
+ key = torch.repeat_interleave(key, G, dim=1)
857
+ value = torch.repeat_interleave(value, G, dim=1)
858
+
859
+ # We're undoing the log -> log2 change of base in the forwards
860
+ logsumexp = logsumexp * math.log(2)
861
+ # The backwards formula for the log -> log2 change of base in the forwards
862
+ grad_logsumexp = grad_logsumexp / math.log(2)
863
+ scores, post_mod_scores = _math_attention_inner(
864
+ query,
865
+ key,
866
+ value,
867
+ fw_graph,
868
+ block_mask,
869
+ scale,
870
+ kernel_options,
871
+ score_mod_other_buffers,
872
+ mask_mod_other_buffers,
873
+ )
874
+ masked_out_rows = logsumexp == -float("inf")
875
+ softmax_scores = torch.exp(post_mod_scores - logsumexp.unsqueeze(-1))
876
+ softmax_scores = torch.where(masked_out_rows.unsqueeze(-1), 0, softmax_scores)
877
+
878
+ grad_value = softmax_scores.to(query.dtype).transpose(-2, -1) @ grad_out
879
+
880
+ grad_softmax_scores = grad_out @ value.transpose(-2, -1)
881
+
882
+ sum_scores = torch.sum(out * grad_out, -1, keepdim=True)
883
+ grad_score_mod = softmax_scores * (
884
+ grad_softmax_scores - sum_scores + grad_logsumexp.unsqueeze(-1)
885
+ )
886
+
887
+ b = torch.arange(0, scores.size(0), device=scores.device)
888
+ h = torch.arange(0, scores.size(1), device=scores.device)
889
+ m = torch.arange(0, scores.size(2), device=scores.device)
890
+ n = torch.arange(0, scores.size(3), device=scores.device)
891
+
892
+ mask_graph = block_mask[-1]
893
+ # Gradient of the inline score_mod function, with respect to the scores
894
+ captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
895
+ out_dims = [0, None, None, None, None] + [None] * len(score_mod_other_buffers)
896
+ from torch.nn.attention.flex_attention import _vmap_for_bhqkv
897
+
898
+ # inputs are [score, b, h, q_idx, kv_idx, gradOut, ...]
899
+ # score and gradOut are "fully" batched
900
+ joint_score_mod = _vmap_for_bhqkv(
901
+ joint_graph,
902
+ prefix=(0,),
903
+ suffix=(0,) + captured_buffers_in_dim,
904
+ out_dims=out_dims,
905
+ )
906
+ with TransformGetItemToIndex():
907
+ grad_scores, _, _, _, _, *grad_score_mod_captured = joint_score_mod(
908
+ scores, b, h, m, n, grad_score_mod, *score_mod_other_buffers
909
+ )
910
+ grad_scores = grad_scores * scale
911
+ grad_scores = grad_scores.to(query.dtype)
912
+
913
+ mask_mod = _vmap_for_bhqkv(
914
+ mask_graph, prefix=(), suffix=(None,) * len(mask_mod_other_buffers)
915
+ )
916
+ with TransformGetItemToIndex():
917
+ mask_scores = mask_mod(b, h, m, n, *mask_mod_other_buffers)
918
+ grad_scores = torch.where(
919
+ mask_scores, grad_scores, torch.tensor(0, dtype=query.dtype)
920
+ )
921
+
922
+ grad_query = grad_scores @ key
923
+ grad_key = grad_scores.transpose(-2, -1) @ query
924
+
925
+ # Reduce DK, DV along broadcasted heads.
926
+ grad_key = grad_key.view(
927
+ grad_key.size(0), -1, G, grad_key.size(-2), grad_key.size(-1)
928
+ )
929
+ grad_value = grad_value.view(
930
+ grad_value.size(0), -1, G, grad_value.size(-2), grad_value.size(-1)
931
+ )
932
+
933
+ grad_key = torch.sum(grad_key, 2, keepdim=False)
934
+ grad_value = torch.sum(grad_value, 2, keepdim=False)
935
+
936
+ # Fill to correctly strided outputs
937
+ actual_grad_query.copy_(grad_query)
938
+ actual_grad_key.copy_(grad_key)
939
+ actual_grad_value.copy_(grad_value)
940
+
941
+ if Bq != Bkv:
942
+ assert (
943
+ Bq > 1 and Bkv == 1
944
+ ), f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}"
945
+
946
+ actual_grad_key = torch.sum(actual_grad_key, 0, keepdim=True)
947
+ actual_grad_value = torch.sum(actual_grad_value, 0, keepdim=True)
948
+
949
+ score_mod_other_buffer_grads = [
950
+ actual_grad.copy_(grad) if isinstance(actual_grad, torch.Tensor) else None
951
+ for actual_grad, grad in zip(
952
+ actual_grad_score_mod_captured, grad_score_mod_captured
953
+ )
954
+ ]
955
+
956
+ return (
957
+ actual_grad_query,
958
+ actual_grad_key,
959
+ actual_grad_value,
960
+ tuple(score_mod_other_buffer_grads),
961
+ )
962
+
963
+
964
+ def trace_flex_attention_backward(
965
+ proxy_mode: ProxyTorchDispatchMode,
966
+ query: torch.Tensor,
967
+ key: torch.Tensor,
968
+ value: torch.Tensor,
969
+ out: torch.Tensor,
970
+ logsumexp: torch.Tensor,
971
+ grad_out: torch.Tensor,
972
+ grad_logsumexp: torch.Tensor,
973
+ fw_graph: Union[Callable, GraphModule],
974
+ joint_graph: GraphModule,
975
+ block_mask: tuple,
976
+ scale: float,
977
+ kernel_options: dict[str, Any],
978
+ score_mod_other_buffers: tuple = (),
979
+ mask_mod_other_buffers: tuple = (),
980
+ ) -> tuple[
981
+ torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
982
+ ]:
983
+ """We already have the forward graph and joint graph from the forward pass, so we create a proxy attach both graphs"""
984
+ from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
985
+
986
+ example_out = flex_attention_backward(
987
+ query,
988
+ key,
989
+ value,
990
+ out,
991
+ logsumexp,
992
+ grad_out,
993
+ grad_logsumexp,
994
+ fw_graph,
995
+ joint_graph,
996
+ block_mask,
997
+ scale,
998
+ kernel_options,
999
+ score_mod_other_buffers,
1000
+ mask_mod_other_buffers,
1001
+ )
1002
+
1003
+ requires_grad = any(pytree.tree_map(lambda x: x.requires_grad, (query, key)))
1004
+ fw_example_vals = [query.new_zeros((), requires_grad=requires_grad)] + [
1005
+ query.new_zeros((), dtype=torch.int) for _ in range(4)
1006
+ ]
1007
+ bw_example_vals = fw_example_vals + [query.new_zeros(())]
1008
+ mask_example_vals = [query.new_zeros((), dtype=torch.int) for _ in range(4)]
1009
+ mask_graph = block_mask[-1]
1010
+ with TransformGetItemToIndex():
1011
+ # There's no active make_fx during the compiled autograd graph's initial capture
1012
+ fw_graph = _maybe_reenter_make_fx(fw_graph)(
1013
+ *fw_example_vals, *score_mod_other_buffers
1014
+ )
1015
+ joint_graph = _maybe_reenter_make_fx(joint_graph)(
1016
+ *bw_example_vals, *score_mod_other_buffers
1017
+ )
1018
+ mask_graph = _maybe_reenter_make_fx(mask_graph)(
1019
+ *mask_example_vals, *mask_mod_other_buffers
1020
+ )
1021
+ assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
1022
+ block_mask = block_mask[:-1] + (mask_graph,)
1023
+
1024
+ qualname = proxy_mode.tracer.get_fresh_qualname("fw_graph")
1025
+ proxy_mode.tracer.root.register_module(qualname, fw_graph) # type: ignore[arg-type]
1026
+ qualname = proxy_mode.tracer.get_fresh_qualname("joint_graph")
1027
+ proxy_mode.tracer.root.register_module(qualname, joint_graph)
1028
+ qualname = proxy_mode.tracer.get_fresh_qualname("mask_graph")
1029
+ proxy_mode.tracer.root.register_module(qualname, mask_graph)
1030
+
1031
+ node_args = (
1032
+ query,
1033
+ key,
1034
+ value,
1035
+ out,
1036
+ logsumexp,
1037
+ grad_out,
1038
+ grad_logsumexp,
1039
+ fw_graph,
1040
+ joint_graph,
1041
+ block_mask,
1042
+ scale,
1043
+ kernel_options,
1044
+ score_mod_other_buffers,
1045
+ mask_mod_other_buffers,
1046
+ )
1047
+ proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
1048
+ out_proxy = proxy_mode.tracer.create_proxy(
1049
+ "call_function",
1050
+ flex_attention_backward,
1051
+ proxy_args,
1052
+ {},
1053
+ name="flex_attention_backward",
1054
+ )
1055
+ return track_tensor_tree(
1056
+ example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
1057
+ )
1058
+
1059
+
1060
+ @flex_attention_backward.py_impl(ProxyTorchDispatchMode)
1061
+ def flex_attention_backward_proxy_torch_dispatch_mode(
1062
+ mode: ProxyTorchDispatchMode,
1063
+ query: torch.Tensor,
1064
+ key: torch.Tensor,
1065
+ value: torch.Tensor,
1066
+ out: torch.Tensor,
1067
+ logsumexp: torch.Tensor,
1068
+ grad_out: torch.Tensor,
1069
+ grad_logsumexp: torch.Tensor,
1070
+ fw_graph: Union[Callable, GraphModule],
1071
+ joint_graph: GraphModule,
1072
+ block_mask: tuple,
1073
+ scale: float,
1074
+ kernel_options: dict[str, Any],
1075
+ score_mod_other_buffers: tuple = (),
1076
+ mask_mod_other_buffers: tuple = (),
1077
+ ) -> tuple[
1078
+ torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
1079
+ ]:
1080
+ assert mode is not None, "Mode should always be enabled for python fallback key"
1081
+ return trace_flex_attention_backward(
1082
+ mode,
1083
+ query,
1084
+ key,
1085
+ value,
1086
+ out,
1087
+ logsumexp,
1088
+ grad_out,
1089
+ grad_logsumexp,
1090
+ fw_graph,
1091
+ joint_graph,
1092
+ block_mask,
1093
+ scale,
1094
+ kernel_options,
1095
+ score_mod_other_buffers,
1096
+ mask_mod_other_buffers,
1097
+ )
1098
+
1099
+
1100
+ @flex_attention_backward.py_functionalize_impl
1101
+ def flex_attention_backward_functionalize(
1102
+ ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI,
1103
+ query: torch.Tensor,
1104
+ key: torch.Tensor,
1105
+ value: torch.Tensor,
1106
+ out: torch.Tensor,
1107
+ logsumexp: torch.Tensor,
1108
+ grad_out: torch.Tensor,
1109
+ grad_logsumexp: torch.Tensor,
1110
+ fw_graph: Union[Callable, GraphModule],
1111
+ joint_graph: GraphModule,
1112
+ block_mask: tuple,
1113
+ scale: float,
1114
+ kernel_options: dict[str, Any],
1115
+ score_mod_other_buffers: tuple = (),
1116
+ mask_mod_other_buffers: tuple = (),
1117
+ ) -> tuple[
1118
+ torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
1119
+ ]:
1120
+ """Defines the functionalization rules for the flex_attention operator.
1121
+
1122
+ Write now we are unwrapping each tensor and then redispatching to the next,
1123
+ since we know that the forward score mod function is assured to be free of mutations
1124
+ to the other_buffers, we skip that mutate check and go straight to redispatching.
1125
+ """
1126
+
1127
+ if has_user_subclass(
1128
+ (
1129
+ query,
1130
+ key,
1131
+ value,
1132
+ out,
1133
+ logsumexp,
1134
+ grad_out,
1135
+ grad_logsumexp,
1136
+ block_mask,
1137
+ scale,
1138
+ kernel_options,
1139
+ score_mod_other_buffers,
1140
+ mask_mod_other_buffers,
1141
+ ),
1142
+ allowed_subclasses=(FakeTensor, FunctionalTensor),
1143
+ ):
1144
+ return NotImplemented
1145
+ query_unwrapped = ctx.unwrap_tensors(query)
1146
+ key_unwrapped = ctx.unwrap_tensors(key)
1147
+ value_unwrapped = ctx.unwrap_tensors(value)
1148
+ out_unwrapped = ctx.unwrap_tensors(out)
1149
+ logsumexp_unwrapped = ctx.unwrap_tensors(logsumexp)
1150
+ grad_out_unwrapped = ctx.unwrap_tensors(grad_out)
1151
+ grad_logsumexp_unwrapped = ctx.unwrap_tensors(grad_logsumexp)
1152
+ block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
1153
+ score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
1154
+ mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)
1155
+
1156
+ # Appease the mypy overlords
1157
+ assert isinstance(query_unwrapped, torch.Tensor)
1158
+ assert isinstance(key_unwrapped, torch.Tensor)
1159
+ assert isinstance(value_unwrapped, torch.Tensor)
1160
+ assert isinstance(out_unwrapped, torch.Tensor)
1161
+ assert isinstance(logsumexp_unwrapped, torch.Tensor)
1162
+ assert isinstance(grad_out_unwrapped, torch.Tensor)
1163
+ assert isinstance(grad_logsumexp_unwrapped, torch.Tensor)
1164
+ assert isinstance(block_mask_unwrapped, tuple)
1165
+ assert isinstance(score_mod_other_buffers_unwrapped, tuple)
1166
+ assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
1167
+
1168
+ with ctx.redispatch_to_next():
1169
+ functional_fw_graph = ctx.functionalize(fw_graph)
1170
+ functional_joint_graph = ctx.functionalize(joint_graph)
1171
+
1172
+ (
1173
+ grad_query,
1174
+ grad_key,
1175
+ grad_value,
1176
+ grad_score_mod_captured,
1177
+ ) = flex_attention_backward(
1178
+ query_unwrapped,
1179
+ key_unwrapped,
1180
+ value_unwrapped,
1181
+ out_unwrapped,
1182
+ logsumexp_unwrapped,
1183
+ grad_out_unwrapped,
1184
+ grad_logsumexp_unwrapped,
1185
+ functional_fw_graph, # type: ignore[arg-type]
1186
+ functional_joint_graph, # type: ignore[arg-type]
1187
+ block_mask_unwrapped,
1188
+ scale,
1189
+ kernel_options,
1190
+ score_mod_other_buffers_unwrapped,
1191
+ mask_mod_other_buffers_unwrapped,
1192
+ )
1193
+
1194
+ return ctx.wrap_tensors((grad_query, grad_key, grad_value, grad_score_mod_captured)) # type: ignore[return-value,arg-type]
1195
+
1196
+
1197
+ @register_fake(flex_attention_backward)
1198
+ def flex_attention_backward_fake_tensor_mode(
1199
+ query: torch.Tensor,
1200
+ key: torch.Tensor,
1201
+ value: torch.Tensor,
1202
+ out: torch.Tensor,
1203
+ logsumexp: torch.Tensor,
1204
+ grad_out: torch.Tensor,
1205
+ grad_logsumexp: torch.Tensor,
1206
+ fw_graph: Union[Callable, GraphModule],
1207
+ joint_graph: GraphModule,
1208
+ block_mask: tuple,
1209
+ scale: float,
1210
+ kernel_options: dict[str, Any],
1211
+ score_mod_other_buffers: tuple = (),
1212
+ mask_mod_other_buffers: tuple = (),
1213
+ ) -> tuple[
1214
+ torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
1215
+ ]:
1216
+ if has_user_subclass(
1217
+ (
1218
+ query,
1219
+ key,
1220
+ value,
1221
+ out,
1222
+ logsumexp,
1223
+ grad_out,
1224
+ grad_logsumexp,
1225
+ block_mask,
1226
+ scale,
1227
+ kernel_options,
1228
+ score_mod_other_buffers,
1229
+ mask_mod_other_buffers,
1230
+ ),
1231
+ allowed_subclasses=(FakeTensor,),
1232
+ ):
1233
+ return NotImplemented
1234
+ Bq, _, _, qk_head_dim = query.shape
1235
+ Bkv, Hkv, seq_len_kv, v_head_dim = value.shape
1236
+
1237
+ grad_query = torch.empty_like(query)
1238
+ # zeros_and_scatter creates a contiguous zeros tensor -> contiguous_format
1239
+ grad_score_mod_captured = tuple(
1240
+ [
1241
+ (
1242
+ torch.empty_like(buffer, memory_format=torch.contiguous_format)
1243
+ if isinstance(buffer, torch.Tensor) and buffer.requires_grad
1244
+ else None
1245
+ )
1246
+ for buffer in score_mod_other_buffers
1247
+ ]
1248
+ )
1249
+
1250
+ broadcasted_grad_key = key.new_empty((Bq, Hkv, seq_len_kv, qk_head_dim))
1251
+ broadcasted_grad_key = _permute_strides(broadcasted_grad_key, key.stride())
1252
+
1253
+ broadcasted_grad_value = value.new_empty((Bq, Hkv, seq_len_kv, v_head_dim))
1254
+ broadcasted_grad_value = _permute_strides(broadcasted_grad_value, value.stride())
1255
+
1256
+ if Bq > 1 and Bkv == 1:
1257
+ grad_key = torch.sum(broadcasted_grad_key, dim=0, keepdim=True)
1258
+ grad_value = torch.sum(broadcasted_grad_value, dim=0, keepdim=True)
1259
+ else:
1260
+ grad_key = broadcasted_grad_key
1261
+ grad_value = broadcasted_grad_value
1262
+
1263
+ return grad_query, grad_key, grad_value, grad_score_mod_captured
1264
+
1265
+
1266
+ flex_attention_backward.py_autograd_impl(
1267
+ autograd_not_implemented(flex_attention_backward, deferred_error=True)
1268
+ )
archive/.venv/Lib/site-packages/torch/_higher_order_ops/foreach_map.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-decorators
2
+ # mypy: allow-untyped-defs
3
+ from typing import Any, Callable
4
+
5
+ from torch._higher_order_ops.base_hop import BaseHOP, FunctionWithNoFreeVars
6
+
7
+
8
+ class ForeachMap(BaseHOP):
9
+ def __init__(self):
10
+ super().__init__("foreach_map")
11
+
12
+ def __call__(self, fn, *operands, **kwargs): # type: ignore[override]
13
+ fn = FunctionWithNoFreeVars(fn)
14
+ return super().__call__(fn, *operands, **kwargs)
15
+
16
+
17
+ _foreach_map = ForeachMap()
18
+
19
+
20
+ def foreach_map(op: Callable, *operands: Any, **kwargs: dict[str, Any]):
21
+ from torch._dynamo.polyfills import foreach_map_fn
22
+
23
+ return _foreach_map(foreach_map_fn, op, *operands, **kwargs)
archive/.venv/Lib/site-packages/torch/_higher_order_ops/hints_wrap.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.utils._pytree as pytree
4
+ from torch._C import DispatchKey
5
+ from torch._higher_order_ops.utils import (
6
+ autograd_not_implemented,
7
+ reenter_make_fx,
8
+ unique_graph_id,
9
+ )
10
+ from torch._ops import HigherOrderOperator
11
+ from torch._subclasses.fake_tensor import FakeTensorMode
12
+ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
13
+
14
+
15
+ # used for wrapping a function/op with context hints
16
+ class HintsWrapper(HigherOrderOperator):
17
+ def __init__(self):
18
+ super().__init__("hints_wrapper")
19
+
20
+ def __call__(self, body_fn, args, kwargs, hints):
21
+ r"""
22
+ Call implementation of hints_wrapper
23
+
24
+ Args:
25
+ body_fn (Callable): A callable function that is within the scope
26
+ that is being traced.
27
+
28
+ args (Tuple of torch.Tensor/int/float/bool): A tuple of inputs to
29
+ body_fn.
30
+
31
+ kwargs (dict): Keyword argument to the body_fn.
32
+
33
+ hints (dict): A dict of context hints which could be passed to
34
+ backend compiler.
35
+ """
36
+ if not isinstance(args, tuple):
37
+ raise RuntimeError(f"args must be a tuple, got {type(args)}")
38
+
39
+ if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args):
40
+ raise RuntimeError(
41
+ "args must be a tuple of tensors, ints, floats, or bools, got "
42
+ f"{args}"
43
+ )
44
+
45
+ if not isinstance(kwargs, dict):
46
+ raise RuntimeError(f"kwargs must be a dict, got {type(kwargs)}")
47
+
48
+ if len(kwargs) > 0:
49
+ raise RuntimeError(
50
+ f"kwargs except for hints are not supported, got {kwargs}"
51
+ )
52
+
53
+ if not isinstance(hints, dict):
54
+ raise RuntimeError(f"hints must be a dict, got {type(hints)}")
55
+
56
+ for k, v in hints.items():
57
+ if not isinstance(k, str):
58
+ raise RuntimeError(f"hints key must be a str, got {k}.")
59
+
60
+ if not isinstance(v, (int, float, bool, str)):
61
+ raise RuntimeError(
62
+ "hints must be a dict containing int, float, bool or str "
63
+ f"value, got value {v} for key {k}."
64
+ )
65
+
66
+ return super().__call__(body_fn, args, kwargs, hints)
67
+
68
+
69
+ hints_wrapper = HintsWrapper()
70
+
71
+
72
+ @hints_wrapper.py_impl(DispatchKey.CompositeExplicitAutograd)
73
+ def hints_wrapper_dense(body_fn, args, kwargs, hints):
74
+ return body_fn(*args, **kwargs)
75
+
76
+
77
+ hints_wrapper.py_autograd_impl(
78
+ autograd_not_implemented(hints_wrapper, deferred_error=True)
79
+ )
80
+
81
+
82
+ @hints_wrapper.py_impl(FakeTensorMode)
83
+ def hints_wrapper_fake_tensor_mode(mode, body_func, args, kwargs, hints):
84
+ flat_args = pytree.tree_leaves(args)
85
+ with mode:
86
+ return body_func(*flat_args, **kwargs)
87
+
88
+
89
+ @hints_wrapper.py_functionalize_impl
90
+ def hints_wrapper_functionalize(ctx, body_fn, args, kwargs, hints):
91
+ from torch._higher_order_ops.utils import _check_alias_and_mutation
92
+
93
+ unwrapped_args = ctx.unwrap_tensors(args)
94
+ unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
95
+ unwrapped_hints = ctx.unwrap_tensors(hints)
96
+ with ctx.redispatch_to_next():
97
+ functional_body_fn = ctx.functionalize(body_fn)
98
+ pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
99
+ _check_alias_and_mutation(
100
+ body_fn, unwrapped_args, "hints_wrapper", pre_dispatch
101
+ )
102
+
103
+ outputs = hints_wrapper(
104
+ functional_body_fn,
105
+ unwrapped_args,
106
+ unwrapped_kwargs,
107
+ unwrapped_hints,
108
+ )
109
+ return ctx.wrap_tensors(outputs)
110
+
111
+
112
+ def trace_hints_wrapper(proxy_mode, hints_wrapper, body_fn, args, kwargs, hints):
113
+ flat_args = tuple(pytree.tree_leaves(args))
114
+ body_graph = reenter_make_fx(body_fn)(*flat_args, **kwargs)
115
+
116
+ _, body_graph_name = unique_graph_id(proxy_mode, prefix="hints_wrapper_body_graph")
117
+ proxy_mode.tracer.root.register_module(body_graph_name, body_graph)
118
+
119
+ new_args: tuple = (body_graph, flat_args, {})
120
+ # merge hints into kwargs
121
+ new_kwargs = {}
122
+ new_kwargs["hints"] = hints
123
+
124
+ proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_args)
125
+ proxy_kwargs = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_kwargs)
126
+
127
+ out_proxy = proxy_mode.tracer.create_proxy(
128
+ "call_function", hints_wrapper, proxy_args, proxy_kwargs, name="hints_wrapper"
129
+ )
130
+
131
+ out = body_fn(*flat_args, **kwargs)
132
+ return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
133
+
134
+
135
+ @hints_wrapper.py_impl(ProxyTorchDispatchMode)
136
+ def inner(proxy_mode, body_fn, args, kwargs, hints):
137
+ if proxy_mode.enable_tracing:
138
+ return trace_hints_wrapper(
139
+ proxy_mode, hints_wrapper, body_fn, args, kwargs, hints
140
+ )
141
+ else:
142
+ return hints_wrapper(body_fn, args, kwargs, hints)
archive/.venv/Lib/site-packages/torch/_higher_order_ops/invoke_subgraph.py ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+
3
+
4
+ import contextlib
5
+ from contextlib import nullcontext
6
+ from dataclasses import dataclass, field
7
+ from typing import Optional, Union
8
+
9
+ import torch
10
+ import torch.utils._pytree as pytree
11
+ from torch._C import DispatchKey
12
+ from torch._dispatch.python import suspend_functionalization
13
+ from torch._higher_order_ops.utils import (
14
+ _from_fun,
15
+ _maybe_reenter_make_fx,
16
+ _set_compilation_env,
17
+ clone_outputs_aliasing_inputs,
18
+ FunctionalizeCtxWrapper,
19
+ get_dummy_aot_autograd_config,
20
+ HopInstance,
21
+ prepare_fw_with_masks,
22
+ reenter_make_fx,
23
+ register_fake,
24
+ save_tensors_and_symints_for_backward,
25
+ saved_tensors_and_symints,
26
+ )
27
+ from torch._ops import HigherOrderOperator
28
+ from torch._subclasses.functional_tensor import disable_functional_mode
29
+ from torch.fx.experimental.proxy_tensor import (
30
+ _temp_remove_metadata_torch_function_mode,
31
+ _temp_remove_pre_dispatch_torch_function_mode,
32
+ disable_proxy_modes_tracing,
33
+ ProxyTorchDispatchMode,
34
+ track_tensor_tree,
35
+ )
36
+ from torch.fx.graph_module import GraphModule
37
+ from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
38
+
39
+
40
+ invoke_subgraph_counter = 0
41
+
42
+
43
# During the tracing of the joint graph, we construct this information. This is
# used to filter out grad_outs/tangents in the `backward` method of
# InvokeSubgraphAutogradOp.
@dataclass
class OutputMetadata:
    # Total number of forward outputs; None until populated.
    num_fw_outs: Optional[int] = None
    # Output positions that are literally None in the forward result.
    indexes_with_none: set[int] = field(default_factory=set)
    # Output positions whose tensors do not require grad.
    indexes_with_no_grad: set[int] = field(default_factory=set)
51
+
52
+
53
class InvokeSubgraphHOP(HigherOrderOperator):
    """Higher-order operator that calls a subgraph as an opaque unit,
    keyed by an identifier so identical regions can be cached and reused."""

    def __init__(self) -> None:
        # Invoke subgraph does not have any state, it is just a wrapper over a
        # subgraph, so we can safely cache the HOP.
        super().__init__("invoke_subgraph", cacheable=True)
        # This is used by the fake tensor cache key validator to extract the
        # subgraph and iterate over the nodes to find if all nodes are fake
        # tensor cacheable.
        self.subgraph_indexes = [
            0,
        ]

    # identifier is setup by upper part of the stack. This helps us in
    # identifying two invoke_subgraph calls have same subgraph.
    def __call__(
        self,
        subgraph: Union[GraphModule, FunctionalizeCtxWrapper],
        identifier: Optional[str],
        *operands,
    ):
        """Validate inputs and dispatch the HOP call."""
        assert identifier is None or isinstance(
            identifier, str
        ), "identifier must be a None or a string"

        assert all(
            isinstance(o, (torch.Tensor, int, torch.SymInt)) for o in operands
        ), f"invoke_subgraph operands must be a list of tensors/ints/SymInts {operands}"

        return super().__call__(subgraph, identifier, *operands)

    def gen_schema(self, subgraph, identifier, *operands):
        """Build an op schema for this call, marking mutated inputs.

        Materializes ``subgraph`` as a GraphModule if needed, then analyzes
        it to discover input mutation and the output list.
        """
        from torch._higher_order_ops.schema import HopSchemaGenerator
        from torch._higher_order_ops.utils import (
            check_input_alias_and_mutation_return_outputs,
            materialize_as_graph,
        )

        gm: torch.fx.GraphModule = (
            subgraph
            if isinstance(subgraph, torch.fx.GraphModule)
            else materialize_as_graph(subgraph, operands)
        )

        schema_gen = HopSchemaGenerator(self)
        schema_gen.add_arg("subgraph", gm)
        schema_gen.add_arg("identifier", identifier)
        (
            _,
            _,
            _,
            mutated_inputs,
            outputs,
        ) = check_input_alias_and_mutation_return_outputs(gm, operands)
        for idx, arg in enumerate(operands):
            schema_gen.add_arg(f"arg{idx}", arg, is_mutated=idx in mutated_inputs)
        for out in outputs:
            schema_gen.add_output(out)

        return schema_gen.gen_schema()
112
+
113
+
114
+ invoke_subgraph = InvokeSubgraphHOP()
115
+
116
+
117
def invoke_subgraph_placeholder(func, *args, **kwargs):
    """Entry point that Dynamo rewrites into an ``invoke_subgraph`` HOP.

    In eager mode this simply calls ``func``. While compiling (non-strict
    export), it routes through ``torch.compile`` with an eager backend so
    Dynamo can perform the rewrite. Calling it directly under Dynamo is an
    error.
    """
    if torch.compiler.is_dynamo_compiling():
        # This is just a placeholder for Dynamo to replace with invoke_subgraph
        raise RuntimeError("invoke_subgraph should not be called directly in Dynamo")

    if torch.compiler.is_compiling():
        # For non-strict export tracing, we still want to go through Dynamo
        from torch._dynamo.backends.debugging import (
            make_eager_backend_with_torch_function_mode,
        )

        def _invoke_subgraph_placeholder_wrapper(func, args):
            return invoke_subgraph_placeholder(func, *args)

        # NOTE(review): kwargs are not forwarded on this compile path —
        # presumably callers only pass positional args here; confirm.
        with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode():
            with _temp_remove_metadata_torch_function_mode() as metadata_mode:
                if metadata_mode:
                    backend = make_eager_backend_with_torch_function_mode(metadata_mode)
                else:
                    backend = "eager"

                return torch.compile(
                    _invoke_subgraph_placeholder_wrapper,
                    backend=backend,
                    fullgraph=True,
                )(func, args)

    return func(*args, **kwargs)
145
+
146
+
147
def mark_compile_region(fn=None):
    """
    Instruct torch.compile to compile the wrapped region once and reuse the
    compiled artifact, instead of aggressively inlining the function as it
    normally would.

    Under the hood this tells TorchDynamo to use the InvokeSubgraph HOP for
    the region. In PyTorch eager, the wrapper is a no-op.

    Usable both as ``@mark_compile_region`` and ``@mark_compile_region()``.
    """

    def wrap(func):
        def inner(*args, **kwargs):
            # Unwrap any previously marked function so that nested markings
            # collapse to the innermost original callable.
            target = func
            while hasattr(target, "__marked_compile_region_fn__"):
                target = target.__marked_compile_region_fn__
            return invoke_subgraph_placeholder(target, *args, **kwargs)

        inner.__marked_compile_region_fn__ = func  # type: ignore[attr-defined]
        return inner

    return wrap(fn) if fn else wrap
173
+
174
+
175
+ def get_invoke_subgraph_cache():
176
+ cache = None
177
+ if tracing_ctx := torch._guards.TracingContext.try_get():
178
+ cache = tracing_ctx.hop_dispatch_set_cache.get_cache(invoke_subgraph)
179
+ return cache
180
+
181
+
182
+ # TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra
183
def trace_joint_graph(fn, fw_inputs, fw_outputs):
    """
    Naively trace out a joint graph. This simplifies the reconstruction of joint
    graph in the min-cut partitioner later on.

    Returns an fx graph for ``joint_fn`` with input signature
    ``(*primals, *tangents)`` and output signature ``(*grads, *fw_outs)``.
    """
    from torch._functorch.aot_autograd import create_joint

    dummy_aot_config = get_dummy_aot_autograd_config()

    # This joint_fn is inserted as the backward graph as is. This simplifies the
    # min-cut partitioner work later on.
    # Input signature - (*primals, *tangents)
    # Output signature - (*grads, *fw_outs)
    # The output signature is deliberately kept grads first and fw_outs second.
    # Having grads first makes the min-cut partitioner HOP graph stitching
    # easier.
    def joint_fn(*primals_and_tangents):
        primals = primals_and_tangents[: len(fw_inputs)]
        tangents = primals_and_tangents[len(fw_inputs) :]

        fw_outs, grads = create_joint(
            prepare_fw_with_masks(fn), aot_config=dummy_aot_config
        )(primals, tangents)

        maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents)

        # return signature is deliberately kept (*grads, *fw_outs). This
        # simplifies partitioning work later on.
        return pytree.tree_map(maybe_clone, tuple(grads + list(fw_outs)))

    primals = list(fw_inputs)
    # This assumes that the tangent strides match fw_outputs strides. Check the
    # InvokeSubgraphAutogradOp backward op for the contiguous call.
    tangents = [_from_fun(out) for out in fw_outputs]

    joint_operands = primals + tangents

    return _maybe_reenter_make_fx(joint_fn)(*joint_operands)
221
+
222
+
223
+ # TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra
224
def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
    """Trace separate forward and joint (backward) graphs for ``subgraph``.

    Returns ``(fw_graph, bw_graph, output_metadata)`` where
    ``output_metadata`` records which forward outputs are None or do not
    require grad. Tracing happens with functionalization and proxy modes
    suspended so fresh example tensors can be created.
    """
    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            # args are functional tensors, generate some example tensors
            fw_inputs = pytree.tree_map(_from_fun, operands)

            from torch._guards import detect_fake_mode

            fake_mode = detect_fake_mode(fw_inputs)
            # Suppress fresh unbacked symbols created while running the
            # subgraph for metadata purposes only.
            context = (
                nullcontext()
                if fake_mode is None or fake_mode.shape_env is None
                else fake_mode.shape_env.ignore_fresh_unbacked_symbols()
            )

            with context:
                fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs))

            num_fw_outs = len(fw_outs)

            # Collect the indexes of none in the output to check that the grad
            # is None at the corresponding index in the backward. This check is
            # performed in the autograd.Function - InvokeSubgraphAutogradOp.
            # Also collect the indexes of no_grad in the output to filter out
            # the grad_outs in the `backward` method.
            output_metadata = OutputMetadata()

            output_metadata.num_fw_outs = num_fw_outs
            for idx, fw_out in enumerate(fw_outs):
                if fw_out is None:
                    output_metadata.indexes_with_none.add(idx)
                elif not fw_out.requires_grad:
                    output_metadata.indexes_with_no_grad.add(idx)

            if grad_outputs is None:
                # Infer grad_outputs to be the same properties as the fw_outputs
                # if they're not passed in
                # Although fw_outs are equivalent to grad_outputs for tracing
                # purposes, we have to carefully handle the None and fw_out that do
                # not have require_grad. At those indexes, we will have None in the
                # backward graph.
                grad_outputs = fw_outs
                grad_outputs = [grad for grad in grad_outputs if grad is not None]
                grad_outputs = [grad for grad in grad_outputs if grad.requires_grad]

                # Force grad_out to be contiguous. This is because at runtime,
                # grad_out could have different strides than fw_outs. So, we
                # force the grad_outs to be contiguous for both tracing and
                # runtime.
                grad_outputs = [grad.contiguous() for grad in grad_outputs]

            if any(
                not isinstance(out, torch.Tensor)
                for out in grad_outputs
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of invoke_subgraph to only contains tensors or None. "
                    f"Got types {[type(out) for out in grad_outputs]}."
                )

            # Trace the forward subgraph
            fw_graph = _maybe_reenter_make_fx(subgraph)(*fw_inputs)

            # Trace the joint graph and assign it to the bwd graph
            bw_graph = trace_joint_graph(
                subgraph,
                fw_inputs,
                grad_outputs,
            )
            return fw_graph, bw_graph, output_metadata
295
+
296
+
297
def get_output_metadata(subgraph, *operands):
    """Run ``subgraph`` on example inputs and collect OutputMetadata.

    Records the total number of forward outputs, the positions that are
    None, and the positions whose tensors do not require grad. No graphs
    are traced here — only metadata is gathered.
    """
    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            # args are functional tensors, generate some example tensors
            fw_inputs = pytree.tree_map(_from_fun, operands)

            from torch._guards import detect_fake_mode

            fake_mode = detect_fake_mode(fw_inputs)
            # Suppress fresh unbacked symbols created during this
            # metadata-only run of the subgraph.
            context = (
                nullcontext()
                if fake_mode is None or fake_mode.shape_env is None
                else fake_mode.shape_env.ignore_fresh_unbacked_symbols()
            )

            with context:
                fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs))

            num_fw_outs = len(fw_outs)

            # Collect the indexes of none in the output to check that the grad
            # is None at the corresponding index in the backward. This check is
            # performed in the autograd.Function - InvokeSubgraphAutogradOp.
            # Also collect the indexes of no_grad in the output to filter out
            # the grad_outs in the `backward` method.
            output_metadata = OutputMetadata()

            output_metadata.num_fw_outs = num_fw_outs
            for idx, fw_out in enumerate(fw_outs):
                if fw_out is None:
                    output_metadata.indexes_with_none.add(idx)
                elif not fw_out.requires_grad:
                    output_metadata.indexes_with_no_grad.add(idx)
            return output_metadata
331
+
332
+
333
def trace_joint_graph_as_bwd(
    subgraph, num_primals, joint_operands, include_key_set, exclude_key_set
):
    """
    Naively trace out a joint graph. This simplifies the reconstruction of joint
    graph in the min-cut partitioner later on.

    ``include_key_set``/``exclude_key_set`` are the dispatch key sets
    snapshotted during forward; they are reinstated here so the joint trace
    sees the same dispatch state the forward ran under.
    """
    from torch._functorch.aot_autograd import create_joint

    dummy_aot_config = get_dummy_aot_autograd_config()

    if isinstance(subgraph, torch.fx.GraphModule):

        def graph_with_interpreter(*args):
            # Running graph with interpreter is needed for propagating the stack_trace
            with torch.fx.traceback.preserve_node_meta():
                return torch.fx.Interpreter(subgraph).run(*args)

        fn = graph_with_interpreter
    else:
        fn = subgraph

    # This joint_fn is inserted as the backward graph as is. This simplifies the
    # min-cut partitioner work later on.
    # Input signature - (*primals, *tangents)
    # Output signature - (*grads, *fw_outs)
    # The output signature is deliberately kept grads first and fw_outs second.
    # Having grads first makes the min-cut partitioner HOP graph stitching
    # easier.
    def joint_fn(*primals_and_tangents):
        primals = primals_and_tangents[:num_primals]
        tangents = primals_and_tangents[num_primals:]

        fw_outs, grads = create_joint(
            prepare_fw_with_masks(fn), aot_config=dummy_aot_config
        )(primals, tangents)

        maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents)

        # return signature is deliberately kept (*grads, *fw_outs). This
        # simplifies partitioning work later on.
        return pytree.tree_map(maybe_clone, tuple(grads + list(fw_outs)))

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            joint_operands = [_from_fun(arg) for arg in joint_operands]
            with contextlib.ExitStack() as stack:
                stack.enter_context(
                    torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set),
                )
                # grad mode must be on for create_joint to compute grads
                with torch.enable_grad():
                    return _maybe_reenter_make_fx(joint_fn)(*joint_operands)
385
+
386
+
387
class InvokeSubgraphAutogradOp(torch.autograd.Function):
    """
    Saves the subgraph, i.e. original callable, in the forward method. And then
    traces out a joint graph in the backward. This delaying of tracing in
    backward, also called as lazy backward, ensures that the assumptions about
    the grad_out strides and tensor-subclass-ness are already accounted for.
    """

    @staticmethod
    def forward(
        ctx,
        subgraph,
        identifier,
        output_metadata,
        *operands,
    ):
        # We want to delay the backward graph construction until the backward.
        # So in forward, we just run the fw callable as is. And save all the
        # information necessary to construct the backward graph in the ctx.
        ctx._subgraph = subgraph
        ctx._identifier = identifier
        ctx._output_metadata = output_metadata
        # We snapshot the dispatch keys in forward for materializing the
        # bw_graph in backward.
        ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
        ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()

        save_tensors_and_symints_for_backward(ctx, operands)

        with torch._C._AutoDispatchBelowAutograd():
            out = invoke_subgraph(
                subgraph,
                f"fw_{identifier}",
                *operands,
            )

        # Check that None is at expected indexes.
        for idx, o in enumerate(out):
            if o is None:
                assert idx in output_metadata.indexes_with_none

        return out

    @staticmethod
    def backward(
        ctx,
        *grad_outs,
    ):
        from torch._dynamo.utils import dynamo_timed

        subgraph = ctx._subgraph
        identifier = ctx._identifier
        output_metadata = ctx._output_metadata
        primals = saved_tensors_and_symints(ctx)

        # Filter out grads that are None or do not require_grad. This was
        # the assumption we made during the tracing of joint_graph.
        filtered_grad_outs = []
        for idx, o in enumerate(grad_outs):
            if o is None:
                assert idx in output_metadata.indexes_with_none
            elif idx in output_metadata.indexes_with_no_grad:
                # Deliberately skip over the grad_outs which we know should be
                # None because the corresponding fwd_out does not require_grad.
                pass
            else:
                filtered_grad_outs.append(o)
        filtered_grad_outs = tuple(filtered_grad_outs)

        # Important note - Even though the forward graph can be same for
        # different invoke_subgraphs, the backward graph can be different
        # because the tangent strides can be different. So, here we cache on
        # tangent_metadata in addition to identifier
        from torch._guards import detect_fake_mode
        from torch._subclasses._fake_tensor_utils import _CacheKeyState
        from torch._subclasses.fake_tensor import extract_tensor_metadata

        fake_mode = detect_fake_mode(primals + filtered_grad_outs)
        state = _CacheKeyState(fake_mode.shape_env)

        # Build a hashable metadata key from the tangents' tensor metadata.
        tangent_metadata: list[object] = []
        for tangent in filtered_grad_outs:
            metadata = extract_tensor_metadata(tangent)
            metadata._flatten_into(tangent_metadata, fake_mode, state)
        tangent_metadata = tuple(tangent_metadata)

        # bw_graph is a joint graph with signature (*primals_and_tangents) and
        # returns (*grads_and_fw_outs). To get the grads, we use the num_fw_outs
        # to extract the grads.
        primals_and_tangents = primals + filtered_grad_outs

        # Check if we have already traced the bwd subgraph.
        bw_graph = None
        suffix = None
        invoke_subgraph_cache = get_invoke_subgraph_cache()
        cache_hit = False
        if invoke_subgraph_cache:
            bw_graph, suffix = invoke_subgraph_cache.get_lazy_bwd_entry(
                identifier, tangent_metadata
            )
            cache_hit = bw_graph is not None

        if bw_graph is None:
            assert suffix is None
            with dynamo_timed(
                "invoke_subgraph_trace_joint_graph", log_pt2_compile_event=True
            ):
                bw_graph = trace_joint_graph_as_bwd(
                    subgraph,
                    len(primals),
                    primals_and_tangents,
                    ctx._fw_include_key_set,
                    ctx._fw_exclude_key_set,
                )

        if invoke_subgraph_cache and not cache_hit:
            suffix = invoke_subgraph_cache.add_lazy_bwd_entry(
                identifier, tangent_metadata, bw_graph
            )

        # Joint graph returns (*grads, *fw_outs); slice off the fw_outs.
        grads = invoke_subgraph(
            bw_graph, f"bw_{identifier}_{suffix}", *primals_and_tangents
        )[: -output_metadata.num_fw_outs]
        return None, None, None, *grads
511
+
512
+
513
@invoke_subgraph.py_autograd_impl
def _(subgraph, identifier, *operands):
    """Autograd key implementation: wrap the call in
    InvokeSubgraphAutogradOp, caching the callable per identifier."""
    # Check if we have already traced the subgraph.
    invoke_subgraph_cache = get_invoke_subgraph_cache()
    if invoke_subgraph_cache:
        if saved_autograd_fn := invoke_subgraph_cache.get_autograd_key_entry(
            identifier
        ):
            return saved_autograd_fn(*operands)

    output_metadata = get_output_metadata(subgraph, *operands)

    def autograd_fn_callable(*args):
        return InvokeSubgraphAutogradOp.apply(
            subgraph, identifier, output_metadata, *args
        )

    # Save the autograd_fn_callable in the dispatch set cache.
    if invoke_subgraph_cache:
        invoke_subgraph_cache.add_autograd_key_entry(identifier, autograd_fn_callable)

    return autograd_fn_callable(*operands)
535
+
536
+
537
@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd)
def _(subgraph, identifier, *operands):
    """Eager (CPU/CUDA) implementation: just run the subgraph directly."""
    from torch.utils._python_dispatch import _get_current_dispatch_mode

    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    return subgraph(*operands)
544
+
545
+
546
@invoke_subgraph.py_functionalize_impl
def _(ctx, subgraph, identifier, *operands):
    """Functionalization implementation: auto-functionalize on input
    mutation, otherwise functionalize the subgraph and re-dispatch."""
    from torch._higher_order_ops.auto_functionalize import (
        can_auto_functionalize,
        do_auto_functionalize_v2,
    )

    unwrapped_operands = ctx.unwrap_tensors(operands)
    hop_instance = HopInstance.create(invoke_subgraph, subgraph, identifier, *operands)
    if can_auto_functionalize(hop_instance):
        # NOTE: [auto_functionalize x invoke_subgraph caching]
        # We call auto_functionalized_v2 to support input mutation of invoke_subgraph.
        # See NOTE [Support input mutation of hops] for the overall design.
        #
        # invoke_subgraph is special because of its identifier based caching mechanism.
        # In invoke_subgraph's functionalization key implementation, we create a new
        # identifier because the subgraph is replaced by FunctionWithNoFreeVars in a
        # functional + epilogue form.
        assert isinstance(identifier, str), identifier
        return do_auto_functionalize_v2(
            ctx.mode,
            hop_instance,
            (subgraph, "auto_functionalized_" + identifier, *operands),
            {},
        )

    with ctx.redispatch_to_next():
        # NB: There is an assumption that subgraph does not mutate inputs and
        # there is no aliasing. Its Dynamo responsibility to prevent formation
        # of invoke_subgraph ops if input aliasing/mutation is detected.
        functionalized_subgraph = FunctionalizeCtxWrapper(ctx, subgraph)
        out = invoke_subgraph(functionalized_subgraph, identifier, *unwrapped_operands)
        return ctx.wrap_tensors(out)
579
+
580
+
581
# Register the hop fake fn. This will be called in the fake_tensor _dispatch_impl.
@register_fake(invoke_subgraph)
def _(subgraph, identifier, *operands):
    """Fake-tensor implementation: run the subgraph under fake tensors to
    compute output metadata; timed for compile-time profiling."""
    from torch._dynamo.utils import dynamo_timed

    with dynamo_timed("invoke_subgraph_fake_tensor", log_pt2_compile_event=True):
        return subgraph(*operands)
588
+
589
+
590
@invoke_subgraph.py_impl(ProxyTorchDispatchMode)
def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, *operands):
    """Proxy-mode implementation: trace (or reuse a cached trace of) the
    subgraph and emit an invoke_subgraph node into the proxy graph."""
    # Check if we have already traced the subgraph.
    graph = None
    invoke_subgraph_cache = get_invoke_subgraph_cache()
    if invoke_subgraph_cache:
        graph = invoke_subgraph_cache.get_proxy_dispatch_entry(identifier)

    if graph is None:
        from torch._dynamo.utils import dynamo_timed

        with dynamo_timed("invoke_subgraph_proxy_tensor", log_pt2_compile_event=True):
            graph = reenter_make_fx(subgraph)(*operands)

            from torch._guards import detect_fake_mode

            # Bake shape-env runtime assertions into the traced subgraph.
            fake_mode = detect_fake_mode(operands)
            insert_deferred_runtime_asserts(
                graph,
                fake_mode.shape_env,
                "invoke_subgraph_proxy_torch_dispatch_mode",
                export=True,
            )
            graph.recompile()

        assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
        if invoke_subgraph_cache:
            invoke_subgraph_cache.add_proxy_dispatch_entry(identifier, graph)

    node_args = (graph, identifier, *operands)

    def _unwrap_proxy(arg):
        if isinstance(arg, torch.fx.GraphModule):
            # NOTE: [invoke_subgraph proxy_mode x auto_functionalize]
            # Previously, we assumed that `invoke_subgraph` would always be traced with the same tracer.
            # This allowed us to cache modules by their identifiers, assuming they were already registered.
            #
            # However, this assumption no longer holds when we auto-functionalize `invoke_subgraph`.
            # auto_functionalize functionalizes the subgraph and wrap it with `FunctionWithNoFreeVars`.
            # In the proxy mode implementation of `auto_functionalized_v2`, we need to materialize `FunctionWithNoFreeVars`
            # input as a graph module. To do this, we re-trace the `invoke_subgraph` hop, which starts a new sub-tracer
            # (see NOTE [materialize callable inputs as graph]). # When the new sub-tracer traces the `invoke_subgraph`
            # with a previously cached identifier, the corresponding graph module might not
            # exist as a submodule in the new tracer's root. Therefore, we register it as a submodule below.
            #
            # The alternative is to give a new identifier when we re-trace the invoke_subgraph but this will increase
            # the compilation time, which defeats the purpose of caching.
            registered_before = False
            for (
                _,
                submod,
            ) in proxy_mode.tracer.root.named_modules():  # type: ignore[union-attr]
                if arg is submod:
                    registered_before = True

            if not registered_before:
                qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph")  # type: ignore[union-attr]
                proxy_mode.tracer.root.register_module(qualname, arg)  # type: ignore[union-attr]
        return proxy_mode.tracer.unwrap_proxy(arg)  # type: ignore[union-attr]

    proxy_args = pytree.tree_map(_unwrap_proxy, node_args)  # type: ignore[union-attr]
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", invoke_subgraph, proxy_args, {}
    )

    # Run the HOP on the traced graph to get example outputs for tracking.
    example_out = invoke_subgraph(graph, identifier, *operands)
    return track_tensor_tree(
        example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
    )
+ )
archive/.venv/Lib/site-packages/torch/_higher_order_ops/map.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ from typing import Callable, Union
4
+ from typing_extensions import TypeVarTuple
5
+
6
+ import torch
7
+ import torch.utils._pytree as pytree
8
+ from torch._C import DispatchKey
9
+ from torch._dispatch.python import suspend_functionalization
10
+ from torch._higher_order_ops.utils import _maybe_run_with_interpreter, reenter_make_fx
11
+ from torch._ops import HigherOrderOperator
12
+ from torch._subclasses.fake_tensor import FakeTensorMode
13
+ from torch._subclasses.functional_tensor import disable_functional_mode
14
+ from torch.fx.experimental.proxy_tensor import (
15
+ disable_proxy_modes_tracing,
16
+ make_fx,
17
+ ProxyTorchDispatchMode,
18
+ track_tensor_tree,
19
+ )
20
+
21
+ from .utils import (
22
+ _from_fun,
23
+ _stack_pytree,
24
+ _unstack_pytree,
25
+ clone_outputs_aliasing_inputs,
26
+ prepare_fw_with_masks,
27
+ save_tensors_and_symints_for_backward,
28
+ saved_tensors_and_symints,
29
+ )
30
+
31
+
32
class MapImpl(HigherOrderOperator):
    """Higher-order operator backing `map`: applies a function over the
    leading dimension of the mapped operands."""

    def __init__(self):
        super().__init__("map_impl")

    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)
38
+
39
+
40
+ map_impl = MapImpl()
41
+
42
+
43
def create_fw_bw_graph(f, num_mapped_args, *args):
    """Trace forward and joint (backward) graphs for a mapped function.

    ``args`` is ``(*mapped_xs, *pos_args)`` where the first
    ``num_mapped_args`` entries are the tensors being mapped over.
    Returns ``(fw_graph, joint_graph)``.
    """
    mapped_xs = args[:num_mapped_args]
    pos_args = args[num_mapped_args:]

    # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs)
            # A single slice along the leading dim serves as the example input.
            example_xs = _unstack_pytree(unwrapped_mapped_xs)[0]

            example_pos_args = [
                _from_fun(arg) if isinstance(arg, torch.Tensor) else arg
                for arg in pos_args
            ]
            example_flat_out = pytree.tree_map(
                _from_fun, f(*example_xs, *example_pos_args)
            )
            if any(
                not isinstance(out, torch.Tensor)
                for out in example_flat_out
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of map only contains tensors or None. "
                    f"Got types {[type(out) for out in example_flat_out]}."
                )
            example_grad = [_from_fun(out) for out in example_flat_out]

            fw_graph = make_fx(f)(*example_xs, *example_pos_args)

        from torch._functorch.aot_autograd import AOTConfig, create_joint

        # Minimal AOTConfig just to drive create_joint; compilers unused.
        dummy_aot_config = AOTConfig(
            fw_compiler=None,  # type: ignore[arg-type]
            bw_compiler=None,  # type: ignore[arg-type]
            partition_fn=None,  # type: ignore[arg-type]
            decompositions={},
            num_params_buffers=0,
            aot_id=0,
            keep_inference_input_mutations=False,
        )

        def joint_f(*example_args):
            joint_mapped_args = example_args[:joint_num_mapped]
            args = example_args[joint_num_mapped:]

            mapped_input = joint_mapped_args[:num_mapped_args]
            mapped_grads = joint_mapped_args[num_mapped_args:]

            joint = create_joint(prepare_fw_with_masks(f), aot_config=dummy_aot_config)
            _, grads = joint(
                list(mapped_input) + list(args),
                [
                    grad
                    for grad in mapped_grads
                    if grad is not None and grad.requires_grad
                ],
            )

            # In order to keep map functional for backward graph,
            # we clone outputs that are aliasing inputs
            maybe_clone = clone_outputs_aliasing_inputs(example_args)

            return pytree.tree_map(maybe_clone, grads)

        joint_num_mapped = len(example_grad) + len(example_xs)
        joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
        return fw_graph, joint_graph
112
+
113
+
114
def map(
    f: Callable[[pytree.PyTree, tuple[pytree.PyTree, ...]], pytree.PyTree],
    xs: Union[pytree.PyTree, torch.Tensor],
    *args: TypeVarTuple,
):
    r"""
    Performs a map of f with xs. Intuitively, you can think of the semantic being:

    out = []
    for idx in len(xs.size(0)):
        xs_sliced = xs.select(0, idx)
        out.append(f(xs_sliced, *args))
    torch.stack(out)

    .. warning::
        `torch._higher_order_ops.map` is a prototype feature in PyTorch. It currently
        does not support autograd and you may run into miscompiles.
        Read more about feature classification at:
        https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype


    Args:
        f (Callable): a callable that takes an input x, that could either be a single Tensor
            or a nested dict, list of tensors and some additional inputs
        xs: the inputs that're to be mapped over. We'll iterate over the first dim of each x
            and perform f on each slice.

        *args: additional arguments provided to each step of f. They could also be omitted and
            map is able to automatically figure out the read dependency.

    Return:
        the stacked output for each step of f

    Raises:
        RuntimeError: if any mapped leaf is not a tensor, the leading
            dimension is 0, or the leading dimensions are inconsistent.

    Example:

        def f(xs):
            return xs[0] + xs[1] + const1 + const2

        xs = [torch.randn(2, 3), torch.randn(2, 3)]
        const1 = torch.randn(2, 3)
        const2 = torch.randn(2, 3)
        # returns a tensor of shape [2, 2, 3]
        torch._higher_order_ops.map(f, xs)

    """
    flat_xs, xs_spec = pytree.tree_flatten(xs)
    flat_args, args_spec = pytree.tree_flatten(args)
    if not all(isinstance(t, torch.Tensor) for t in flat_xs):
        raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")

    shapes = [xs.shape for xs in flat_xs]
    leading_dim_size = shapes[0][0]
    if leading_dim_size == 0:
        raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")

    if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
        raise RuntimeError(
            f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
        )

    def run_flattened_map(f, flat_xs, flat_args):
        # Re-assemble the user's pytrees inside the flat-signature wrapper
        # before calling the original f.
        def wrapped_fn(*flat_args, f, xs_tree_spec, args_tree_spec, num_xs):
            xs = pytree.tree_unflatten(flat_args[:num_xs], xs_tree_spec)
            args = pytree.tree_unflatten(flat_args[num_xs:], args_tree_spec)
            return f(xs, *args)

        inner_f = functools.partial(
            wrapped_fn,
            f=f,
            xs_tree_spec=xs_spec,
            args_tree_spec=args_spec,
            num_xs=len(flat_xs),
        )
        return map_impl(inner_f, flat_xs, flat_args)

    from torch._higher_order_ops.utils import _maybe_compile_and_run_fn

    return _maybe_compile_and_run_fn(run_flattened_map, f, flat_xs, flat_args)
192
+
193
+
194
+ class MapAutogradOp(torch.autograd.Function):
195
+ @staticmethod
196
+ def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
197
+ save_tensors_and_symints_for_backward(ctx, flat_args)
198
+ ctx._joint_graph = joint_graph
199
+ ctx._num_mapped_args = num_mapped_args
200
+ with torch._C._AutoDispatchBelowAutograd():
201
+ return (
202
+ *map_impl(
203
+ fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:]
204
+ ),
205
+ )
206
+
207
+ @staticmethod
208
+ def backward(ctx, *flat_grads):
209
+ fw_args = saved_tensors_and_symints(ctx)
210
+ fw_mapped_args = fw_args[: ctx._num_mapped_args]
211
+ pos_args = fw_args[ctx._num_mapped_args :]
212
+
213
+ grads = map_impl(
214
+ ctx._joint_graph,
215
+ fw_mapped_args + flat_grads,
216
+ pos_args,
217
+ )
218
+ return None, None, None, *grads
219
+
220
+
221
+ def trace_map(proxy_mode, func_overload, f, xs, pos_args):
222
+ example_input = _unstack_pytree(xs)[0]
223
+ body_graph = f
224
+
225
+ body_graph = reenter_make_fx(body_graph)(*example_input, *pos_args)
226
+
227
+ next_name = proxy_mode.tracer.get_fresh_qualname("body_graph_")
228
+
229
+ proxy_mode.tracer.root.register_module(next_name, body_graph)
230
+
231
+ fake_outs = map_impl(body_graph, xs, pos_args)
232
+
233
+ node_args = (body_graph, list(xs), list(pos_args))
234
+ proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
235
+ out_proxy = proxy_mode.tracer.create_proxy(
236
+ "call_function", func_overload, proxy_args, {}, name="map_impl"
237
+ )
238
+ return track_tensor_tree(
239
+ fake_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
240
+ )
241
+
242
+
243
+ @map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
244
+ def map_dense(f, xs, pos_args):
245
+ pytrees = [f(*inp, *pos_args) for inp in _unstack_pytree(xs)]
246
+ return _stack_pytree(pytrees)
247
+
248
+
249
+ @map_impl.py_autograd_impl
250
+ def map_autograd(f, xs, pos_args):
251
+ num_mapped_args = len(xs)
252
+ fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args)
253
+ flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args)
254
+ return flat_out
255
+
256
+
257
+ @map_impl.py_impl(ProxyTorchDispatchMode)
258
+ def map_proxy_torch_dispatch_mode(mode, f, xs, args):
259
+ return trace_map(mode, map_impl, f, xs, args)
260
+
261
+
262
+ @map_impl.py_impl(FakeTensorMode)
263
+ def map_fake_tensor_mode(mode, f, xs, args):
264
+ with mode:
265
+ return map_dense(f, xs, args)
266
+
267
+
268
+ @map_impl.py_functionalize_impl
269
+ def map_functionalize(ctx, f, xs, pos_args):
270
+ from torch._higher_order_ops.utils import _check_alias_and_mutation
271
+
272
+ unwrapped_xs = ctx.unwrap_tensors(xs)
273
+ unwrapped_args = ctx.unwrap_tensors(pos_args)
274
+ wrapped_fn = ctx.functionalize(_maybe_run_with_interpreter(f))
275
+
276
+ with ctx.redispatch_to_next():
277
+ example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
278
+ pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
279
+ _check_alias_and_mutation(f, example_inputs, "map", pre_dispatch)
280
+ map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
281
+ return ctx.wrap_tensors(map_return)
282
+
283
+
284
+ def _fake_map(f, x, *args):
285
+ from functorch.experimental.control_flow import _stack_pytree, _unstack_pytree
286
+
287
+ x_pytrees = _unstack_pytree(x)
288
+ zs = []
289
+ for xp in x_pytrees:
290
+ zs.append(f(xp, *args))
291
+ return _stack_pytree(zs)
archive/.venv/Lib/site-packages/torch/_higher_order_ops/out_dtype.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+
3
+ import torch
4
+ import torch.utils._pytree as pytree
5
+ from torch._C import DispatchKey
6
+ from torch._higher_order_ops.utils import autograd_not_implemented
7
+ from torch._ops import HigherOrderOperator
8
+ from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
9
+ from torch._subclasses.fake_tensor import FakeTensorMode
10
+ from torch.fx.experimental.proxy_tensor import (
11
+ disable_proxy_modes_tracing,
12
+ maybe_handle_decomp,
13
+ ProxyTorchDispatchMode,
14
+ track_tensor_tree,
15
+ )
16
+
17
+
18
+ # TODO to figure out a more generic approach
19
+ ALLOWABLE_OPS = [
20
+ torch.ops.aten.linear.default,
21
+ torch.ops.aten.mm.default,
22
+ torch.ops.aten.conv2d.default,
23
+ torch.ops.aten.convolution.default,
24
+ torch.ops.aten.mul.Tensor,
25
+ torch.ops.aten.mul.Scalar,
26
+ torch.ops.aten.div.Tensor,
27
+ torch.ops.aten.div.Scalar,
28
+ ]
29
+
30
+
31
+ class OutDtypeOperator(HigherOrderOperator):
32
+ """
33
+ The out_dtype operator takes an existing ATen functional operator, an
34
+ `out_dtype` argument, and arguments to the original operator, and executes
35
+ the original operator and returns a Tensor with the `out_dtype` precision.
36
+ This operator does not mandate a compute precision so it allows the
37
+ representation to not be opinionated about the exact implementation.
38
+
39
+ The general implementation for all operators will be the following:
40
+ 1. Promote inputs dtypes based on default PyTorch dtype promotion rules,
41
+ using the dtypes of all input Tensors/Scalars and the `out_dtype`
42
+ arugument.
43
+ 2. Execute the operator
44
+ 3. Cast the output to `out_dtype`
45
+ """
46
+
47
+ def __init__(self) -> None:
48
+ super().__init__("out_dtype")
49
+
50
+ def __call__(self, op, output_dtype, *args):
51
+ if not isinstance(op, torch._ops.OpOverload):
52
+ raise ValueError("out_dtype's first argument must be an OpOverload")
53
+ if op._schema.is_mutable:
54
+ raise ValueError(
55
+ "out_dtype's first argument needs to be a functional operator"
56
+ )
57
+ if not (
58
+ len(op._schema.returns) == 1
59
+ and isinstance(op._schema.returns[0].type, torch.TensorType)
60
+ ):
61
+ raise ValueError(
62
+ "out_dtype's can only apply to ops that return a single tensor"
63
+ f"Instead got {[r.type for r in op._schema.returns]}"
64
+ )
65
+
66
+ if op not in ALLOWABLE_OPS:
67
+ raise ValueError(
68
+ f"out_dtype only allows the following operators: {ALLOWABLE_OPS}."
69
+ )
70
+
71
+ res = super().__call__(op, output_dtype, *args)
72
+
73
+ return res
74
+
75
+
76
+ out_dtype = OutDtypeOperator()
77
+
78
+
79
+ def trace_out_dtype(proxy_mode, func_overload, op, output_dtype, *args):
80
+ # NB: Long-term we should put the decomposition logic into
81
+ # ProxyTorchDispatchMode so that people do not need to call maybe_handle_decomp
82
+ # in all HigherOrderOp proxy implementations.
83
+ r = maybe_handle_decomp(proxy_mode, func_overload, (op, output_dtype, *args), {})
84
+ if r is not NotImplemented:
85
+ return r
86
+
87
+ with disable_proxy_modes_tracing():
88
+ # This is a simplified implementation of this operator just for tracing.
89
+ # Actual implementation may also first promote the arguments
90
+ out = op(*args).to(dtype=output_dtype)
91
+
92
+ node_args = (op, output_dtype, *args)
93
+ proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
94
+ out_proxy = proxy_mode.tracer.create_proxy(
95
+ "call_function", func_overload, proxy_args, {}, name="out_dtype"
96
+ )
97
+ return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
98
+
99
+
100
+ @out_dtype.py_impl(DispatchKey.CompositeExplicitAutograd)
101
+ def out_dtype_dense(op: torch._ops.OpOverload, output_dtype: torch.dtype, *args):
102
+ if is_int_mm(op, output_dtype, args):
103
+ return torch._int_mm(*args)
104
+ return out_dtype_fallback(op, output_dtype, *args)
105
+
106
+
107
+ def is_int_mm(op, output_dtype, args):
108
+ return (
109
+ op == torch.ops.aten.mm.default
110
+ and output_dtype == torch.int32
111
+ and len(args) == 2
112
+ and args[0].dtype == torch.int8
113
+ and args[1].dtype == torch.int8
114
+ and args[0].is_cuda
115
+ and args[1].is_cuda
116
+ )
117
+
118
+
119
+ def out_dtype_fallback(op, output_dtype, *args):
120
+ flat_inputs = pytree.arg_tree_leaves(*args) + [torch.ones(1, dtype=output_dtype)]
121
+ promote_dtype: torch.dtype = elementwise_dtypes(
122
+ *flat_inputs,
123
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
124
+ )[0]
125
+
126
+ casted_args = pytree.tree_map_only(
127
+ torch.Tensor, lambda arg: arg.to(dtype=promote_dtype), args
128
+ )
129
+ res = op(*casted_args).to(dtype=output_dtype)
130
+ return res
131
+
132
+
133
+ out_dtype.py_autograd_impl(autograd_not_implemented(out_dtype, deferred_error=True))
134
+
135
+
136
+ @out_dtype.py_impl(ProxyTorchDispatchMode)
137
+ def out_dtype_proxy(
138
+ mode: ProxyTorchDispatchMode,
139
+ op: torch._ops.OpOverload,
140
+ output_dtype: torch.dtype,
141
+ *args,
142
+ ):
143
+ return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
144
+
145
+
146
+ @out_dtype.py_impl(FakeTensorMode)
147
+ def out_dtype_fake_tensor_mode(
148
+ mode: FakeTensorMode,
149
+ op: torch._ops.OpOverload,
150
+ output_dtype: torch.dtype,
151
+ *args,
152
+ ):
153
+ with mode:
154
+ return out_dtype_dense(op, output_dtype, *args)
155
+
156
+
157
+ @out_dtype.py_functionalize_impl
158
+ def out_dtype_func(ctx, op, output_dtype, *args):
159
+ unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
160
+
161
+ with ctx.redispatch_to_next():
162
+ res = out_dtype(op, output_dtype, *unwrapped_args)
163
+ return ctx.wrap_tensors(res)
archive/.venv/Lib/site-packages/torch/_higher_order_ops/run_const_graph.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ from torch._C import DispatchKey
4
+ from torch._higher_order_ops.utils import autograd_not_implemented
5
+ from torch._ops import HigherOrderOperator
6
+ from torch._subclasses.fake_tensor import FakeTensorMode
7
+ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
8
+ from torch.utils import _pytree as pytree
9
+
10
+
11
+ class RunConstGraph(HigherOrderOperator):
12
+ def __init__(self):
13
+ super().__init__("run_const_graph")
14
+
15
+ def __call__(self, graph, args):
16
+ return super().__call__(graph, args)
17
+
18
+
19
+ run_const_graph = RunConstGraph()
20
+
21
+
22
+ @run_const_graph.py_impl(ProxyTorchDispatchMode)
23
+ def run_const_graph_dispatch_mode(mode, graph, args):
24
+ const_gm, weights = graph, args
25
+ p_args = pytree.tree_map(mode.tracer.unwrap_proxy, (graph, args))
26
+ assert isinstance(const_gm, torch.fx.GraphModule)
27
+ assert not hasattr(mode.tracer.root, "_const_graph")
28
+ mode.tracer.root.register_module("_const_graph", const_gm)
29
+
30
+ proxy = mode.tracer.create_proxy("call_function", run_const_graph, p_args, {})
31
+
32
+ out = const_gm(*weights)
33
+ return track_tensor_tree(out, proxy, constant=None, tracer=mode.tracer)
34
+
35
+
36
+ @run_const_graph.py_functionalize_impl
37
+ def run_const_graph_functional(ctx, graph, args):
38
+ unwrapped_args = ctx.unwrap_tensors(args)
39
+
40
+ with ctx.redispatch_to_next():
41
+ out = run_const_graph(*unwrapped_args)
42
+ return ctx.wrap_tensors(out)
43
+
44
+
45
+ run_const_graph.py_autograd_impl(
46
+ autograd_not_implemented(run_const_graph, deferred_error=True)
47
+ )
48
+
49
+
50
+ @run_const_graph.py_impl(FakeTensorMode)
51
+ def run_const_graph_fake_tensor_mode(mode, graph, args):
52
+ assert isinstance(graph, torch.fx.GraphModule)
53
+ with mode:
54
+ return graph(*args)
55
+
56
+
57
+ @run_const_graph.py_impl(DispatchKey.CPU)
58
+ def run_const_graph_cpu(graph, args):
59
+ assert isinstance(graph, torch.fx.GraphModule)
60
+ return graph(*args)
archive/.venv/Lib/site-packages/torch/_higher_order_ops/scan.py ADDED
@@ -0,0 +1,929 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import itertools
4
+ from collections.abc import Sequence
5
+ from typing import Any, Callable, Optional
6
+
7
+ import torch
8
+ import torch._prims_common as utils
9
+ import torch.utils._pytree as pytree
10
+ from torch._C import DispatchKey
11
+ from torch._higher_order_ops.cond import create_bw_fn
12
+ from torch._higher_order_ops.utils import (
13
+ _maybe_compile_and_run_fn,
14
+ check_meta_consistency,
15
+ first_slice_copy,
16
+ materialize_as_graph,
17
+ reenter_make_fx,
18
+ save_tensors_and_symints_for_backward,
19
+ saved_tensors_and_symints,
20
+ unique_graph_id,
21
+ validate_subgraph_args_types,
22
+ )
23
+ from torch._ops import HigherOrderOperator
24
+ from torch._subclasses.fake_tensor import FakeTensorMode
25
+ from torch.fx.experimental.proxy_tensor import (
26
+ disable_proxy_modes_tracing,
27
+ ProxyTorchDispatchMode,
28
+ track_tensor_tree,
29
+ )
30
+ from torch.utils._python_dispatch import _get_current_dispatch_mode
31
+
32
+
33
+ aten = torch._ops.ops.aten
34
+
35
+
36
+ def wrap_combine_fn_flat(
37
+ *args, combine_fn, spec_init, spec_xs, num_init_leaves, num_inp_leaves
38
+ ):
39
+ assert len(args) == (
40
+ num_init_leaves + num_inp_leaves
41
+ ), f"Combin_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}"
42
+ carry = pytree.tree_unflatten(args[:num_init_leaves], spec_init)
43
+ xs = pytree.tree_unflatten(args[num_init_leaves:], spec_xs)
44
+ return combine_fn(carry, xs)
45
+
46
+
47
+ def _extract_carry_and_out(flat_out: list[Any], num_carry: int):
48
+ return split_into_chunks(flat_out, [num_carry, len(flat_out) - num_carry])
49
+
50
+
51
+ # We also do a clone with contiguous_format. This is to be consistent with
52
+ # eager semantic of scan, which stacks the outputs. The result is contiguous
53
+ # as a result of the stack operation.
54
+ def stack_y(y: torch.Tensor, scan_length: int) -> torch.Tensor:
55
+ return (
56
+ y.unsqueeze(0)
57
+ .repeat(*([scan_length] + [1] * y.ndim))
58
+ .clone(memory_format=torch.contiguous_format)
59
+ )
60
+
61
+
62
+ # NOTE: These functions can be reused in associative_scan and eventually moved to
63
+ # torch._higher_order_ops.utils
64
+ def get_tensor_mask(tensor_list: list[Any]) -> list[bool]:
65
+ # Returns a mask whether a list element is a tensor or not
66
+ return [True if isinstance(v, torch.Tensor) else False for v in tensor_list]
67
+
68
+
69
+ def mask_list(
70
+ mask: list[bool], inp: list[Any], other: Optional[list[Any]] = None
71
+ ) -> list[Any]:
72
+ # Masks elements on an `inp` list.
73
+ # If other is None, then the elements of the `inp` list where the mask is False are removed
74
+ # If other is not None, then the elements of the `inp` list where the mask is False are
75
+ # replaced with the elements of the `other` list
76
+ assert len(mask) == len(
77
+ inp
78
+ ), "The length of the mask needs to be identical to the length of the input"
79
+ if other is not None:
80
+ assert len(inp) == len(
81
+ other
82
+ ), "If an input and an other list is provided, they need to have the same length"
83
+ return [i if m else o for m, i, o in zip(mask, inp, other)]
84
+ else:
85
+ return [i for m, i in zip(mask, inp) if m]
86
+
87
+
88
+ def first_slice_copy_with_grad(li: list[Any]) -> list[Any]:
89
+ # First_slice_copy does not keep the original requires_grad flag,
90
+ # but we need it for materialize_as_graph
91
+ # in order to compute the correct gradients
92
+ # The reason why first_slice_copy doesn't keep requires_grad flag is
93
+ # because it's called in torch.autograd.Function.backward/forward.
94
+ slc = [first_slice_copy(x).requires_grad_(x.requires_grad) for x in li]
95
+ return slc
96
+
97
+
98
+ def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]:
99
+ it = iter(iterable)
100
+ assert sum(chunk_sizes) == len(
101
+ iterable
102
+ ), "the sum of all chunks needs to match the length of the iterable."
103
+ return [list(itertools.islice(it, size)) for size in chunk_sizes]
104
+
105
+
106
+ def call_operator(operator, *args):
107
+ return pytree.tree_leaves(operator(*args))
108
+
109
+
110
+ def scan(
111
+ combine_fn: Callable[
112
+ [pytree.PyTree, pytree.PyTree], tuple[pytree.PyTree, pytree.PyTree]
113
+ ],
114
+ init: pytree.PyTree,
115
+ xs: pytree.PyTree,
116
+ *,
117
+ dim: int = 0,
118
+ reverse: bool = False,
119
+ ) -> tuple[pytree.PyTree, pytree.PyTree]:
120
+ r"""
121
+ Performs an inclusive scan with a combine function.
122
+
123
+ .. warning::
124
+ `torch.scan` is a prototype feature in PyTorch. It currently
125
+ does not support autograd and you may run into miscompiles.
126
+ Read more about feature classification at:
127
+ https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
128
+
129
+ Args:
130
+ combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> (Tensor, Tensor)``,
131
+ or if xs is a pytree ``(pytree, pytree) -> (pytree, pytree)``.
132
+ The first input to ``combine_fn`` is the previous or initial scan carry
133
+ and the second input element to ``combine_fn`` is a slice of the input along dim.
134
+ The first output element of ``combine_fn`` is the next scan carry
135
+ and the second output of ``combine_fn`` represents a slice of the output.
136
+ This function must be pure, i.e., no lifted arguments are supported at the moment
137
+ and may not have any side effects.
138
+ init (torch.Tensor or pytree with tensor leaves): The inital scan carry, a tensor, or nested pytree of tensors.
139
+ The ``init`` is expected to have the same pytree structure as the first output element (i.e. carry)
140
+ of ``combine_fn``.
141
+ xs (torch.Tensor or pytree with tensor leaves): The input tensor, or nested pytree of tensors.
142
+
143
+ Kwargs:
144
+ dim (int): the dimension to scan over, default 0.
145
+ reverse (bool): A boolean stating if the scan should be reversed with respect to ``dim``, default ``False``.
146
+
147
+ Returns:
148
+ final_carry (torch.Tensor or pytree with tensor leaves),
149
+ the final carry of the scan operation with same pytree structure as init.
150
+ out (torch.Tensor or pytree with tensor leaves),
151
+ each tensor leaf is a stacked output along first dim, where each slice is the output of a scan iteration.
152
+
153
+ Restrictions:
154
+ - The combine_fn shouldn't have any aliasing between input-input, input-output, and output-output. E.g. return a view
155
+ or the same tensor as input is not supported. As a workaround, can clone the output to avoid aliasing.
156
+
157
+ - The combine_fn shoudn't mutate any inputs. We'll remove the mutation restriction for inference soon. Please file an issue
158
+ if you input mutation support for training is needed.
159
+
160
+ - The combine_fn's init carry should match the next_carry in pytree structure and in tensor metadata.
161
+
162
+ Example::
163
+
164
+ def add(x: torch.Tensor, y: torch.Tensor):
165
+ next_carry = y = x + y
166
+ # clone the output to avoid output-output aliasing
167
+ return next_carry, y.clone()
168
+
169
+ i0 = torch.zeros(1)
170
+ xs = torch.arange(5)
171
+ # returns torch.tensor([10.]), torch.tensor([[0], [1.], [3.], [6.], [10.]])
172
+ last_carry, cumsum = scan(add, init=i0, xs=xs)
173
+
174
+
175
+ """
176
+ # The reason we flatten init and xs before calling into dynamo is that
177
+ # we want to create a consistent input ordering for combine_fn
178
+ # and we also want to the input ordering matches the output ordering.
179
+ leaves_init, spec_init = pytree.tree_flatten(init)
180
+ leaves_xs_orig, spec_xs = pytree.tree_flatten(xs)
181
+
182
+ # Shortcut if no xs is provided
183
+ if len(leaves_xs_orig) == 0:
184
+ return init, []
185
+
186
+ def _validate_input(cfn, lxs, linit, d, r):
187
+ # Basic arguments check
188
+ if not callable(cfn):
189
+ raise RuntimeError("Combine_fn must be a callable, but got {cfn}")
190
+ if not isinstance(d, int):
191
+ raise RuntimeError("Dim must be an int, but got " + str(type(d)))
192
+ if not isinstance(r, bool):
193
+ raise RuntimeError("Reverse must be a bool, but got " + str(type(r)))
194
+
195
+ # Checks for init
196
+ if len(linit) == 0:
197
+ raise RuntimeError("scan() operator requires init leaves.")
198
+ for x in linit:
199
+ if not isinstance(x, torch.Tensor):
200
+ raise RuntimeError(f"All init leaves must be a Tensor but got {x}")
201
+
202
+ # Checks for xs
203
+ for x in lxs:
204
+ if not isinstance(x, torch.Tensor):
205
+ raise RuntimeError(f"All xs leaves must be a Tensor but got {x}")
206
+ if any(x.ndim <= d for x in lxs):
207
+ raise RuntimeError(
208
+ "All xs leaves must at least have 'dim' number of dimensions and scan dimension > 0"
209
+ )
210
+ if any(x.shape[d] == 0 for x in lxs):
211
+ raise RuntimeError(
212
+ "All xs leaves must at least have 'dim' number of dimensions and scan dimension > 0"
213
+ )
214
+
215
+ ndim = leaves_xs_orig[0].ndim
216
+ dim = utils.canonicalize_dim(ndim, dim)
217
+
218
+ _validate_input(combine_fn, leaves_xs_orig, leaves_init, dim, reverse)
219
+
220
+ # Move scan dim to 0 and always perform scan on dim 0
221
+ leaves_xs = []
222
+ for elem in leaves_xs_orig:
223
+ leaves_xs.append(torch.movedim(elem, dim, 0))
224
+
225
+ if reverse:
226
+ leaves_xs = [torch.flip(elem, [0]) for elem in leaves_xs]
227
+
228
+ # TODO: Support _inductor lowering
229
+ # TODO: Unify handling of pytrees for control flow ops, such as cond, while_loop, etc.
230
+
231
+ combine_fn = functools.partial(
232
+ wrap_combine_fn_flat,
233
+ combine_fn=combine_fn,
234
+ spec_init=spec_init,
235
+ spec_xs=spec_xs,
236
+ num_init_leaves=len(leaves_init),
237
+ num_inp_leaves=len(leaves_xs),
238
+ )
239
+
240
+ def run_flattened_scan(combine_fn, leaves_init, leaves_xs):
241
+ return scan_op(combine_fn, leaves_init, leaves_xs, additional_inputs=())
242
+
243
+ carry, out = _maybe_compile_and_run_fn(
244
+ run_flattened_scan,
245
+ combine_fn,
246
+ leaves_init,
247
+ leaves_xs,
248
+ )
249
+
250
+ if reverse:
251
+ out = pytree.tree_map(lambda elem: elem.flip([0]), out)
252
+
253
+ return carry, out
254
+
255
+
256
+ class ScanOp(HigherOrderOperator):
257
+ def __init__(self):
258
+ super().__init__("scan")
259
+
260
+ def __call__(self, combine_fn, init, xs, additional_inputs):
261
+ # There is currently an issue that the ScanOp is sometimes called with
262
+ # the additional_inputs being a list. See https://github.com/pytorch/pytorch/issues/145785
263
+ # Once this issue is resolved, the assertion should only allow tuples
264
+ # and the tuple cast should be removed
265
+ assert isinstance(
266
+ additional_inputs, (tuple, list)
267
+ ), "additional_inputs must be a tuple."
268
+ additional_inputs = (
269
+ tuple(additional_inputs)
270
+ if isinstance(additional_inputs, list)
271
+ else additional_inputs
272
+ )
273
+ validate_subgraph_args_types(additional_inputs)
274
+ return super().__call__(combine_fn, init, xs, additional_inputs)
275
+
276
+
277
+ scan_op = ScanOp()
278
+
279
+
280
+ def generic_scan(operator, init, xs, dim=0, additional_inputs=()):
281
+ def _scan(init, xs):
282
+ """Perform scan on `elems` using `elems_init."""
283
+ carry = init
284
+ if len(xs) == 0:
285
+ return carry, []
286
+
287
+ num_elems = xs[0].shape[dim]
288
+ ind = 0
289
+
290
+ # Compute dummy shapes for the pre-allocation
291
+ num_init_leaves = len(init)
292
+ dummy_carry, dummy_out = _extract_carry_and_out(
293
+ call_operator(
294
+ operator,
295
+ *carry,
296
+ *[first_slice_copy(elem, dim) for elem in xs],
297
+ *additional_inputs,
298
+ ),
299
+ num_init_leaves,
300
+ )
301
+
302
+ out_tensor_mask = get_tensor_mask(dummy_out)
303
+ dummy_out_masked = mask_list(out_tensor_mask, dummy_out)
304
+
305
+ # Pre-alocate
306
+ # outs -> Output matrix
307
+ # idxs -> Index matrix for scatter_
308
+ # out: (num_elems, M, N, ...)
309
+ # idx: (1, M, N)
310
+ outs = [
311
+ torch.zeros(
312
+ [num_elems] + list(e.size()),
313
+ dtype=e.dtype,
314
+ device=e.device,
315
+ )
316
+ for i, e in enumerate(dummy_out_masked)
317
+ ]
318
+ idxs = [
319
+ torch.ones_like(e, dtype=torch.int64).unsqueeze(0)
320
+ for i, e in enumerate(dummy_out_masked)
321
+ ]
322
+
323
+ def store_out_in_outs(out, ind):
324
+ # Store the intermediate out in the outs matrix
325
+ for o, x, idx in zip(outs, out, idxs):
326
+ # o: (num_elems, M, N ...)
327
+ # x: (M, N, ...) -> (1, M, N)
328
+ # ind * idx: (1, M, N,) with values to be ind
329
+ # essentially: o[ind][n][k] = x[0][n][k]
330
+ o.scatter_(0, ind * idx, x.unsqueeze(0))
331
+
332
+ for i in range(num_elems):
333
+ ind = i
334
+ carry, out = _extract_carry_and_out(
335
+ call_operator(
336
+ operator,
337
+ *carry,
338
+ *[elem.select(dim, ind) for elem in xs],
339
+ *additional_inputs,
340
+ ),
341
+ num_init_leaves,
342
+ )
343
+
344
+ # Store the inits in the outs matrix.
345
+ store_out_in_outs(mask_list(out_tensor_mask, out), ind)
346
+
347
+ # Expand outs with None depending on the tensor mask of the output
348
+ outs_expanded = [outs.pop(0) if out_m else None for out_m in out_tensor_mask]
349
+
350
+ return [*carry, *outs_expanded]
351
+
352
+ scans = _scan(init, xs)
353
+ return scans
354
+
355
+
356
+ def trace_scan(
357
+ proxy_mode,
358
+ func_overload,
359
+ combine_fn: Callable,
360
+ init: list[torch.Tensor],
361
+ xs: list[torch.Tensor],
362
+ additional_inputs: tuple[torch.Tensor],
363
+ ):
364
+ from torch._dynamo.utils import clone_input
365
+
366
+ with disable_proxy_modes_tracing():
367
+ sample_inits = [clone_input(x_init) for x_init in init]
368
+ sample_inputs = [first_slice_copy(x) for x in xs]
369
+ sample_additional_inputs = [
370
+ clone_input(x) if isinstance(x, torch.Tensor) else x
371
+ for x in additional_inputs
372
+ ]
373
+ combine_graph = reenter_make_fx(combine_fn)(
374
+ *sample_inits, *sample_inputs, *sample_additional_inputs
375
+ )
376
+
377
+ outputs = None
378
+ for node in combine_graph.graph.nodes:
379
+ if node.op == "output":
380
+ assert outputs is None
381
+ assert len(node.args) == 1
382
+ outputs = node.args[0]
383
+
384
+ assert outputs is not None
385
+
386
+ carry, output = _extract_carry_and_out(outputs, len(init))
387
+ init_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
388
+ i.clone() for i in init
389
+ ]
390
+ carry_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
391
+ c.meta["val"] for c in carry
392
+ ]
393
+ check_meta_consistency(
394
+ init_fake_tensors, carry_fake_tensors, "init", "carry", include_contiguity=False
395
+ )
396
+
397
+ _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph")
398
+
399
+ proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph)
400
+
401
+ args = (combine_graph, init, xs, additional_inputs)
402
+ proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
403
+ out_proxy = proxy_mode.tracer.create_proxy(
404
+ "call_function", func_overload, proxy_args, {}, name="scan"
405
+ )
406
+
407
+ with disable_proxy_modes_tracing():
408
+ scan_length = xs[0].shape[0]
409
+ fake_carry, fake_outputs = _extract_carry_and_out(
410
+ [o.meta["val"] for o in outputs], len(init)
411
+ )
412
+ out = (
413
+ *fake_carry,
414
+ *(stack_y(t, scan_length) for t in fake_outputs),
415
+ )
416
+
417
+ return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
418
+
419
+
420
+ @scan_op.py_impl(DispatchKey.CompositeExplicitAutograd)
421
+ def scan_op_dense(combine_fn, init, xs, additional_inputs):
422
+ mode = _get_current_dispatch_mode()
423
+ assert mode is None, "Mode should never be enabled for CPU/CUDA key"
424
+ return generic_scan(combine_fn, init, xs, additional_inputs=additional_inputs)
425
+
426
+
427
+ class ScanAutogradOp(torch.autograd.Function):
428
+ """
429
+ Example ::
430
+
431
+ def combine_fn(x: torch.Tensor, y: torch.Tensor):
432
+ next_carry = y = x * y
433
+ return next_carry, y
434
+
435
+ The ``combine_fn_bw``, computing the gradients for x and y of ``combine_fn`` is computed as:
436
+ def combine_fn_bw(x: torch.Tensor, y: torch.Tensor, g_carry: torch.Tensor, g_y: torch.Tensor):
437
+ return g_y * y + g_carry * y, g_y * x + g_carry * x
438
+
439
+ Note: In a real usecase of scan, there may be additional_inputs that participate in the
440
+ forward as well as in the backward of the scan operator. For the sake of readability those inputs
441
+ have been omitted in the following example, but are included in the subsequent detailed description below
442
+
443
+ The forward output of scan is computed as:
444
+ carry, ys = scan(combine_fn, init, xs).
445
+
446
+ This computation can be unpacked as
447
+ c_0, ys_0 = combine_fn(init, xs_0)
448
+ c_1, ys_1 = combine_fn(carry_0, xs_1)
449
+ c_2, ys_2 = combine_fn(carry_1, xs_2)
450
+ ...
451
+ c_T, ys_T = combine_fn(carry_(T-1), xs_T)
452
+
453
+ We collect c_0, c_1, ..., c_T into a vector of carries that we save for the backward,
454
+ but we only output (c_T, ys),
455
+ where ys is the vector of all intermediate outputs [y_0, y_1, ..., y_T].
456
+
457
+ Given the carries and the ys, the gradients for xs and for init can be computed as follows:
458
+ We receive the upstream gradients in torch.autograd.Function, i.e., we get g_c_T and g_ys,
459
+ where g_ys is the vector of all intermediate gradients of the outputs [g_ys_0, g_ys_1, ..., g_ys_T]
460
+
461
+ We then proceed to compute the gradients for the init (g_init) and the xs (g_xs) by running a
462
+ scan operation reverse over time. For example,
463
+
464
+ g_c_(T-1), g_xs_T = combine_fn_bw(c_(T-1), xs_T, g_c_T, g_ys_T)
465
+ g_c_(T-2), g_xs_(T-1) = combine_fn_bw(c_(T-2), xs_(T-1), g_c_(T-1), g_ys_(T-1))
466
+ g_c_(T-3), g_xs_(T-2) = combine_fn_bw(c_(T-3), xs_(T-2), g_c_(T-2), g_ys_(T-2))
467
+ ...
468
+ g_init, g_xs_1 = combine_fn_bw(c_0, xs_1, g_c_0, g_ys_1)
469
+ 0 , g_xs_0 = combine_fn_bw(init, xs_0, g_init, g_ys_0),
470
+
471
+ where combine_fn_bw takes the forward inputs of step t (i.e. c_(t-1), xs_t),
472
+ the gradients of the carry of step t (i.e. g_c_t) and
473
+ the upstream gradient of the output of step t (i.e. g_ys_T)
474
+ and returns the gradient of xs_t -> g_xs_t, as well as the gradient for the carry of step t-1 -> g_c_(t-1).
475
+
476
+ Through this procedure we end up with the
477
+ gradients for the init -> g_init,
478
+ the gradients for the xs -> g_xs.
479
+
480
+
481
+ NOTE: [scan autograd implementation]
482
+
483
+ The forward of scan can be computed as:
484
+ 1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint``:
485
+ To use a scan operation for the backward path as well, we need access to the carries from all steps.
486
+ Thus, the function ``combine_fn`` is wrapped such that it returns all carries and not only the last carry.
487
+ In particular, we define ``combine_fn_with_carry_checkpoint``:
488
+ def combine_fn_with_carry_checkpoint(x: torch.Tensor, y: torch.Tensor):
489
+ carry, y = combine_fn(x, y)
490
+ return carry, (carry, y)
491
+
492
+ The scan operator will stack all outputs along the scan dimension.
493
+ Thus, by putting next_carry also into outputs of ``combine_fn_with_carry_checkpoint``,
494
+ the carries from all steps will be stacked and hence gives us chekpointed_carries
495
+
496
+ 2.) Compute all carries, the last carry and all outputs using ``combine_fn_with_carry_checkpoint``:
497
+ c_T, (carries, ys) = scan_op(combine_fn_with_carry_checkpoint, init, xs, additional_inputs),
498
+ Where c_T (last carry) and ys (all outputs) are the original results of scan with the ``combine_fn``.
499
+ However, carries are checkpointed carries from all steps.
500
+ As a result of the forward, only the last carry c_T and the ys are returned,
501
+ while all carries are saved for the backward.
502
+
503
+ The backward of scan can be computed as:
504
+
505
+ 3.) Prepare the backward graph:
506
+ We prepare the backward graph to be used in the backward function.
507
+ We utilize ``create_bw_fn`` to generate the joint function, i.e.,
508
+ ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands), where fw_operands = [init, xs_0, additional_inputs]
509
+
510
+ The ctx._combine_fn_bw requires the primals (operands)
511
+ followed by the tangents (upstream gradients) from a single step
512
+ and produces the gradients of that step, i.e.,
513
+ g_c_(T-1), g_xs_T, g_additional_input_T = ctx._combine_fn_bw(c_(T-1), xs_T, additional_inputs, g_c_T, g_ys_T).
514
+
515
+ 4.) Create a wrapper of the ``combine_fn_bw``, i.e., ``combine_fn_bw_grad_accumulation``:
516
+ In the forward, there may be additional inputs that participate in every forward step.
517
+ The gradients for those additional inputs are also computed at every step and need to be accumulated over all steps,
518
+ which is taken care of in this wrapper. For example:
519
+ def combine_fn_bw_grad_accumulation(*args):
520
+ carried_g_additional_input = args[:num_additional_inputs]
521
+ inputs_bw_fn = args[num_additional_inputs:]
522
+ g_c_(t-1), g_xs_t, g_additional_input_t = ctx._combine_fn_bw(*inputs_bw_fn)
523
+ new_g_additional_inputs = carried_g_additional_input + g_additional_input_t
524
+ # The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator
525
+ # The ``g_xs_t`` is encoded as the output of the backward scan operator
526
+ return [*new_g_additional_inputs, *g_c_t, *g_xs_t]
527
+
528
+ 5.) Perform the backward scan as
529
+ g_additional_inputs, g_init, g_xs = scan_op(combine_fn_bw_grad_accumulation, bw_init, bw_xs), where
530
+ bw_init consists of the initial gradient carry for the additional_inputs (initialized with 0s):
531
+ initial_g_additional_inputs, and the gradient of the last carry: g_c_T. Thus:
532
+ bwd_init = [*initial_g_additional_inputs, *g_c_T].
533
+
534
+ bw_xs consists of the combination of the upstream gradients g_ys,
535
+ the forward carries prepended with the fw_init, i.e., bw_carries = concat([fw_init, fw_carries[:-1]]) and
536
+ the fw_xs. In particular,
537
+ bwd_xs = [*g_ys, *bw_carries, *fw_xs].
538
+
539
+ Note: g_c_T and g_ys are provided through the torch.autograd.Function.backward's input
540
+
541
+ As demonstrated in the Example above, this backward scan then yields the gradient for the init -> g_init
542
+ and the gradient for the xs -> g_xs
543
+
544
+ NOTE: [scan partial grad handling]
545
+ If any element of init, of xs, of the outputs or of the additional_inputs does not require gradients,
546
+ i.e., requires_grad=False, there will be still gradients returned for those elements,
547
+ but those gradients will be a tensor filled with zeros of the same shape as the element itself.
548
+
549
+ A special case are additional_inputs that are not tensors. Such inputs can occur for example with symbolic tracing,
550
+ where the shape symbol (SymInt) becomes an additional_input.
551
+ For such cases, we compute a ``additional_inputs_tensor_mask``, which is True for elements of additional_inputs
552
+ that are tensors and False otherwise. Gradients of additional_inputs are only accumulated if this mask is True,
553
+ otherwise, the value of initial_g_additional_inputs is passed, which is None for non-Tensor values.
554
+ """
555
+
556
+ @staticmethod
557
+ def forward(
558
+ ctx,
559
+ combine_fn,
560
+ num_leaves_init,
561
+ num_leaves_xs,
562
+ num_additional_inputs,
563
+ *operands,
564
+ ):
565
+ ctx._num_leaves_init = num_leaves_init
566
+ ctx._num_leaves_xs = num_leaves_xs
567
+ ctx._num_additional_inputs = num_additional_inputs
568
+ ctx._combine_fn = combine_fn
569
+ init, xs, additional_inputs = split_into_chunks(
570
+ operands, [num_leaves_init, num_leaves_xs, num_additional_inputs]
571
+ )
572
+ additional_inputs_tensor_mask = get_tensor_mask(additional_inputs)
573
+ ctx._additional_inputs_tensor_mask = additional_inputs_tensor_mask
574
+
575
+ # We snapshot the dispatch keys in forward for materializing the
576
+ # the bw_graph in backward.
577
+ ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
578
+ ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
579
+
580
+ # 1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint``
581
+ # The wrapper of the forward graph returns carries from all iterations,
582
+ # not just from the last iteration. These are required in the backward path
583
+ def combine_fn_with_carry_checkpoint(*args):
584
+ carry, y = _extract_carry_and_out(combine_fn(*args), num_leaves_init)
585
+ return [
586
+ *carry,
587
+ # We additionally checkpoint all the intemediate carry outputs for backward.
588
+ *[
589
+ n_c.clone().detach() if isinstance(n_c, torch.Tensor) else n_c
590
+ for n_c in carry
591
+ ],
592
+ *y,
593
+ ]
594
+
595
+ with torch._C._AutoDispatchBelowAutograd():
596
+ # 2.) Compute the all carries, the last carry and all outputs using ``combine_fn_with_carry_checkpoint``
597
+ c_T, carries_ys = _extract_carry_and_out(
598
+ scan_op(
599
+ combine_fn_with_carry_checkpoint,
600
+ init,
601
+ xs,
602
+ additional_inputs,
603
+ ),
604
+ num_leaves_init,
605
+ )
606
+
607
+ # Collect the carries for each time step from the outs
608
+ # and save them for the backward path
609
+ carries = list(carries_ys[:num_leaves_init])
610
+ ys = list(carries_ys[num_leaves_init:])
611
+ save_tensors_and_symints_for_backward(ctx, list(operands) + carries + ys)
612
+ ctx._num_leaves_ys = len(ys)
613
+
614
+ return (*c_T, *ys)
615
+
616
+ @staticmethod
617
+ def backward(ctx, *flat_grads):
618
+ r"""
619
+ This function computes the gradients of the scan operation.
620
+ It does so by using a scan operator using all carries and the upstream gradients (see description above)
621
+
622
+ Args:
623
+ flat_grads (torch.Tensor): The tensor of flattened upstream gradients.
624
+ """
625
+
626
+ # Collect the saved items from the forward
627
+ num_leaves_init = ctx._num_leaves_init
628
+ num_leaves_xs = ctx._num_leaves_xs
629
+ num_leaves_ys = ctx._num_leaves_ys
630
+ num_additional_inputs = ctx._num_additional_inputs
631
+ additional_inputs_tensor_mask = ctx._additional_inputs_tensor_mask
632
+
633
+ def prepend_init_to_carries(init, carries):
634
+ # Prepare the carries for the backward path.
635
+ # This requires to concatenate the init and the carries
636
+ return [
637
+ torch.cat([torch.unsqueeze(i, 0), c[:-1]], dim=0)
638
+ for i, c in zip(init, carries)
639
+ ]
640
+
641
+ def initialize_g_additional_inputs(
642
+ additional_inputs,
643
+ ):
644
+ # The initial gradients for the additional_inputs are all zeros
645
+ g_additional_inputs = [
646
+ torch.zeros_like(ai) if ai_tm else None
647
+ for ai_tm, ai in zip(additional_inputs_tensor_mask, additional_inputs)
648
+ ]
649
+ return g_additional_inputs
650
+
651
+ # Retrieve the forward inputs and the forward outputs and dissect them
652
+ flat_args = saved_tensors_and_symints(ctx)
653
+ fw_init, fw_xs, additional_inputs, fw_carries, fw_ys = split_into_chunks(
654
+ flat_args,
655
+ [
656
+ num_leaves_init,
657
+ num_leaves_xs,
658
+ num_additional_inputs,
659
+ num_leaves_init,
660
+ num_leaves_ys,
661
+ ],
662
+ )
663
+
664
+ # 3.) Prepare the backward graph
665
+ fw_operands = (
666
+ *fw_init,
667
+ *[first_slice_copy(xs) for xs in fw_xs],
668
+ *additional_inputs,
669
+ )
670
+ ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands)
671
+
672
+ # 4.) Create the BW wrapper to accumulate the gradients for the additional_inputs
673
+ def combine_fn_bw_grad_accumulation(*args):
674
+ # Dissect args and re-order them for the ``ctx._combine_fn_bw``
675
+ # The content of ``combine_fn_bw_tangents`` is [*carries_g, *outs_g]
676
+ # The content of ``combine_fn_bw_primals`` is [*init, *xs, *additional_inputs]
677
+ (
678
+ carried_g_additional_input,
679
+ combine_fn_bw_tangents,
680
+ combine_fn_bw_primals,
681
+ ) = split_into_chunks(
682
+ args,
683
+ [
684
+ num_additional_inputs,
685
+ num_leaves_init + num_leaves_ys,
686
+ num_leaves_init + num_leaves_xs + num_additional_inputs,
687
+ ],
688
+ )
689
+ combine_fn_bw_args = (*combine_fn_bw_primals, *combine_fn_bw_tangents)
690
+
691
+ g_c_t, g_xs_t, g_additional_inputs_t = split_into_chunks(
692
+ ctx._combine_fn_bw(*combine_fn_bw_args),
693
+ [num_leaves_init, num_leaves_xs, num_additional_inputs],
694
+ )
695
+
696
+ new_g_additional_inputs = [
697
+ # If the additional inputs are ints or SymInts, those values are taken as is and no gradients are added
698
+ carr_g + curr_g if add_inp_tm else carr_g
699
+ for add_inp_tm, carr_g, curr_g in zip(
700
+ additional_inputs_tensor_mask,
701
+ carried_g_additional_input,
702
+ g_additional_inputs_t,
703
+ )
704
+ ]
705
+
706
+ # The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator
707
+ # The ``g_xs_t`` is encoded as the output of the backward scan operator
708
+ return [*new_g_additional_inputs, *g_c_t, *g_xs_t]
709
+
710
+ # Materialize the ``combine_fn_bw_grad_accumulation``
711
+ def construct_args_single_step_bw():
712
+ # This function constructs the arguments for a single step of the backward scan.
713
+ # In other words, it creates the arguments for ``combine_fn_bw_grad_accumulation``
714
+ # The order of the arguments returned is identical to the order the backward scan
715
+ # operations provides
716
+
717
+ # The following arguments are used for the backward part of the joint graph
718
+ # The first argument relates to the gradient accumulation of the additional inputs.
719
+ # Because only tensor elements of additional inputs can have requires_grad=True,
720
+ # the values for non-tensor elements of additional inputs are None
721
+ masked_additional_inputs = [
722
+ a.clone() if add_inp_tm else None
723
+ for add_inp_tm, a in zip(
724
+ additional_inputs_tensor_mask, additional_inputs
725
+ )
726
+ ]
727
+
728
+ # The second argument relates to the gradients of the carries.
729
+ # Because the arguments are for a single step only,
730
+ # only the first slice of the carries is used.
731
+ sliced_carries = [first_slice_copy(c) for c in fw_carries]
732
+
733
+ # The third argument relates to the gradients of the ys.
734
+ # Because the arguments are for a single step only,
735
+ # only the first slice of the ys is used.
736
+ sliced_ys = [first_slice_copy(o) for o in fw_ys]
737
+
738
+ # The following arguments are used for the forward part of the joint graph
739
+ # The fourth argument relates to the init for the forward.
740
+ # I.e., fw_init
741
+
742
+ # The fifth argument relates to the xs for the forward.
743
+ # Because the arguments are for a single step only,
744
+ # only the first slice of the xs is used.
745
+ # Note: It is important to preserve the requires_grad flag of xs
746
+ # and thus we use the wrapper function ``first_slice_copy_with_grad``
747
+ fw_xs_slice = first_slice_copy_with_grad(fw_xs)
748
+
749
+ # The last argument relates to the additional inputs for the forward.
750
+ # I.e., additional_inputs
751
+
752
+ return (
753
+ *masked_additional_inputs,
754
+ *sliced_carries,
755
+ *sliced_ys,
756
+ *fw_init,
757
+ *fw_xs_slice,
758
+ *additional_inputs,
759
+ )
760
+
761
+ args_single_step_bw = construct_args_single_step_bw()
762
+
763
+ # TODO: we need to materialize the bw graphs because dynamo is unable to
764
+ # trace through the joint function when torch.compile torch.autograd.grad.
765
+ combine_fn_bw_grad_accumulation_gm = materialize_as_graph(
766
+ combine_fn_bw_grad_accumulation,
767
+ args_single_step_bw,
768
+ ctx._fw_include_key_set,
769
+ ctx._fw_exclude_key_set,
770
+ force_enable_grad=True,
771
+ )
772
+
773
+ # Decompose the flat_grads into g_c_T, g_ys
774
+ g_c_T, g_ys = split_into_chunks(flat_grads, [num_leaves_init, num_leaves_ys])
775
+
776
+ # Initialize the g_additional_inputs with zero-tensors.
777
+ # This step is necessary because the gradients of the additional inputs are accumulated in the
778
+ # ``wrapper_bwd_combine_fn`` and thus need a zero-initialized starting point
779
+ initial_g_additional_inputs = initialize_g_additional_inputs(additional_inputs)
780
+
781
+ # Prepend the inits to the carries.
782
+ # This is needed, because when computing the gradients, the last carry is not needed
783
+ # but the first carry, the init, is required.
784
+ bw_carries = prepend_init_to_carries(fw_init, fw_carries)
785
+
786
+ # Prepare the xs for the backward scan.
787
+ bwd_xs = [*g_ys, *bw_carries, *fw_xs]
788
+
789
+ # The flipping of the ``bwd_xs`` is necessary because the scan_op in the backward is always performed in reverse
790
+ bwd_xs = [torch.flip(elem, [0]) for elem in bwd_xs]
791
+
792
+ # Prepare the bwd_init
793
+ bwd_init = [*initial_g_additional_inputs, *g_c_T]
794
+
795
+ # 5.) Perform the backwrad scan:
796
+ # The ``combine_fn_bw_wrapped`` receives the
797
+ # initial_g_additional_inputs and the last carry as the ``bwd_init`` and the
798
+ # gradients of the outputs (g_ys), as well as the fw_carries and the fw_xs of the forward as the ``bwd_xs``
799
+ gradients = scan_op(
800
+ combine_fn_bw_grad_accumulation_gm,
801
+ bwd_init,
802
+ bwd_xs,
803
+ additional_inputs,
804
+ )
805
+
806
+ # Unpack the computed gradients
807
+ g_additional_inputs, g_init, g_xs = split_into_chunks(
808
+ gradients, [num_additional_inputs, num_leaves_init, num_leaves_xs]
809
+ )
810
+
811
+ # The flipping back along the scan dimension is required to get the gradients in the right order for ``xs``
812
+ g_xs = [torch.flip(elem, [0]) for elem in g_xs]
813
+
814
+ return *[None] * 4, *g_init, *g_xs, *g_additional_inputs
815
+
816
+
817
+ @scan_op.py_autograd_impl
818
+ def scan_autograd(combine_fn, init, xs, additional_inputs):
819
+ num_leaves_init = len(init)
820
+ num_leaves_xs = len(xs)
821
+ num_additional_inputs = len(additional_inputs)
822
+
823
+ flat_out = ScanAutogradOp.apply(
824
+ combine_fn,
825
+ num_leaves_init,
826
+ num_leaves_xs,
827
+ num_additional_inputs,
828
+ *(tuple(init) + tuple(xs) + additional_inputs),
829
+ )
830
+ return *flat_out[:num_leaves_init], *flat_out[num_leaves_init:]
831
+
832
+
833
+ @scan_op.py_impl(ProxyTorchDispatchMode)
834
+ def scan_proxy_mode(mode, combine_fn, init, xs, additional_inputs):
835
+ return trace_scan(mode, scan_op, combine_fn, init, xs, additional_inputs)
836
+
837
+
838
+ @scan_op.py_impl(FakeTensorMode)
839
+ def scan_fake_tensor_mode(mode, combine_fn, init, xs, additional_inputs):
840
+ with mode:
841
+ scan_length = xs[0].shape[0]
842
+ carry, outputs = _extract_carry_and_out(
843
+ combine_fn(
844
+ *init,
845
+ *[first_slice_copy(inp) for inp in xs],
846
+ *additional_inputs,
847
+ ),
848
+ len(init),
849
+ )
850
+ out = (
851
+ *carry,
852
+ *(stack_y(t, scan_length) for t in outputs),
853
+ )
854
+ return out
855
+
856
+
857
+ @scan_op.py_functionalize_impl
858
+ def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
859
+ from torch._higher_order_ops.utils import (
860
+ _check_alias_and_mutation,
861
+ _maybe_run_with_interpreter,
862
+ )
863
+
864
+ unwrapped_xs = ctx.unwrap_tensors(xs)
865
+ unwrapped_init = ctx.unwrap_tensors(init)
866
+ unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
867
+
868
+ with ctx.redispatch_to_next():
869
+ functional_combine_fn = ctx.functionalize(
870
+ _maybe_run_with_interpreter(combine_fn)
871
+ )
872
+ sample_unwrapped_xs_sliced = [first_slice_copy(inp) for inp in unwrapped_xs]
873
+ sample_inputs = list(
874
+ itertools.chain(
875
+ unwrapped_init,
876
+ sample_unwrapped_xs_sliced,
877
+ unwrapped_additional_inputs,
878
+ )
879
+ )
880
+ pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
881
+ _check_alias_and_mutation(combine_fn, sample_inputs, "scan", pre_dispatch)
882
+ ret = scan_op(
883
+ functional_combine_fn,
884
+ unwrapped_init,
885
+ unwrapped_xs,
886
+ unwrapped_additional_inputs,
887
+ )
888
+ return ctx.wrap_tensors(ret)
889
+
890
+
891
+ # dense implementation for scan. Used for testing only.
892
+ def _fake_scan(combine_fn, init, xs=None, dim=0, reverse=False):
893
+ carry_leaves, carry_spec = pytree.tree_flatten(init)
894
+ inp_leaves, inp_spec = pytree.tree_flatten(xs)
895
+ if xs is None or len(inp_leaves) == 0:
896
+ return init, []
897
+ result_flat = []
898
+ carry = carry_leaves
899
+ op = reversed if reverse else lambda x: x
900
+
901
+ dummy_carry, dummy_out = combine_fn(
902
+ pytree.tree_unflatten(carry, carry_spec),
903
+ pytree.tree_unflatten(
904
+ [first_slice_copy(elem, dim) for elem in inp_leaves],
905
+ inp_spec,
906
+ ),
907
+ )
908
+ dummy_out_leaves, dummy_out_spec = pytree.tree_flatten(dummy_out)
909
+ num_leaves = len(dummy_out_leaves)
910
+
911
+ for ind in op(range(inp_leaves[0].size(dim))):
912
+ xs = [elem.select(dim, ind) for elem in inp_leaves]
913
+
914
+ carry, y = combine_fn(
915
+ pytree.tree_unflatten(carry, carry_spec),
916
+ pytree.tree_unflatten(xs, inp_spec),
917
+ )
918
+ carry, _ = pytree.tree_flatten(carry)
919
+ y, _ = pytree.tree_flatten(y)
920
+ result_flat.append(y)
921
+
922
+ results = [
923
+ torch.stack([e[leave_ind] for e in op(result_flat)])
924
+ for leave_ind in range(num_leaves)
925
+ ]
926
+ return (
927
+ pytree.tree_unflatten(carry, carry_spec),
928
+ pytree.tree_unflatten(results, dummy_out_spec),
929
+ )
archive/.venv/Lib/site-packages/torch/_higher_order_ops/schema.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass
3
+ from typing import Any, Optional
4
+
5
+ import torch
6
+ import torch.utils._pytree as pytree
7
+ from torch.fx.node import Target
8
+
9
+
10
+ # Below is an implementation of generating FunctionSchema from example values.
11
+ # This is helpful for generating FunctionSchema for HigherOrderOperator, where
12
+ # we don't have a function to inspect and each call of the higher order operator
13
+ # would have different schema.
14
+ @dataclass(frozen=True)
15
+ class HopArgumentInfo:
16
+ # Could give a name to the operand by default it's empty string.
17
+ name: str
18
+ example_value: Any
19
+ # Provide an default_value
20
+ default_value: Any
21
+ # Whether this arugment gets mutated in the hop subgraph.
22
+ # For output, this should always be False
23
+ is_mutated: bool
24
+ kw_only: bool
25
+
26
+
27
+ class HopArgumentInfoGen:
28
+ @staticmethod
29
+ def from_example(
30
+ example_value: Any,
31
+ *,
32
+ name: str = "",
33
+ default_value: Optional[Any] = None,
34
+ is_mutated: bool = False,
35
+ kw_only: bool = False,
36
+ ) -> HopArgumentInfo:
37
+ if default_value is not None:
38
+ assert type(example_value) == type(
39
+ default_value
40
+ ), f"example_value type {type(example_value)} doesn't match default_value type: {type(default_value)}"
41
+
42
+ return HopArgumentInfo(
43
+ name=name,
44
+ example_value=example_value,
45
+ default_value=default_value,
46
+ is_mutated=is_mutated,
47
+ kw_only=kw_only,
48
+ )
49
+
50
+
51
+ class CTypeGen:
52
+ convert_to_base_ty = {
53
+ int: torch._C.IntType.get(),
54
+ float: torch._C.FloatType.get(),
55
+ str: torch._C.StringType.get(),
56
+ bool: torch._C.BoolType.get(),
57
+ }
58
+
59
+ # should return torch._C.JitType but that annotation is busted
60
+ @staticmethod
61
+ def from_example(obj: Any) -> Any:
62
+ import torch
63
+
64
+ if isinstance(obj, torch.fx.GraphModule):
65
+ return torch._C.AnyType.get()
66
+ elif isinstance(obj, torch.SymInt):
67
+ return torch._C.SymIntType.get()
68
+ return torch._C._jit_try_infer_type(obj).type()
69
+
70
+
71
+ class CArgumentGen:
72
+ @staticmethod
73
+ def from_hop_argument_info(
74
+ arg_idx: int, arg_info: HopArgumentInfo, is_output: bool = False
75
+ ) -> Any:
76
+ typ = CTypeGen.from_example(arg_info.example_value)
77
+ if is_output:
78
+ return torch._C.Argument("", typ, None, None, False, None)
79
+
80
+ alias_set = set({f"alias::a{arg_idx}"}) if arg_info.is_mutated else set()
81
+ alias_info = torch._C._AliasInfo(arg_info.is_mutated, alias_set, alias_set) # type: ignore[attr-defined]
82
+ return torch._C.Argument(
83
+ arg_info.name,
84
+ typ,
85
+ None,
86
+ arg_info.default_value,
87
+ arg_info.kw_only,
88
+ alias_info,
89
+ )
90
+
91
+
92
+ class HopSchemaGenerator:
93
+ def __init__(self, hop: torch._ops.HigherOrderOperator):
94
+ self.arg_infos: list[HopArgumentInfo] = []
95
+ self.example_outputs: list[Any] = []
96
+ self.schema_tree_spec: Optional[pytree.TreeSpec] = None
97
+ self.hop = hop
98
+
99
+ def add_arg(
100
+ self,
101
+ name: str,
102
+ example_value: Any,
103
+ default_value: Optional[Any] = None,
104
+ is_mutated: bool = False,
105
+ kw_only: bool = False,
106
+ ) -> None:
107
+ if callable(example_value):
108
+ assert isinstance(
109
+ example_value, (torch.fx.GraphModule, torch._ops.OperatorBase)
110
+ ), (
111
+ "Expect callable to be a GraphModule or an. Please call materialize_as_graph first "
112
+ f"to turn callable arguments {example_value} into a GraphModule."
113
+ )
114
+ _, flat_spec = pytree.tree_flatten(example_value)
115
+ if not flat_spec.is_leaf():
116
+ raise RuntimeError(
117
+ f"example_value {example_value} is not a leaf node. "
118
+ "Please only add flattened inputs to the hop schema. "
119
+ "If you need some structure in the arguments, please"
120
+ "add_arg for flattened args one by one then "
121
+ "call add_schema_tree_spec to register the original pytree "
122
+ " spec of the args."
123
+ )
124
+
125
+ arg_info = HopArgumentInfoGen.from_example(
126
+ example_value=example_value,
127
+ name=name,
128
+ default_value=default_value,
129
+ is_mutated=is_mutated,
130
+ kw_only=kw_only,
131
+ )
132
+ self.arg_infos.append(arg_info)
133
+
134
+ def add_output(self, output: Any) -> None:
135
+ self.example_outputs.append(output)
136
+
137
+ def add_schema_tree_spec(self, *args: Any, **kwargs: Any) -> None:
138
+ """schema tree spec is the tree spec from flattening all inputs to the hop with pytree.tree_flatten
139
+ Since torch.FunctionSchema only have proper mutation/alias support for flattened inputs, we need
140
+ to store the tree spec in order to reconstruct the inputs to the hop.
141
+ """
142
+ self.schema_tree_spec = pytree.tree_flatten((args, kwargs))[1]
143
+
144
+ def gen_schema(self) -> torch._C.FunctionSchema:
145
+ for i, arg_info in enumerate(self.arg_infos):
146
+ arg_spec = pytree.tree_flatten(arg_info.example_value)[1]
147
+ if not arg_spec.is_leaf() and self.schema_tree_spec is None:
148
+ raise RuntimeError(
149
+ f"example_value of arg_infos[{i}] is {arg_info.example_value}, which is not a leaf node. "
150
+ "Please call add_schema_tree_spec to add a schema tree spec first. "
151
+ "Or consider changing the hop's signature to only take flattened arguments."
152
+ )
153
+
154
+ return CFunctionSchemaGen.from_hop_argument_info(
155
+ str(self.hop),
156
+ self.arg_infos,
157
+ HopArgumentInfoGen.from_example(tuple(self.example_outputs), name="out"),
158
+ self.schema_tree_spec,
159
+ )
160
+
161
+
162
+ class CFunctionSchemaGen:
163
+ """
164
+ Note: [HigherOrderOperator schema generation]
165
+ Each invocation of a HigherOrderOperator will have a different schema.
166
+ For example, the schema of torch.cond varies depending on the true_fn and
167
+ false_fn. So we need a way to generate the schema for each invocation of a HOP.
168
+
169
+ We want to enforce the following invariants for HOP's schema:
170
+ 1. Flattened inputs. There should be no pytree structure in it.
171
+ 2. Flattened outputs. Note even if the hop returns a single value, it should be wrapped as a tuple.
172
+ 3. No aliasing. This includes inp-inp aliasing, inp-out aliasing and out-out aliasing.
173
+
174
+ By enforcing these invariants, we could make HOP's schema meets the requirement of schema parser
175
+ and makes hop easier to handle downstream. For example, suppose we have an invoke_quant_test HOP:
176
+
177
+ class GraphModule(torch.nn.Module):
178
+ def forward(self, l_x_, l_y_):
179
+ subgraph_0 = self.subgraph_0
180
+ invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4');
181
+
182
+ class subgraph_0(torch.nn.Module):
183
+ def forward(self, l_x_, l_y_):
184
+ add_ = l_x_.add_(1)
185
+ matmul = l_x_ @ l_y_
186
+ sin = matmul.sin()
187
+ child = sin.cos()
188
+ child_1 = l_x_ + l_y_
189
+ child_2 = l_x_ - l_y_
190
+ child_3 = l_x_ @ l_y_
191
+ return (child, child_1, child_2, child_3)
192
+
193
+ By encoding the inputs of hop into a list of HopArgumentInfo and output as a single HopArgumentInfo,
194
+ we would get the following schema:
195
+ invoke_quant_test(Any arg0, Tensor(!) arg1, Tensor arg2, str scheme="\\"nf4\\"") -> (Tensor, Tensor, Tensor, Tensor)
196
+ """
197
+
198
+ @staticmethod
199
+ def from_hop_argument_info(
200
+ op_name: str,
201
+ inp_argument_info: list[HopArgumentInfo],
202
+ out_argument_info: HopArgumentInfo,
203
+ schema_tree_spec: Optional[pytree.TreeSpec],
204
+ ) -> Any:
205
+ args = []
206
+ for i, arg_info in enumerate(inp_argument_info):
207
+ args.append(CArgumentGen.from_hop_argument_info(i, arg_info))
208
+
209
+ # NOTE: we want the output to always be a single argument with torch._C.TupleType.
210
+ assert isinstance(
211
+ out_argument_info.example_value, tuple
212
+ ), f"expect out_argument_info's example_value to be a tuple but got {out_argument_info.example_value}"
213
+ assert (
214
+ not out_argument_info.is_mutated
215
+ ), "out_argument_info.is_mutated should always be set to False."
216
+ rets = None
217
+ if len(out_argument_info.example_value) == 1:
218
+ rets = [CArgumentGen.from_hop_argument_info(0, out_argument_info, True)]
219
+ else:
220
+ rets = [
221
+ CArgumentGen.from_hop_argument_info(
222
+ i,
223
+ HopArgumentInfoGen.from_example(
224
+ name=f"out{i}",
225
+ example_value=val,
226
+ default_value=None,
227
+ is_mutated=False,
228
+ ),
229
+ is_output=True,
230
+ )
231
+ for i, val in enumerate(out_argument_info.example_value)
232
+ ]
233
+
234
+ return HopSchema(
235
+ op_name,
236
+ "",
237
+ args,
238
+ rets,
239
+ False,
240
+ False,
241
+ schema_tree_spec,
242
+ )
243
+
244
+
245
+ class HopSchema(torch._C.FunctionSchema):
246
+ def __init__(
247
+ self,
248
+ name: str,
249
+ overload_name: str,
250
+ arguments: list[torch._C.Argument],
251
+ returns: list[torch._C.Argument],
252
+ is_vararg: bool,
253
+ is_varret: bool,
254
+ schema_tree_spec: Optional[pytree.TreeSpec],
255
+ ):
256
+ self.tree_spec = schema_tree_spec
257
+ self.is_vararg = is_vararg
258
+ self.is_varret = is_varret
259
+ super().__init__(
260
+ name,
261
+ overload_name,
262
+ arguments,
263
+ returns,
264
+ self.is_vararg,
265
+ self.is_varret,
266
+ )
267
+
268
+ def __deepcopy__(self, memo: Any) -> "HopSchema":
269
+ # Need to additionally copy the tree_spec since
270
+ # it's not a member of torch._C.FunctionSchema
271
+ return HopSchema(
272
+ self.name,
273
+ self.overload_name,
274
+ self.arguments,
275
+ self.returns,
276
+ self.is_vararg,
277
+ self.is_varret,
278
+ copy.deepcopy(self.tree_spec),
279
+ )
280
+
281
+
282
+ def find_hop_schema(
283
+ gm: torch.fx.GraphModule, target: Target
284
+ ) -> list[torch._C.FunctionSchema]:
285
+ schemas = []
286
+ for node in gm.graph.find_nodes(op="call_function", target=target):
287
+
288
+ def _get_example_value(node: torch.fx.Node) -> Any:
289
+ if node.op == "get_attr":
290
+ assert isinstance(node.target, str)
291
+ return getattr(gm, node.target)
292
+ else:
293
+ return (
294
+ node.meta["example_value"]
295
+ if "example_value" in node.meta
296
+ else node.meta["val"]
297
+ )
298
+
299
+ fake_args, fake_kwargs = pytree.tree_map_only(
300
+ torch.fx.Node,
301
+ _get_example_value,
302
+ (node.args, node.kwargs),
303
+ )
304
+ schema = node.target.gen_schema(*fake_args, **fake_kwargs)
305
+ schemas.append(schema)
306
+ return schemas
archive/.venv/Lib/site-packages/torch/_higher_order_ops/strict_mode.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch._subclasses.functional_tensor
4
+ import torch.utils._pytree as pytree
5
+ from torch._C import DispatchKey
6
+ from torch._functorch.utils import exposed_in
7
+ from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_implemented
8
+ from torch._ops import HigherOrderOperator
9
+ from torch._subclasses.fake_tensor import FakeTensorMode
10
+ from torch.fx.experimental.proxy_tensor import (
11
+ _temp_remove_metadata_torch_function_mode,
12
+ _temp_remove_pre_dispatch_torch_function_mode,
13
+ disable_proxy_modes_tracing,
14
+ make_fx,
15
+ ProxyTorchDispatchMode,
16
+ track_tensor_tree,
17
+ )
18
+ from torch.utils._python_dispatch import _get_current_dispatch_mode
19
+
20
+
21
+ @exposed_in("torch")
22
+ def strict_mode(callable, operands):
23
+ from torch._dynamo.backends.debugging import (
24
+ make_eager_backend_with_torch_function_modes,
25
+ )
26
+
27
+ if torch.compiler.is_dynamo_compiling():
28
+ return strict_mode_op(callable, operands)
29
+
30
+ with _set_compilation_env():
31
+ with _temp_remove_metadata_torch_function_mode() as metadata_mode:
32
+ with _temp_remove_pre_dispatch_torch_function_mode() as predispatch_mode:
33
+ modes = [metadata_mode, predispatch_mode]
34
+ modes = [mode for mode in modes if mode is not None]
35
+ if modes:
36
+ backend = make_eager_backend_with_torch_function_modes(modes)
37
+ else:
38
+ backend = "eager"
39
+ with torch._dynamo.utils.disable_cache_limit():
40
+ return torch.compile(
41
+ strict_mode_op, backend=backend, fullgraph=True
42
+ )(callable, operands)
43
+
44
+
45
+ class StrictMode(HigherOrderOperator):
46
+ def __init__(self):
47
+ super().__init__("strict_mode")
48
+
49
+ def __call__(self, callable, operands):
50
+ return super().__call__(callable, operands)
51
+
52
+
53
+ strict_mode_op = StrictMode()
54
+
55
+
56
+ @strict_mode_op.py_impl(DispatchKey.CompositeExplicitAutograd)
57
+ def strict_mode_op_dense(callable, operands):
58
+ mode = _get_current_dispatch_mode()
59
+ assert mode is None, "Mode should never be enabled for CPU/CUDA key"
60
+ return callable(*operands)
61
+
62
+
63
+ strict_mode_op.py_autograd_impl(
64
+ autograd_not_implemented(strict_mode_op, deferred_error=True)
65
+ )
66
+
67
+
68
+ @strict_mode_op.py_impl(ProxyTorchDispatchMode)
69
+ def inner(mode, callable, operands):
70
+ return trace_strict_mode(mode, strict_mode_op, callable, operands)
71
+
72
+
73
+ def trace_strict_mode(mode, strict_mode_op, callable, operands):
74
+ pre_dispatch = getattr(mode, "pre_dispatch", False)
75
+
76
+ with disable_proxy_modes_tracing():
77
+ graph = make_fx(callable, pre_dispatch=pre_dispatch)(*operands)
78
+
79
+ graph_name = mode.tracer.get_fresh_qualname("strict_graph_")
80
+ mode.tracer.root.register_module(graph_name, graph)
81
+
82
+ args = (graph, operands)
83
+
84
+ proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
85
+
86
+ out_proxy = mode.tracer.create_proxy(
87
+ "call_function", strict_mode_op, proxy_args, {}, name="strict_mode"
88
+ )
89
+
90
+ out = graph(*operands)
91
+ return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
92
+
93
+
94
+ @strict_mode_op.py_impl(FakeTensorMode)
95
+ def strict_mode_fake_tensor_mode(mode, callable, operands):
96
+ with mode:
97
+ true_outs = callable(*operands)
98
+ return true_outs
99
+
100
+
101
+ @strict_mode_op.py_functionalize_impl
102
+ def strict_mode_func(ctx, callable, inputs):
103
+ unwrapped_inputs = ctx.unwrap_tensors(inputs)
104
+ with ctx.redispatch_to_next():
105
+ functional_callable = ctx.functionalize(callable)
106
+
107
+ cond_return = strict_mode_op(functional_callable, unwrapped_inputs)
108
+ return ctx.wrap_tensors(cond_return)
archive/.venv/Lib/site-packages/torch/_higher_order_ops/torchbind.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+ from contextlib import contextmanager
4
+
5
+ import torch
6
+ from torch._C import DispatchKey # @manual
7
+ from torch._functorch._aot_autograd.utils import KNOWN_TYPES
8
+ from torch._higher_order_ops.utils import autograd_not_implemented
9
+ from torch._library.fake_class_registry import (
10
+ _is_script_object,
11
+ _ns_and_class_name,
12
+ FakeScriptObject,
13
+ )
14
+ from torch._ops import HigherOrderOperator
15
+ from torch._subclasses.fake_tensor import FakeTensorMode
16
+ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
17
+ from torch.fx.node import has_side_effect
18
+ from torch.utils import _pytree as pytree
19
+
20
+
21
+ log = logging.getLogger(__name__)
22
+
23
+
24
+ # The call_torchbind operator represents a method invocation on a torchbind
25
+ # object. The calling convention is:
26
+ # call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs)
27
+ # We do not expect users to write this operator directly. Instead it will be
28
+ # emitted by Dynamo when tracing encounters a torchbind object.
29
+ class CallTorchBind(HigherOrderOperator):
30
+ def __init__(self):
31
+ super().__init__("call_torchbind")
32
+
33
+ def __call__(self, obj, method, *args, **kwargs):
34
+ return super().__call__(obj, method, *args, **kwargs)
35
+
36
+ @staticmethod
37
+ def schema(obj, method) -> torch.FunctionSchema:
38
+ """
39
+ Returns the schema of ``CallTorchbind.__call__``.
40
+ """
41
+ assert isinstance(obj, torch._inductor.ir.TorchBindObject)
42
+ val = obj.get_real_obj()
43
+ schema = val._get_method(method).schema
44
+ schema_str = str(schema)
45
+ new_schema_str = f"call_torchbind({str(schema.arguments[0].real_type)} {schema.arguments[0].name},"
46
+ first_comma_index = schema_str.find(",")
47
+ if first_comma_index == -1:
48
+ # If no comma is found, find the last closing parenthesis
49
+ first_comma_index = schema_str.rfind(") ->")
50
+ new_schema_str = new_schema_str + " str method" + schema_str[first_comma_index:]
51
+ new_schema = torch._C.parse_schema(new_schema_str)
52
+ return new_schema
53
+
54
+
55
+ call_torchbind = CallTorchBind()
56
+
57
+ # Register this operator as side-effectful with FX.
58
+ # TODO: this is not really sufficient. While passes (hopefully) check
59
+ # Node.is_impure() and make good decisions, we also assume we can execute the
60
+ # graph as many times as we want without changing behavior, which is NOT true of
61
+ # ops that mutate torchbind object state.
62
+ has_side_effect(call_torchbind)
63
+
64
+ _orig_scriptmethod_call = torch.ScriptMethod.__call__
65
+
66
+
67
+ def torchbind_method_redispatch(self, *args, **kwargs):
68
+ if _is_script_object(self.raw_owner):
69
+ return call_torchbind(self.raw_owner, self.name, *args, **kwargs)
70
+ return _orig_scriptmethod_call(self, *args, **kwargs)
71
+
72
+
73
+ @contextmanager
74
+ def enable_torchbind_tracing():
75
+ """Context manager that acts as a feature flag to enable torchbind tracing
76
+ behavior. Once torchbind tracing has been stabilized, we can remove this and
77
+ turn it always on.
78
+ """
79
+ try:
80
+ KNOWN_TYPES.append(torch.ScriptObject)
81
+ torch.ScriptMethod.__call__ = torchbind_method_redispatch # type: ignore[method-assign]
82
+ yield
83
+ finally:
84
+ assert (
85
+ KNOWN_TYPES.pop() is torch.ScriptObject
86
+ ), "Someone else messed with KNOWN_TYPES during tracing, exploding."
87
+ torch.ScriptMethod.__call__ = _orig_scriptmethod_call # type: ignore[method-assign]
88
+
89
+
90
+ @call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd)
91
+ def call_torchbind_impl(obj, method, *args, **kwargs):
92
+ if isinstance(obj, torch.ScriptObject):
93
+ return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs)
94
+ elif isinstance(obj, FakeScriptObject):
95
+ return getattr(obj.wrapped_obj, method)(*args, **kwargs)
96
+ else:
97
+ raise RuntimeError(f"Unsupported first arg type {type(obj)} for call_torchbind")
98
+
99
+
100
+ @call_torchbind.py_impl(ProxyTorchDispatchMode)
101
+ def inner(mode, *args, **kwargs):
102
+ proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
103
+ proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
104
+
105
+ out_proxy = mode.tracer.create_proxy(
106
+ "call_function",
107
+ call_torchbind,
108
+ proxy_args,
109
+ proxy_kwargs,
110
+ )
111
+ out = call_torchbind(*args, **kwargs)
112
+
113
+ obj, method, *_rest_args = args
114
+ if isinstance(obj, torch.ScriptObject):
115
+ ns, class_name = _ns_and_class_name(
116
+ obj._type().qualified_name() # type: ignore[attr-defined]
117
+ )
118
+ log.warning(
119
+ "Tracing torchbind method %s.%s with real ScriptObject. This may"
120
+ " cause the original object being mutated. If this is not intended,"
121
+ ' You can register a fake class with torch._library.register_fake_class("%s::%s").',
122
+ class_name,
123
+ method,
124
+ ns,
125
+ class_name,
126
+ )
127
+
128
+ ret = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
129
+ if "val" not in out_proxy.node.meta:
130
+ assert out is None or isinstance(
131
+ out, (int, float, bool)
132
+ ), "Currently, only these constant dtypes are supported to be returned from torchbind methods."
133
+ out_proxy.node.meta["val"] = out
134
+ return ret
135
+
136
+
137
+ # When tracing with fake script object, the call_torchbind op will return a fake tensor
138
+ # When tracing with real script object, the call_torchbind op may return a real tensor,
139
+ # we need to convert it to fake tensor mannually. Dynamic shape is surpported.
140
+ @call_torchbind.py_impl(FakeTensorMode)
141
+ def call_torchbind_fake(mode, *args, **kwargs):
142
+ with mode:
143
+ out = call_torchbind_impl(*args, **kwargs)
144
+ return pytree.tree_map_only(
145
+ torch.Tensor,
146
+ lambda x: mode.from_tensor(x, static_shapes=True)
147
+ if not isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
148
+ else x,
149
+ out,
150
+ )
151
+
152
+
153
+ call_torchbind.py_autograd_impl(
154
+ autograd_not_implemented(call_torchbind, deferred_error=True)
155
+ )
156
+
157
+
158
+ @call_torchbind.py_functionalize_impl
159
+ def call_torchbind_func(ctx, *args, **kwargs):
160
+ from torch._higher_order_ops.effects import handle_effects
161
+
162
+ return handle_effects(
163
+ ctx.mode._allow_token_discovery, ctx.mode._tokens, call_torchbind, args, kwargs
164
+ )
archive/.venv/Lib/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py ADDED
@@ -0,0 +1,2051 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import copy
3
+ import dataclasses
4
+ import functools
5
+ import inspect
6
+ import itertools
7
+ import logging
8
+ import operator
9
+ import threading
10
+ from collections import defaultdict
11
+ from collections.abc import Sequence
12
+ from typing import Any, Callable, Optional, TYPE_CHECKING, Union
13
+ from typing_extensions import Never
14
+
15
+ import sympy
16
+
17
+ import torch.fx as fx
18
+ import torch.utils._pytree as pytree
19
+ from torch import SymInt, Tensor
20
+ from torch._C import DispatchKey
21
+ from torch._ops import HigherOrderOperator
22
+ from torch._prims_common import clone_preserve_strides
23
+ from torch._subclasses.fake_tensor import FakeTensorMode
24
+ from torch.fx.experimental.proxy_tensor import (
25
+ disable_proxy_modes_tracing,
26
+ ProxyTorchDispatchMode,
27
+ track_tensor_tree,
28
+ )
29
+ from torch.fx.experimental.symbolic_shapes import guard_scalar
30
+ from torch.types import IntLikeType
31
+
32
+
33
+ if TYPE_CHECKING:
34
+ from triton._C.libtriton.ir import (
35
+ module as TritonIRModule,
36
+ operation as TritonIROperation,
37
+ )
38
+
39
+ from torch._dynamo.symbolic_convert import InstructionTranslator
40
+ from torch._dynamo.variables.constant import ConstantVariable
41
+ from torch._dynamo.variables.functions import TritonKernelVariable
42
+ from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
43
+ from torch.fx.proxy import Proxy
44
+ from torch.utils._triton import has_triton
45
+
46
+ TritonMetaParamsType = dict[str, int]
47
+ TritonGridTupleType = tuple[Union[int, sympy.Expr, SymInt], ...]
48
+ TritonGridCallableType = Callable[[TritonMetaParamsType], tuple[int, ...]]
49
+ TritonGridType = Union[TritonGridTupleType, TritonGridCallableType]
50
+
51
+ if has_triton():
52
+ from triton.runtime.autotuner import Autotuner, Config as TritonConfig
53
+ from triton.runtime.jit import JITFunction
54
+ else:
55
+
56
+ class Autotuner: # type: ignore[no-redef]
57
+ pass
58
+
59
+ class JITFunction: # type: ignore[no-redef]
60
+ pass
61
+
62
+ TritonKernelType = Union[Autotuner, JITFunction]
63
+ # mypy specifically complains that TritonAutotunerType is not a valid type if Autotuner is not inside of a Union.
64
+ TritonAutotunerType = Union[Autotuner]
65
+
66
+ log = logging.getLogger("torch._dynamo")
67
+
68
+ # e.g. for a host-side Triton TMA API call ``create_2d_tma_descriptor(ptr, 50, 60, 32, 15, 4)``,
69
+ # the metadata will look like ``("experimental", ([50, 60], [32, 15], 4))``
70
+ TMAExperimentalMetadata = tuple[
71
+ str, # type of TMA (should be "experimental")
72
+ tuple[
73
+ list[IntLikeType], # dims
74
+ list[IntLikeType], # block_dims
75
+ IntLikeType, # element_size
76
+ ],
77
+ ]
78
+
79
+ # e.g. for host-side Triton TMA API call ``TensorDescriptor.from_tensor(ptr, [32, 64])``
80
+ # the metadata will look like ``("stable", ([32, 64],))``
81
+ TMAStableMetadata = tuple[
82
+ str, # type of TMA ("experimental" or "stable")
83
+ tuple[list[IntLikeType],], # block_shape
84
+ ]
85
+
86
+
87
+ def create_tma_experimental_metadata(
88
+ dims: list[IntLikeType],
89
+ block_dims: list[IntLikeType],
90
+ element_size: IntLikeType,
91
+ ) -> TMAExperimentalMetadata:
92
+ return ("experimental", (dims, block_dims, element_size))
93
+
94
+
95
+ def maybe_unpack_tma_experimental_metadata(
96
+ tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata]
97
+ ) -> Optional[tuple[list[IntLikeType], list[IntLikeType], IntLikeType]]:
98
+ if not tma_meta or len(tma_meta) != 2:
99
+ return None
100
+ if tma_meta[0] == "experimental":
101
+ return tma_meta[1] # type: ignore[return-value]
102
+ return None
103
+
104
+
105
+ def create_tma_stable_metadata(
106
+ block_shape: list[IntLikeType],
107
+ ) -> TMAStableMetadata:
108
+ return ("stable", (block_shape,))
109
+
110
+
111
+ def maybe_unpack_tma_stable_metadata(
112
+ tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata]
113
+ ) -> Optional[tuple[list[IntLikeType]]]:
114
+ if not tma_meta or len(tma_meta) != 2:
115
+ return None
116
+ if tma_meta[0] == "stable":
117
+ return tma_meta[1] # type: ignore[return-value]
118
+ return None
119
+
120
+
121
+ # TMADescriptorMetadata maps kernel parameter names to the metadata that allows
122
+ # reconstructing TMA descriptors from the underlying tensors (passed as kernel
123
+ # arguments in the fx graph, instead of the TMA descriptors).
124
+ #
125
+ # Since there are two TMA APIs (the old "experimental" API and the new "stable" API),
126
+ # each entry in the dict is a tuple that starts with a string, either "experimental"
127
+ # or "stable". The second entry in the tuple is another tuple, with data that depends
128
+ # on the API type (see TMAExperimentalMetadata and TMAStableMetadata above).
129
+ #
130
+ # These are stored as raw tuples (instead of classes) for ease of serialization.
131
+ TMADescriptorMetadata = dict[
132
+ str, # kernel parameter name
133
+ Union[TMAExperimentalMetadata, TMAStableMetadata],
134
+ ]
135
+
136
+
137
+ ###############################################################################
138
+ # Kernel Side Table
139
+
140
+
141
+ # We cannot put Triton Kernels into the FX graph as the graph nodes
142
+ # do not support arbitrary functions.
143
+ # Use a side table.
144
+ # We use two dicts so that fetching both the kernel and id are O(1)
145
+ class KernelSideTable:
146
+ id_to_kernel: dict[int, "TritonKernelType"] = {}
147
+ kernel_to_id: dict["TritonKernelType", int] = {}
148
+ constant_args: dict[int, dict[str, Any]] = {}
149
+ lock = threading.Lock()
150
+
151
+ # Returns index on the table
152
+ def add_kernel(self, kernel: "TritonKernelType") -> int:
153
+ with self.lock:
154
+ if kernel in self.kernel_to_id:
155
+ return self.kernel_to_id[kernel]
156
+
157
+ idx = len(self.id_to_kernel)
158
+ self.id_to_kernel[idx] = kernel
159
+ self.kernel_to_id[kernel] = idx
160
+ return idx
161
+
162
+ # Returns the triton kernel at the given index
163
+ def get_kernel(self, idx: int) -> "TritonKernelType":
164
+ # No need to lock here as fetching from dict is atomic
165
+ assert idx in self.id_to_kernel
166
+ return self.id_to_kernel[idx]
167
+
168
+ # Not every constant arg can be added to the graph. Use this side table
169
+ # for constant args.
170
+ def add_constant_args(self, args: dict[str, Any]) -> int:
171
+ with self.lock:
172
+ idx = len(self.constant_args)
173
+ self.constant_args[idx] = args
174
+ return idx
175
+
176
+ # Returns the constant args
177
+ def get_constant_args(self, idx: int) -> dict[str, Any]:
178
+ # No need to lock here as fetching from dict is atomic
179
+ assert idx in self.constant_args
180
+ return self.constant_args[idx]
181
+
182
+ # Resets the table (only meant to be used in unit tests)
183
+ # This is only safe assuming single threaded execution
184
+ def reset_table(self) -> None:
185
+ self.id_to_kernel = {}
186
+ self.kernel_to_id = {}
187
+ self.constant_args = {}
188
+
189
+
190
+ kernel_side_table = KernelSideTable()
191
+
192
+
193
+ ###############################################################################
194
+ # Mutation Tracker
195
+
196
+
197
+ @dataclasses.dataclass(frozen=True)
198
+ class Param:
199
+ idx: int
200
+
201
+
202
+ @dataclasses.dataclass(frozen=True)
203
+ class Intermediate:
204
+ idx: int
205
+
206
+ def fake(self) -> bool:
207
+ return self.idx < 0
208
+
209
+
210
+ @dataclasses.dataclass(frozen=True)
211
+ class Op:
212
+ name: str
213
+ fn_call_name: Optional[str]
214
+ args: list[Union[Param, Intermediate]]
215
+ ret: Intermediate = dataclasses.field(repr=False)
216
+ # used for scf.yield: see [Note: scf.yield fix-up]
217
+ sub_idx: Optional[int] = None
218
+ # used for tt.elementwise_inline_asm
219
+ # `is_pure = True` assumes the asm block has no side-effects
220
+ is_pure: bool = False
221
+
222
+ def __post_init__(self) -> None:
223
+ if self.name == "tt.call":
224
+ assert self.fn_call_name is not None
225
+ else:
226
+ assert self.fn_call_name is None
227
+
228
+
229
+ def generate_ttir(
230
+ kernel: "TritonKernelType",
231
+ kwargs: dict[str, Any],
232
+ tma_descriptor_metadata: TMADescriptorMetadata,
233
+ ) -> tuple["TritonIRModule", list[str]]:
234
+ """
235
+ Uses Triton's internal code generation to create TTIR
236
+ """
237
+ import sympy
238
+ import triton
239
+ import triton.runtime.jit
240
+ from triton.compiler.compiler import ASTSource
241
+ from triton.runtime.autotuner import Autotuner
242
+ from triton.runtime.jit import JITFunction
243
+
244
+ from torch._inductor.utils import (
245
+ get_triton_attrs_descriptor_version,
246
+ triton_version_uses_attrs_dict,
247
+ TritonAttrsDescriptorVersion,
248
+ )
249
+ from torch.utils._triton import has_triton_tensor_descriptor_host_tma
250
+
251
+ triton_version = get_triton_attrs_descriptor_version()
252
+
253
+ import torch._inductor.ir
254
+ from torch._subclasses.fake_tensor import FakeTensor
255
+
256
+ if isinstance(kernel, Autotuner):
257
+ if len(kernel.configs) > 0:
258
+ # If we are autotuning, then it doesn't matter which version gets
259
+ # picked for tracing purposes, so lets pick the first one
260
+ kwargs = {**kwargs, **kernel.configs[0].kwargs}
261
+ kernel = kernel.fn
262
+
263
+ assert isinstance(kernel, JITFunction)
264
+
265
+ context = triton._C.libtriton.ir.context()
266
+ target = triton.runtime.driver.active.get_current_target()
267
+ backend = triton.compiler.compiler.make_backend(target)
268
+ options = backend.parse_options({})
269
+
270
+ # ignore backend-specific kwargs same way as in the native Triton code
271
+ # https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596
272
+ # why this is important for user-defined Triton kernels on AMD: https://github.com/pytorch/pytorch/issues/140800
273
+ for name in list(kwargs):
274
+ if name not in kernel.arg_names and name in options.__dict__:
275
+ kwargs.pop(name)
276
+
277
+ if len(kwargs) != len(kernel.arg_names):
278
+ raise ValueError(
279
+ "Incorrect number of arguments passed to kernel: "
280
+ f"passed {list(kwargs.keys())}, expected {kernel.arg_names}."
281
+ )
282
+
283
+ # Replace all SymExprs with a regular value for TTIR generation
284
+ # Replace all FakeTensor/TensorBox with real tensors
285
+ # These replacements are needed for triton's type, key and config functions
286
+ ordered_args: dict[str, Any] = {}
287
+ for name in kernel.arg_names:
288
+ a = kwargs[name]
289
+ if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool, sympy.Expr)):
290
+ ordered_args[name] = 2
291
+ elif (
292
+ stable_meta := maybe_unpack_tma_stable_metadata(
293
+ tma_descriptor_metadata.get(name, None)
294
+ )
295
+ ) is not None:
296
+ from triton.tools.tensor_descriptor import TensorDescriptor
297
+
298
+ block_shape = stable_meta[0]
299
+ with torch._C._DisableTorchDispatch():
300
+ # need 16-byte aligned strides
301
+ elements_per_dim = max(1, 16 // a.dtype.itemsize)
302
+ base_tensor = torch.empty(
303
+ [elements_per_dim] * len(block_shape), dtype=a.dtype
304
+ )
305
+ ordered_args[name] = TensorDescriptor.from_tensor(base_tensor, block_shape)
306
+ elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)):
307
+ with torch._C._DisableTorchDispatch():
308
+ ordered_args[name] = torch.empty(2, dtype=a.dtype)
309
+ else:
310
+ ordered_args[name] = a
311
+
312
+ def is_stable_tensor_descriptor_arg(arg: Any) -> bool:
313
+ if has_triton_tensor_descriptor_host_tma():
314
+ from triton.tools.tensor_descriptor import TensorDescriptor
315
+
316
+ if isinstance(arg, TensorDescriptor):
317
+ return True
318
+ return False
319
+
320
+ def is_tensor_like_arg(arg: Any) -> bool:
321
+ if isinstance(arg, Tensor) or is_stable_tensor_descriptor_arg(arg):
322
+ return True
323
+ return False
324
+
325
+ # Note: one would expect that each input to the triton kernel maps to
326
+ # one input parameter in the TTIR. This is _not_ true for TMA descriptors:
327
+ # one TMA descriptor gets converted into:
328
+ # * one TMA descriptor input
329
+ # * N strides, for a rank-N tensor
330
+ # * N sizes, for a rank-N tensor
331
+ # To account for this, we inject some fake arg names as placeholders for
332
+ # the stride and size parameters.
333
+ def get_tensor_names(name: str, arg: Any) -> list[str]:
334
+ if isinstance(arg, Tensor):
335
+ return [name]
336
+ if is_stable_tensor_descriptor_arg(arg):
337
+ stable_meta = maybe_unpack_tma_stable_metadata(
338
+ tma_descriptor_metadata[name]
339
+ )
340
+ assert stable_meta is not None
341
+ block_shape = stable_meta[0]
342
+ tensor_rank = len(block_shape)
343
+ names = [name]
344
+ names.extend(name + f" STRIDE PLACEHOLDER {i}" for i in range(tensor_rank))
345
+ names.extend(name + f" SIZE PLACEHOLDER {i}" for i in range(tensor_rank))
346
+ return names
347
+ return []
348
+
349
+ ordered_tensor_names = list(
350
+ itertools.chain.from_iterable(
351
+ get_tensor_names(name, arg) for name, arg in ordered_args.items()
352
+ )
353
+ )
354
+
355
+ def _get_specialization(args): # type: ignore[no-untyped-def]
356
+ # Support multiple triton versions.
357
+ # This code basically copies JITFunction.run() logic to get the attrs to construct an ASTSource.
358
+ if triton_version == TritonAttrsDescriptorVersion.V1_COMPILER:
359
+ return kernel._get_config(*args)
360
+ elif triton_version in {
361
+ TritonAttrsDescriptorVersion.V2_BACKENDS,
362
+ TritonAttrsDescriptorVersion.V3_BACKENDS_TUPLE,
363
+ }:
364
+ from triton.backends.compiler import AttrsDescriptor # noqa: F401
365
+
366
+ target = triton.runtime.driver.active.get_current_target()
367
+ backend_ = triton.compiler.compiler.make_backend(target)
368
+ return backend_.get_attrs_descriptor(args, kernel.params)
369
+ else:
370
+ assert (
371
+ get_triton_attrs_descriptor_version()
372
+ == TritonAttrsDescriptorVersion.V4_DICT
373
+ )
374
+ # specialize_impl switched to create_specialize_impl in https://github.com/triton-lang/triton/pull/6099
375
+ if hasattr(triton.runtime.jit, "create_specialize_impl"):
376
+ try:
377
+ # Latest versions of Triton take specialize_extra as an arg to create_specialize_impl
378
+ specialize_impl = triton.runtime.jit.create_specialize_impl(
379
+ specialize_extra=backend.get_arg_specialization
380
+ )
381
+ except TypeError: # Unknown arg `specialize_extra`
382
+ # Older versions of Triton take specialize_extra as an arg to specialize_impl
383
+ specialize_impl = functools.partial(
384
+ triton.runtime.jit.create_specialize_impl(),
385
+ specialize_extra=backend.get_arg_specialization,
386
+ )
387
+ else:
388
+ from triton.runtime.jit import specialize_impl as specialize_impl_orig
389
+
390
+ specialize_impl = functools.partial(
391
+ specialize_impl_orig,
392
+ specialize_extra=backend.get_arg_specialization,
393
+ )
394
+
395
+ from triton._utils import find_paths_if, get_iterable_path
396
+
397
+ # logic is copied from: binder = create_function_from_signature(self.signature, self.params, backend)
398
+ attrvals = []
399
+ for arg, kp in zip(args, kernel.params):
400
+ if kp.is_constexpr:
401
+ attrvals.append(arg)
402
+ else:
403
+ spec = specialize_impl(
404
+ arg,
405
+ is_const=kp.is_const,
406
+ specialize_value=not kp.do_not_specialize,
407
+ align=not kp.do_not_specialize_on_alignment,
408
+ )
409
+ attrvals.append(spec[1])
410
+
411
+ attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
412
+ attrs = {
413
+ k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs
414
+ }
415
+ return attrs
416
+
417
+ specialization = _get_specialization(ordered_args.values())
418
+ constants = {
419
+ name: arg for name, arg in ordered_args.items() if not is_tensor_like_arg(arg)
420
+ }
421
+
422
+ if (mangle_type := getattr(triton.runtime.jit, "mangle_type", None)) is not None:
423
+
424
+ def get_signature_value(idx: int, arg: Any) -> str:
425
+ if kernel.params[idx].is_constexpr:
426
+ return "constexpr"
427
+ return mangle_type(arg)
428
+
429
+ else:
430
+
431
+ def get_signature_value(idx: int, arg: Any) -> str:
432
+ return kernel._type_of(kernel.key_of(arg))
433
+
434
+ if triton_version_uses_attrs_dict():
435
+ # In newer versions of Triton, the signature includes constexpr args
436
+ signature = {
437
+ name: get_signature_value(i, arg)
438
+ for i, (name, arg) in enumerate(ordered_args.items())
439
+ }
440
+ else:
441
+ # In older versions of Triton, the signature does not include constexpr args
442
+ signature = {
443
+ name: get_signature_value(i, arg)
444
+ for i, (name, arg) in enumerate(ordered_args.items())
445
+ if i not in kernel.constexprs
446
+ }
447
+
448
+ triton._C.libtriton.ir.load_dialects(context)
449
+ backend.load_dialects(context)
450
+
451
+ src = ASTSource(kernel, signature, constants, specialization)
452
+
453
+ # Triton changes ASTSource.make_ir to take 3/4 arguments. Handle
454
+ # backward compatibility here.
455
+ make_ir_sig_params = len(inspect.signature(src.make_ir).parameters)
456
+ get_codegen_implementation_sig_params = len(
457
+ inspect.signature(backend.get_codegen_implementation).parameters
458
+ )
459
+ if make_ir_sig_params == 2:
460
+ ttir_module = src.make_ir(options, context)
461
+ elif make_ir_sig_params == 3:
462
+ codegen_fns = backend.get_codegen_implementation()
463
+ ttir_module = src.make_ir(options, codegen_fns, context)
464
+ else:
465
+ codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
466
+ codegen_fns = backend.get_codegen_implementation(*codegen_args)
467
+ module_map = backend.get_module_map()
468
+ ttir_module = src.make_ir(options, codegen_fns, module_map, context)
469
+ if not ttir_module.verify():
470
+ raise RuntimeError("Verification for TTIR module has failed")
471
+
472
+ return ttir_module, ordered_tensor_names
473
+
474
+
475
def ttir_to_functions(
    ttir_module: "TritonIRModule",
) -> dict[str, dict[Intermediate, list[Op]]]:
    """
    Walk the `ttir_module` bottom up to mine the `functions` from
    the structured MLIR entities representing the Triton kernel
    (mlir::Operation, mlir::Block, mlir::Region).

    Returns a mapping: function name -> (op result -> ops producing it).
    Block arguments are rewired to Params (for tt.func) or to the
    Intermediates of the enclosing op's operands (for scf/tt.reduce/tt.scan),
    so each returned function is a flat, self-contained dataflow graph.
    """
    functions: dict[str, dict[Intermediate, list[Op]]] = {}

    # block id --> op result (Intermediate) --> one or more ops
    op_stack: dict[int, dict[Intermediate, list[Op]]] = defaultdict(
        lambda: defaultdict(list)
    )
    region_id_to_block_ids: dict[int, list[int]] = defaultdict(list)
    block_id_to_block_arg_ids: dict[int, list[int]] = {}
    replacements: dict[int, Union[Intermediate, Param]] = {}
    reindex_map: dict[int, int] = {}
    # negative indices denote synthesized ("fake") Intermediates that do not
    # correspond to any real MLIR value
    next_fake_intermediate = 0

    def reindex(idx: int) -> int:
        # Compact raw MLIR value ids into small consecutive integers.
        if idx not in reindex_map:
            reindex_map[idx] = len(reindex_map)
        return reindex_map[idx]

    def mlir_to_functions(op: "TritonIROperation") -> None:
        # Visitor invoked (bottom-up) for every operation in the module.
        name: str = op.get_name()
        if name == "builtin.module":
            # this wraps all tt.func ops
            return

        operand_ids: list[int] = [
            reindex(op.get_operand(i).id()) for i in range(op.get_num_operands())
        ]
        result_ids: list[int] = [
            reindex(op.get_result(i).id()) for i in range(op.get_num_results())
        ]

        child_block_ids: list[int] = []
        for i in [op.get_region(i).id() for i in range(op.get_num_regions())]:
            # as the walk is bottom-up, the region_id_to_block_ids[i]
            # must be populated by the time we process the enclosing op
            child_block_ids.extend(region_id_to_block_ids[i])

        parent_block_id = -1
        parent_block = op.get_block()
        if parent_block is not None:
            parent_block_id = parent_block.id()
            if parent_block_id not in block_id_to_block_arg_ids:
                block_id_to_block_arg_ids[parent_block_id] = []
                for i in range(parent_block.get_num_arguments()):
                    block_id_to_block_arg_ids[parent_block_id].append(
                        reindex(parent_block.get_argument(i).id()),
                    )
                # the region info is collected via ops' parent blocks to be
                # used later when the region's enclosing op is traversed
                parent_region = parent_block.get_parent()
                if parent_region is not None:
                    region_id_to_block_ids[parent_region.id()].append(parent_block_id)

        nonlocal next_fake_intermediate

        if name == "tt.func":
            # for function ops: gather and inline
            # the ops from all child blocks
            fn_ops = defaultdict(list)
            for child_block_id in child_block_ids:
                for result, block_fn_ops in op_stack.pop(child_block_id).items():
                    for block_fn_op in block_fn_ops:
                        fn_ops[result].append(block_fn_op)

            # replace the corresponding Intermediates in the
            # child op args with the function args (Params)
            for i, idx in enumerate(block_id_to_block_arg_ids[child_block_ids[0]]):
                replacements[idx] = Param(i)

            for fn_op_list in fn_ops.values():
                for fn_op in fn_op_list:
                    for i in range(len(fn_op.args)):
                        arg = fn_op.args[i]
                        seen = set()  # to break cycles
                        # there can be transitive replacements, but likely
                        # no cycles (we keep the `seen` set just in case)
                        while (
                            isinstance(arg, Intermediate)
                            and arg.idx in replacements
                            and arg.idx not in seen
                        ):
                            seen.add(arg.idx)
                            arg = fn_op.args[i] = replacements[arg.idx]

            # next function capture starts
            # with empty replacements
            replacements.clear()

            fn_name = op.get_str_attr("sym_name")
            functions[fn_name] = fn_ops
        elif child_block_ids:
            if name in {"scf.if", "scf.for", "scf.while", "tt.reduce", "tt.scan"}:
                # for blocked ops: inline the enclosed ops into
                # the parent block + rewire the last op in each
                # child block to return the block result
                return_ops = []
                for block_id in child_block_ids:
                    if name == "scf.for":
                        # example:
                        # %result = scf.for %iv = %lb to %ub step %step iter_args(%arg = %init) -> (i32) ...
                        # block args: 2 (%iv, %arg)
                        # op operands: 4 (%lb, %ub, %step, %init)
                        # `%arg` is mapping to `%init`
                        for i, idx in enumerate(block_id_to_block_arg_ids[block_id]):
                            if i == 0:
                                # the induction variable has no producing op;
                                # give it a fake Intermediate
                                next_fake_intermediate -= 1
                                replacements[idx] = Intermediate(next_fake_intermediate)
                            else:
                                replacements[idx] = Intermediate(operand_ids[i + 2])
                    elif name == "scf.while":
                        # example:
                        # %3:3 = scf.while (%arg2 = %1, %arg3 = %2, %arg4 = %c0_i32_8) ...
                        # block args: 3 (%arg2, %arg3, %arg4)
                        # op operands: 3 (%1, %2, %c0_i32_8)
                        # `%arg2` is mapping to `%1`, `%arg3` is mapping to `%2`, ...
                        for i, idx in enumerate(block_id_to_block_arg_ids[block_id]):
                            replacements[idx] = Intermediate(operand_ids[i])
                    elif name == "scf.if":
                        # the scf block args are ignored by the pass. but, as they
                        # may be used as operands of the ops inside the block
                        # (and nested blocks inlined in the current block by now),
                        # they are replaced by new fake Intermediates to avoid "this
                        # operand is not returned by any other op in the fn" error
                        # in the downstream analysis
                        for idx in block_id_to_block_arg_ids[block_id]:
                            next_fake_intermediate -= 1
                            replacements[idx] = Intermediate(next_fake_intermediate)
                    else:
                        assert name in ("tt.reduce", "tt.scan")
                        # wire the block arguments to the op arguments
                        num_operands = len(operand_ids)
                        block_arg_ids = block_id_to_block_arg_ids[block_id]
                        assert len(block_arg_ids) == 2 * num_operands, (
                            f"{name} is expected to have twice as "
                            "many block arguments as op arguments: "
                            f"{operand_ids=}, {block_arg_ids=}."
                        )
                        for i, idx in enumerate(block_arg_ids):
                            # for a tt.reduce/tt.scan op with N arguments, the block
                            # arguments comprise N reduced values followed by
                            # N current values corresponding to the N op args
                            replacements[idx] = Intermediate(
                                operand_ids[i % num_operands]
                            )

                    if block_id in op_stack:
                        block_ops = op_stack.pop(block_id)
                        if not block_ops:
                            continue
                        last_ret, last_ops = block_ops.popitem()
                        if all(
                            op.name
                            in ("scf.yield", "tt.reduce.return", "tt.scan.return")
                            for op in last_ops
                        ):
                            # if last_ops are all return ops, treat them separately
                            return_ops.extend(last_ops)
                        else:
                            # otherwise, return last_ops to the block
                            block_ops[last_ret] = last_ops
                        for op_result, child_ops in block_ops.items():
                            op_stack[parent_block_id][op_result].extend(child_ops)

                scf_results = [Intermediate(idx) for idx in result_ids]

                if return_ops and all(
                    (op.name == "scf.yield" and len(result_ids) == len(op.args))
                    for op in return_ops
                ):
                    # [Note: scf.yield fix-up]
                    #
                    # TL;DR: if our scf.yield takes N args, then we'll create N scf.yield ops to handle each of the
                    # args.
                    #
                    # **Context**:
                    # During mutation analysis, the analysis pass will identify mutating ops (e.g. tt.store)
                    # and then DFS upwards towards the parameters of the function. Specifically, the analysis pass
                    # looks at the mutated arg in tt.store; then looks for its source ops; and then recurses on the
                    # arguments to each of the source ops.
                    #
                    # In the case of scf.if/scf.for, we may have multiple return ops, each passed as an arg
                    # to scf.yield:
                    #
                    #   %18:2 = scf.if %... -> (!tt.ptr<f32>, !tt.ptr<f32>) {
                    #     ...
                    #     scf.yield %1, %2
                    #   } else {
                    #     scf.yield %3, %4
                    #   }
                    #
                    # And for each of the returns of the scf.if, we'd naively assign the source op of each of the
                    # return values to be the scf.yields. But the scf.yields take _all_ the returns as arguments.
                    # Therefore, if _any_ of the return values of the scf.if are mutated, then the analysis pass
                    # would mark _all_ of the yield args as mutated.
                    #
                    # **Solution**:
                    # For the purposes of this analysis pass, we create N yield ops - one for each
                    # return-val/yield-arg. In the example above, we'll have two scf.yield's for each branch of the
                    # scf.if.

                    for return_op in return_ops:
                        for i, (scf_result, yield_arg) in enumerate(
                            zip(scf_results, return_op.args)
                        ):
                            sub_yield_op = Op(
                                return_op.name,
                                return_op.fn_call_name,
                                [yield_arg],
                                return_op.ret,
                                sub_idx=i,
                            )
                            op_stack[parent_block_id][scf_result].append(sub_yield_op)

                else:
                    # fallback: every scf result is attributed to every return op
                    for scf_result in scf_results:
                        for return_op in return_ops:
                            op_stack[parent_block_id][scf_result].append(return_op)
            else:
                raise RuntimeError(
                    f"Unknown blocked function: {name}. Can't capture the TTIR."
                )
        else:
            # plain (region-free) op: record it under each of its results
            callee = None
            if name == "tt.call":
                callee = op.get_flat_symbol_ref_attr("callee")
            args: list[Union[Param, Intermediate]] = [
                Intermediate(operand) for operand in operand_ids
            ]
            block_ops = op_stack[parent_block_id]

            is_pure = False
            # Handle the case for tt.elementwise_inline_asm to set `is_pure` for mutation analysis
            if name == "tt.elementwise_inline_asm":
                is_pure = op.get_bool_attr("pure")

            if result_ids:
                for result_id in result_ids:
                    res = Intermediate(result_id)
                    block_ops[res].append(Op(name, callee, args, res, is_pure=is_pure))
            else:
                # ops with no results (e.g. tt.store) still need a slot in the
                # dataflow map; synthesize a fake result for them
                next_fake_intermediate -= 1
                fake_res = Intermediate(next_fake_intermediate)
                block_ops[fake_res].append(
                    Op(name, callee, args, fake_res, is_pure=is_pure)
                )

    ttir_module.walk(mlir_to_functions)

    return functions
731
+
732
+
733
class MemoizeWithCycleCheck:
    """Memoizing wrapper keyed on ``(fn_name, *args)`` that also rejects
    recursion: the cache slot is pre-seeded with ``None`` before the wrapped
    function runs, so a re-entrant call on the same key observes the
    placeholder and raises ``RuntimeError``.
    """

    fn: Callable[..., Any]
    cache: dict[tuple[Any], Any]

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn
        self.reset()

    def __call__(
        self,
        functions: dict[str, dict[Intermediate, list[Op]]],
        fn_name: str,
        *args: Any,
    ) -> list[bool]:
        cache_key: tuple[Any, ...] = (fn_name, *args)
        if cache_key not in self.cache:
            # Placeholder marks this key as "in progress"; a recursive call
            # with the same key will see it and bail out below.
            self.cache[cache_key] = None
            self.cache[cache_key] = self.fn(functions, fn_name, *args)
        value = self.cache[cache_key]
        if value is None:
            raise RuntimeError("Recursion is not supported")
        return value

    def reset(self) -> None:
        """Drop all memoized results (and any in-progress markers)."""
        self.cache = {}
757
+
758
+
759
@MemoizeWithCycleCheck
def get_tma_stores(
    functions: dict[str, dict[Intermediate, list[Op]]], fn_name: str
) -> set[Union[Intermediate, Param]]:
    """
    Identifies all intermediates and parameters that are written to by a
    `tt.experimental_descriptor_store`. It tracks only the specific values
    written to via experimental_descriptor_store and the input values to
    `tt.reinterpret_tensor_descriptor` used to construct the direct inputs
    to tt.experimental_descriptor_store - not any recursive values
    used to construct those values.

    For example: for
      tt.reinterpret_tensor_descriptor(Intermediate(idx=0), ...)
      Intermediate(idx=1) = tt.experimental_descriptor_store(Intermediate(idx=0), ...)
    this function will return [Intermediate(idx=0), Intermediate(idx=1)],

    However
      Intermediate(idx=4) = arith.addptr(Intermediate(idx=2), Intermediate(idx=3))
      Intermediate(idx=5) = tt.experimental_descriptor_store(Intermediate(idx=4), ...)
      tt.experimental_descriptor_store(Intermediate(idx=5), ...)
    this function will mark only idx=4 and idx=5 (but not idx=2 or idx=3)

    If an intermediate/parameter is passed into a function and is written to
    via experimental_descriptor_store within that function, the argument to the
    function will also be marked.
    """

    result: set[Union[Intermediate, Param]] = set()

    ops = functions[fn_name]
    for op_list in ops.values():
        for op in op_list:
            if op.name == "tt.call":
                # propagate through calls: if the callee TMA-stores through
                # parameter i, mark the corresponding argument here as well
                assert op.fn_call_name in functions
                tma_stores = get_tma_stores(functions, op.fn_call_name)
                for i, inp in enumerate(op.args):
                    if Param(idx=i) in tma_stores:
                        result.add(inp)
            elif op.name == "tt.experimental_descriptor_store":
                assert len(op.args) >= 1
                result.add(op.args[0])

    # one extra (non-recursive) hop: also mark the inputs of any
    # tt.reinterpret_tensor_descriptor that produced a stored-to value;
    # iterate over a snapshot since we mutate `result` inside the loop
    for val in list(result):
        if val in ops:
            if not isinstance(val, Intermediate):
                continue
            for op in ops[val]:
                if op.name == "tt.reinterpret_tensor_descriptor":
                    assert len(op.args) >= 1
                    result.add(op.args[0])

    return result
812
+
813
+
814
@MemoizeWithCycleCheck
def analyze_kernel_mutations(
    functions: dict[str, dict[Intermediate, list[Op]]], fn_name: str, num_args: int
) -> list[bool]:
    """
    Analyzes the graph to detect all sinks from a predefined list of sinks
    by using triton's MemWrite trait list. NOTE: What if triton exposed this?
    From each sink, it traverses the CFG backwards to identify all the input
    pointers that are mutated.

    Returns a list of length `num_args`; entry i is True iff the i-th kernel
    parameter is (potentially) mutated.
    """
    # Name of mutation op to mutated parameter indices
    # List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td
    # All the OPs that have MemWrite trait.
    # What if Triton exposed this?
    MUTATION_OPS = {
        "tt.store": [0],
        "tt.atomic_cas": [0],
        "tt.atomic_rmw": [0],
        "tt.experimental_descriptor_store": [0],
        "tt.experimental_tensormap_create": [0],
        "tt.descriptor_store": [0],
    }
    # Ops that we want to bail out on
    UNKNOWN_OPS = {"tt.elementwise_inline_asm"}

    stack: list[Union[Param, Intermediate]] = []
    visited = set()
    ops = functions[fn_name]
    tma_stores = get_tma_stores(functions, fn_name)

    # Phase 1: seed the stack with every value directly written by a
    # known-mutating op (or by a mutated callee parameter).
    for op_list in ops.values():
        for op in op_list:
            # If we encounter an operation with effects that cannot be reliably analyzed
            # (e.g. `tt.elementwise_inline_asm`), we assume it does not mutate any input parameters.
            if op.name in UNKNOWN_OPS:
                if op.name == "tt.elementwise_inline_asm" and op.is_pure:
                    log.warning(
                        "TTIR mutation analysis: Skipping pure tt.elementwise_inline_asm op (is_pure=True)"
                    )
                    continue
                raise RuntimeError(
                    f"ttir analysis hit an op we do not know how to analyze: {op.name}"
                )

            if op.name == "tt.experimental_tensormap_create":
                # Note: this is how we implement experimental_descriptor_store mutation analysis.
                # for on-device TMA.
                # experimental_tensormap_store(a, b, ...) stores b to the location specified
                # by descriptor in the memory of a.
                # To track this, we first find all the intermediates/params to which we store via
                # experimental_tensormap_store (get_tma_stores, called above). Then, during this
                # analysis we wait to find the corresponding experimental_tensormap_create (if it
                # exists), at which point we will mark the global_ptr as mutated (as done below).
                assert len(op.args) >= 2
                if op.args[0] in tma_stores:
                    stack.append(op.args[1])

            if op.name == "tt.call":
                assert op.fn_call_name in functions
                mutations = analyze_kernel_mutations(
                    functions, op.fn_call_name, len(op.args)
                )
                stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated)
            else:
                stack.extend(op.args[idx] for idx in MUTATION_OPS.get(op.name, []))

    # Phase 2: walk the dataflow graph backwards from each seed.
    # The following is an iterative DFS algorithm
    mutated = [False] * num_args
    while stack:
        arg = stack.pop()
        if arg in visited:
            continue

        visited.add(arg)

        if isinstance(arg, Param):
            if arg.idx >= num_args:
                # This is an argument defined in the kernel, not passed in
                continue
            mutated[arg.idx] = True
        elif isinstance(arg, Intermediate) and not arg.fake():
            for op in ops[arg]:
                # Skip arguments to load
                if op.name != "tt.load":
                    stack.extend(op.args)
    return mutated
900
+
901
+
902
def identify_mutated_tensors(
    kernel: "TritonKernelType",
    kwargs: dict[str, Any],
    tma_descriptor_metadata: TMADescriptorMetadata,
) -> list[str]:
    """
    Given a triton kernel and the arguments for this kernel, this function
    1) Retrieves the TTIR converted version of the kernel from Triton's API.
    2) Parses the TTIR and creates a control flow graph
    3) Analyzes the graph to detect all input tensor mutations

    Returns the names of the (potentially) mutated kwargs. On any failure it
    falls back to the conservative answer: every tensor kwarg is mutated.
    """

    ttir_module = None
    functions = None
    try:
        ttir_module, ordered_tensor_names = generate_ttir(
            kernel, kwargs, tma_descriptor_metadata
        )

        # extract functions from TTIR using MLIR bindings exposed by Triton code
        functions = ttir_to_functions(ttir_module)

        assert functions is not None
        kernel_name = next(iter(functions.keys()))
        # Triton codegen modifies the name
        assert kernel.fn.__name__ in kernel_name
        # Reset the cache between top level invocations
        # The cache for analyze kernel mutations is mainly used for cycle
        # detection, so each top level invocation needs a clean cache
        analyze_kernel_mutations.reset()
        get_tma_stores.reset()
        mutations = analyze_kernel_mutations(
            functions, kernel_name, len(ordered_tensor_names)
        )

        return [
            ordered_tensor_names[i] for i, mutated in enumerate(mutations) if mutated
        ]
    except Exception:
        # Deliberate best-effort behavior: never fail compilation because the
        # analysis failed; assume everything is mutated instead.
        log.warning(
            "Encountered an exception in identify_mutated_tensors, assuming every input is mutated",
            exc_info=True,
        )
        if ttir_module is not None:
            log.debug("TTIR:\n%s", str(ttir_module))
        if functions is not None:
            log.debug("functions:")
            for name, fn in functions.items():
                log.debug("===\t%s\t===", name)
                for ret, ops in fn.items():
                    log.debug("%s\t=>\t%s", ret, ops)
        return [key for key, value in kwargs.items() if isinstance(value, Tensor)]
954
+
955
+
956
+ ###############################################################################
957
+ # Triton Kernel Wrappers
958
+
959
+
960
# Used for wrapping a Triton Kernel
class TritonKernelWrapperMutation(HigherOrderOperator):
    """HOP that runs a user-defined Triton kernel, mutating the tensors in
    `kwargs` in place; it produces no functional outputs."""

    def __init__(self) -> None:
        super().__init__("triton_kernel_wrapper_mutation", cacheable=True)

    def __call__(
        self,
        kernel_idx: int,
        constant_args_idx: int,
        grid: list["TritonGridType"],
        tma_descriptor_metadata: TMADescriptorMetadata,
        kwargs: dict[str, Any],
    ) -> Any:
        # Forward everything by keyword so downstream dispatch always sees
        # the same named argument set.
        hop_kwargs = dict(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            tma_descriptor_metadata=tma_descriptor_metadata,
            kwargs=kwargs,
        )
        return super().__call__(**hop_kwargs)


triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()
983
+
984
+
985
# Used for wrapping a Triton Kernel in a functional manner
class TritonKernelWrapperFunctional(HigherOrderOperator):
    """Functional variant of the Triton-kernel HOP: instead of mutating its
    inputs it returns fresh tensors for the entries in `tensors_to_clone`."""

    def __init__(self) -> None:
        super().__init__("triton_kernel_wrapper_functional", cacheable=True)

    def __call__(
        self,
        kernel_idx: int,
        constant_args_idx: int,
        grid: list["TritonGridType"],
        tma_descriptor_metadata: TMADescriptorMetadata,
        kwargs: dict[str, Any],
        tensors_to_clone: list[str],
    ) -> dict[str, Any]:
        # Forward everything by keyword so downstream dispatch always sees
        # the same named argument set.
        hop_kwargs = dict(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            tma_descriptor_metadata=tma_descriptor_metadata,
            kwargs=kwargs,
            tensors_to_clone=tensors_to_clone,
        )
        return super().__call__(**hop_kwargs)


triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()
1010
+
1011
+
1012
@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
) -> None:
    """Eager (dense) implementation: resolve the kernel and its constant args
    from the side table, build the grid function, materialize TMA descriptors
    if requested, and actually launch the Triton kernel."""
    from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code

    kernel = kernel_side_table.get_kernel(kernel_idx)
    constant_args = kernel_side_table.get_constant_args(constant_args_idx)

    if len(grid) == 1:
        grid_fn = grid[0]
    else:
        # multiple grids (one per autotuner config): generate and exec a
        # dispatching grid function
        fn_name, code = user_defined_kernel_grid_fn_code(
            kernel.fn.__name__, kernel.configs, grid
        )
        namespace: dict[str, Any] = {}
        exec(code, namespace)
        grid_fn = namespace[fn_name]

    if tma_descriptor_metadata:
        # as we need to launch the kernel here, we "unwrap" the
        # tma_descriptor_metadata, create the TMA descriptors
        # from it, and replace the tensors in the kwargs by the
        # corresponding TMA descriptors before launching
        kwargs = kwargs.copy()
        for k, v in tma_descriptor_metadata.items():
            tensor = kwargs[k]
            if (exp_meta := maybe_unpack_tma_experimental_metadata(v)) is not None:
                # older "experimental" TMA API
                from triton.tools.experimental_descriptor import (  # noqa: F401
                    create_1d_tma_descriptor,
                    create_2d_tma_descriptor,
                )

                dims, block_dims, element_size = exp_meta
                create_tma_descriptor = (
                    create_1d_tma_descriptor
                    if len(dims) == 1
                    else create_2d_tma_descriptor
                )
                kwargs[k] = create_tma_descriptor(
                    tensor.data_ptr(),
                    *dims,
                    *block_dims,
                    element_size,
                )
            else:
                # newer "stable" TMA API
                stable_meta = maybe_unpack_tma_stable_metadata(v)
                assert stable_meta is not None
                from triton.tools.tensor_descriptor import TensorDescriptor

                block_shape = stable_meta[0]
                kwargs[k] = TensorDescriptor.from_tensor(tensor, block_shape)

    # move as many positional arguments from dicts to args as we
    # can to circumvent the bug with the kwargs and pre_/post_hook:
    # https://github.com/triton-lang/triton/issues/5082
    # TODO: remove this when the Triton issue above is fixed
    args = []
    # copy kwargs and constant_args here to
    # avoid mutating the original inputs
    kwargs = kwargs.copy()
    constant_args = constant_args.copy()
    for name in kernel.arg_names:
        if name in kwargs:
            args.append(kwargs.pop(name))
        elif name in constant_args:
            args.append(constant_args.pop(name))
        else:
            # first argument not found positionally: stop converting, the
            # rest stay as keywords
            break

    kernel[grid_fn](*args, **kwargs, **constant_args)
1088
+
1089
+
1090
@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode)
def triton_kernel_wrapper_mutation_fake_tensor_mode(
    mode: FakeTensorMode,
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
) -> None:
    """Fake-tensor implementation: the mutation HOP has no outputs, so under
    FakeTensorMode there is nothing to compute."""
    with mode:
        pass
    return None
1102
+
1103
+
1104
@triton_kernel_wrapper_mutation.py_impl(DispatchKey.Meta)
def _(
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
) -> None:
    # Meta implementation: the mutation HOP has no outputs and meta tensors
    # carry no data, so this is a no-op.
    return None
1114
+
1115
+
1116
def trace_triton_kernel_wrapper(
    proxy_mode: ProxyTorchDispatchMode,
    func_overload: Callable[..., Any],
    node_args: dict[str, Any],
) -> Optional[dict[str, Any]]:
    """Run `func_overload(**node_args)` for real (with proxy tracing
    disabled), record an equivalent `call_function` node into the proxy
    graph, and associate the concrete outputs with the new node's proxies."""
    tracer = proxy_mode.tracer

    # Execute eagerly so the HOP's actual effect happens exactly once.
    with disable_proxy_modes_tracing():
        concrete_out = func_overload(**node_args)

    # Mirror the call into the traced graph, mapping tensor args back to
    # their proxies.
    unwrapped_args = pytree.tree_map(tracer.unwrap_proxy, node_args)  # type: ignore[union-attr]
    graph_proxy = tracer.create_proxy(
        "call_function",
        func_overload,
        (),
        unwrapped_args,
        name=f"{func_overload.__name__}_proxy",
    )

    return track_tensor_tree(concrete_out, graph_proxy, constant=None, tracer=tracer)
1137
+
1138
+
1139
@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode,
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
) -> None:
    """Proxy-mode implementation: record the mutation HOP into the traced
    graph; the HOP itself returns nothing."""
    node_args = dict(
        kernel_idx=kernel_idx,
        constant_args_idx=constant_args_idx,
        grid=grid,
        tma_descriptor_metadata=tma_descriptor_metadata,
        kwargs=kwargs,
    )
    trace_triton_kernel_wrapper(mode, triton_kernel_wrapper_mutation, node_args)
    return None
1162
+
1163
+
1164
def get_mutated_tensors(
    kernel_idx: int,
    constant_args_idx: int,
    kwargs: dict[str, Any],
    tma_descriptor_metadata: TMADescriptorMetadata,
) -> list[str]:
    """Resolve the kernel and its constant args from the side table, then run
    the TTIR mutation analysis over the combined argument set."""
    triton_kernel = kernel_side_table.get_kernel(kernel_idx)
    const_args = kernel_side_table.get_constant_args(constant_args_idx)
    combined_args = {**kwargs, **const_args}
    return identify_mutated_tensors(
        triton_kernel, combined_args, tma_descriptor_metadata
    )
1175
+
1176
+
1177
@triton_kernel_wrapper_mutation.py_functionalize_impl
def triton_kernel_wrapper_mutation_functionalize(
    ctx: "BaseFunctionalizeAPI",
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
) -> None:
    """Functionalization: rewrite the mutating HOP into the functional HOP,
    then propagate the functional outputs back into the original (wrapped)
    inputs via the functionalization context."""
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
    # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each
    # other, and one gets mutated in kernel, and later another gets mutated,
    # they are no longer equal. Fix this by graph breaking on this condition
    # earlier in dynamo.
    tensors_to_clone = get_mutated_tensors(
        kernel_idx, constant_args_idx, unwrapped_kwargs, tma_descriptor_metadata
    )
    with ctx.redispatch_to_next():
        unwrapped_outputs = triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            tma_descriptor_metadata=tma_descriptor_metadata,
            kwargs=unwrapped_kwargs,
            tensors_to_clone=tensors_to_clone,
        )

    # every functional output must correspond to one of the original kwargs
    assert set(unwrapped_outputs.keys()).issubset(set(kwargs.keys()))
    for key, output_arg in unwrapped_outputs.items():
        if not isinstance(output_arg, Tensor):
            continue
        input_arg = kwargs[key]
        assert isinstance(input_arg, Tensor)

        ctx.replace(input_arg, output_arg)
        # indicate that above replace is hidden from autograd
        ctx.mark_mutation_hidden_from_autograd(input_arg)
        ctx.commit_update(input_arg)
        ctx.sync(input_arg)
    return None
1217
+
1218
+
1219
@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_functional_dense(
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
    tensors_to_clone: list[str],
) -> dict[str, Any]:
    """Eager implementation of the functional HOP: clone the to-be-mutated
    tensors, run the mutating HOP on the clones, and return the clones.

    TODO(oulgen): For performance reasons, we want to ensure that these
    `clone_preserve_strides` calls are never executed at runtime
    (inductor should always optimize them away).
    Requires https://github.com/pytorch/pytorch/issues/109240
    """
    staged_kwargs: dict[str, Any] = {}
    for name, value in kwargs.items():
        if name in tensors_to_clone:
            staged_kwargs[name] = clone_preserve_strides(value)
        else:
            staged_kwargs[name] = value

    triton_kernel_wrapper_mutation(
        kernel_idx=kernel_idx,
        constant_args_idx=constant_args_idx,
        grid=grid,
        tma_descriptor_metadata=tma_descriptor_metadata,
        kwargs=staged_kwargs,
    )

    # Only the cloned (mutated) entries are the functional outputs.
    return {
        name: value
        for name, value in staged_kwargs.items()
        if name in tensors_to_clone
    }
1245
+
1246
+
1247
@triton_kernel_wrapper_functional.py_impl(FakeTensorMode)
def triton_kernel_wrapper_functional_fake_tensor_mode(
    mode: FakeTensorMode,
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
    tensors_to_clone: list[str],
) -> dict[str, Any]:
    """Fake-tensor implementation: produce fake clones for exactly the
    entries the kernel would mutate; no kernel is launched.

    TODO(oulgen): For performance reasons, we want to ensure that these
    `clone_preserve_strides` calls are never executed at runtime
    (inductor should always optimize them away).
    Requires https://github.com/pytorch/pytorch/issues/109240
    """
    with mode:
        fake_outputs: dict[str, Any] = {}
        for name, value in kwargs.items():
            if name in tensors_to_clone:
                fake_outputs[name] = clone_preserve_strides(value)
        return fake_outputs
1268
+
1269
+
1270
@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode,
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
    tensors_to_clone: list[str],
) -> dict[str, Any]:
    """Proxy-mode implementation: record the functional HOP into the traced
    graph and return its tracked outputs."""
    node_args = dict(
        kernel_idx=kernel_idx,
        constant_args_idx=constant_args_idx,
        grid=grid,
        tma_descriptor_metadata=tma_descriptor_metadata,
        kwargs=kwargs,
        tensors_to_clone=tensors_to_clone,
    )
    traced = trace_triton_kernel_wrapper(
        mode, triton_kernel_wrapper_functional, node_args
    )
    assert traced is not None
    return traced
1295
+
1296
+
1297
@triton_kernel_wrapper_functional.py_functionalize_impl
def triton_kernel_wrapper_functional_functionalize(
    ctx: "BaseFunctionalizeAPI",
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
    tensors_to_clone: list[str],
) -> dict[str, Any]:
    """Functionalization: unwrap the functional-tensor kwargs, redispatch the
    functional HOP underneath, and re-wrap its outputs."""
    inner_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
    with ctx.redispatch_to_next():
        inner_outputs = triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            tma_descriptor_metadata=tma_descriptor_metadata,
            kwargs=inner_kwargs,
            tensors_to_clone=tensors_to_clone,
        )
    return ctx.wrap_tensors(inner_outputs)  # type: ignore[return-value,arg-type]
1318
+
1319
+
1320
# The two HOPs do not participate in these dispatch keys; register explicit
# fallthroughs so the dispatcher keeps resolving down to the real impls.
# Fix: the duplicate AutogradCUDA registration for the functional HOP has
# been removed (it was registered twice, which is at best redundant).
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCPU)

triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCPU)
1338
+
1339
+
1340
+ ###############################################################################
1341
+ # The "TritonHOPifier": a class that transforms a call to a triton kernel into
1342
+ # a call to the triton_kernel_wrapper_mutation HOP.
1343
+
1344
+
1345
class TritonHOPifier:
    """Orchestrator for converting a user-defined triton kernel into a call
    to the triton_kernel_wrapper_mutation HOP.

    It has two main use cases.

    1. When Dynamo sees a triton kernel, it wraps it into a TritonKernelVariable
    and uses the TritonHOPifier to convert calls to the TritonKernelVariable
    into a call to the HOP.

    2. In order to capture a user-defined triton kernel while performing
    tracing (via make_fx or non-strict export), a user must annotate their
    triton kernel with the `wrap_triton` decorator. The decorator uses
    TritonHOPifier to convert calls to the triton kernel into a call
    to the HOP (which can then be traced).

    Because Dynamo has its own calling conventions for e.g. invoking a user-defined function
    TritonHOPifier is an abstract class that can be overridden by its subclasses.
    """

    # ---- Abstract interface: each subclass supplies its own representation
    # ---- of values (Dynamo VariableTrackers vs. plain Python objects).

    def raise_unsupported(self, msg: str) -> Never:
        raise NotImplementedError("abstract method")

    def is_callable(self, maybe_callable: Any) -> bool:
        raise NotImplementedError("abstract method")

    def get_value(self, val: Any) -> Any:
        raise NotImplementedError("abstract method")

    def call_grid(  # type: ignore[no-untyped-def]
        self,
        grid,
        meta,
        tx,
    ) -> Union[tuple[Union[int, sympy.Expr, SymInt], ...], tuple["Proxy", ...]]:
        raise NotImplementedError("abstract method")

    def wrap_user_defined_obj(
        self,
        user_obj: Any,
        tx: Optional["InstructionTranslator"],
        variable: Optional[
            Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
        ],
        name: str,
    ) -> Any:
        raise NotImplementedError("abstract method")

    def call_user_defined_fn(
        self,
        user_fn: Callable[..., Any],
        args: list,
        kwargs: dict,
        tx: Optional["InstructionTranslator"],
        variable: Optional[
            Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
        ],
    ) -> Any:
        raise NotImplementedError("abstract method")

    def maybe_unpack_configs(
        self, configs: list["TritonConfig"], tx: Optional["InstructionTranslator"]
    ) -> list["TritonConfig"]:
        raise NotImplementedError("abstract method")

    def maybe_unpack_heuristic_result(self, result: Any) -> Any:
        raise NotImplementedError("abstract method")

    @staticmethod
    def do_prune_configs(  # type: ignore[no-untyped-def]
        autotuner: "TritonAutotunerType",
        early_config_prune: Optional[Callable],
        perf_model: Optional[Callable],
        top_k: float,
        configs: list,
        named_args: dict,
        kwargs: dict,
    ) -> list["TritonConfig"]:
        """Prune autotuning configs the same way triton's Autotuner would,
        without calling Autotuner.prune_configs (the `autotuner` parameter is
        accepted for signature parity but not referenced here)."""
        # Reimplement autotuner.prune_configs(...) here
        # see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py  # noqa: E501,B950
        # We do this to avoid calling prune_configs, which in turn calls early_config_prune and perf_model
        # These are both user-defined functions which can contain side effects, so we want to sandbox them in Dynamo

        if early_config_prune:
            configs = early_config_prune(configs, named_args, **kwargs)

        if perf_model:
            # we assert top_k is a float before calling this
            if isinstance(top_k, float) and top_k <= 1.0:
                top_k = int(len(configs) * top_k)
            elif not isinstance(top_k, int):
                """
                Slice index must be an integer, SupportsIndex or None
                """
                raise TypeError(
                    "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
                )
            if len(configs) > top_k:
                est_timing = [
                    (
                        config,
                        float(
                            perf_model(**named_args, **kwargs, **config.all_kwargs())
                        ),
                    )
                    for config in configs
                ]
                configs = [
                    config[0]
                    for config in sorted(est_timing, key=operator.itemgetter(1))[:top_k]
                ]
        return configs

    def call_HOP(  # type: ignore[no-untyped-def]
        self,
        variable,
        grids,
        combined_args: dict[str, Any],
        tx,
    ) -> Optional["ConstantVariable"]:
        raise NotImplementedError("abstract method")

    def check_grid(  # type: ignore[no-untyped-def]
        self, grid
    ) -> Union[tuple[Union[int, sympy.Expr, SymInt], ...], tuple["Proxy", ...]]:
        raise NotImplementedError("abstract method")

    def init_variable(
        self,
        variable: Union["TraceableTritonKernelWrapper", "TritonKernelVariable"],
        kernel: "TritonKernelType",
        kernel_idx: Optional[int],
        grid: Optional["TritonGridType"],
    ) -> None:
        """Attach `kernel`/`kernel_idx`/`grid` to `variable`, registering the
        kernel in the side table, and reject Autotuner features we can't trace."""
        from triton.runtime.autotuner import Autotuner

        assert kernel is not None

        variable.kernel = kernel
        variable.kernel_idx = kernel_side_table.add_kernel(kernel)

        assert kernel_idx is None or variable.kernel_idx == kernel_idx

        variable.grid = grid

        if isinstance(kernel, Autotuner):
            import torch
            import torch._dynamo

            # We only support configs, keys, and restore_value arguments
            # of triton.autotune. Make sure other arguments are defaulted.
            defaults = inspect.signature(Autotuner.__init__).parameters
            # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep.
            # The call to get_first_attr is to maintain backward-compatibility.

            def defaults_ok(
                attr: str, alternates: tuple[str, ...], values: tuple[Any, ...]
            ) -> bool:
                if attr not in defaults:
                    return True
                value = torch._dynamo.utils.get_first_attr(kernel, attr, *alternates)
                if value == defaults[attr].default:
                    return True
                return value in values

            if (
                not torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args
                and (
                    not defaults_ok("num_warmups", ("warmup",), (25, None))
                    or not defaults_ok("num_reps", ("rep",), (100, None))
                    or not defaults_ok("use_cuda_graph", (), (False,))
                )
            ):
                self.raise_unsupported(
                    "Only configs, keys, restore_value, and reset_to_zero are supported for triton.autotune"
                )
            if (
                not torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args
                and (
                    # pre_hook requires running arbitrary code at runtime, which we cannot handle at this time
                    # https://github.com/pytorch/pytorch/issues/139059
                    # we can't support pre_hook or post_hook in user defined triton kernels at the moment,
                    # as they require the ability to execute code at runtime (AOTI can't support this)
                    (
                        hasattr(kernel, "user_defined_pre_hook")
                        and kernel.user_defined_pre_hook is not False
                    )
                    or (
                        hasattr(kernel, "user_defined_post_hook")
                        and kernel.user_defined_post_hook is not False
                    )
                    or (
                        # Check Config passed to autotuner in configs
                        any(cfg.pre_hook is not None for cfg in kernel.configs)
                    )
                )
            ):
                self.raise_unsupported(
                    "pre_hook and post_hook are not supported in triton.Autotune or triton.Config"
                )

    def call_getitem(
        self,
        variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"],
        args: Sequence[Any],
    ) -> Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]:
        """Handle `kernel[grid]`: return a new variable carrying the grid."""
        # __getitem__ should only be called if we don't already have a grid
        # Only grid needs to be passed
        if variable.grid is not None or len(args) != 1:
            self.raise_unsupported(
                "Triton kernels should be called with only a single grid"
            )

        return type(variable)(
            kernel=variable.kernel,
            kernel_idx=variable.kernel_idx,
            grid=args[0],
        )

    def call_run(
        self,
        variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"],
        args: Sequence[Any],
        kwargs: dict[str, Any],
        tx: Optional["InstructionTranslator"],
    ) -> Optional["ConstantVariable"]:
        """Handle `kernel.run(*args, grid=...)` by rewriting it into the
        `kernel[grid](*args)` form and dispatching to call_triton_kernel."""
        if "grid" not in kwargs:
            self.raise_unsupported("Triton kernel requires to be called with a grid")
        grid = kwargs.pop("grid")
        kwargs.pop("warmup", None)
        # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args)
        return self.call_triton_kernel(
            type(variable)(
                kernel=variable.kernel, kernel_idx=variable.kernel_idx, grid=grid
            ),
            args,
            kwargs,
            tx,
        )

    def call_triton_kernel(
        self,
        variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"],
        args: Sequence[Any],
        kwargs: dict[str, Any],
        tx: Optional["InstructionTranslator"],
    ) -> Optional["ConstantVariable"]:
        """Main entry: normalize the kernel (heuristics, special config kwargs,
        config pruning — each normalization recurses with a rewrapped kernel),
        compute the grid(s), specialize constexpr args, then emit the HOP call."""
        from triton import JITFunction
        from triton.runtime.autotuner import autotune, Autotuner, Config, Heuristics

        # Check if num_ctas is in kwargs
        if "num_ctas" in kwargs:
            self.raise_unsupported(
                "Passing num_ctas directly to the Triton kernel is not supported. "
                "Please use a Config in @triton.autotune instead."
            )

        # Make sure the kernel has a grid
        if variable.grid is None:
            self.raise_unsupported("Triton kernels should always be called with a grid")

        # raise an exception if there are multiple @triton.autotune decorators
        iter_kernel = variable.kernel
        autotuner_count = 0
        while not isinstance(iter_kernel, JITFunction):
            if isinstance(iter_kernel, Autotuner):
                autotuner_count += 1
            if autotuner_count > 1:
                self.raise_unsupported(
                    "Passing multiple @triton.autotune decorators is not supported. "
                    "Please use a single @triton.autotune decorator instead."
                )
            iter_kernel = iter_kernel.fn

        # Process the @triton.heuristics decorator:
        # - We know there is only 1 autotuner decorator here
        # - We can apply the heuristic to all triton.Configs in the order that the decorators appear
        # This way, when the config is selected, the heuristics have already been applied.
        # - Decorators that appear *before* the autotuner are already processed correctly
        if isinstance(variable.kernel, Autotuner) and isinstance(
            variable.kernel.fn, Heuristics
        ):
            # unwrap the heuristics decorator, we don't need it anymore
            # variable.kernel ==> Autotuner
            # variable.kernel.fn ==> Heuristics
            # ...
            # There can be arbitrarily many heuristics wrappers here!
            # ...
            # variable.kernel.fn ==> JITFunction

            # Copy the configs, we are going to be modifying them
            new_configs = copy.deepcopy(variable.kernel.configs)

            named_args = dict(zip(variable.kernel.arg_names, args))

            # Iterate through all of the heuristics wrappers that come after the autotune wrapper
            iter_kernel = variable.kernel.fn
            while isinstance(iter_kernel, Heuristics):
                # For each config, apply the heuristic fn(s)
                for config_idx in range(len(new_configs)):
                    for kwarg_key, heuristic_fn in iter_kernel.values.items():
                        # Run heuristics on the combined configs + kwargs
                        heuristic_result = self.call_user_defined_fn(
                            heuristic_fn,
                            [
                                {
                                    **named_args,
                                    **kwargs,
                                    **new_configs[config_idx].__dict__["kwargs"],
                                },
                            ],
                            {},
                            tx,
                            variable,
                        )

                        # Update the kwargs in each config
                        # maybe_unpack_heuristic_result raises unsupported if the value is non-constant
                        new_configs[config_idx].__dict__["kwargs"][
                            kwarg_key
                        ] = self.maybe_unpack_heuristic_result(heuristic_result)

                iter_kernel = iter_kernel.fn
            assert isinstance(iter_kernel, JITFunction)
            prune_configs_by = {
                "perf_model": variable.kernel.perf_model,
                "early_config_prune": variable.kernel.early_config_prune,
                "configs_top_k": variable.kernel.configs_top_k,
            }
            new_kernel = autotune(
                configs=new_configs, key=[], prune_configs_by=prune_configs_by
            )(iter_kernel)
            # create a new variable to contain the new (wrapped) kernel;
            # skip kernel_idx to get a new record in the kernel side table
            new_var = type(variable)(new_kernel, None, variable.grid)
            return self.call_triton_kernel(new_var, args, kwargs, tx)

        SPECIAL_CONFIG_NAMES = {
            "num_warps",
            "num_stages",
            "num_ctas",
            "num_consumer_groups",
            "num_buffers_warp_spec",
        }

        # move special config names to configs out of kwargs
        special_kwargs = {}
        for name in SPECIAL_CONFIG_NAMES:
            if name in kwargs:
                # remove special kwargs from `kwargs`
                val = kwargs.pop(name)
                special_kwargs[name] = self.get_value(val)

        if special_kwargs:
            if isinstance(variable.kernel, Autotuner):
                # if there is Autotuner already, set
                # special kwargs to each of its configs
                new_configs = copy.deepcopy(variable.kernel.configs)
                for config in new_configs:
                    config.__dict__.update(special_kwargs)
                prune_configs_by = {
                    "perf_model": variable.kernel.perf_model,
                    "early_config_prune": variable.kernel.early_config_prune,
                    "configs_top_k": variable.kernel.configs_top_k,
                }

                new_kernel = autotune(
                    configs=new_configs, key=[], prune_configs_by=prune_configs_by
                )(variable.kernel.fn)
            else:
                # if there is no Autotuner, wrap the kernel into a
                # new one with a single config with special kwargs
                new_config = Config(kwargs={}, **special_kwargs)

                new_kernel = autotune(configs=[new_config], key=[])(variable.kernel)

            # create a new variable to contain the new (wrapped) kernel;
            # skip kernel_idx to get a new record in the kernel side table
            new_var = type(variable)(new_kernel, None, variable.grid)
            return self.call_triton_kernel(new_var, args, kwargs, tx)

        if isinstance(variable.kernel, Autotuner):
            special_param_names = []
            for name in SPECIAL_CONFIG_NAMES:
                if name in variable.kernel.fn.arg_names:
                    special_param_names.append(name)

            if special_param_names:
                # If the Triton kernel has SPECIAL_CONFIG_NAMES in parameters, those should
                # be passed from the kernel configs: the behavior of Triton runtime is that
                # those values get folded into the kernel arguments iff there are parameters
                # with the same name. Normally the values of those parameters are defined
                # outside the `kwargs` part of the autotuning configs. Here we move them to
                # the `kwargs` part (if they're absent there) to facilitate passing them as
                # arguments to the kernel downstream.
                updated = False
                new_configs = copy.deepcopy(variable.kernel.configs)
                for config in new_configs:
                    for name in special_param_names:
                        if name not in config.__dict__["kwargs"]:
                            assert (
                                name in config.__dict__
                            ), f"{name} must be in autotuning configs to be used as a kernel parameter"
                            config.__dict__["kwargs"][name] = config.__dict__[name]
                            updated = True

                if updated:
                    prune_configs_by = {
                        "perf_model": variable.kernel.perf_model,
                        "early_config_prune": variable.kernel.early_config_prune,
                        "configs_top_k": variable.kernel.configs_top_k,
                    }

                    new_kernel = autotune(
                        configs=new_configs, prune_configs_by=prune_configs_by, key=[]
                    )(variable.kernel.fn)
                    new_var = type(variable)(new_kernel, None, variable.grid)
                    return self.call_triton_kernel(new_var, args, kwargs, tx)

        # These are the default values in upstream Triton
        # see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py  # noqa: E501,B950
        default_perf_model = None
        default_early_config_prune = None

        # run prune_configs_by
        if isinstance(variable.kernel, Autotuner) and (
            variable.kernel.perf_model != default_perf_model
            or variable.kernel.early_config_prune != default_early_config_prune
        ):
            # Prune the configs
            named_args = dict(zip(variable.kernel.arg_names, args))

            # The source information is important here so the guards are installed correctly

            wrapped_early_configs_prune = self.wrap_user_defined_obj(
                variable.kernel.early_config_prune,
                tx,
                variable,
                "early_config_prune",
            )

            wrapped_perf_model = self.wrap_user_defined_obj(
                variable.kernel.perf_model, tx, variable, "perf_model"
            )

            wrapped_configs_top_k = self.wrap_user_defined_obj(
                variable.kernel.configs_top_k, tx, variable, "configs_top_k"
            )

            wrapped_configs = self.wrap_user_defined_obj(
                variable.kernel.configs, tx, variable, "configs"
            )

            pruned_configs = self.call_user_defined_fn(
                self.do_prune_configs,
                [
                    variable,
                    wrapped_early_configs_prune,
                    wrapped_perf_model,
                    wrapped_configs_top_k,
                    wrapped_configs,
                    named_args,
                    kwargs,
                ],
                {},
                tx,
                variable,
            )

            pruned_configs = self.maybe_unpack_configs(pruned_configs, tx)

            # after pruning the configs, create a new autotuner object with
            # these configs and recurse.
            new_kernel = autotune(configs=pruned_configs, key=[])(variable.kernel.fn)
            # create a new variable to contain the new (wrapped) kernel;
            # skip kernel_idx to get a new record in the kernel side table
            new_var = type(variable)(new_kernel, None, variable.grid)
            return self.call_triton_kernel(new_var, args, kwargs, tx)

        # Both for grid's meta as well as for the kernel, we need combined
        # args and kwargs combined and normalized
        combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs}

        # precompute the grid for the kernel
        configs = (
            [config.kwargs for config in variable.kernel.configs]
            if isinstance(variable.kernel, Autotuner)
            else [{}]
        )
        grids = []
        for config_args in configs:
            # If the grid is a function, then lets execute it and convert it to
            # a list
            grid = variable.grid
            assert grid is not None
            if self.is_callable(grid):
                # Populate the special "meta" argument to call the grid function
                meta = {**combined_args_raw, **config_args}
                grid = self.call_grid(grid, meta, tx)  # type: ignore[arg-type]
            grids.append(self.check_grid(grid))

        for i in range(len(grids)):
            if not isinstance(grids[i], tuple):
                self.raise_unsupported("Only tuple grids are supported")
            # inductor expects all grids to be 3-tuple so lets make it
            if len(grids[i]) == 1:
                grids[i] = (grids[i][0], 1, 1)
            elif len(grids[i]) == 2:
                grids[i] = (grids[i][0], grids[i][1], 1)
            elif len(grids[i]) > 3:
                self.raise_unsupported("Grid can have at most rank 3")

        assert len(grids) != 0
        if isinstance(variable.kernel, JITFunction):
            constexprs = variable.kernel.constexprs
        else:
            # If we are looking at an @triton.autotune decorator, the nested function should be a JITFunction
            # This is because we don't support @triton.heuristics or nested @triton.autotune decorators yet
            assert isinstance(variable.kernel, Autotuner)
            constexprs = variable.kernel.fn.constexprs

        for idx, arg_name in enumerate(variable.kernel.arg_names):
            if idx in constexprs:
                if arg_name in combined_args_raw:
                    # [Note: Specialize tl.constexpr args in user-defined triton kernels]
                    # This arg is marked as tl.constexpr. That means that triton will recompile every time
                    # this value changes.
                    # https://github.com/pytorch/pytorch/issues/136504
                    # One option is to correctly pass the symints in so that the symbolic expressions are defined
                    # when the triton code is being executed.
                    # But since triton will have to recompile either way, we instead just specialize on the value.
                    #
                    # Depending on the type of `variable` we might expect different types for the symbolic args:
                    # either SymNodeVariables (for TritonKernelVariables) or SymInts (TracingTritonKernelWrapper)
                    combined_args_raw[arg_name] = variable.specialize_symbolic(
                        combined_args_raw[arg_name]
                    )
        return self.call_HOP(variable, grids, combined_args_raw, tx)
1883
+
1884
+
1885
+ ###############################################################################
1886
+ # Helpers for wrap_triton API that makes a user-defined triton kernel traceable into
1887
+ # a graph via make_fx or non-strict export (coming soon)
1888
+
1889
+
1890
class TracingTritonHOPifier(TritonHOPifier):
    """Concrete TritonHOPifier used outside Dynamo (wrap_triton tracing):
    values are plain Python objects, so most hooks pass through, and `tx`
    (the Dynamo InstructionTranslator) is always None here."""

    def raise_unsupported(self, msg: str) -> Never:
        raise RuntimeError(msg)

    def is_callable(self, maybe_callable: Any) -> bool:
        return callable(maybe_callable)

    def get_value(self, val: Any) -> Any:
        return val

    def call_grid(
        self,
        grid: "TritonGridCallableType",
        meta: "TritonMetaParamsType",
        tx: None,
    ) -> tuple[Union[int, sympy.Expr, SymInt], ...]:
        # Outside Dynamo the grid fn is a plain callable; just invoke it.
        assert tx is None
        assert isinstance(meta, dict)
        assert callable(grid)
        return grid(meta)

    def wrap_user_defined_obj(
        self,
        user_obj: Any,
        tx: Optional["InstructionTranslator"],
        variable: Optional[
            Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
        ],
        name: str,
    ) -> Any:
        # No VariableTracker wrapping needed outside Dynamo.
        assert tx is None
        return user_obj

    def call_user_defined_fn(
        self,
        user_fn: Callable[..., Any],
        args: list,
        kwargs: dict,
        tx: Optional["InstructionTranslator"],
        variable: Optional[
            Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
        ],
    ) -> Any:
        assert isinstance(args, list)
        assert isinstance(kwargs, dict)
        assert callable(user_fn)
        return user_fn(*args, **kwargs)

    def maybe_unpack_configs(
        self, configs: list["TritonConfig"], tx: Optional["InstructionTranslator"]
    ) -> list["TritonConfig"]:
        assert isinstance(configs, list)
        return configs

    def maybe_unpack_heuristic_result(self, result: Any) -> Any:
        return result

    def check_grid(
        self,
        grid: "TritonGridType",
    ) -> tuple[Union[int, sympy.Expr, SymInt], ...]:
        if not isinstance(grid, collections.abc.Sequence):
            raise RuntimeError(
                "wrap_triton can only handle grids that resolve to Sequence[int]."
            )
        # normalize to tuple
        return tuple(grid)

    def store_non_graphable_args(
        self,
        combined_args: dict[str, Any],
    ) -> tuple[dict, int]:
        """
        Some args cannot be stored in the FX graph.
        Put them in the side table.
        """

        def is_graphable(val: Any) -> bool:
            return isinstance(val, (fx.node.base_types, fx.Node))

        non_graphable_args = {
            k: v for k, v in combined_args.items() if not is_graphable(v)
        }
        graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)}

        constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args)

        return graphable_args, constant_args_idx

    def call_HOP(
        self,
        variable: "TraceableTritonKernelWrapper",
        grids: list["TritonGridTupleType"],
        combined_args: dict[str, Any],
        tx: None,
    ) -> None:
        """Emit the actual triton_kernel_wrapper_mutation HOP call."""
        assert tx is None
        assert isinstance(variable, TraceableTritonKernelWrapper)

        graphable_args, constant_args_idx = self.store_non_graphable_args(combined_args)

        assert isinstance(variable.kernel_idx, int)
        return triton_kernel_wrapper_mutation(
            kernel_idx=variable.kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grids,  # type: ignore[arg-type]
            # TMA descriptor capturing not yet
            # supported in non-dynamo tracing
            tma_descriptor_metadata={},
            kwargs=graphable_args,
        )
2001
+
2002
+
2003
# Module-level singleton used by TraceableTritonKernelWrapper below.
tracing_triton_hopifier_singleton = TracingTritonHOPifier()
2004
+
2005
+
2006
class TraceableTritonKernelWrapper:
    """Wrapper produced by `wrap_triton`: mimics a triton kernel's call
    surface (`kernel[grid](...)`, `kernel.run(...)`) and routes calls through
    tracing_triton_hopifier_singleton when wrap_triton is enabled."""

    kernel: "TritonKernelType"
    kernel_idx: Optional[int]
    grid: Optional["TritonGridType"]

    def __init__(
        self,
        kernel: "TritonKernelType",
        kernel_idx: Optional[int],
        grid: Optional["TritonGridType"],
    ) -> None:
        self.kernel = None
        self.grid = None
        # init_variable populates kernel, kernel_idx (via the side table), and grid.
        tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)
        assert self.kernel is not None

    def __getitem__(self, *args: Sequence[Any]) -> "TraceableTritonKernelWrapper":
        # Supports the `kernel[grid]` indexing idiom; returns a new wrapper.
        return tracing_triton_hopifier_singleton.call_getitem(self, args)  # type: ignore[return-value]

    def run(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any:
        from torch._library.triton import is_wrap_triton_enabled

        if is_wrap_triton_enabled():
            return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)
        else:
            # wrap_triton disabled: fall back to the real triton kernel.
            assert self.kernel is not None
            return self.kernel.run(*args, **kwargs)

    def __call__(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any:
        from torch._library.triton import is_wrap_triton_enabled

        if is_wrap_triton_enabled():
            return tracing_triton_hopifier_singleton.call_triton_kernel(
                self, args, kwargs, None
            )
        else:
            # wrap_triton disabled: launch the underlying kernel directly.
            assert self.kernel is not None
            return self.kernel[self.grid](*args, **kwargs)

    def specialize_symbolic(self, arg: Sequence[Any]) -> Any:
        import torch

        # See [Note: Specialize tl.constexpr args in user-defined triton kernels]
        if isinstance(arg, (torch.SymInt, torch.SymBool, torch.SymFloat)):
            return guard_scalar(arg)
        return arg
archive/.venv/Lib/site-packages/torch/_higher_order_ops/utils.py ADDED
@@ -0,0 +1,1134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ import functools
4
+ from contextlib import contextmanager, ExitStack, nullcontext
5
+ from dataclasses import dataclass
6
+ from typing import Any, Callable, Optional, overload, TypeVar, Union
7
+
8
+ import torch
9
+ import torch.fx.traceback as fx_traceback
10
+ import torch.utils._pytree as pytree
11
+ from torch._dispatch.python import suspend_functionalization
12
+ from torch._guards import detect_fake_mode
13
+ from torch._higher_order_ops.schema import HopSchema
14
+ from torch._ops import HigherOrderOperator, OperatorBase, OpOverload
15
+ from torch._subclasses.fake_tensor import FakeTensor
16
+ from torch._subclasses.functional_tensor import (
17
+ disable_functional_mode,
18
+ FunctionalTensor,
19
+ )
20
+ from torch.fx.experimental.proxy_tensor import (
21
+ _temp_remove_metadata_torch_function_mode,
22
+ disable_proxy_modes_tracing,
23
+ make_fx,
24
+ )
25
+ from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
26
+ from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
27
+ from torch.multiprocessing.reductions import StorageWeakRef
28
+
29
+
30
+ @dataclass
31
+ class UnsupportedAliasMutationException(RuntimeError):
32
+ reason: str
33
+
34
+
35
def autograd_not_implemented_inner(
    operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
    """If autograd is enabled and any of the arguments require grad this will either
    raise an error or return a DelayedError depending on the value of delayed.

    Args:
        operator: The Operator to call with the *args and **kwargs with
        op_name: The name of the Operator
        delayed_error: If True, return a DelayedError instead of raising an error
        args: The flattened operands to the Operator
        kwargs: The keyword arguments to the Operator

    Raises:
        RuntimeError: If autograd is enabled and any of the arguments to the Operator
    """
    # Run the op below the Autograd dispatch key so no autograd graph is recorded.
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                # DelayedError only raises when its grad_fn is actually used in
                # backward, rather than failing eagerly here.
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )

                def fake_requires_grad(tensor):
                    # Only floating/complex tensors can require grad; detach so
                    # we don't flip requires_grad on the original output.
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")
        return result
75
+
76
+
77
def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
    """Return a callable that runs *op* without autograd support.

    The callable forwards to autograd_not_implemented_inner, which either
    raises or produces a DelayedError (per *deferred_error*) when any tensor
    argument requires grad.
    """
    return functools.partial(autograd_not_implemented_inner, op, deferred_error)
82
+
83
+
84
def _maybe_run_with_interpreter(fn):
    """Return fn, or an fx.Interpreter-based wrapper of fn when node metadata
    (e.g. stack traces) must be propagated through a GraphModule."""
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running graph with interpreter is needed for propagating the stack_trace
        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter
    return maybe_interpreted_fn
94
+
95
+
96
def _maybe_compile_and_run_fn(fn, *args):
    """Run fn(*args), first compiling it with dynamo (eager backend, fullgraph)
    unless we are already inside a dynamo compilation."""
    if not torch.compiler.is_dynamo_compiling():
        from torch._dynamo.backends.debugging import (
            make_eager_backend_with_torch_function_mode,
        )

        with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
            # If a metadata torch-function mode is active, temporarily remove it
            # so tracing doesn't see it, and let the eager backend re-apply it.
            with _temp_remove_metadata_torch_function_mode() as metadata_mode:
                if metadata_mode:
                    backend = make_eager_backend_with_torch_function_mode(metadata_mode)
                else:
                    backend = "eager"
                return torch.compile(fn, backend=backend, fullgraph=True)(*args)
    else:
        return fn(*args)
111
+
112
+
113
def reenter_make_fx(fn):
    """Wrap fn so it is traced as a subgraph of the currently-active make_fx
    tracing session (asserts that such a session exists at call time)."""
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    @functools.wraps(fn)
    def wrapped(*args):
        assert (
            _CURRENT_MAKE_FX_TRACER is not None
        ), "Cannot reenter make_fx when we're not under a make_fx tracing session"
        return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
            _maybe_run_with_interpreter(fn), *args
        )

    return wrapped
126
+
127
+
128
def _maybe_reenter_make_fx(fn):
    """Trace fn with the active make_fx tracer if one exists; otherwise start a
    fresh make_fx trace (in fake mode unless the inputs are already fake)."""
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    if _CURRENT_MAKE_FX_TRACER is not None:
        return reenter_make_fx(fn)
    else:

        def _maybe_make_fx_with_fake_mode(fn):
            @functools.wraps(fn)
            def wrapped(*args):
                from torch._guards import detect_fake_mode

                fake_mode = detect_fake_mode(args)
                if fake_mode is None:
                    # we create a fake_mode here to make sure we could
                    # trace the graph with data-dependent calls e.g. .item()
                    return make_fx(fn, tracing_mode="fake")(*args)
                # Tracing with "real" mode if all inputs have been fakified
                return make_fx(fn)(*args)

            return wrapped

        return _maybe_make_fx_with_fake_mode(fn)
151
+
152
+
153
def check_meta_consistency(
    lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
    rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
    lhs_name: str,
    rhs_name: str,
    include_contiguity: bool = True,
) -> None:
    """Validate that two lists of tensors/ints agree pairwise in metadata.

    Raises torch._dynamo.exc.UncapturedHigherOrderOpError when the lists
    differ in length, or when any pair differs in tensor metadata (shape,
    dtype, ...; requires_grad excluded), in tensor-vs-int kind, or in device.
    """

    def diff_meta_pairs(
        lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
        rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
    ) -> list[str]:
        def diff_meta(
            lhs: Union[torch.Tensor, torch.SymInt, int],
            rhs: Union[torch.Tensor, torch.SymInt, int],
        ) -> str:
            # Tensor/tensor pairs: compare extracted TensorMetadata fields.
            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
                return ", ".join(
                    diff_tensor_meta(
                        _extract_tensor_metadata(
                            lhs, include_contiguity=include_contiguity
                        ),
                        _extract_tensor_metadata(
                            rhs, include_contiguity=include_contiguity
                        ),
                        check_grad=False,
                    )
                )
            else:

                def _both_int_types(lhs, rhs):
                    return isinstance(lhs, (int, torch.SymInt)) and isinstance(
                        rhs, (int, torch.SymInt)
                    )

                def _both_tensor(lhs, rhs):
                    return isinstance(lhs, torch.Tensor) and isinstance(
                        rhs, torch.Tensor
                    )

                # Mixed kinds (tensor vs int) are a mismatch by themselves.
                if not _both_int_types(lhs, rhs) and not _both_tensor(lhs, rhs):
                    return f"type: {lhs} vs {rhs}"

            return ""

        # Manually check the device of lhs and rhs as this field is currently not part of TensorMetadata
        def diff_device(
            lhs: Union[torch.Tensor, torch.SymInt, int],
            rhs: Union[torch.Tensor, torch.SymInt, int],
        ) -> str:
            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
                if (
                    rhs.device.type == lhs.device.type
                    and rhs.device.index == lhs.device.index
                ):
                    return ""
                else:
                    return "device"
            return ""

        if len(lhs_list) != len(rhs_list):
            raise torch._dynamo.exc.UncapturedHigherOrderOpError(
                f"Expected {lhs_name} and {rhs_name} to have same number of outputs but got lhs:{lhs_list} and rhs:{rhs_list}"
            )
        all_diffs = []
        for i, (lhs, rhs) in enumerate(zip(lhs_list, rhs_list)):
            if diff := diff_meta(lhs, rhs):
                all_diffs.append(
                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
                )
            if diff := diff_device(lhs, rhs):
                all_diffs.append(
                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
                )
        return all_diffs

    if all_diffs := diff_meta_pairs(lhs_list, rhs_list):
        diff_str = "\n".join(all_diffs)
        raise torch._dynamo.exc.UncapturedHigherOrderOpError(
            f"Expected {lhs_name} and {rhs_name} to have same metadata but found:\n{diff_str}"
        )
233
+
234
+
235
@contextmanager
def _set_compilation_env():
    """Context manager preparing global flags for compiling a HOP body:
    disables the legacy fx tracing flag and lets dynamo keep empty graphs.
    Both flags are restored on exit."""
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    _old_allow_empty_graphs = torch._dynamo.config.allow_empty_graphs
    # The issue is tracked in https://github.com/pytorch/pytorch/issues/144360: when dynamo finds
    # the top-level frame produces no graph, the default behavior is to fallback to eager.
    # Then when it encounters an inner function, it will try to trace that function again, which is unnecessary.
    # For while_loop, during inspecting the inner call, we trace into the python dispatcher
    # logic, which is not traceable as of today. So the proper fix can be either 1. allow dispatch
    # logic to be dynamo traceable or 2. fixing https://github.com/pytorch/pytorch/issues/144360.
    # but it exposes some bugs in existing tests so we have to have a temporary flag to control
    # the behavior, which allows dynamo to store an empty graph for a frame without falling back to eager
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
        # once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
        torch._dynamo.config.allow_empty_graphs = True
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
        torch._dynamo.config.allow_empty_graphs = _old_allow_empty_graphs
256
+
257
+
258
# The invariant here is that we always trace the branch with fake tensor
def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
    """make_fx-trace fn over inputs, guaranteeing fake-tensor tracing.

    If the inputs already carry a fake mode, enter it and trace in "real" mode
    (the inputs themselves are fake); otherwise let make_fx fabricate a fake
    mode via tracing_mode="fake". Deferred runtime asserts for symbolic shapes
    are inserted into the resulting graph when a shape_env is present.
    """
    fake_mode = detect_fake_mode(inputs)
    tracing_mode = "real"
    if fake_mode is None:
        fake_mode = nullcontext()
        tracing_mode = "fake"

    # Note: we need to turn off proxy tensor mode to avoid tracing infra
    # code that happens in make_fx e.g. we now call as_strided when wrapping tensor
    # as fake tensor.
    with fake_mode, disable_proxy_modes_tracing():
        gm = make_fx(
            fn,
            tracing_mode=tracing_mode,
            pre_dispatch=pre_dispatch,
            _error_on_data_dependent_ops=False,
        )(*inputs)
    if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None:
        insert_deferred_runtime_asserts(
            gm, fake_mode.shape_env, "hoo_maybe_fake_tracing", export=True
        )
    return gm
281
+
282
+
283
def potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
    """Fake-trace gm over inputs and report aliasing/mutation information.

    Returns ((inp_inp_alias_map, inp_out_alias_map, out_out_alias_map),
    mutated_input_indices). If functionalizing a nested HOP raises
    UnsupportedAliasMutationException during tracing, returns True instead.

    NOTE(review): the `True` early-return has a different shape than the
    normal tuple return; callers that unpack the tuple would fail on it —
    confirm whether that path is actually reachable for them.
    """
    # The previous version also had `except Exception as e: raise e`, which is
    # a no-op re-raise; plain propagation is equivalent and clearer.
    try:
        gm = _maybe_fake_tracing(gm, inputs, pre_dispatch)
    except UnsupportedAliasMutationException:
        # this can happen when nested cond_op is functionalized
        return True

    example_inputs = [
        ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
    ]
    (
        inp_inp_alias_map,
        inp_out_alias_map,
        out_out_alias_map,
        inp_mutation,
    ) = check_input_alias_and_mutation(gm, example_inputs)
    return (inp_inp_alias_map, inp_out_alias_map, out_out_alias_map), inp_mutation
303
+
304
+
305
def analyze_potential_input_alias_or_mutation(name, aliases, input_mutations):
    """Raise RuntimeError describing any detected input aliasing or mutation.

    Args:
        name: human-readable name of the HOP/subgraph, used in error messages.
        aliases: iterable of alias maps (index -> index); any non-empty map
            means some input is aliased.
        input_mutations: collection identifying inputs that are mutated.

    Raises:
        RuntimeError: if any alias map is non-empty or any mutation was found.
    """
    if any(len(alias_map) > 0 for alias_map in aliases):
        # TODO: Investigate here further which node is exactly aliasing
        # The old messages embedded line-continuation whitespace runs and were
        # missing a space before the final sentence; build them cleanly.
        aliased = {el for el_map in aliases for el in el_map.keys()}
        raise RuntimeError(
            f"{name} where aliases appear. "
            f"In particular, these inputs {aliased} get aliased. "
            "Please ensure that this doesn't happen."
        )
    if len(input_mutations) > 0:
        # TODO: Investigate here further which node is exactly mutating the inputs
        raise RuntimeError(
            f"{name} where the inputs are mutated. "
            f"In particular, these nodes are mutating the inputs {set(input_mutations)}. "
            "Please ensure that this doesn't happen."
        )
322
+
323
+
324
def _has_potential_branch_input_mutation(gm, inputs, pre_dispatch=False):
    """Return True iff fake-tracing gm over inputs reveals any input mutation."""
    _alias_maps, mutated = potential_input_alias_or_mutation(
        gm, inputs, pre_dispatch
    )
    return len(mutated) > 0
332
+
333
+
334
def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
    """Return (has_any_alias, has_input_mutation) for gm fake-traced on inputs."""
    alias_maps, mutated = potential_input_alias_or_mutation(
        gm, inputs, pre_dispatch
    )
    # Any non-empty inp-inp, inp-out or out-out alias map counts as aliasing.
    has_alias = any(len(alias_map) > 0 for alias_map in alias_maps)
    return has_alias, len(mutated) > 0
350
+
351
+
352
def _collect_fake_inputs(inputs):
    """Extract FakeTensor/SymInt/int example values from fx proxies/nodes.

    Each input is either an fx Proxy/Node carrying an "example_value" in its
    meta, or a plain int. Batched/functional tensor wrappers are unwrapped
    down to the underlying FakeTensor.
    """
    from torch._subclasses.fake_tensor import FakeTensor

    # Get the example values of the inputs.
    inputs_fake: list[Union[FakeTensor, torch.Tensor, int]] = []
    for inp in inputs:
        if isinstance(inp, (torch.fx.proxy.Proxy, torch.fx.node.Node)):
            inp = inp.node if isinstance(inp, torch.fx.proxy.Proxy) else inp
            if hasattr(inp, "meta"):
                val = inp.meta["example_value"]
                if isinstance(val, torch.Tensor):
                    if torch._C._functorch.is_batchedtensor(
                        val
                    ) or torch._C._functorch.is_functionaltensor(val):
                        # This case is for batched or functional tensors
                        # Unwrap the tensors layer by layer until a FakeTensor
                        while torch._C._functorch.is_batchedtensor(
                            val
                        ) or torch._C._functorch.is_functionaltensor(val):
                            val = torch._C._functorch.get_unwrapped(val)
                        assert isinstance(val, FakeTensor)
                        inputs_fake.append(val)
                    else:
                        # This is the standard case of a TensorVariable
                        assert isinstance(val, FakeTensor)
                        inputs_fake.append(val)
                else:
                    # This case is for SymInts and other non-Tensor elements
                    assert not isinstance(val, torch.Tensor)
                    inputs_fake.append(val)
        else:
            # This case is for ints
            assert isinstance(inp, int)
            inputs_fake.append(inp)

    return inputs_fake
388
+
389
+
390
def _check_alias_and_mutation(graph_module, inputs_fake, name, pre_dispatch):
    """Raise RuntimeError if graph_module aliases or mutates its fake inputs."""
    has_alias, has_mutation = has_potential_input_alias_or_mutation(
        graph_module, inputs_fake, pre_dispatch=pre_dispatch
    )
    if has_alias:
        raise RuntimeError(f"{name} might be aliasing the input or the output!")
    if has_mutation:
        raise RuntimeError(f"{name} might be modifying the input!")
400
+
401
+
402
def unique_graph_id(proxy_mode, prefix):
    """Returns a unique name and id for a graph to be added to a proxy_mode tracer.

    Delegates to unique_graph_name_with_root using the tracer's root module.
    """
    # There are probably better ways - I know that create_arg has some self incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    return unique_graph_name_with_root(proxy_mode.tracer.root, prefix)
408
+
409
+
410
def unique_graph_name_with_root(
    root: torch.fx.GraphModule, prefix: str
) -> tuple[int, str]:
    """Return (i, f"{prefix}_{i}") for the smallest i not already an attribute
    of *root*, so the name can be registered without clobbering anything."""
    counter = 0
    while True:
        candidate = f"{prefix}_{counter}"
        if not hasattr(root, candidate):
            return counter, candidate
        counter += 1
422
+
423
+
424
def _from_fun(t):
    """Return a fresh, non-functional tensor mirroring t's metadata.

    Non-bool tensors get an uninitialized empty_strided stand-in (only
    metadata matters for tracing); bool tensors are cloned from an unwrapped,
    non-functional version so the result itself is not a functional tensor.
    Non-tensors are returned unchanged.
    """
    from torch._functorch.aot_autograd import from_fun

    if isinstance(t, torch.Tensor):
        if t.dtype != torch.bool:
            return torch.empty_strided(
                t.size(),
                t.stride(),
                dtype=t.dtype,
                requires_grad=t.requires_grad,
                device=t.device,
            )
        else:
            # clone of a functional tensor produces a functional tensor
            # but we want to avoid it so we clone a non-functional version
            maybe_unfunc_t = t
            if isinstance(t, FunctionalTensor):
                torch._sync(t)
                maybe_unfunc_t = from_fun(t)
            elif torch._is_functional_tensor(t):
                # need to handle both types of functionalization here:
                # these are the tensors that came from the user,
                # which could be either FunctionalTensorWrapper or FunctionalTensor
                torch._sync(t)
                maybe_unfunc_t = torch._from_functional_tensor(t)
            return maybe_unfunc_t.clone()
    return t
451
+
452
+
453
def clone_outputs_aliasing_inputs(args):
    """Return a callable that clones any tensor sharing storage with *args*.

    Used to keep traced backward graphs functional: outputs that alias an
    input are replaced with clones, everything else passes through untouched.
    """
    seen_storages = set()
    for arg in args:
        if isinstance(arg, torch.Tensor):
            seen_storages.add(StorageWeakRef(arg._typed_storage()))

    def maybe_clone(t):
        if isinstance(t, torch.Tensor):
            if StorageWeakRef(t._typed_storage()) in seen_storages:
                return t.clone()
        return t

    return maybe_clone
469
+
470
+
471
def prepare_fw_with_masks(fn):
    """Wrap fn so it returns (outputs, masks), where masks[i] is True iff
    outputs[i] is a tensor that requires grad."""

    def fw_with_masks(*args):
        outs = fn(*args)
        masks = [
            isinstance(out, torch.Tensor) and out.requires_grad for out in outs
        ]
        return outs, masks

    return fw_with_masks
480
+
481
+
482
def prepare_fw_with_masks_all_requires_grad(fn):
    """Wrap fn so it returns (outputs, requires_grad_masks), forcing every
    tensor output to require grad whenever any tensor input does.

    NOTE: requires_grad_(True) mutates the output tensors in place.
    """

    def fw_with_masks(*args):
        fw_out = fn(*args)
        # Note [force all outputs to be require grad]
        # Instead of using the original fn, we set the output of original
        # fn to all require grad. This is consistent with the behavior
        # of autograd.Function, where if any one of the inputs requires grad
        # all output will be require grad. This also makes the downstream
        # require_gradness reasoning much easier.
        if pytree.tree_any_only(torch.Tensor, lambda t: t.requires_grad, args):
            fw_out = pytree.tree_map_only(
                torch.Tensor, lambda x: x.requires_grad_(True), fw_out
            )
        return fw_out, pytree.tree_map_only(
            torch.Tensor, lambda x: x.requires_grad, fw_out
        )

    return fw_with_masks
500
+
501
+
502
# This function replaces None gradients with all-zero gradients.
# `None` gradients are problematic for CUDA graphs. Those gradients are
# replaced with an all-zero tensor for better optimization
def unmask_none_gradients(grads, operands):
    """Pair grads with operands, materializing zeros for missing tensor grads.

    int/SymInt operands (e.g. lifted shape arguments) legitimately have no
    gradient and keep None.
    """
    allowed_types = (torch.Tensor, int, torch.SymInt)
    assert all(
        isinstance(o, allowed_types) for o in operands
    ), f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}"

    result = []
    for grad, operand in zip(grads, operands):
        if grad is None and isinstance(operand, torch.Tensor):
            # Replace a missing tensor gradient with explicit zeros.
            result.append(torch.zeros_like(operand))
        else:
            result.append(grad)

    return result
524
+
525
+
526
+ def _maybe_fake_prop_ignore_unbacked(fn, args):
527
+ with ExitStack() as ctx_stack:
528
+ if (fake_mode := detect_fake_mode(args)) is not None:
529
+ ctx_stack.enter_context(fake_mode)
530
+ if fake_mode.shape_env is not None:
531
+ ctx_stack.enter_context(
532
+ fake_mode.shape_env.ignore_fresh_unbacked_symbols()
533
+ )
534
+ return fn(*args)
535
+
536
+
537
def redirect_to_mode(hop: OperatorBase, mode):
    """Utility for redispatching HOP to underlying mode

    Args:
        hop: The HOP to redispatch
        mode: The mode to redispatch to

    Returns:
        A decorated function that implements the HOP for the given mode
    """

    @hop.py_impl(mode)
    def impl(mode, *args, **kwargs):
        # Hand the call straight to the mode's __torch_dispatch__ (empty types).
        return mode.__torch_dispatch__(hop, [], args, kwargs)

    return impl
553
+
554
+
555
# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
    """Trace fn into a forward graph and a joint backward graph.

    fw_inputs/fw_outputs are example (non-functional) tensors used for
    tracing. Returns (fw_graph, joint_graph); joint_graph maps gradients
    (plus, when use_output_and_grad_bw, forward inputs+outputs) to input
    gradients, with None grads replaced by zeros and input-aliasing outputs
    cloned.
    """
    from torch._functorch.aot_autograd import AOTConfig, create_joint

    # Note:[HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and Python key. Currently, we only suspend functionalization but more can be
    # added when required. Will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    # However, it's the outside wrapper that tracer creates proxies for. This casuses tracer fail to
    # fetch the proxy for the inputs and fail to capture any operations on them.
    #
    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
    # wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
    # only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
    # Instead, it will create _tensor_constant as output.

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    example_grad = [_from_fun(out) for out in fw_outputs]
    num_grads = len(example_grad)
    fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)

    def joint_fn(*joint_operands_grads):
        # Two calling conventions: a single ((grads), (inputs+outputs)) pair,
        # or a flat (grads..., inputs...) argument list.
        if use_output_and_grad_bw:
            grads = joint_operands_grads[0]
            inputs = joint_operands_grads[1][-1:]
        else:
            grads = joint_operands_grads[:num_grads]
            inputs = joint_operands_grads[num_grads:]

        joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
        _, grads = joint(
            list(inputs),
            [grad for grad in grads if grad is not None and grad.requires_grad],
        )

        # Unmask None gradients to all-zero gradients
        unmasked_grads = unmask_none_gradients(grads, inputs)

        # In order to keep map functional for backward graph,
        # we clone outputs that are aliasing inputs
        maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)

        return pytree.tree_map(maybe_clone, unmasked_grads)

    if use_output_and_grad_bw:
        example_xs_out = list(fw_inputs) + list(fw_outputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            (list(example_grad), list(example_xs_out))
        )
    else:
        example_xs_out = list(fw_inputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            *(list(example_grad) + list(example_xs_out))
        )

    return fw_graph, joint_graph
624
+
625
+
626
+ def _unstack_pytree(xs):
627
+ flat_xs, inspec = pytree.tree_flatten(xs)
628
+ if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
629
+ raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
630
+
631
+ if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
632
+ raise RuntimeError(
633
+ f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
634
+ )
635
+
636
+ a = zip(*flat_xs)
637
+
638
+ pytrees = [pytree.tree_unflatten(tuple, inspec) for tuple in a]
639
+ return pytrees
640
+
641
+
642
+ def _stack_pytree(pytrees):
643
+ flat_out = []
644
+ out_spec = None
645
+ for pt in pytrees:
646
+ flat_pt, out_spec = pytree.tree_flatten(pt)
647
+ flat_out.append(flat_pt)
648
+ assert out_spec is not None
649
+ b = zip(*flat_out)
650
+ stacked_out = []
651
+ for leaves in b:
652
+ if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
653
+ stacked_out.append(torch.stack(leaves))
654
+ elif all(leaf is None for leaf in leaves):
655
+ # Backward graph can return None output when forward inputs doesn't require grad.
656
+ # When we eagerly execute backward graph, we need to call _stack_pytree on its output,
657
+ # therefore we need to deal with None output.
658
+ stacked_out.append(None) # type: ignore[arg-type]
659
+ else:
660
+ raise RuntimeError(f"Cannot stack {leaves}.")
661
+ return pytree.tree_unflatten(stacked_out, out_spec)
662
+
663
+
664
def save_tensors_and_symints_for_backward(ctx, args):
    """Save a mixed sequence of tensors and symints/ints on an autograd ctx.

    ctx.save_for_backward only accepts tensors, so args are partitioned:
    tensors go through ctx.save_for_backward, while SymInts/ints/Nones are
    stashed directly on ctx.sym_int_args. ctx.pos records, per original
    position, which partition (0 = tensor, 1 = other) the arg went into so
    saved_tensors_and_symints can reassemble args in their original order.

    For example, for args = (x, y, s0, z, s1) with tensors x, y, z:
        saved tensors   = (x, y, z)
        ctx.sym_int_args = [s0, s1]
        ctx.pos          = [0, 0, 1, 0, 1]
    """
    assert all(
        isinstance(arg, (torch.Tensor, torch.SymInt, int, type(None))) for arg in args
    ), args
    tensor_args: list[Any] = []
    other_args: list[Any] = []
    pos = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            tensor_args.append(arg)
            pos.append(0)
        else:
            other_args.append(arg)
            pos.append(1)

    # Guard against double-saving on the same ctx.
    assert not hasattr(ctx, "sym_int_args"), "ctx already has sym_int_args attribute."
    assert not hasattr(ctx, "pos"), "ctx already has pos attribute."
    ctx.save_for_backward(*tensor_args)
    ctx.sym_int_args = other_args
    ctx.pos = pos
694
+
695
+
696
def saved_tensors_and_symints(ctx):
    """Reassemble, in original order, the args stored by
    save_tensors_and_symints_for_backward: ctx.pos says whether each slot
    comes from ctx.saved_tensors (0) or ctx.sym_int_args (1)."""
    tensors = iter(ctx.saved_tensors)
    symints = iter(ctx.sym_int_args)
    return tuple(
        next(tensors) if which == 0 else next(symints) for which in ctx.pos
    )
710
+
711
+
712
def get_dummy_aot_autograd_config():
    """Return a minimal placeholder AOTConfig (no compilers/partitioner) for
    APIs that require a config object but don't use those fields."""
    from torch._functorch.aot_autograd import AOTConfig

    return AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )
724
+
725
+
726
# Slices off the first element of a given dimension
def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Return a copy of t's first slice along ``dim`` (that dim is removed)."""
    return torch.select_copy(t, dim, 0)
729
+
730
+
731
+ # Reports the difference between meta of two tensors in a string
732
+ def diff_tensor_meta(
733
+ meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
734
+ ) -> list[str]:
735
+ from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
736
+
737
+ pair_diffs = []
738
+ for meta_name in TensorMetadata._fields:
739
+ if not check_grad and meta_name == "requires_grad":
740
+ continue
741
+ val1 = getattr(meta1, meta_name)
742
+ val2 = getattr(meta2, meta_name)
743
+ try:
744
+ if val1 != val2:
745
+ pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
746
+ except GuardOnDataDependentSymNode as _:
747
+ pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'")
748
+ continue
749
+ return pair_diffs
750
+
751
+
752
+ # Note [lifted arg types in hop]
753
+ # For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
754
+ # This has implications for the types of lifted args for different dispatch keys:
755
+ # 1. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd need to support torch.Symint
756
+ # lifted args because it's on the path of torch.compile(dynamic=True).
757
+ # 2. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd, CompositeExplicitAutograd need
758
+ # to support int arguments. In the eager run case, we re-trace the subgraph in AutogradKey, so inner
759
+ # hops may receive int inputs from the shape of outer tensor inputs.
760
+ # However, CompositeExplicitAutograd won't receive SymInt inputs because it only accepts real tensor inputs.
761
def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]]):
    """Assert every lifted arg is a Tensor, int or SymInt.

    See Note [lifted arg types in hop] above for why exactly these types are
    allowed. Fix: the isinstance check now uses the local ``allowed_types``
    tuple instead of repeating it, so the message and the check can't drift.
    """
    allowed_types = (torch.Tensor, int, torch.SymInt)
    assert all(
        isinstance(arg, allowed_types) for arg in lifted_args
    ), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"
766
+
767
+
768
+ # TODO: Return a more detailed information as to which node
769
+ # causes a mutation or an alias. This may requires a per operator tensor version checking
770
def check_input_alias_and_mutation(
    gm: torch.fx.GraphModule,
    fake_args: list[FakeTensor],
) -> tuple[dict[int, int], dict[int, int], dict[int, int], list[int]]:
    """Like check_input_alias_and_mutation_return_outputs, but drops the final
    outputs element and returns only the alias maps and mutated input indices."""
    (
        inp_inp_alias_map,
        inp_out_alias_map,
        out_out_alias_map,
        mutated_inputs,
    ) = check_input_alias_and_mutation_return_outputs(gm, fake_args)[:-1]
    return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs
781
+
782
+
783
def check_input_alias_and_mutation_return_outputs(
    gm: torch.fx.GraphModule,
    fake_args: Union[list[FakeTensor], tuple[FakeTensor, ...]],
) -> tuple[
    dict[int, int],
    dict[int, int],
    dict[int, int],
    list[int],
    Union[tuple[Any, ...], list[Any]],
]:
    """Run gm once on cloned fake args and report aliasing/mutation plus outputs.

    Returns (inp_inp_alias_map, inp_out_alias_map, out_out_alias_map,
    mutated_input_indices, outputs). Mutation is detected via tensor-version
    bumps; aliasing via shared storages (StorageWeakRef).
    """
    # This function can be called under autograd, functional, proxy and fake tensor mode.
    # We need to return either a fake tensor or a real tensor depending on the mode.
    # to detect the input mutation/aliasing.
    with disable_proxy_modes_tracing(), disable_functional_mode(), suspend_functionalization():

        def _from_functional_tensor(t: torch.Tensor) -> torch.Tensor:
            # Replace functional tensors with metadata-matching plain tensors.
            if isinstance(t, FunctionalTensor) or torch._is_functional_tensor(t):
                return torch.empty_strided(
                    t.size(),
                    t.stride(),
                    dtype=t.dtype,
                    requires_grad=t.requires_grad,
                    device=t.device,
                )
            return t

        fake_args = pytree.tree_map_only(
            torch.Tensor, _from_functional_tensor, fake_args
        )
    # We want to disable active functional, proxy and fake modes if any.
    # to create a encapsulated environment for fake tensor prop
    with torch.utils._python_dispatch._disable_current_modes():
        """This function returns mutated inputs, inp-inp alias, inp-out alias, out-out alias
        in the graph module gm. It checks whether input tensor versions have
        changed after run gm once to detect mutation and checks tensor storage
        to detect alias.
        """

        def _tensor_version(t) -> Optional[int]:
            if isinstance(t, torch.Tensor):
                if not isinstance(t, FakeTensor):
                    raise RuntimeError("Only fake tensor is allowed")
                return t._version
            return None

        def _tensor_storage(t) -> StorageWeakRef:
            return StorageWeakRef(t._typed_storage())

        def _get_shape_env(
            fake_args,
        ) -> Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv]:
            # detect_fake_mode requires there could be only one active fake mode. This
            # restricts the usage of this function because the global TracingContext
            # has a persistent fake mode but fake tensors can be created
            # outside of the tracing context (e.g. in testing).
            # Instead, we just look at fake_args fake tensor mode
            if len(fake_args) == 0:
                return torch.fx.experimental.symbolic_shapes.ShapeEnv()

            for arg in fake_args:
                if isinstance(arg, FakeTensor):
                    return arg.fake_mode.shape_env
            return None

        # Clone the fake args to avoid mutating the original fake args
        with ExitStack() as ctx_stack:
            # We need to re-use prev_fake_mode's shape env to resolve
            # the runtime assertions for unbacked symbols.
            new_fake_mode = torch._subclasses.FakeTensorMode(
                shape_env=_get_shape_env(fake_args),
                allow_non_fake_inputs=False,
            )
            # We need to temporarily turn inference_mode off because
            # under inference mode, tensor version counter is not tracked.
            no_inference_mode_ctx = torch.inference_mode(False)
            ctx_stack.enter_context(new_fake_mode)
            ctx_stack.enter_context(no_inference_mode_ctx)
            if new_fake_mode.shape_env is not None:
                ctx_stack.enter_context(
                    new_fake_mode.shape_env.ignore_fresh_unbacked_symbols()
                )

            # create new fake tensors in new fake mode to avoid mutating original tensors
            cloned = [
                torch.empty_strided(
                    arg.size(),
                    arg.stride(),
                    dtype=arg.dtype,
                    device=arg.device,
                    requires_grad=arg.requires_grad,
                    layout=arg.layout,
                )
                if isinstance(arg, torch.Tensor)
                else arg
                for arg in fake_args
            ]
            before = [_tensor_version(arg) for arg in cloned]
            outputs = gm(*cloned)
            outputs = [outputs] if not isinstance(outputs, (list, tuple)) else outputs
            after = [_tensor_version(arg) for arg in cloned]
            # A version bump between before/after means the input was mutated.
            mutated_inputs = [
                i for i, (v1, v2) in enumerate(zip(before, after)) if v1 != v2
            ]
            # We need to analyze the original fake_args to detect
            # inp-inp alias.
            inp_storage_map = {
                _tensor_storage(inp): i
                for i, inp in enumerate(fake_args)
                if isinstance(inp, torch.Tensor)
            }
            inp_inp_alias_map = {
                i: inp_storage_map[_tensor_storage(inp)]
                for i, inp in enumerate(fake_args)
                if isinstance(inp, torch.Tensor)
                and inp_storage_map[_tensor_storage(inp)] != i
            }
            out_storage_map = {
                _tensor_storage(out): i
                for i, out in enumerate(outputs)
                if isinstance(out, torch.Tensor)
            }
            out_out_alias_map = {
                i: out_storage_map[_tensor_storage(out)]
                for i, out in enumerate(outputs)
                if isinstance(out, torch.Tensor)
                and out_storage_map[_tensor_storage(out)] != i
            }
            # inp-out aliasing is checked against the cloned inputs, since the
            # outputs were produced from them.
            inp_out_alias_map = {
                i: out_storage_map[_tensor_storage(inp)]
                for i, inp in enumerate(cloned)
                if isinstance(inp, torch.Tensor)
                and _tensor_storage(inp) in out_storage_map
            }
    return (
        inp_inp_alias_map,
        inp_out_alias_map,
        out_out_alias_map,
        mutated_inputs,
        outputs,
    )
922
+
923
+
# Registry of fake implementations for HOPs, populated by register_fake()
# below and consulted inside FakeTensorMode's _dispatch_impl.
# NOTE(review): the key annotation says OpOverload, but register_fake() below
# is documented for HOPs (HigherOrderOperator) — confirm which is intended.
registered_hop_fake_fns: dict[torch._ops.OpOverload, Callable] = {}


# Preserves the decorated function's type through the register_fake overloads.
F = TypeVar("F", bound=Callable)
928
+
929
+
@overload
def register_fake(hop, fn: None = None) -> Callable[[F], F]:
    ...


@overload
def register_fake(hop, fn: F) -> F:
    ...


def register_fake(hop, fn=None):
    """
    Register a fake function for a HOP. This is conceptually equivalent of the
    register_fake utility for the custom ops. The registered function is called
    inside the fake_tensor _dispatch_impl.

    Usable either as a decorator (``@register_fake(hop)``) or as a direct call
    (``register_fake(hop, fn)``).
    """
    # Each HOP may have at most one registered fake implementation.
    assert hop not in registered_hop_fake_fns

    def register(func: F) -> F:
        # Imported lazily to avoid import cycles with the fake tensor module.
        from torch._subclasses.fake_tensor import FakeTensorMode

        # Route the HOP's FakeTensorMode dispatch to the mode itself, which
        # then looks up `func` in registered_hop_fake_fns.
        redirect_to_mode(hop, FakeTensorMode)

        registered_hop_fake_fns[hop] = func
        return func

    if fn is None:
        return register
    return register(fn)
959
+
960
+
class FunctionalizeCtxWrapper:
    """
    This is a dummy wrapper to facilitate fake tensor caching.

    For AOT Dispatcher metadata collection pass, HOPs go from functionalization
    key to fake tensor key. The functionalization key wraps the subgraphs in a
    function, which changes from call to call even though the subgraph might
    still be same.

    To enable fake tensor caching, we just wrap the ctx and subgraph in this
    class and then use the subgraph as the hash.
    """

    # Prevents PYTORCH_TEST_WITH_DYNAMO=1 test failures
    @torch._disable_dynamo
    def __init__(self, ctx, subgraph):
        # ctx: functionalization context providing .functionalize();
        # subgraph: callable (possibly a torch.fx.GraphModule) to run.
        self.ctx = ctx
        self.subgraph = subgraph

    def __hash__(self):
        # Hash by subgraph identity only, so wrappers over the same subgraph
        # land in the same cache slot regardless of the ctx instance.
        # NOTE(review): __eq__ is not defined (default identity), so equal
        # hashes do not imply equal wrappers — confirm that is acceptable
        # for the fake tensor cache.
        return id(self.subgraph)

    def __repr__(self):
        return f"FunctionalizeCtxWrapper on subgraph {self.subgraph})"

    def __call__(self, *args, **kwargs):
        if isinstance(self.subgraph, torch.fx.GraphModule):
            # Running graph with interpreter is needed for propagating the stack_trace
            with fx_traceback.preserve_node_meta():
                return self.ctx.functionalize(torch.fx.Interpreter(self.subgraph).run)(
                    *args, **kwargs
                )
        return self.ctx.functionalize(self.subgraph)(*args, **kwargs)
994
+
995
+
996
+ # A wrapper over HigherOrderOperator that also carries its schema
997
+ class HopInstance:
998
+ def __init__(self, op: HigherOrderOperator, schema: HopSchema):
999
+ assert isinstance(op, HigherOrderOperator), op
1000
+ self._op = op
1001
+ # Using "_" to be consistent with how we access _schema of OpOverload
1002
+ self._schema = schema
1003
+
1004
+ def __call__(self, *args, **kwargs):
1005
+ return self._op(*args, **kwargs)
1006
+
1007
+ @staticmethod
1008
+ def create(hop: HigherOrderOperator, *args, **kwargs):
1009
+ return HopInstance(hop, hop.gen_schema(*args, **kwargs))
1010
+
1011
+
# This call_op can be used to call a HopInstance with
# flat args and kwargs. We need to make use of the hop's schema's tree_spec
# to unflatten the args and kwargs before calling the hop.
def call_op(op: Union[OpOverload, HopInstance], args, kwargs):
    """Invoke ``op`` with flat ``args``/``kwargs``, rebinding them per its schema."""
    # Plain OpOverloads need no schema-driven rebinding.
    if isinstance(op, OpOverload):
        return op(*args, **kwargs)

    assert isinstance(op, HopInstance), op
    schema = op._schema
    bound_args = list(args)
    bound_kwargs = {}
    # Positional args cover a prefix of the schema arguments; everything that
    # remains must be supplied by name through kwargs.
    for arg in schema.arguments[len(bound_args) :]:
        assert arg.name in kwargs, (arg.name, kwargs)
        val = kwargs[arg.name]
        if not arg.kwarg_only:
            bound_args.append(val)
        else:
            bound_kwargs[arg.name] = val

    if schema.tree_spec is not None:
        # The schema was generated from flattened pytree inputs: all arguments
        # are positional and must be unflattened back to the call structure.
        assert len(bound_args) == len(schema.arguments) and len(bound_kwargs) == 0
        args, kwargs = pytree.tree_unflatten(bound_args, schema.tree_spec)
        return op(*args, **kwargs)
    else:
        assert len(bound_args) + len(bound_kwargs) == len(schema.arguments)
        return op(*bound_args, **bound_kwargs)
1038
+
1039
+
def materialize_as_graph(
    fn: Callable,
    args: tuple[Any, ...],
    include_key_set: Optional[torch._C.DispatchKeySet] = None,
    exclude_key_set: Optional[torch._C.DispatchKeySet] = None,
    force_enable_grad=False,
) -> torch.fx.GraphModule:
    """
    Trace ``fn`` into a GraphModule under a controlled dispatch-key environment.

    Args:
        fn: callable to trace.
        args: example inputs; functional tensors are unwrapped via _from_fun.
        include_key_set / exclude_key_set: dispatch key sets to force while
            tracing; default to the current thread-local TLS sets.
        force_enable_grad: if True, trace with grad mode enabled.
    """
    if include_key_set is None:
        include_key_set = torch._C._dispatch_tls_local_include_set()
    if exclude_key_set is None:
        exclude_key_set = torch._C._dispatch_tls_local_exclude_set()

    @torch._dynamo.disable(recursive=True, reason=None)
    def _materialize_as_graph_inner():
        # Tracing happens with ambient functionalization suspended; proxy
        # tracing is disabled only while building the unwrapped example inputs.
        with suspend_functionalization(), disable_functional_mode():
            with disable_proxy_modes_tracing():
                unfunc_t = [_from_fun(arg) for arg in args]
            with contextlib.ExitStack() as stack:
                stack.enter_context(
                    torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set),
                )
                if force_enable_grad:
                    stack.enter_context(torch.enable_grad())
                return _maybe_reenter_make_fx(fn)(*unfunc_t)

    gm = _materialize_as_graph_inner()
    assert gm is not None
    return gm
1068
+
1069
+
def materialize_callable_in_args(op: HopInstance, args, kwargs):
    """
    Replace callable leaves of ``(args, kwargs)`` with traced GraphModules.

    Traces one call to ``op`` and, for each schema argument whose proxy is a
    get_attr node pointing at a traced sub-GraphModule, substitutes that
    GraphModule for the original callable; all other leaves pass through.
    """
    schema = op._schema
    hop = op._op
    flat_args, flat_spec = pytree.tree_flatten((args, kwargs))

    def wrapped_fn(*flat_args):
        return call_op(op, args, kwargs)

    # We need to trace the higher order op in order to materialize the callable inputs that
    # are a callable (e.g. after functionalization key)
    gm = reenter_make_fx(wrapped_fn)(*flat_args)
    hop_node = gm.graph.find_nodes(op="call_function", target=hop)[0]
    arg_proxies = pytree.tree_leaves((hop_node.args, hop_node.kwargs))
    assert isinstance(schema, torch._C.FunctionSchema) and len(arg_proxies) == len(
        schema.arguments
    )

    # call_op preserves ordering of proxies via schema
    materialized_args = []
    for i, (proxy, arg) in enumerate(zip(arg_proxies, schema.arguments)):
        if (
            isinstance(proxy, torch.fx.Node)
            and proxy.op == "get_attr"
            and isinstance(getattr(gm, proxy.target), torch.fx.GraphModule)  # type: ignore[arg-type]
        ):
            # A get_attr of a GraphModule means the original leaf was a
            # callable that got traced into a subgraph.
            assert callable(flat_args[i]), (schema, args, kwargs)
            materialized_args.append(getattr(gm, proxy.target))  # type: ignore[arg-type]
        else:
            materialized_args.append(flat_args[i])

    return pytree.tree_unflatten(materialized_args, flat_spec)
1101
+
1102
+
1103
+ def has_user_subclass(args, allowed_subclasses):
1104
+ """Check if any tensor arguments are user subclasses.
1105
+
1106
+ This is used to determine if tensor subclasses should get a chance to run
1107
+ their own implementation first before falling back to the default implementation.
1108
+
1109
+ Args:
1110
+ args: Arguments to check (will be flattened with pytree)
1111
+ allowed_subclasses: Tuple of allowed subclass types
1112
+
1113
+ Returns:
1114
+ True if user tensor subclasses are found, False otherwise
1115
+ """
1116
+ flat_args, _ = pytree.tree_flatten(args)
1117
+
1118
+ val = any(
1119
+ isinstance(a, torch.Tensor)
1120
+ and type(a) is not torch.Tensor
1121
+ and not isinstance(a, allowed_subclasses)
1122
+ for a in flat_args
1123
+ )
1124
+ return val
1125
+
1126
+
1127
+ def _has_gen_schema(op: HigherOrderOperator):
1128
+ # There is an InvokeQuant argument we cannot gen_schema.
1129
+ if op is torch.ops.higher_order.invoke_quant_packed:
1130
+ return False
1131
+ method = "gen_schema"
1132
+ return hasattr(type(op), method) and getattr(type(op), method) is not getattr(
1133
+ HigherOrderOperator, method
1134
+ )
archive/.venv/Lib/site-packages/torch/_higher_order_ops/while_loop.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ from typing import Callable, Union
4
+
5
+ import torch
6
+ import torch.utils._pytree as pytree
7
+ from torch._C import DispatchKey
8
+ from torch._higher_order_ops.utils import (
9
+ _maybe_run_with_interpreter,
10
+ _set_compilation_env,
11
+ autograd_not_implemented,
12
+ check_meta_consistency,
13
+ reenter_make_fx,
14
+ validate_subgraph_args_types,
15
+ )
16
+ from torch._ops import HigherOrderOperator
17
+ from torch._subclasses.fake_tensor import FakeTensorMode
18
+ from torch.fx.experimental.proxy_tensor import (
19
+ _temp_remove_metadata_torch_function_mode,
20
+ ProxyTorchDispatchMode,
21
+ track_tensor_tree,
22
+ )
23
+
24
+
class WhileLoopOp(HigherOrderOperator):
    """The ``while_loop`` HigherOrderOperator.

    Validates that the carried/additional inputs are flat tuples (or lists)
    of allowed leaf types before dispatching to the per-key implementations.
    """

    def __init__(self) -> None:
        super().__init__("while_loop")

    def __call__(
        self,
        cond_fn: Callable,
        body_fn: Callable,
        carried_inputs: tuple[Union[torch.Tensor, int, float, bool], ...],
        additional_inputs: tuple[Union[torch.Tensor, torch.SymInt, int], ...],
        /,
    ):
        if not isinstance(carried_inputs, (tuple, list)):
            raise RuntimeError(
                f"carried_inputs must be a tuple or list, got {type(carried_inputs)}"
            )
        if not isinstance(additional_inputs, (tuple, list)):
            raise RuntimeError(
                f"additional_inputs must be a tuple or list, got {type(additional_inputs)}"
            )

        # Rejects leaf types that the HOP machinery cannot handle.
        validate_subgraph_args_types(carried_inputs)
        validate_subgraph_args_types(additional_inputs)
        return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)


# Singleton instance; all dispatch-key implementations below register on it.
while_loop_op = WhileLoopOp()
52
+
53
+
def while_loop(cond_fn, body_fn, carried_inputs):
    r"""
    Run body_fn(*carried_inputs) while cond_fn(*carried_inputs) returns a True scalar tensor. Returns the output of body_fn or
    initial carried_inputs.

    .. warning::
        `torch.while_loop` is a prototype feature in PyTorch. It has limited support for input and output types and
        doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
        Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

    `while_loop` is a structured control flow operator. It preserves the loop semantic across the torch.compile and torch.export.

    `while_loop` is equivalent to the following:

        def while_loop(cond_fn, body_fn, carried_inputs):
            val = carried_inputs
            while cond_fn(*val):
                val = body_fn(*val)
            return val

    Args:
        cond_fn (Callable): A callable function that returns a boolean Scalar tensor or a python boolean.

        body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors or ints

        carried_inputs (Tuple of possibly nested dict/list/tuple of tensors or ints): A tuple of inputs to cond_fn and body_fn.
            It's also the initial value of states that are carried across iterations. Note that when pass an integer as carry,
            the corresponding return of while_loop will be another int with unknown values because we don't know how many
            iterations while_loop will run.

    Example 1:

        def cond_fn(iter, x):
            return iter.sum() < 10

        def body_fn(iter, x):
            return iter + 1, x.sin()

        while_loop(cond_fn, body_fn, (torch.zeros(1), torch.randn(3, 4)))

    Example 2:

        def cond_fn(int_iter, x):
            return 2 * int_iter < x.shape[0]

        def body_fn(int_iter, x):
            return int_iter + 1, x + int_iter

        while_loop(cond_fn, body_fn, (0, torch.randn(3, 4)))

    Restrictions:

    - body_fn must return tensors or int with the same metadata (e.g. shape, dtype) as inputs.

    - body_fn and cond_fn must not in-place mutate the carried_inputs. A clone before the mutation is required.

    - body_fn and cond_fn must not mutate python variables (e.g. list/dict) created outside of the body_fn.

    - body_fn and cond_fn's output cannot alias any of the inputs. A clone is required.

    .. warning::
        Temporal Limitations:

        - 'while_loop' only supports **inference** right now. Autograd will be supported in the future.

    """
    from torch._dynamo.backends.debugging import (
        make_eager_backend_with_torch_function_mode,
    )

    # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo.
    # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs.
    additional_inputs: tuple = ()

    # The reason we flatten the output before calling into dynamo is that
    # we want to create a consistent input ordering for cond_fn and body_fn.
    # and we also want to the input ordering matches the output ordering.
    # Also see NOTE: [why we cannot use "automatic" for while_loop]
    # Construct flat cond_fn and flat_body_fn, which takes flattened inputs
    flat_inputs, in_spec = pytree.tree_flatten((carried_inputs, additional_inputs))

    def flat_cond_fn(*flat_args):
        carried, additional = pytree.tree_unflatten(flat_args, in_spec)
        return cond_fn(*carried, *additional)

    def flat_body_fn(*flat_args):
        carried, additional = pytree.tree_unflatten(flat_args, in_spec)
        return body_fn(*carried, *additional)

    if torch.compiler.is_dynamo_compiling():
        return while_loop_op(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())

    def _validate_input(cond_fn, body_fn, carried_inputs):
        from torch._higher_order_ops.utils import validate_subgraph_args_types

        if not callable(cond_fn) or not callable(body_fn):
            raise RuntimeError("Expect cond_fn and body_fn to be callable.")

        validate_subgraph_args_types(flat_inputs)

        if not pytree.tree_all(
            lambda t: isinstance(t, (torch.Tensor, torch.SymInt, int)), carried_inputs
        ):
            raise RuntimeError(
                "Expect carried_inputs to be a tuple of possibly nested dict/list/tuple that only"
                f"consists of tensor or int leaves, but got {carried_inputs}."
            )

    _validate_input(cond_fn, body_fn, carried_inputs)

    # Dynamo is expecting a callable with "__code__" attribute.
    # We cannot directly pass cond_op to it. So we wrap it in a dummy function.
    def _while_loop_op_wrapper(*args, **kwargs):
        return while_loop_op(*args, **kwargs)

    with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
        with _temp_remove_metadata_torch_function_mode() as metadata_mode:
            # NOTE(review): the metadata torch-function mode is removed twice;
            # the inner `with` looks redundant since the outer one has already
            # removed the mode — confirm whether this is intentional.
            with _temp_remove_metadata_torch_function_mode() as metadata_mode:
                if metadata_mode:
                    backend = make_eager_backend_with_torch_function_mode(metadata_mode)
                else:
                    backend = "eager"
                return torch.compile(
                    _while_loop_op_wrapper, backend=backend, fullgraph=True
                )(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
179
+
180
+
@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
    """Eager implementation of while_loop: a plain Python ``while`` loop that
    re-invokes ``cond_fn``/``body_fn`` until the predicate is falsy, validating
    the predicate type and the body output's shape along the way."""
    if not isinstance(carried_inputs, (tuple, list)):
        raise RuntimeError(
            f"carried_inputs must be a tuple or list but got {type(carried_inputs)}"
        )

    def _check_pred(pred):
        # Accept only a 0-dim boolean tensor or a python bool.
        is_bool_scalar_tensor = (
            isinstance(pred, torch.Tensor)
            and pred.size() == torch.Size([])
            and pred.dtype == torch.bool
        )
        if not (is_bool_scalar_tensor or isinstance(pred, bool)):
            raise RuntimeError(
                f"cond_fn must return a boolean scalar tensor or a boolean but got {pred}"
            )

    carried_vals = carried_inputs
    while True:
        pred = cond_fn(*carried_vals, *additional_inputs)
        if not pred:
            # Falsy predicate terminates the loop without type validation,
            # matching the walrus-based original.
            break
        _check_pred(pred)
        out = body_fn(*carried_vals, *additional_inputs)
        assert isinstance(
            out, tuple
        ), f"body_fn should return a tuple but got {type(out)}"
        assert len(out) == len(
            carried_inputs
        ), "body_fn should return the same number of elements as carried_inputs"
        carried_vals = out
    return carried_vals
213
+
214
+
# while_loop has no autograd formula; this registers a kernel that raises a
# "not implemented" error — presumably deferred until gradients are actually
# requested (deferred_error=True), see autograd_not_implemented in utils.
while_loop_op.py_autograd_impl(
    autograd_not_implemented(while_loop_op, deferred_error=True)
)
218
+
219
+
220
+ def _find_or_create_fake_mode() -> FakeTensorMode:
221
+ from torch.fx.experimental.symbolic_shapes import ShapeEnv
222
+
223
+ fake_mode = torch._guards.detect_fake_mode()
224
+ if fake_mode is None:
225
+ fake_mode = FakeTensorMode(shape_env=ShapeEnv())
226
+
227
+ return fake_mode
228
+
229
+
230
+ def _create_unbacked_symint(
231
+ fake_mode: FakeTensorMode, ignore_fresh_unbacked_symbols: bool
232
+ ) -> torch.SymInt:
233
+ assert (
234
+ fake_mode is not None and fake_mode.shape_env is not None
235
+ ), "Must provide a fake_mode with shape_env."
236
+ ctx = (
237
+ contextlib.nullcontext()
238
+ if not ignore_fresh_unbacked_symbols
239
+ else fake_mode.shape_env.ignore_fresh_unbacked_symbols()
240
+ )
241
+ with ctx:
242
+ return fake_mode.shape_env.create_unbacked_symint()
243
+
244
+
@while_loop_op.py_impl(ProxyTorchDispatchMode)
def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
    """Proxy-mode implementation: trace cond_fn/body_fn into sub-GraphModules
    (with int carries unspecialized to unbacked symints), register them on the
    tracer root, and emit a single ``while_loop`` call_function node."""

    def _trace_while_loop(
        proxy_mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
    ):
        # NOTE [unspecialize int carry with unbacked symints]
        # When we support int carry, we'll also need to support int output of body_fn because
        # previous iteration's output is next iteration's input and they must match.
        # For carries, when we start tracing while_loop, they can be
        # - constants e.g. (0, [1, 3])
        # - backed symints (x.shape[0], [x.shape[1] + x.stride[1], x.shape[2]])
        # - unbacked symints e.g. (u0, [u0 + u1, u2])
        # We choose the most conservative design: in all cases, we create new unbacked symints to trace the
        # subgraph. It's possible to do some analysis on initial carry and the output of first
        # iteration to determine a better range for the output unbacked symbol e.g. when input is an unbacked
        # symint >= 0 before the while_loop but in general this is difficult because we don't know
        # the number of iterations. Users would have to re-constrain the unbacked symint in subgraph if needed.
        #
        # For output of fake cond_fn, it could be constant bool or SymBool (e.g. return x.shape[0] < 4,
        # where x.shape[0] can be either static or dynamic). In the case of constant bool, we should do a
        # specialization (NYI).

        # For output of fake body_fn, it could be all three types though from user's point of view,
        # they're all integers e.g.

        # init_carry = (0, s0, u1, t)
        # def body_fn(u0, s0, u1, t):
        #     ...
        #     return (t.shape[0], t.shape[1], t.shape[2], y + 1)
        #
        # It may seem that a constant output isn't possible: users shouldn't write a while_loop
        # that always return 0. But it could be that a shape is not set as dynamic properly (e.g.
        # automatic dynamic hasn't been triggered).
        #
        # For this reason, we treat int, symint outputs in the same way:
        # - they can match against any of int, symint carry
        # - we unspecialize them with new unbacked symints in fake while_loop
        # Similarly, we could do some analysis to refine the output ranges but it's easier to start with
        # fresh unbacked symints. One surprising case can be: an input unbacked symint is constrained by
        # users to be >= 0 (either before while_loop or inside body_fn) and it increments by 1 in each
        # iteration. Ideally, we should know that the final output is >= 0 but we didn't constrain the
        # unbacked symint output of subgraph as of today because this requires a smart range analysis.
        fake_mode: FakeTensorMode = _find_or_create_fake_mode()
        unspecialized_carried_inputs = pytree.tree_map_only(
            (int, torch.SymInt),
            # For temporarily created unbacked symints, we don't need to bind them to any proxy
            lambda _: _create_unbacked_symint(
                fake_mode, ignore_fresh_unbacked_symbols=True
            ),
            carried_inputs,
        )

        cond_graph = reenter_make_fx(cond_fn)(
            *unspecialized_carried_inputs, *additional_inputs
        )
        body_graph = reenter_make_fx(body_fn)(
            *unspecialized_carried_inputs, *additional_inputs
        )

        # Find the first unused "while_loop_cond_graph_{i}" attribute name on
        # the tracer root so multiple while_loops in one trace don't collide.
        next_name = None
        i = 0
        while not next_name:
            candidate = f"while_loop_cond_graph_{i}"
            if hasattr(proxy_mode.tracer.root, candidate):
                i += 1
            else:
                next_name = candidate
        cond_graph_name = next_name
        body_graph_name = f"while_loop_body_graph_{i}"
        assert not hasattr(proxy_mode.tracer.root, body_graph_name)

        proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph)
        proxy_mode.tracer.root.register_module(body_graph_name, body_graph)

        args = (cond_graph, body_graph, carried_inputs, additional_inputs)

        proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

        out_proxy = proxy_mode.tracer.create_proxy(
            "call_function", while_loop_op, proxy_args, {}, name="while_loop"
        )

        # Run the op once on the unspecialized carries to produce example
        # outputs, then associate them with the proxy node.
        out = while_loop_op(
            cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs
        )
        return track_tensor_tree(
            out, out_proxy, constant=None, tracer=proxy_mode.tracer
        )

    return _trace_while_loop(
        mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
    )
337
+
338
+
@while_loop_op.py_impl(FakeTensorMode)
def while_loop_fake_tensor_mode(
    mode, cond_fn, body_fn, carried_inputs, additional_inputs
):
    """Fake-tensor implementation: run body_fn once to derive output metadata,
    check it is consistent with the carried inputs, and unspecialize int/SymInt
    outputs to fresh unbacked symints."""
    with mode:
        # NOTE: [Handling unbacked symints in subgraph of while_loop]
        # The idea is that the scope of unbacked symints are limited to the subgraph.
        #
        # We're implementing the fake tensor mode of while_loop operator.
        # and we run body_fn once to get a fake output.
        # Let's first consider the case that unbacked symints are tensor shapes:
        #
        # Case 1:
        # if the unbacked symints is local to the subgraph e.g.
        #   def body_fn(it, x):
        #       nz = x.nonzero()
        #       return it+1, nz.sum()
        # we can just ignore the newly created unbacked symints because it has
        # no effect on the output of while_loop and it's tracked when we tracing.
        # the subgraph.
        #
        # Case 2:
        # if the unbacked symints are shape of output of while_loop e.g.
        #   def body_fn(it, x):
        #       nz = x.nonzero()
        #       return it+1, nz
        # This will fail the shape check because in each iteration, the carried_input's shape
        # must match the output shape as nz.shape contains newly allocated unbacked symint, this
        # won't match the carried_input's shape.
        #
        # Case 3:
        # if the unbacked symints are shape of carried_inputs e.g.
        #   nz = a.nonzero()
        #   body_fn(it, nz):
        #       return it+1, nz.sin() + 1
        # There's no new unbacked symints allocated in subgraph, so we're safe.
        #
        # NOTE(review): this assumes ``mode.shape_env`` is not None — confirm
        # that callers always construct the fake mode with a ShapeEnv.
        with mode.shape_env.ignore_fresh_unbacked_symbols():
            # body_fn return output with the same pytree and tensor meta data as carried_inputs
            # so we could just return the output after one iteration.
            body_outs = body_fn(*carried_inputs, *additional_inputs)
            check_meta_consistency(
                carried_inputs,
                body_outs,
                "carried_inputs",
                "body_output",
                include_contiguity=False,
            )
        # See NOTE [unspecialize int carry with unbacked symints]
        return pytree.tree_map_only(
            (int, torch.SymInt),
            # For while_loop's unbacked symint output, we want them to be bound
            # to the proxy of while_loop's output.
            lambda _: _create_unbacked_symint(
                mode, ignore_fresh_unbacked_symbols=False
            ),
            body_outs,
        )
396
+
397
+
@while_loop_op.py_functionalize_impl
def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
    """Functionalization: unwrap functional tensors, verify cond_fn/body_fn
    neither mutate nor alias their inputs, redispatch to the next layer, and
    re-wrap the results."""
    from torch._higher_order_ops.utils import _check_alias_and_mutation

    unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs)
    unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
    unwrapped_inputs = unwrapped_carried_inputs + unwrapped_additional_inputs
    with ctx.redispatch_to_next():
        functional_cond_fn = ctx.functionalize(_maybe_run_with_interpreter(cond_fn))
        functional_body_fn = ctx.functionalize(_maybe_run_with_interpreter(body_fn))
        # `mode` (and hence pre_dispatch) only exists on some functionalization
        # ctx implementations, so probe for it.
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        for fn, fn_name in [
            (cond_fn, "cond_fn"),
            (body_fn, "body_fn"),
        ]:
            # while_loop requires both subgraphs to be mutation- and alias-free
            # w.r.t. their inputs (see the restrictions in while_loop's doc).
            _check_alias_and_mutation(fn, unwrapped_inputs, fn_name, pre_dispatch)
        ret = while_loop_op(
            functional_cond_fn,
            functional_body_fn,
            unwrapped_carried_inputs,
            unwrapped_additional_inputs,
        )
        return ctx.wrap_tensors(ret)
archive/.venv/Lib/site-packages/torch/_higher_order_ops/wrap.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import inspect
3
+ import itertools
4
+ import logging
5
+ from typing import Optional
6
+
7
+ from torch._logging import warning_once
8
+ from torch._ops import HigherOrderOperator
9
+ from torch.types import _dtype
10
+
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+ uid = itertools.count(1)
15
+
16
+
# Used for testing the HigherOrderOperator mechanism
class Wrap(HigherOrderOperator):
    """Identity HOP: runs ``func(*args, **kwargs)`` outside of dynamo tracing."""

    def __init__(self) -> None:
        super().__init__("wrap")

    def __call__(self, func, *args, **kwargs):
        # Dynamo already traces the body of HigherOrderOp beforehand when it
        # encounters one, so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        # @disable keeps dynamo from tracing the eager execution of func.
        @disable
        def wrapper():
            result = func(*args, **kwargs)
            return result

        return wrapper()


wrap = Wrap()
37
+
38
+
class WrapWithSetGradEnabled(HigherOrderOperator):
    """HOP that runs ``wrapped_func`` under ``torch.set_grad_enabled(enable_grad)``."""

    def __init__(self) -> None:
        super().__init__("wrap_with_set_grad_enabled")

    def __call__(self, enable_grad, wrapped_func, *args, **kwargs):
        # Dynamo already traces the body of HigherOrderOp beforehand when it
        # encounters one, so there is no need to trace into it here.
        # Note: `import torch._dynamo` also binds the name `torch` locally,
        # which the set_grad_enabled call below relies on.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            with torch.set_grad_enabled(enable_grad):
                return wrapped_func(*args, **kwargs)

        return wrapper()


wrap_with_set_grad_enabled = WrapWithSetGradEnabled()
58
+
59
+
class WrapWithAutocast(HigherOrderOperator):
    """HOP that runs ``wrapped_func`` inside a ``torch.autocast`` context.

    The first four arguments mirror torch.autocast's parameters
    (device_type, dtype, enabled, cache_enabled).
    """

    def __init__(self):
        super().__init__("wrap_with_autocast")

    def __call__(
        self,
        device_type: str,
        dtype: Optional[_dtype],
        enabled: bool,
        cache_enabled: Optional[bool],
        wrapped_func,
        *args,
        **kwargs,
    ):
        # Dynamo already traces the body of HigherOrderOp beforehand when it
        # encounters one, so there is no need to trace into it here.
        # Note: `import torch._dynamo` also binds the name `torch` locally,
        # which the autocast call below relies on.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            with torch.autocast(device_type, dtype, enabled, cache_enabled):
                return wrapped_func(*args, **kwargs)

        return wrapper()


wrap_with_autocast = WrapWithAutocast()
88
+
89
+
# This HOP allows you to bypass dynamo tracing of the wrapper function while
# still tracing the inner function.
# Takes two callables: The first, `wrapper_fn`, accepts `inner_fn` and returns a
# callable with the same signature. The second is the `inner_fn` itself. Any
# extra *args and **kwargs are forwarded to `wrapper_fn(inner_fn)` when it is
# executed.
class DynamoBypassingWrapper(HigherOrderOperator):
    """Runs ``wrapper_fn(inner_fn)(*args, **kwargs)`` with the wrapper itself
    hidden from dynamo tracing."""

    def __init__(self):
        super().__init__("dynamo_bypassing_wrapper")

    def __call__(
        self,
        wrapper_fn_or_key,
        inner_fn,
        *args,
        **kwargs,
    ):
        # Dynamo already traces the body of HigherOrderOp beforehand, so there
        # is no need to trace into it here. Importing torch._dynamo also binds
        # the name `torch` locally for the GraphModule check below.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        if isinstance(wrapper_fn_or_key, str):
            # Compiled path: the real wrapper was stashed on the traced
            # GraphModule's meta dict under this string key.
            assert isinstance(inner_fn, torch.fx.GraphModule)
            wrapper_fn = inner_fn.meta[wrapper_fn_or_key]
        else:
            # Eager path: the wrapper callable is passed in directly.
            wrapper_fn = wrapper_fn_or_key

        @disable
        def _invoke_outside_dynamo():
            return wrapper_fn(inner_fn)(*args, **kwargs)

        return _invoke_outside_dynamo()


dynamo_bypassing_wrapper = DynamoBypassingWrapper()
127
+
128
+
129
+ class WrapActivationCheckpoint(HigherOrderOperator):
130
+ """
131
+ This operator is used to wrap torch.utils.checkpoint. This avoids
132
+ TorchDynamo to look into saved tensor hooks and directly passes the control
133
+ to AOT Autograd, which is ok with tracing saved tensor hooks. As a result of
134
+ AOT tracing torch.utils.checkpoint code, we have a backward graph with
135
+ recomputed forward nodes.
136
+
137
+ However, we might deprecate this operator soon. The difficulty arises in the
138
+ functionalization of rng ops. Today, there are two different
139
+ functionalization of rng ops - one at AOT autograd and other at Inductor.
140
+ And they are difficult to map to each other. The rng states also complicate
141
+ pattern matching in Inductor. Due to the ease of implementation, we are
142
+ currently inclined towards functionalization at Inductor level, which means
143
+ that duplication/recomputation is done as a compiler pass in the
144
+ partitioners. See TagActivationCheckpoint for more information.
145
+ """
146
+
147
+ def __init__(self) -> None:
148
+ super().__init__("wrap_activation_checkpoint", cacheable=False)
149
+
150
+ def __call__(self, function, *args, **kwargs):
151
+ # use_reentrant is set to False because this op is going to be traced.
152
+ # And we ensure that AOT Autograd traces through the non reentrant
153
+ # version of checkpointing.
154
+ import torch.fx.traceback as fx_traceback
155
+ from torch.fx import Interpreter
156
+
157
+ kwargs["use_reentrant"] = False
158
+ kwargs["preserve_rng_state"] = False
159
+ # Using interpreter allows preservation of metadata through torch.compile stack.
160
+ with fx_traceback.preserve_node_meta():
161
+ from torch.utils.checkpoint import checkpoint
162
+
163
+ return checkpoint(Interpreter(function).run, *args, **kwargs)
164


# Module-level singleton instance of the HOP; call as
# wrap_activation_checkpoint(function, *args, **kwargs).
wrap_activation_checkpoint = WrapActivationCheckpoint()
167
+
168
+
169
+ class TagActivationCheckpoint(HigherOrderOperator):
170
+ """
171
+ This operator is supposed to be used only with torch.compile stack. This
172
+ accepts a Fx graph module which needs to be checkpointed. This operator adds
173
+ "recomputable" tag to the nodes of the Fx graph that should be recomputed.
174
+
175
+ The goal is to:
176
+ 1. Avoid using Dynamo to trace through saved tensor hooks.
177
+ 2. For selective checkpointing case, let AOTAutograd trace through
178
+ saved tensor hooks but has special logic with TorchDispatchMode to override
179
+ the usual saved_tensor_hooks fn logic in order to tag the nodes.
180
+ 3. Rely on the partitioners to actually duplicate the nodes.
181
+ This sits well in the torch.compile stack, because by the time graph
182
+ reaches partitioner, inductor has already run its functionalization of rng
183
+ ops (by setting fixed seed for each random op, see `replace_random_passes`).
184
+ Therefore, the duplication of nodes, by design, respects the rng states in
185
+ the forward and recomputed forward in backward.
186
+ """
187
+
188
+ def __init__(self) -> None:
189
+ super().__init__("tag_activation_checkpoint", cacheable=False)
190
+
191
+ @staticmethod
192
+ def divide_kwargs(kwargs):
193
+ """
194
+ checkpoint fn can have mixed kwargs between checkpointed fn and
195
+ checkpoint fn itself. For example
196
+ >> def gn(x, y, z=None):
197
+ >> a = torch.matmul(x, y)
198
+ >> if z is not None:
199
+ >> return torch.matmul(a, z)
200
+ >> return a
201
+ >> def fn(x, y, z):
202
+ >> return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))
203
+ In the above case, z belongs to checkpointed function gn, but
204
+ use_reentrant belongs to the checkpoint function. This function splits
205
+ the kwargs into checkpoint_kwargs and gmod_kwargs (or
206
+ checkpointed_fn_kwargs).
207
+ We do sorting to ensure same graph from run to run for better
208
+ debuggability. It is not required for correctness.
209
+ """
210
+ from torch.utils.checkpoint import checkpoint
211
+
212
+ ckpt_signature = inspect.signature(checkpoint)
213
+ checkpoint_keys = set()
214
+ for name in ckpt_signature.parameters:
215
+ if name in ("function", "args", "kwargs"):
216
+ continue
217
+ checkpoint_keys.add(name)
218
+
219
+ # `preserve_rng_state` is not a regular kwarg
220
+ checkpoint_keys.add("preserve_rng_state")
221
+
222
+ checkpoint_kwargs = {
223
+ name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys
224
+ }
225
+ gmod_kwargs = {
226
+ name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys
227
+ }
228
+ return checkpoint_kwargs, gmod_kwargs
229
+
230
+ def tag_nodes(self, gmod, is_sac):
231
+ from torch.utils.checkpoint import CheckpointPolicy
232
+
233
+ unique_graph_id = next(uid)
234
+ for node in gmod.graph.nodes:
235
+ if node.op in ("call_function", "call_method", "call_module"):
236
+ node.meta["ac_graph_id"] = unique_graph_id
237
+ if is_sac:
238
+ # For selective checkpointing, we will populate this tag later in _CachingTorchDispatchMode.
239
+ node.meta["recompute"] = None
240
+ else:
241
+ # Under vanilla activation checkpointing, all nodes should be recomputed.
242
+ node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
243
+ return gmod
244
+
245
+ def __call__(self, gmod, *args, **kwargs):
246
+ import torch.fx.traceback as fx_traceback
247
+ from torch.fx import Interpreter
248
+
249
+ if "_checkpoint_context_fn" in gmod.meta:
250
+ warning_once(
251
+ log,
252
+ """
253
+ Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
254
+ Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
255
+ """,
256
+ )
257
+ # use_reentrant is set to False because this op is going to be traced.
258
+ # And we ensure that AOT Autograd traces through the non reentrant
259
+ # version of checkpointing.
260
+ kwargs["use_reentrant"] = False
261
+ # preserve_rng_state is set to False because we want to prevent AOTAutograd from tracing through
262
+ # `torch.random.fork_rng` op (which is not supported yet under CUDA).
263
+ # This doesn't mean that we don't preserve RNG state. Instead, we will always preserve RNG state
264
+ # regardless of this flag (by doing RNG functionalization via `replace_random_passes` in Inductor
265
+ # instead of in AOTAutograd).
266
+ kwargs["preserve_rng_state"] = False
267
+ kwargs["context_fn"] = gmod.meta["_checkpoint_context_fn"]
268
+ # We first tag all nodes as "recompute" in this graph, and then we undo the "recompute" tag
269
+ # for specific nodes in _CachingTorchDispatchMode in torch/utils/checkpoint.py.
270
+ gmod = self.tag_nodes(gmod, is_sac=True)
271
+ # Using interpreter allows preservation of metadata through torch.compile stack.
272
+ with fx_traceback.preserve_node_meta():
273
+ from torch.utils.checkpoint import checkpoint
274
+
275
+ return checkpoint(Interpreter(gmod).run, *args, **kwargs)
276
+ else:
277
+ gmod = self.tag_nodes(gmod, is_sac=False)
278
+ # Using interpreter allows preservation of metadata through torch.compile stack.
279
+ # TODO: We want to use the same `checkpoint(Interpreter(gmod).run, *args, **kwargs)` here
280
+ # as the `context_fn != None` case, but that depends on in-place op support in TorchDispatchMode + torch.compile.
281
+ # (for details on in-place op issue, run `test_compile_selective_checkpoint_inplace_op` unit test)
282
+ with fx_traceback.preserve_node_meta():
283
+ return Interpreter(gmod).run(*args)
284


# Module-level singleton instance of the HOP; call as
# tag_activation_checkpoint(gmod, *args, **kwargs).
tag_activation_checkpoint = TagActivationCheckpoint()