Add files using upload-large-folder tool
Browse files- archive/.venv/Lib/site-packages/torch/_C/_cudnn.pyi +14 -0
- archive/.venv/Lib/site-packages/torch/_C/_cusparselt.pyi +1 -0
- archive/.venv/Lib/site-packages/torch/_C/_distributed_autograd.pyi +26 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/executorch_call_delegate.py +175 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/flat_apply.py +125 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/flex_attention.py +1268 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/foreach_map.py +23 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/hints_wrap.py +142 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/invoke_subgraph.py +658 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/map.py +291 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/out_dtype.py +163 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/run_const_graph.py +60 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/scan.py +929 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/schema.py +306 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/strict_mode.py +108 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/torchbind.py +164 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py +2051 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/utils.py +1134 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/while_loop.py +420 -0
- archive/.venv/Lib/site-packages/torch/_higher_order_ops/wrap.py +286 -0
archive/.venv/Lib/site-packages/torch/_C/_cudnn.pyi
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import IntEnum

# Defined in torch/csrc/cuda/shared/cudnn.cpp
# Presumably True when the binding is backed by CUDA's cuDNN (as opposed to a
# ROCm/MIOpen build) — confirm against cudnn.cpp.
is_cuda: bool

# cuDNN version loaded at runtime, as (major, minor, patch).
def getRuntimeVersion() -> tuple[int, int, int]: ...

# cuDNN version the extension was compiled against, as (major, minor, patch).
def getCompileVersion() -> tuple[int, int, int]: ...

# cuDNN version collapsed into a single integer.
def getVersionInt() -> int: ...

# RNN cell types accepted by the cuDNN RNN API; concrete values are assigned
# on the C++ side.
class RNNMode(IntEnum):
    rnn_relu = ...
    rnn_tanh = ...
    lstm = ...
    gru = ...
|
archive/.venv/Lib/site-packages/torch/_C/_cusparselt.pyi
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
# Version of the linked cuSPARSELt library as a single integer.
def getVersionInt() -> int: ...
archive/.venv/Lib/site-packages/torch/_C/_distributed_autograd.pyi
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any

import torch

# This module is defined in torch/csrc/distributed/autograd/init.cpp

class DistAutogradContext:
    # Opaque handle to one distributed autograd pass; the accessors below
    # expose the C++ context's internal bookkeeping.
    def _context_id(self) -> int: ...
    def _recv_functions(self) -> dict[int, Any]: ...
    def _send_functions(self) -> dict[int, Any]: ...
    def _known_worker_ids(self) -> set[int]: ...

# Create a fresh autograd context on this worker.
def _new_context() -> DistAutogradContext: ...

# Release the context with the given id (frees its recv/send functions).
def _release_context(context_id: int) -> None: ...
def _get_max_id() -> int: ...
def _is_valid_context(worker_id: int) -> bool: ...
def _retrieve_context(context_id: int) -> DistAutogradContext: ...
def _current_context() -> DistAutogradContext: ...
def _init(worker_id: int) -> None: ...
def _get_debug_info() -> dict[str, str]: ...

# Run the distributed backward pass rooted at `roots` within the given context.
def backward(
    context_id: int,
    roots: list[torch.Tensor],
    retain_graph: bool = False,
) -> None: ...

# Gradients accumulated in the context, keyed by the tensor they belong to.
def get_gradients(context_id: int) -> dict[torch.Tensor, torch.Tensor]: ...
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/executorch_call_delegate.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This source code is licensed under the BSD-style license found in the
|
| 7 |
+
# LICENSE file in the root directory of this source tree.
|
| 8 |
+
|
| 9 |
+
# pyre-strict
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from typing import Any, cast
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.utils._pytree as pytree
|
| 17 |
+
from torch._ops import HigherOrderOperator
|
| 18 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 19 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 20 |
+
disable_proxy_modes_tracing,
|
| 21 |
+
get_proxy_slot,
|
| 22 |
+
ProxyTorchDispatchMode,
|
| 23 |
+
track_tensor_tree,
|
| 24 |
+
)
|
| 25 |
+
from torch.utils._pytree import tree_flatten
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ExecutorchCallDelegate(HigherOrderOperator):
    """Higher-order operator that invokes a lowered (delegated) backend module.

    Calling convention: ``executorch_call_delegate(lowered_module, *args)``.
    """

    def __init__(self):
        super().__init__("executorch_call_delegate")

    def __call__(self, lowered_module, *args):
        return super().__call__(lowered_module, *args)


executorch_call_delegate = ExecutorchCallDelegate()
# These dispatch keys need no special handling for this op; fall through to
# the next key in the dispatcher.
executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot)
executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView)
executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)

# Class name used to duck-type LoweredBackendModule by name instead of
# importing it (importing would create a circular import).
LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# pyre-ignore
def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
    """Record an executorch_call_delegate call into the proxy tracer's graph.

    Runs the delegate eagerly (with proxy tracing disabled) to obtain concrete
    outputs, registers the lowered module on the tracer root, then emits a
    call_function node and ties the outputs to its proxy.
    """

    # pyre-ignore
    def _unwrap_proxy(e):
        # Values that are not tensors/symbolic scalars pass through unchanged.
        if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
            return e
        return get_proxy_slot(
            cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy  # type: ignore[attr-defined]
        )

    if not is_lowered_module(lowered_module):
        raise ValueError(
            "executorch_call_delegate()'s first argument must be a LoweredBackendModule"
        )

    # Compute real outputs outside tracing so this eager call is not itself
    # captured in the graph.
    with disable_proxy_modes_tracing():
        out = call_delegate_cpu(lowered_module, *args)

    # Registration on the tracer root is the side effect we want here; the
    # returned qualified name is unused.
    get_lowered_module_name(proxy_mode.tracer.root, lowered_module)

    node_args = (lowered_module, *args)
    proxy_args = pytree.tree_map(_unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="executorch_call_delegate"
    )
    # Associate the concrete outputs with the freshly created proxy node.
    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
# pyre-ignore
def call_delegate_cpu(lowered_module, *args):
    """Eager implementation: run the original (pre-lowering) module directly."""
    # FX wraps containers in immutable_dict/immutable_list; convert those back
    # to plain dict/list before invoking the original module.
    immutable_to_mutable: dict[type, type] = {
        torch.fx.immutable_collections.immutable_dict: dict,
        torch.fx.immutable_collections.immutable_list: list,
    }
    immutable_types = tuple(immutable_to_mutable.keys())
    converted_args = pytree.tree_map_only(
        immutable_types,
        lambda node: immutable_to_mutable[type(node)](node),
        args,
        lambda node: isinstance(node, immutable_types),
    )
    return lowered_module.original_module.module()(*converted_args)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@executorch_call_delegate.py_autograd_impl
# pyre-ignore
def call_delegate_autograd(lowered_module, *args):
    """Autograd dispatch: run the delegate with AutogradCPU excluded.

    TODO: support autograd for real. For now, if any tensor input requires
    grad, floating/complex outputs are re-marked as requiring grad so callers
    see grad_fn-bearing tensors.
    """
    flat_inputs, _ = tree_flatten([lowered_module, *args])
    needs_grad = any(
        t.requires_grad for t in flat_inputs if isinstance(t, torch.Tensor)
    )

    with torch._C._ExcludeDispatchKeyGuard(
        torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
    ):
        result = executorch_call_delegate(lowered_module, *args)

    if not needs_grad:
        return result

    # Create aliases of the outputs that have requires_grad=True. We need at
    # least one grad-requiring input to err_fn so the output gets a grad_fn.
    # pyre-ignore
    def fake_requires_grad(var):
        if var is not None:
            var = var.detach()
            if torch.is_floating_point(var) or torch.is_complex(var):
                var.requires_grad = True
        return var

    return pytree.tree_map_only(torch.Tensor, fake_requires_grad, result)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@executorch_call_delegate.py_impl(ProxyTorchDispatchMode)
# pyre-ignore
def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args):
    """Proxy-mode dispatch: trace the delegate call into the current FX graph."""
    return trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@executorch_call_delegate.py_impl(FakeTensorMode)
# pyre-ignore
def call_delegate_fake_tensor_mode(mode, lowered_module, *args):
    """Fake-tensor dispatch: propagate metadata by running the original module."""
    with mode:
        fake_result = call_delegate_cpu(lowered_module, *args)
    return fake_result
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@executorch_call_delegate.py_functionalize_impl
# pyre-ignore
def call_delegate_functionalize(ctx, lowered_module, *args):
    """Functionalization dispatch: unwrap args, redispatch, rewrap outputs."""
    plain_args = tuple(ctx.unwrap_tensors(a) for a in args)
    with ctx.redispatch_to_next():
        delegate_out = executorch_call_delegate(lowered_module, *plain_args)
        return ctx.wrap_tensors(delegate_out)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre
def is_lowered_module(obj: Any) -> bool:
    """Duck-typed check for LoweredBackendModule.

    The class name is compared instead of using isinstance, because importing
    LoweredBackendModule here could create a circular import.
    """
    return LOWERED_BACKEND_MODULE_TYPE == type(obj).__name__
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def get_lowered_module_name(
    root: torch.nn.Module,
    # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
    lowered_module: LOWERED_BACKEND_MODULE_TYPE,  # type: ignore[valid-type]
) -> str:
    """Attach `lowered_module` to `root` under a fresh name and return the name.

    Names have the form "lowered_module_<i>", using the smallest i that is not
    already an attribute of `root`.
    """
    index = 0
    while hasattr(root, f"lowered_module_{index}"):
        index += 1
    qualname = f"lowered_module_{index}"

    root.add_module(qualname, lowered_module)
    return qualname
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/flat_apply.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Callable
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.fx.node
|
| 7 |
+
import torch.utils._pytree as pytree
|
| 8 |
+
from torch._ops import HigherOrderOperator
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def is_graphable(val) -> bool:
    """Definition: a graphable type is a type that is an acceptable input/output type to a FX node."""
    return isinstance(val, torch.fx.node.base_types)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def is_graphable_type(typ) -> bool:
    """Return whether the given type is graphable."""
    # Expects a type object; issubclass raises TypeError for non-types.
    return issubclass(typ, torch.fx.node.base_types)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def to_graphable(stuff):
    """Flatten `stuff` into a flat list of graphable leaves plus its treespec.

    We could consider preserving things like List[int] to improve perf and
    readability (right now that is all flattened out).
    """
    leaves, spec = pytree.tree_flatten(stuff)
    offenders = [leaf for leaf in leaves if not is_graphable(leaf)]
    if offenders:
        raise RuntimeError(
            f"Expected all pytree.tree_leaves of (args, kwargs) to be graphable types, but found "
            f"non-fx-graphable type {type(offenders[0])}. If this type is meant to be constant, mark it as "
            f"via pytree.register_constant; otherwise, register it as a pytree."
        )
    return leaves, spec
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def from_graphable(flat_args, spec):
    """The inverse of to_graphable: rebuild the original structure."""
    return pytree.tree_unflatten(flat_args, spec)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def func_to_graphable(func):
    """
    Pack and flatten a function type into graphable types.
    This is useful for legalizing the function argument of `flat_apply`.
    """
    wrapped = _ConstantFunction(func)
    return pytree.tree_flatten(wrapped)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass(frozen=True)
|
| 51 |
+
class _ConstantFunction:
|
| 52 |
+
func: Callable
|
| 53 |
+
|
| 54 |
+
def __call__(self, *args, **kwargs):
|
| 55 |
+
return self.func(*args, **kwargs)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Allow _ConstantFunction to appear in pytrees as a constant leaf.
pytree.register_constant(_ConstantFunction)

# Operator objects that flat_apply accepts directly (without constant
# wrapping): ordinary ops, overload packets, and other higher-order operators.
_op_types = (
    torch._ops.OpOverload,
    torch._ops.OpOverloadPacket,
    torch._ops.HigherOrderOperator,
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class FlatApply(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("flat_apply")

    def __call__(self, func, in_spec, *flat_args, **_unused):
        """
        Apply `func` to arguments that were flattened into graphable leaves.

        A function whose inputs contain non-graphable types cannot be placed
        into an FX graph directly. If every non-graphable piece of
        (args, kwargs) is a pytree, the graph can instead record
        flat_apply(func, in_spec, *flat_args), which is roughly:

        >>> def flat_apply_impl(func, in_spec, *flat_args):
        >>>     args, kwargs = pytree.tree_unflatten(flat_args, in_spec)
        >>>     output = func(*args, **kwargs)
        >>>     return output

        Two input cases are supported:
        - a container type (e.g. of tensors) registered as a pytree; it is
          tree_flattened and its spec stored, and
        - a constant type (torch.compile specializes on it) registered via
          pytree.register_constant; the constant goes directly into the spec.
        """
        assert isinstance(func, _op_types) or pytree._is_constant_holder(func)
        assert not _unused
        return impl(func, in_spec, *flat_args)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def impl(func, in_spec, *flat_args):
    """Shared implementation: unflatten the args, call func, validate outputs."""
    if not isinstance(func, _op_types):
        # assume _ConstantFunction
        func = pytree._retrieve_constant(func)
        assert isinstance(func, _ConstantFunction)

    args, kwargs = from_graphable(flat_args, in_spec)
    out = func(*args, **kwargs)

    # Right now, every output must be graphable or a (possibly nested)
    # list/tuple of graphables.
    #
    # TODO: this could be extended to non-graphable constant outputs (the
    # assumption being that they are identical on every run) and to pytree
    # outputs — for those it is unclear whether to return (flat_output, spec)
    # or just (flat_output,); in the latter case the tracers must carry the
    # output specs so the object can be rebuilt from flat_output alone.
    def _valid(x):
        if isinstance(x, (tuple, list)):
            return all(_valid(item) for item in x)
        return is_graphable(x)

    assert _valid(out)
    return out


flat_apply = FlatApply()
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/flex_attention.py
ADDED
|
@@ -0,0 +1,1268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from collections.abc import Sequence
|
| 3 |
+
from typing import Any, Callable, Optional, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.utils._pytree as pytree
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch._C import DispatchKey
|
| 9 |
+
from torch._higher_order_ops.utils import (
|
| 10 |
+
_has_potential_branch_input_mutation,
|
| 11 |
+
_maybe_reenter_make_fx,
|
| 12 |
+
autograd_not_implemented,
|
| 13 |
+
has_user_subclass,
|
| 14 |
+
redirect_to_mode,
|
| 15 |
+
reenter_make_fx,
|
| 16 |
+
register_fake,
|
| 17 |
+
save_tensors_and_symints_for_backward,
|
| 18 |
+
saved_tensors_and_symints,
|
| 19 |
+
UnsupportedAliasMutationException,
|
| 20 |
+
validate_subgraph_args_types,
|
| 21 |
+
)
|
| 22 |
+
from torch._ops import HigherOrderOperator
|
| 23 |
+
from torch._subclasses import FakeTensor
|
| 24 |
+
from torch._subclasses.functional_tensor import FunctionalTensor
|
| 25 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 26 |
+
make_fx,
|
| 27 |
+
ProxyTorchDispatchMode,
|
| 28 |
+
track_tensor_tree,
|
| 29 |
+
)
|
| 30 |
+
from torch.fx.graph_module import GraphModule
|
| 31 |
+
from torch.utils.checkpoint import _CachedTorchDispatchMode, _CachingTorchDispatchMode
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Duplicate of _inductor/kernel/flex_attention.py to avoid circular import
|
| 35 |
+
def _construct_strides(
|
| 36 |
+
sizes: Sequence[int],
|
| 37 |
+
fill_order: Sequence[int],
|
| 38 |
+
) -> Sequence[int]:
|
| 39 |
+
"""From a list of sizes and a fill order, construct the strides of the permuted tensor."""
|
| 40 |
+
# Initialize strides
|
| 41 |
+
assert len(sizes) == len(
|
| 42 |
+
fill_order
|
| 43 |
+
), "Length of sizes must match the length of the fill order"
|
| 44 |
+
strides = [0] * len(sizes)
|
| 45 |
+
|
| 46 |
+
# Start with stride 1 for the innermost dimension
|
| 47 |
+
current_stride = 1
|
| 48 |
+
|
| 49 |
+
# Iterate through the fill order populating strides
|
| 50 |
+
for dim in fill_order:
|
| 51 |
+
strides[dim] = current_stride
|
| 52 |
+
current_stride *= sizes[dim]
|
| 53 |
+
|
| 54 |
+
return strides
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _permute_strides(out: torch.Tensor, query_strides: tuple[int, ...]) -> torch.Tensor:
    """
    Create a new tensor with the same data and shape as the input,
    but with strides permuted based on the input tensor's stride order.

    Args:
        out (torch.Tensor): The output tensor of attention.
        query_strides (List[int]): The stride order of the input query tensor

    Returns:
        torch.Tensor: A new tensor with same shape and data as the input,
        but with strides permuted based on the query tensor's stride order.
    """
    # Local import to avoid a circular dependency on inductor.
    from torch._inductor.ir import get_fill_order

    assert out.storage_offset() == 0, "Only support storage_offset == 0"
    fill_order = get_fill_order(query_strides)
    permuted_strides = _construct_strides(out.shape, fill_order)
    result = out.new_empty(out.shape).as_strided(out.shape, permuted_strides)
    result.copy_(out)
    return result
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class FlexAttentionHOP(HigherOrderOperator):
    """Higher-order operator for the flex_attention forward pass.

    `score_mod` is a subgraph applied to attention scores; `block_mask` is a
    tuple describing block sparsity (its last element is the mask_mod callable
    — see _math_attention_inner). The trailing buffer tuples hold extra values
    captured by the score_mod / mask_mod subgraphs. Returns a pair of tensors
    (presumably output and logsumexp, judging by the backward op's parameter
    names — confirm).
    """

    def __init__(self) -> None:
        # cacheable=True marks this HOP as safe to cache.
        super().__init__("flex_attention", cacheable=True)

    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        score_mod: Callable,
        block_mask: tuple,
        scale: float,
        kernel_options: dict[str, Any],
        score_mod_other_buffers: tuple = (),
        mask_mod_other_buffers: tuple = (),
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Reject buffer entries whose types the subgraphs cannot handle.
        validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
        return super().__call__(
            query,
            key,
            value,
            score_mod,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )


flex_attention = FlexAttentionHOP()
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class FlexAttentionBackwardHOP(HigherOrderOperator):
    """Higher-order operator for the flex_attention backward pass.

    Takes the forward inputs/outputs (query, key, value, out, logsumexp), the
    incoming gradients, and the traced forward/joint subgraphs. Returns four
    values — three tensors plus a tuple of optional tensors (presumably grads
    for query/key/value and for the score_mod buffers — confirm against the
    autograd implementation).
    """

    def __init__(self) -> None:
        super().__init__("flex_attention_backward")

    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        logsumexp: torch.Tensor,
        grad_out: torch.Tensor,
        grad_logsumexp: torch.Tensor,
        fw_graph: Union[Callable, GraphModule],
        joint_graph: GraphModule,
        block_mask: tuple,
        scale: float,
        kernel_options: dict[str, Any],
        score_mod_other_buffers: tuple = (),
        mask_mod_other_buffers: tuple = (),
    ) -> tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
    ]:
        # Reject buffer entries whose types the subgraphs cannot handle.
        validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers)
        return super().__call__(
            query,
            key,
            value,
            out,
            logsumexp,
            grad_out,
            grad_logsumexp,
            fw_graph,
            joint_graph,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )


flex_attention_backward = FlexAttentionBackwardHOP()
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _math_attention_inner(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: tuple,
    scale: float,
    kernel_options: dict[str, Any],
    score_mod_other_buffers: tuple = (),
    mask_mod_other_buffers: tuple = (),
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute raw and modified attention scores in plain (math) PyTorch.

    Returns (scores, post_mod_scores): the scaled raw scores and the scores
    after applying score_mod, with mask_mod-rejected positions set to -inf.
    `value` and `kernel_options` are unused here.
    """
    # Local import to avoid importing dynamo machinery at module load.
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    # Accumulate in fp32 unless inputs are fp64, for numerical stability.
    working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32

    scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision)

    # Index vectors over (batch, head, q_idx, kv_idx); scores is assumed 4-D
    # here (indexed via size(0..3)).
    b = torch.arange(0, scores.size(0), device=scores.device)
    h = torch.arange(0, scores.size(1), device=scores.device)
    m = torch.arange(0, scores.size(2), device=scores.device)
    n = torch.arange(0, scores.size(3), device=scores.device)

    # Captured buffers are broadcast (in_dim=None) rather than vmapped over.
    captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
    from torch.nn.attention.flex_attention import _vmap_for_bhqkv

    # first input is score
    score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,), suffix=captured_buffers_in_dim)

    # By convention the mask_mod callable is the last entry of block_mask.
    mask_mod = block_mask[-1]
    mask_mod_in_dim_buffers = (None,) * len(mask_mod_other_buffers)
    mask_mod = _vmap_for_bhqkv(mask_mod, prefix=(), suffix=mask_mod_in_dim_buffers)

    with TransformGetItemToIndex():
        scores = (scores * scale).to(working_precision)
        post_mod_scores = torch.where(
            mask_mod(b, h, m, n, *mask_mod_other_buffers),
            score_mod(scores, b, h, m, n, *score_mod_other_buffers),
            torch.tensor(-float("inf"), dtype=working_precision, device=scores.device),
        )

    return scores, post_mod_scores
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def math_attention(
|
| 202 |
+
query: torch.Tensor,
|
| 203 |
+
key: torch.Tensor,
|
| 204 |
+
value: torch.Tensor,
|
| 205 |
+
score_mod: Callable,
|
| 206 |
+
block_mask: tuple,
|
| 207 |
+
scale: float,
|
| 208 |
+
kernel_options: dict[str, Any],
|
| 209 |
+
score_mod_other_buffers: tuple = (),
|
| 210 |
+
mask_mod_other_buffers: tuple = (),
|
| 211 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 212 |
+
"""Eager implementation
|
| 213 |
+
|
| 214 |
+
This implementation uses vmap to vectorize the score_mod function over the batch, head, m, and n dimensions.
|
| 215 |
+
We then apply the vectorized score_mod function to the scores matrix. Each wrap of vmap applies one of the
|
| 216 |
+
batch, head, m, or n dimensions. We need to apply vmap 4 times to vectorized over all 4 dimensions.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
query: The query tensor
|
| 220 |
+
key: The key tensor
|
| 221 |
+
value: The value tensor
|
| 222 |
+
score_mod: The score_mod function
|
| 223 |
+
other_buffers: Other buffers that are passed to the score_mod function
|
| 224 |
+
"""
|
| 225 |
+
# broadcast query & key along head dim for GQA
|
| 226 |
+
G = query.size(1) // key.size(1)
|
| 227 |
+
value = torch.repeat_interleave(value, G, dim=1)
|
| 228 |
+
key = torch.repeat_interleave(key, G, dim=1)
|
| 229 |
+
|
| 230 |
+
Bq, Bkv = query.size(0), key.size(0)
|
| 231 |
+
if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)):
|
| 232 |
+
raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}")
|
| 233 |
+
|
| 234 |
+
key = key.expand((Bq, *key.size()[1:]))
|
| 235 |
+
value = value.expand((Bq, *value.size()[1:]))
|
| 236 |
+
|
| 237 |
+
_, post_mod_scores = _math_attention_inner(
|
| 238 |
+
query,
|
| 239 |
+
key,
|
| 240 |
+
value,
|
| 241 |
+
score_mod,
|
| 242 |
+
block_mask,
|
| 243 |
+
scale,
|
| 244 |
+
kernel_options,
|
| 245 |
+
score_mod_other_buffers,
|
| 246 |
+
mask_mod_other_buffers,
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
# Set fully masked rows' sumexp to 0.0
|
| 250 |
+
logsumexp = post_mod_scores.logsumexp(dim=-1)
|
| 251 |
+
masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1)
|
| 252 |
+
logsumexp = torch.where(masked_rows, -float("inf"), logsumexp)
|
| 253 |
+
|
| 254 |
+
post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1)
|
| 255 |
+
|
| 256 |
+
return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
@flex_attention.py_impl(DispatchKey.CompositeExplicitAutograd)
|
| 260 |
+
def sdpa_dense(
|
| 261 |
+
query: torch.Tensor,
|
| 262 |
+
key: torch.Tensor,
|
| 263 |
+
value: torch.Tensor,
|
| 264 |
+
score_mod: Callable,
|
| 265 |
+
block_mask: tuple,
|
| 266 |
+
scale: float,
|
| 267 |
+
kernel_options: dict[str, Any],
|
| 268 |
+
score_mod_other_buffers: tuple = (),
|
| 269 |
+
mask_mod_other_buffers: tuple = (),
|
| 270 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 271 |
+
out, lse = math_attention(
|
| 272 |
+
query,
|
| 273 |
+
key,
|
| 274 |
+
value,
|
| 275 |
+
score_mod,
|
| 276 |
+
block_mask,
|
| 277 |
+
scale,
|
| 278 |
+
kernel_options,
|
| 279 |
+
score_mod_other_buffers,
|
| 280 |
+
mask_mod_other_buffers,
|
| 281 |
+
)
|
| 282 |
+
out = _permute_strides(out, query.stride())
|
| 283 |
+
return out, lse
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def trace_flex_attention(
|
| 287 |
+
proxy_mode: ProxyTorchDispatchMode,
|
| 288 |
+
query: torch.Tensor,
|
| 289 |
+
key: torch.Tensor,
|
| 290 |
+
value: torch.Tensor,
|
| 291 |
+
score_mod: Callable,
|
| 292 |
+
block_mask: tuple,
|
| 293 |
+
scale: float,
|
| 294 |
+
kernel_options: dict[str, Any],
|
| 295 |
+
score_mod_other_buffers: tuple = (),
|
| 296 |
+
mask_mod_other_buffers: tuple = (),
|
| 297 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 298 |
+
"""Traces the flex_attention operator with the given score_mod function and other_buffers.
|
| 299 |
+
|
| 300 |
+
Trace SDPA will call make_fx with "fake" example vals and then trace the score_mod function
|
| 301 |
+
This will produce a GraphModule that will be stored on the root tracer as "sdpa_score". We
|
| 302 |
+
access this graph module in inductor to inline the score_mod function to the triton template.
|
| 303 |
+
"""
|
| 304 |
+
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
|
| 305 |
+
|
| 306 |
+
example_out = flex_attention(
|
| 307 |
+
query,
|
| 308 |
+
key,
|
| 309 |
+
value,
|
| 310 |
+
score_mod,
|
| 311 |
+
block_mask,
|
| 312 |
+
scale,
|
| 313 |
+
kernel_options,
|
| 314 |
+
score_mod_other_buffers,
|
| 315 |
+
mask_mod_other_buffers,
|
| 316 |
+
)
|
| 317 |
+
example_vals = [query.new_zeros((), requires_grad=query.requires_grad)] + [
|
| 318 |
+
query.new_zeros((), dtype=torch.int) for _ in range(4)
|
| 319 |
+
]
|
| 320 |
+
mask_example_vals = [query.new_zeros((), dtype=torch.int) for _ in range(4)]
|
| 321 |
+
mask_mod = block_mask[-1]
|
| 322 |
+
with TransformGetItemToIndex():
|
| 323 |
+
score_graph = reenter_make_fx(score_mod)(
|
| 324 |
+
*example_vals, *score_mod_other_buffers
|
| 325 |
+
)
|
| 326 |
+
mask_graph = reenter_make_fx(mask_mod)(
|
| 327 |
+
*mask_example_vals, *mask_mod_other_buffers
|
| 328 |
+
)
|
| 329 |
+
assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
|
| 330 |
+
block_mask = block_mask[:-1] + (mask_graph,)
|
| 331 |
+
qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_score")
|
| 332 |
+
proxy_mode.tracer.root.register_module(qualname, score_graph)
|
| 333 |
+
mask_qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_mask")
|
| 334 |
+
proxy_mode.tracer.root.register_module(mask_qualname, mask_graph)
|
| 335 |
+
node_args = (
|
| 336 |
+
query,
|
| 337 |
+
key,
|
| 338 |
+
value,
|
| 339 |
+
score_graph,
|
| 340 |
+
block_mask,
|
| 341 |
+
scale,
|
| 342 |
+
kernel_options,
|
| 343 |
+
score_mod_other_buffers,
|
| 344 |
+
mask_mod_other_buffers,
|
| 345 |
+
)
|
| 346 |
+
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
|
| 347 |
+
out_proxy = proxy_mode.tracer.create_proxy(
|
| 348 |
+
"call_function", flex_attention, proxy_args, {}
|
| 349 |
+
)
|
| 350 |
+
return track_tensor_tree(
|
| 351 |
+
example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
@flex_attention.py_impl(ProxyTorchDispatchMode)
|
| 356 |
+
def flex_attention_proxy_torch_dispatch_mode(
|
| 357 |
+
mode: ProxyTorchDispatchMode,
|
| 358 |
+
query: torch.Tensor,
|
| 359 |
+
key: torch.Tensor,
|
| 360 |
+
value: torch.Tensor,
|
| 361 |
+
score_mod: Callable,
|
| 362 |
+
block_mask: tuple,
|
| 363 |
+
scale: float,
|
| 364 |
+
kernel_options: dict[str, Any],
|
| 365 |
+
score_mod_other_buffers: tuple = (),
|
| 366 |
+
mask_mod_other_buffers: tuple = (),
|
| 367 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 368 |
+
assert mode is not None, "Mode should always be enabled for python fallback key"
|
| 369 |
+
return trace_flex_attention(
|
| 370 |
+
mode,
|
| 371 |
+
query,
|
| 372 |
+
key,
|
| 373 |
+
value,
|
| 374 |
+
score_mod,
|
| 375 |
+
block_mask,
|
| 376 |
+
scale,
|
| 377 |
+
kernel_options,
|
| 378 |
+
score_mod_other_buffers,
|
| 379 |
+
mask_mod_other_buffers,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
@flex_attention.py_functionalize_impl
|
| 384 |
+
def flex_attention_functionalize(
|
| 385 |
+
ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI,
|
| 386 |
+
query: torch.Tensor,
|
| 387 |
+
key: torch.Tensor,
|
| 388 |
+
value: torch.Tensor,
|
| 389 |
+
score_mod: Callable,
|
| 390 |
+
block_mask: tuple,
|
| 391 |
+
scale: float,
|
| 392 |
+
kernel_options: dict[str, Any],
|
| 393 |
+
score_mod_other_buffers: tuple = (),
|
| 394 |
+
mask_mod_other_buffers: tuple = (),
|
| 395 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 396 |
+
"""Defines the functionalization rules for the flex_attention operator.
|
| 397 |
+
|
| 398 |
+
Write now we are unwrapping each tensor and then redispatching to the next, however we want to
|
| 399 |
+
guard against any mutations in the score_mod function, to the other_buffers since those
|
| 400 |
+
are free variables.
|
| 401 |
+
"""
|
| 402 |
+
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
|
| 403 |
+
|
| 404 |
+
if has_user_subclass(
|
| 405 |
+
(
|
| 406 |
+
query,
|
| 407 |
+
key,
|
| 408 |
+
value,
|
| 409 |
+
score_mod,
|
| 410 |
+
block_mask,
|
| 411 |
+
scale,
|
| 412 |
+
kernel_options,
|
| 413 |
+
score_mod_other_buffers,
|
| 414 |
+
mask_mod_other_buffers,
|
| 415 |
+
),
|
| 416 |
+
allowed_subclasses=(FakeTensor, FunctionalTensor),
|
| 417 |
+
):
|
| 418 |
+
return NotImplemented
|
| 419 |
+
|
| 420 |
+
query_unwrapped = ctx.unwrap_tensors(query)
|
| 421 |
+
key_unwrapped = ctx.unwrap_tensors(key)
|
| 422 |
+
value_unwrapped = ctx.unwrap_tensors(value)
|
| 423 |
+
block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
|
| 424 |
+
score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
|
| 425 |
+
mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)
|
| 426 |
+
|
| 427 |
+
# Appease the mypy overlords
|
| 428 |
+
assert isinstance(query_unwrapped, torch.Tensor)
|
| 429 |
+
assert isinstance(key_unwrapped, torch.Tensor)
|
| 430 |
+
assert isinstance(value_unwrapped, torch.Tensor)
|
| 431 |
+
assert isinstance(block_mask_unwrapped, tuple)
|
| 432 |
+
assert isinstance(score_mod_other_buffers_unwrapped, tuple)
|
| 433 |
+
assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
|
| 434 |
+
|
| 435 |
+
example_vals = (
|
| 436 |
+
[query_unwrapped.new_zeros(())]
|
| 437 |
+
+ [query_unwrapped.new_zeros((), dtype=torch.int) for _ in range(4)]
|
| 438 |
+
+ list(score_mod_other_buffers_unwrapped)
|
| 439 |
+
)
|
| 440 |
+
with ctx.redispatch_to_next():
|
| 441 |
+
functional_score_mod = ctx.functionalize(score_mod)
|
| 442 |
+
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
|
| 443 |
+
with TransformGetItemToIndex():
|
| 444 |
+
# TODO: So far only the input mutations are checked
|
| 445 |
+
# In the other HOPs, also aliases are checked which is
|
| 446 |
+
# omitted here
|
| 447 |
+
mutates = _has_potential_branch_input_mutation(
|
| 448 |
+
score_mod, example_vals, pre_dispatch
|
| 449 |
+
)
|
| 450 |
+
# The only care about mutations of existing buffers since we can't replay these.
|
| 451 |
+
# However, we can just error if anything is detected
|
| 452 |
+
if mutates:
|
| 453 |
+
raise UnsupportedAliasMutationException("Mutations detected in score_mod")
|
| 454 |
+
|
| 455 |
+
out = flex_attention(
|
| 456 |
+
query_unwrapped,
|
| 457 |
+
key_unwrapped,
|
| 458 |
+
value_unwrapped,
|
| 459 |
+
functional_score_mod,
|
| 460 |
+
block_mask_unwrapped,
|
| 461 |
+
scale,
|
| 462 |
+
kernel_options,
|
| 463 |
+
score_mod_other_buffers_unwrapped,
|
| 464 |
+
mask_mod_other_buffers_unwrapped,
|
| 465 |
+
)
|
| 466 |
+
return ctx.wrap_tensors(out) # type: ignore[return-value, arg-type]
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
@register_fake(flex_attention)
|
| 470 |
+
def flex_attention_fake_impl(
|
| 471 |
+
query: torch.Tensor,
|
| 472 |
+
key: torch.Tensor,
|
| 473 |
+
value: torch.Tensor,
|
| 474 |
+
score_mod: Callable,
|
| 475 |
+
block_mask: tuple,
|
| 476 |
+
scale: float,
|
| 477 |
+
kernel_options: dict[str, Any],
|
| 478 |
+
score_mod_other_buffers: tuple = (),
|
| 479 |
+
mask_mod_other_buffers: tuple = (),
|
| 480 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 481 |
+
if has_user_subclass(
|
| 482 |
+
(
|
| 483 |
+
query,
|
| 484 |
+
key,
|
| 485 |
+
value,
|
| 486 |
+
score_mod,
|
| 487 |
+
block_mask,
|
| 488 |
+
scale,
|
| 489 |
+
kernel_options,
|
| 490 |
+
score_mod_other_buffers,
|
| 491 |
+
mask_mod_other_buffers,
|
| 492 |
+
),
|
| 493 |
+
allowed_subclasses=(FakeTensor,),
|
| 494 |
+
):
|
| 495 |
+
return NotImplemented
|
| 496 |
+
|
| 497 |
+
# TODO: Figure out a better way to handle this for NJT than using sum()
|
| 498 |
+
if query.is_nested:
|
| 499 |
+
out = torch.empty_like(query, memory_format=torch.contiguous_format)
|
| 500 |
+
logsumexp = query.sum(dim=-1)
|
| 501 |
+
return out, logsumexp
|
| 502 |
+
|
| 503 |
+
v_head_dim = value.size(-1)
|
| 504 |
+
batch_size, num_heads, seq_len_q, _q_head_dim = query.shape
|
| 505 |
+
logsumexp = query.new_empty(batch_size, num_heads, seq_len_q, dtype=torch.float32)
|
| 506 |
+
out_shape = (batch_size, num_heads, seq_len_q, v_head_dim)
|
| 507 |
+
out = query.new_empty(out_shape)
|
| 508 |
+
out = _permute_strides(out, query.stride())
|
| 509 |
+
return out, logsumexp
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
# Registers dispatches for SAC
|
| 513 |
+
redirect_to_mode(flex_attention, _CachingTorchDispatchMode)
|
| 514 |
+
redirect_to_mode(flex_attention, _CachedTorchDispatchMode)
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
# ---------------------------- Autograd Implementation ----------------------------
|
| 518 |
+
def create_fw_bw_graph(
|
| 519 |
+
score_mod: Callable,
|
| 520 |
+
index_values: tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
|
| 521 |
+
other_buffers: tuple[Tensor, ...],
|
| 522 |
+
) -> tuple[Callable, Callable]:
|
| 523 |
+
# See Note:[HOP create fw_bw graph]
|
| 524 |
+
|
| 525 |
+
# All of these imports need to be here in order to avoid circular dependencies
|
| 526 |
+
from torch._dispatch.python import suspend_functionalization
|
| 527 |
+
from torch._functorch.aot_autograd import AOTConfig, create_joint
|
| 528 |
+
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
| 529 |
+
from torch._subclasses.functional_tensor import disable_functional_mode
|
| 530 |
+
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
|
| 531 |
+
|
| 532 |
+
dummy_aot_config = AOTConfig(
|
| 533 |
+
fw_compiler=None, # type: ignore[arg-type]
|
| 534 |
+
bw_compiler=None, # type: ignore[arg-type]
|
| 535 |
+
partition_fn=None, # type: ignore[arg-type]
|
| 536 |
+
decompositions={},
|
| 537 |
+
num_params_buffers=0,
|
| 538 |
+
aot_id=0,
|
| 539 |
+
keep_inference_input_mutations=False,
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
with suspend_functionalization(), disable_functional_mode():
|
| 543 |
+
with disable_proxy_modes_tracing():
|
| 544 |
+
|
| 545 |
+
def _from_fun(
|
| 546 |
+
t: Union[Tensor, torch.SymInt, int],
|
| 547 |
+
) -> Union[Tensor, torch.SymInt, int]:
|
| 548 |
+
if isinstance(t, torch.Tensor):
|
| 549 |
+
return torch.empty_strided(
|
| 550 |
+
t.size(),
|
| 551 |
+
t.stride(),
|
| 552 |
+
device=t.device,
|
| 553 |
+
dtype=t.dtype,
|
| 554 |
+
requires_grad=t.requires_grad,
|
| 555 |
+
)
|
| 556 |
+
return t
|
| 557 |
+
|
| 558 |
+
# If someone runs this hop under the default compiler backend ("eager")
|
| 559 |
+
# Then this path will be run with the actual user inputs. We convert them
|
| 560 |
+
# to fake tensors in order to not perform any actual compute.
|
| 561 |
+
from torch._guards import detect_fake_mode
|
| 562 |
+
|
| 563 |
+
fake_mode = detect_fake_mode(index_values)
|
| 564 |
+
if fake_mode is None:
|
| 565 |
+
fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
|
| 566 |
+
|
| 567 |
+
with fake_mode:
|
| 568 |
+
unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
|
| 569 |
+
unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)
|
| 570 |
+
|
| 571 |
+
assert all(
|
| 572 |
+
isinstance(t, (FakeTensor, int, torch.SymInt))
|
| 573 |
+
for t in unwrapped_score_mod_indexes + unwrapped_other_buffers
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
example_flat_out = pytree.tree_map(
|
| 577 |
+
_from_fun,
|
| 578 |
+
score_mod(*unwrapped_score_mod_indexes, *unwrapped_other_buffers),
|
| 579 |
+
)
|
| 580 |
+
if not isinstance(example_flat_out, torch.Tensor):
|
| 581 |
+
raise RuntimeError(
|
| 582 |
+
"Expected output of score_mod to be a tensor."
|
| 583 |
+
f"Got type {type(example_flat_out)}."
|
| 584 |
+
)
|
| 585 |
+
example_grad = _from_fun(example_flat_out)
|
| 586 |
+
|
| 587 |
+
def joint_f(
|
| 588 |
+
score: Tensor,
|
| 589 |
+
b: Tensor,
|
| 590 |
+
h: Tensor,
|
| 591 |
+
m: Tensor,
|
| 592 |
+
n: Tensor,
|
| 593 |
+
example_grad: Tensor,
|
| 594 |
+
*other_buffers: tuple[Tensor, ...],
|
| 595 |
+
) -> tuple[Tensor, ...]:
|
| 596 |
+
def fw_with_masks(
|
| 597 |
+
*args: tuple[Tensor, ...]
|
| 598 |
+
) -> tuple[tuple[Tensor], tuple[bool]]:
|
| 599 |
+
fw_out = score_mod(*args)
|
| 600 |
+
out_requires_grad = fw_out.requires_grad
|
| 601 |
+
return ((fw_out,), (out_requires_grad,))
|
| 602 |
+
|
| 603 |
+
joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
|
| 604 |
+
args = [score, b, h, m, n] + list(other_buffers)
|
| 605 |
+
optional_grad = [example_grad] if example_grad.requires_grad else []
|
| 606 |
+
_, grads = joint(args, optional_grad)
|
| 607 |
+
|
| 608 |
+
return grads
|
| 609 |
+
|
| 610 |
+
joint_graph = make_fx(joint_f)(
|
| 611 |
+
*unwrapped_score_mod_indexes, example_grad, *unwrapped_other_buffers
|
| 612 |
+
)
|
| 613 |
+
return score_mod, joint_graph
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
class FlexAttentionAutogradOp(torch.autograd.Function):
|
| 617 |
+
@staticmethod
|
| 618 |
+
def forward(
|
| 619 |
+
ctx: Any,
|
| 620 |
+
query: Tensor,
|
| 621 |
+
key: Tensor,
|
| 622 |
+
value: Tensor,
|
| 623 |
+
fw_graph: Callable,
|
| 624 |
+
joint_graph: Callable,
|
| 625 |
+
block_mask: tuple[Any, ...],
|
| 626 |
+
scale: float,
|
| 627 |
+
kernel_options: dict[str, Any],
|
| 628 |
+
mask_mod_other_buffers: tuple[Any, ...],
|
| 629 |
+
*score_mod_other_buffers: tuple[Any, ...],
|
| 630 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 631 |
+
any_buffer_requires_grad = any(
|
| 632 |
+
buffer.requires_grad
|
| 633 |
+
for buffer in mask_mod_other_buffers
|
| 634 |
+
if isinstance(buffer, torch.Tensor)
|
| 635 |
+
)
|
| 636 |
+
assert (
|
| 637 |
+
not any_buffer_requires_grad
|
| 638 |
+
), "Captured buffers from mask mod that require grad are not supported."
|
| 639 |
+
ctx._fw_graph = fw_graph
|
| 640 |
+
ctx._joint_graph = joint_graph
|
| 641 |
+
ctx._mask_graph = block_mask[-1]
|
| 642 |
+
ctx.scale = scale
|
| 643 |
+
ctx.kernel_options = kernel_options
|
| 644 |
+
ctx._score_mod_other_buffers_len = len(score_mod_other_buffers)
|
| 645 |
+
with torch._C._AutoDispatchBelowAutograd():
|
| 646 |
+
out, logsumexp = flex_attention(
|
| 647 |
+
query,
|
| 648 |
+
key,
|
| 649 |
+
value,
|
| 650 |
+
fw_graph,
|
| 651 |
+
block_mask,
|
| 652 |
+
scale,
|
| 653 |
+
kernel_options,
|
| 654 |
+
score_mod_other_buffers,
|
| 655 |
+
mask_mod_other_buffers,
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
save_tensors_and_symints_for_backward(
|
| 659 |
+
ctx,
|
| 660 |
+
(
|
| 661 |
+
query,
|
| 662 |
+
key,
|
| 663 |
+
value,
|
| 664 |
+
out,
|
| 665 |
+
logsumexp,
|
| 666 |
+
*block_mask[:-1],
|
| 667 |
+
*score_mod_other_buffers,
|
| 668 |
+
*mask_mod_other_buffers,
|
| 669 |
+
),
|
| 670 |
+
)
|
| 671 |
+
return out, logsumexp
|
| 672 |
+
|
| 673 |
+
@staticmethod
|
| 674 |
+
def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> tuple[Optional[Tensor], ...]: # type: ignore[override]
|
| 675 |
+
fw_args = saved_tensors_and_symints(ctx)
|
| 676 |
+
(
|
| 677 |
+
query,
|
| 678 |
+
key,
|
| 679 |
+
value,
|
| 680 |
+
out,
|
| 681 |
+
logsumexp,
|
| 682 |
+
query_lengths,
|
| 683 |
+
kv_lengths,
|
| 684 |
+
kv_num_blocks,
|
| 685 |
+
kv_indices,
|
| 686 |
+
full_kv_num_blocks,
|
| 687 |
+
full_kv_indices,
|
| 688 |
+
q_num_blocks,
|
| 689 |
+
q_indices,
|
| 690 |
+
full_q_num_blocks,
|
| 691 |
+
full_q_indices,
|
| 692 |
+
Q_BLOCK_SIZE,
|
| 693 |
+
KV_BLOCK_SIZE,
|
| 694 |
+
*other_buffers,
|
| 695 |
+
) = fw_args
|
| 696 |
+
fw_graph = ctx._fw_graph
|
| 697 |
+
joint_graph = ctx._joint_graph
|
| 698 |
+
mask_graph = ctx._mask_graph
|
| 699 |
+
scale = ctx.scale
|
| 700 |
+
kernel_options = ctx.kernel_options
|
| 701 |
+
score_mod_other_buffers = tuple(
|
| 702 |
+
other_buffers[: ctx._score_mod_other_buffers_len]
|
| 703 |
+
)
|
| 704 |
+
mask_mod_other_buffers = tuple(
|
| 705 |
+
other_buffers[ctx._score_mod_other_buffers_len :]
|
| 706 |
+
)
|
| 707 |
+
# We have asserted that mask_mod_other_buffers do not require grad,
|
| 708 |
+
# but score_mod_other_buffers can require grad.
|
| 709 |
+
none_grads = [None] * 6
|
| 710 |
+
(
|
| 711 |
+
grad_query,
|
| 712 |
+
grad_key,
|
| 713 |
+
grad_value,
|
| 714 |
+
grad_score_mod_captured,
|
| 715 |
+
) = flex_attention_backward(
|
| 716 |
+
query,
|
| 717 |
+
key,
|
| 718 |
+
value,
|
| 719 |
+
out,
|
| 720 |
+
logsumexp,
|
| 721 |
+
grad_out,
|
| 722 |
+
grad_logsumexp,
|
| 723 |
+
fw_graph,
|
| 724 |
+
joint_graph,
|
| 725 |
+
(
|
| 726 |
+
query_lengths,
|
| 727 |
+
kv_lengths,
|
| 728 |
+
kv_num_blocks,
|
| 729 |
+
kv_indices,
|
| 730 |
+
full_kv_num_blocks,
|
| 731 |
+
full_kv_indices,
|
| 732 |
+
q_num_blocks,
|
| 733 |
+
q_indices,
|
| 734 |
+
full_q_num_blocks,
|
| 735 |
+
full_q_indices,
|
| 736 |
+
Q_BLOCK_SIZE,
|
| 737 |
+
KV_BLOCK_SIZE,
|
| 738 |
+
mask_graph,
|
| 739 |
+
),
|
| 740 |
+
scale,
|
| 741 |
+
kernel_options,
|
| 742 |
+
score_mod_other_buffers,
|
| 743 |
+
mask_mod_other_buffers,
|
| 744 |
+
)
|
| 745 |
+
return grad_query, grad_key, grad_value, *none_grads, *grad_score_mod_captured
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
# TODO: Rework DispatchKey.Autograd to py_autograd_impl
|
| 749 |
+
@flex_attention.py_impl(DispatchKey.Autograd)
|
| 750 |
+
def flex_attention_autograd(
|
| 751 |
+
query: torch.Tensor,
|
| 752 |
+
key: torch.Tensor,
|
| 753 |
+
value: torch.Tensor,
|
| 754 |
+
score_mod: Callable,
|
| 755 |
+
block_mask: tuple,
|
| 756 |
+
scale: float,
|
| 757 |
+
kernel_options: dict[str, Any],
|
| 758 |
+
score_mod_other_buffers: tuple[Tensor, ...] = (),
|
| 759 |
+
mask_mod_other_buffers: tuple[Tensor, ...] = (),
|
| 760 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 761 |
+
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
|
| 762 |
+
|
| 763 |
+
with TransformGetItemToIndex():
|
| 764 |
+
input_requires_grad = any(
|
| 765 |
+
isinstance(t, torch.Tensor) and t.requires_grad
|
| 766 |
+
for t in (query, key, value, *score_mod_other_buffers)
|
| 767 |
+
)
|
| 768 |
+
if torch.is_grad_enabled() and input_requires_grad:
|
| 769 |
+
example_vals = (
|
| 770 |
+
query.new_zeros((), requires_grad=input_requires_grad),
|
| 771 |
+
query.new_zeros((), dtype=torch.int),
|
| 772 |
+
query.new_zeros((), dtype=torch.int),
|
| 773 |
+
query.new_zeros((), dtype=torch.int),
|
| 774 |
+
query.new_zeros((), dtype=torch.int),
|
| 775 |
+
)
|
| 776 |
+
fw_graph, bw_graph = create_fw_bw_graph(
|
| 777 |
+
score_mod, example_vals, score_mod_other_buffers
|
| 778 |
+
)
|
| 779 |
+
else:
|
| 780 |
+
fw_graph, bw_graph = score_mod, None
|
| 781 |
+
out, logsumexp = FlexAttentionAutogradOp.apply(
|
| 782 |
+
query,
|
| 783 |
+
key,
|
| 784 |
+
value,
|
| 785 |
+
fw_graph,
|
| 786 |
+
bw_graph,
|
| 787 |
+
block_mask,
|
| 788 |
+
scale,
|
| 789 |
+
kernel_options,
|
| 790 |
+
mask_mod_other_buffers,
|
| 791 |
+
*score_mod_other_buffers,
|
| 792 |
+
)
|
| 793 |
+
return out, logsumexp
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
# ---------------------------- Backward HOP Implementation ----------------------------
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
@flex_attention_backward.py_impl(DispatchKey.CompositeExplicitAutograd)
|
| 800 |
+
def sdpa_dense_backward(
|
| 801 |
+
query: torch.Tensor,
|
| 802 |
+
key: torch.Tensor,
|
| 803 |
+
value: torch.Tensor,
|
| 804 |
+
out: torch.Tensor,
|
| 805 |
+
logsumexp: torch.Tensor,
|
| 806 |
+
grad_out: torch.Tensor,
|
| 807 |
+
grad_logsumexp: torch.Tensor,
|
| 808 |
+
fw_graph: Callable, # GraphModule type hint?
|
| 809 |
+
joint_graph: Callable,
|
| 810 |
+
block_mask: tuple,
|
| 811 |
+
scale: float,
|
| 812 |
+
kernel_options: dict[str, Any],
|
| 813 |
+
score_mod_other_buffers: tuple,
|
| 814 |
+
mask_mod_other_buffers: tuple,
|
| 815 |
+
) -> tuple[
|
| 816 |
+
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
|
| 817 |
+
]:
|
| 818 |
+
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
|
| 819 |
+
|
| 820 |
+
Bq, Hq, seq_len_q, qk_head_dim = query.shape
|
| 821 |
+
Bkv, Hkv, seq_len_kv, v_head_dim = value.shape
|
| 822 |
+
|
| 823 |
+
# Get outputs before calling repeat interleave and permute to input stride orders
|
| 824 |
+
actual_grad_query = query.new_empty((Bq, Hq, seq_len_q, qk_head_dim))
|
| 825 |
+
actual_grad_query = _permute_strides(actual_grad_query, query.stride())
|
| 826 |
+
|
| 827 |
+
actual_grad_key = key.new_empty((Bq, Hkv, seq_len_kv, qk_head_dim))
|
| 828 |
+
actual_grad_key = _permute_strides(actual_grad_key, key.stride())
|
| 829 |
+
|
| 830 |
+
actual_grad_value = value.new_empty((Bq, Hkv, seq_len_kv, v_head_dim))
|
| 831 |
+
actual_grad_value = _permute_strides(actual_grad_value, value.stride())
|
| 832 |
+
|
| 833 |
+
def _maybe_new_buffer(
|
| 834 |
+
buffer: Union[torch.Tensor, torch.SymInt, int],
|
| 835 |
+
) -> Optional[Union[torch.Tensor, torch.SymInt, int]]:
|
| 836 |
+
if isinstance(buffer, torch.Tensor):
|
| 837 |
+
return (
|
| 838 |
+
torch.empty_like(buffer, memory_format=torch.contiguous_format)
|
| 839 |
+
if buffer.requires_grad
|
| 840 |
+
else None
|
| 841 |
+
)
|
| 842 |
+
return buffer
|
| 843 |
+
|
| 844 |
+
actual_grad_score_mod_captured = [
|
| 845 |
+
_maybe_new_buffer(buffer) for buffer in score_mod_other_buffers
|
| 846 |
+
]
|
| 847 |
+
|
| 848 |
+
Bq, Bkv = query.size(0), key.size(0)
|
| 849 |
+
if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)):
|
| 850 |
+
raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}")
|
| 851 |
+
|
| 852 |
+
key = key.expand((Bq, *key.size()[1:]))
|
| 853 |
+
value = value.expand((Bq, *value.size()[1:]))
|
| 854 |
+
|
| 855 |
+
G = query.size(1) // key.size(1)
|
| 856 |
+
key = torch.repeat_interleave(key, G, dim=1)
|
| 857 |
+
value = torch.repeat_interleave(value, G, dim=1)
|
| 858 |
+
|
| 859 |
+
# We're undoing the log -> log2 change of base in the forwards
|
| 860 |
+
logsumexp = logsumexp * math.log(2)
|
| 861 |
+
# The backwards formula for the log -> log2 change of base in the forwards
|
| 862 |
+
grad_logsumexp = grad_logsumexp / math.log(2)
|
| 863 |
+
scores, post_mod_scores = _math_attention_inner(
|
| 864 |
+
query,
|
| 865 |
+
key,
|
| 866 |
+
value,
|
| 867 |
+
fw_graph,
|
| 868 |
+
block_mask,
|
| 869 |
+
scale,
|
| 870 |
+
kernel_options,
|
| 871 |
+
score_mod_other_buffers,
|
| 872 |
+
mask_mod_other_buffers,
|
| 873 |
+
)
|
| 874 |
+
masked_out_rows = logsumexp == -float("inf")
|
| 875 |
+
softmax_scores = torch.exp(post_mod_scores - logsumexp.unsqueeze(-1))
|
| 876 |
+
softmax_scores = torch.where(masked_out_rows.unsqueeze(-1), 0, softmax_scores)
|
| 877 |
+
|
| 878 |
+
grad_value = softmax_scores.to(query.dtype).transpose(-2, -1) @ grad_out
|
| 879 |
+
|
| 880 |
+
grad_softmax_scores = grad_out @ value.transpose(-2, -1)
|
| 881 |
+
|
| 882 |
+
sum_scores = torch.sum(out * grad_out, -1, keepdim=True)
|
| 883 |
+
grad_score_mod = softmax_scores * (
|
| 884 |
+
grad_softmax_scores - sum_scores + grad_logsumexp.unsqueeze(-1)
|
| 885 |
+
)
|
| 886 |
+
|
| 887 |
+
b = torch.arange(0, scores.size(0), device=scores.device)
|
| 888 |
+
h = torch.arange(0, scores.size(1), device=scores.device)
|
| 889 |
+
m = torch.arange(0, scores.size(2), device=scores.device)
|
| 890 |
+
n = torch.arange(0, scores.size(3), device=scores.device)
|
| 891 |
+
|
| 892 |
+
mask_graph = block_mask[-1]
|
| 893 |
+
# Gradient of the inline score_mod function, with respect to the scores
|
| 894 |
+
captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
|
| 895 |
+
out_dims = [0, None, None, None, None] + [None] * len(score_mod_other_buffers)
|
| 896 |
+
from torch.nn.attention.flex_attention import _vmap_for_bhqkv
|
| 897 |
+
|
| 898 |
+
# inputs are [score, b, h, q_idx, kv_idx, gradOut, ...]
|
| 899 |
+
# score and gradOut are "fully" batched
|
| 900 |
+
joint_score_mod = _vmap_for_bhqkv(
|
| 901 |
+
joint_graph,
|
| 902 |
+
prefix=(0,),
|
| 903 |
+
suffix=(0,) + captured_buffers_in_dim,
|
| 904 |
+
out_dims=out_dims,
|
| 905 |
+
)
|
| 906 |
+
with TransformGetItemToIndex():
|
| 907 |
+
grad_scores, _, _, _, _, *grad_score_mod_captured = joint_score_mod(
|
| 908 |
+
scores, b, h, m, n, grad_score_mod, *score_mod_other_buffers
|
| 909 |
+
)
|
| 910 |
+
grad_scores = grad_scores * scale
|
| 911 |
+
grad_scores = grad_scores.to(query.dtype)
|
| 912 |
+
|
| 913 |
+
mask_mod = _vmap_for_bhqkv(
|
| 914 |
+
mask_graph, prefix=(), suffix=(None,) * len(mask_mod_other_buffers)
|
| 915 |
+
)
|
| 916 |
+
with TransformGetItemToIndex():
|
| 917 |
+
mask_scores = mask_mod(b, h, m, n, *mask_mod_other_buffers)
|
| 918 |
+
grad_scores = torch.where(
|
| 919 |
+
mask_scores, grad_scores, torch.tensor(0, dtype=query.dtype)
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
grad_query = grad_scores @ key
|
| 923 |
+
grad_key = grad_scores.transpose(-2, -1) @ query
|
| 924 |
+
|
| 925 |
+
# Reduce DK, DV along broadcasted heads.
|
| 926 |
+
grad_key = grad_key.view(
|
| 927 |
+
grad_key.size(0), -1, G, grad_key.size(-2), grad_key.size(-1)
|
| 928 |
+
)
|
| 929 |
+
grad_value = grad_value.view(
|
| 930 |
+
grad_value.size(0), -1, G, grad_value.size(-2), grad_value.size(-1)
|
| 931 |
+
)
|
| 932 |
+
|
| 933 |
+
grad_key = torch.sum(grad_key, 2, keepdim=False)
|
| 934 |
+
grad_value = torch.sum(grad_value, 2, keepdim=False)
|
| 935 |
+
|
| 936 |
+
# Fill to correctly strided outputs
|
| 937 |
+
actual_grad_query.copy_(grad_query)
|
| 938 |
+
actual_grad_key.copy_(grad_key)
|
| 939 |
+
actual_grad_value.copy_(grad_value)
|
| 940 |
+
|
| 941 |
+
if Bq != Bkv:
|
| 942 |
+
assert (
|
| 943 |
+
Bq > 1 and Bkv == 1
|
| 944 |
+
), f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}"
|
| 945 |
+
|
| 946 |
+
actual_grad_key = torch.sum(actual_grad_key, 0, keepdim=True)
|
| 947 |
+
actual_grad_value = torch.sum(actual_grad_value, 0, keepdim=True)
|
| 948 |
+
|
| 949 |
+
score_mod_other_buffer_grads = [
|
| 950 |
+
actual_grad.copy_(grad) if isinstance(actual_grad, torch.Tensor) else None
|
| 951 |
+
for actual_grad, grad in zip(
|
| 952 |
+
actual_grad_score_mod_captured, grad_score_mod_captured
|
| 953 |
+
)
|
| 954 |
+
]
|
| 955 |
+
|
| 956 |
+
return (
|
| 957 |
+
actual_grad_query,
|
| 958 |
+
actual_grad_key,
|
| 959 |
+
actual_grad_value,
|
| 960 |
+
tuple(score_mod_other_buffer_grads),
|
| 961 |
+
)
|
| 962 |
+
|
| 963 |
+
|
| 964 |
+
def trace_flex_attention_backward(
|
| 965 |
+
proxy_mode: ProxyTorchDispatchMode,
|
| 966 |
+
query: torch.Tensor,
|
| 967 |
+
key: torch.Tensor,
|
| 968 |
+
value: torch.Tensor,
|
| 969 |
+
out: torch.Tensor,
|
| 970 |
+
logsumexp: torch.Tensor,
|
| 971 |
+
grad_out: torch.Tensor,
|
| 972 |
+
grad_logsumexp: torch.Tensor,
|
| 973 |
+
fw_graph: Union[Callable, GraphModule],
|
| 974 |
+
joint_graph: GraphModule,
|
| 975 |
+
block_mask: tuple,
|
| 976 |
+
scale: float,
|
| 977 |
+
kernel_options: dict[str, Any],
|
| 978 |
+
score_mod_other_buffers: tuple = (),
|
| 979 |
+
mask_mod_other_buffers: tuple = (),
|
| 980 |
+
) -> tuple[
|
| 981 |
+
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
|
| 982 |
+
]:
|
| 983 |
+
"""We already have the forward graph and joint graph from the forward pass, so we create a proxy attach both graphs"""
|
| 984 |
+
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
|
| 985 |
+
|
| 986 |
+
example_out = flex_attention_backward(
|
| 987 |
+
query,
|
| 988 |
+
key,
|
| 989 |
+
value,
|
| 990 |
+
out,
|
| 991 |
+
logsumexp,
|
| 992 |
+
grad_out,
|
| 993 |
+
grad_logsumexp,
|
| 994 |
+
fw_graph,
|
| 995 |
+
joint_graph,
|
| 996 |
+
block_mask,
|
| 997 |
+
scale,
|
| 998 |
+
kernel_options,
|
| 999 |
+
score_mod_other_buffers,
|
| 1000 |
+
mask_mod_other_buffers,
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
requires_grad = any(pytree.tree_map(lambda x: x.requires_grad, (query, key)))
|
| 1004 |
+
fw_example_vals = [query.new_zeros((), requires_grad=requires_grad)] + [
|
| 1005 |
+
query.new_zeros((), dtype=torch.int) for _ in range(4)
|
| 1006 |
+
]
|
| 1007 |
+
bw_example_vals = fw_example_vals + [query.new_zeros(())]
|
| 1008 |
+
mask_example_vals = [query.new_zeros((), dtype=torch.int) for _ in range(4)]
|
| 1009 |
+
mask_graph = block_mask[-1]
|
| 1010 |
+
with TransformGetItemToIndex():
|
| 1011 |
+
# There's no active make_fx during the compiled autograd graph's initial capture
|
| 1012 |
+
fw_graph = _maybe_reenter_make_fx(fw_graph)(
|
| 1013 |
+
*fw_example_vals, *score_mod_other_buffers
|
| 1014 |
+
)
|
| 1015 |
+
joint_graph = _maybe_reenter_make_fx(joint_graph)(
|
| 1016 |
+
*bw_example_vals, *score_mod_other_buffers
|
| 1017 |
+
)
|
| 1018 |
+
mask_graph = _maybe_reenter_make_fx(mask_graph)(
|
| 1019 |
+
*mask_example_vals, *mask_mod_other_buffers
|
| 1020 |
+
)
|
| 1021 |
+
assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
|
| 1022 |
+
block_mask = block_mask[:-1] + (mask_graph,)
|
| 1023 |
+
|
| 1024 |
+
qualname = proxy_mode.tracer.get_fresh_qualname("fw_graph")
|
| 1025 |
+
proxy_mode.tracer.root.register_module(qualname, fw_graph) # type: ignore[arg-type]
|
| 1026 |
+
qualname = proxy_mode.tracer.get_fresh_qualname("joint_graph")
|
| 1027 |
+
proxy_mode.tracer.root.register_module(qualname, joint_graph)
|
| 1028 |
+
qualname = proxy_mode.tracer.get_fresh_qualname("mask_graph")
|
| 1029 |
+
proxy_mode.tracer.root.register_module(qualname, mask_graph)
|
| 1030 |
+
|
| 1031 |
+
node_args = (
|
| 1032 |
+
query,
|
| 1033 |
+
key,
|
| 1034 |
+
value,
|
| 1035 |
+
out,
|
| 1036 |
+
logsumexp,
|
| 1037 |
+
grad_out,
|
| 1038 |
+
grad_logsumexp,
|
| 1039 |
+
fw_graph,
|
| 1040 |
+
joint_graph,
|
| 1041 |
+
block_mask,
|
| 1042 |
+
scale,
|
| 1043 |
+
kernel_options,
|
| 1044 |
+
score_mod_other_buffers,
|
| 1045 |
+
mask_mod_other_buffers,
|
| 1046 |
+
)
|
| 1047 |
+
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
|
| 1048 |
+
out_proxy = proxy_mode.tracer.create_proxy(
|
| 1049 |
+
"call_function",
|
| 1050 |
+
flex_attention_backward,
|
| 1051 |
+
proxy_args,
|
| 1052 |
+
{},
|
| 1053 |
+
name="flex_attention_backward",
|
| 1054 |
+
)
|
| 1055 |
+
return track_tensor_tree(
|
| 1056 |
+
example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
|
| 1060 |
+
@flex_attention_backward.py_impl(ProxyTorchDispatchMode)
|
| 1061 |
+
def flex_attention_backward_proxy_torch_dispatch_mode(
|
| 1062 |
+
mode: ProxyTorchDispatchMode,
|
| 1063 |
+
query: torch.Tensor,
|
| 1064 |
+
key: torch.Tensor,
|
| 1065 |
+
value: torch.Tensor,
|
| 1066 |
+
out: torch.Tensor,
|
| 1067 |
+
logsumexp: torch.Tensor,
|
| 1068 |
+
grad_out: torch.Tensor,
|
| 1069 |
+
grad_logsumexp: torch.Tensor,
|
| 1070 |
+
fw_graph: Union[Callable, GraphModule],
|
| 1071 |
+
joint_graph: GraphModule,
|
| 1072 |
+
block_mask: tuple,
|
| 1073 |
+
scale: float,
|
| 1074 |
+
kernel_options: dict[str, Any],
|
| 1075 |
+
score_mod_other_buffers: tuple = (),
|
| 1076 |
+
mask_mod_other_buffers: tuple = (),
|
| 1077 |
+
) -> tuple[
|
| 1078 |
+
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
|
| 1079 |
+
]:
|
| 1080 |
+
assert mode is not None, "Mode should always be enabled for python fallback key"
|
| 1081 |
+
return trace_flex_attention_backward(
|
| 1082 |
+
mode,
|
| 1083 |
+
query,
|
| 1084 |
+
key,
|
| 1085 |
+
value,
|
| 1086 |
+
out,
|
| 1087 |
+
logsumexp,
|
| 1088 |
+
grad_out,
|
| 1089 |
+
grad_logsumexp,
|
| 1090 |
+
fw_graph,
|
| 1091 |
+
joint_graph,
|
| 1092 |
+
block_mask,
|
| 1093 |
+
scale,
|
| 1094 |
+
kernel_options,
|
| 1095 |
+
score_mod_other_buffers,
|
| 1096 |
+
mask_mod_other_buffers,
|
| 1097 |
+
)
|
| 1098 |
+
|
| 1099 |
+
|
| 1100 |
+
@flex_attention_backward.py_functionalize_impl
|
| 1101 |
+
def flex_attention_backward_functionalize(
|
| 1102 |
+
ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI,
|
| 1103 |
+
query: torch.Tensor,
|
| 1104 |
+
key: torch.Tensor,
|
| 1105 |
+
value: torch.Tensor,
|
| 1106 |
+
out: torch.Tensor,
|
| 1107 |
+
logsumexp: torch.Tensor,
|
| 1108 |
+
grad_out: torch.Tensor,
|
| 1109 |
+
grad_logsumexp: torch.Tensor,
|
| 1110 |
+
fw_graph: Union[Callable, GraphModule],
|
| 1111 |
+
joint_graph: GraphModule,
|
| 1112 |
+
block_mask: tuple,
|
| 1113 |
+
scale: float,
|
| 1114 |
+
kernel_options: dict[str, Any],
|
| 1115 |
+
score_mod_other_buffers: tuple = (),
|
| 1116 |
+
mask_mod_other_buffers: tuple = (),
|
| 1117 |
+
) -> tuple[
|
| 1118 |
+
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
|
| 1119 |
+
]:
|
| 1120 |
+
"""Defines the functionalization rules for the flex_attention operator.
|
| 1121 |
+
|
| 1122 |
+
Write now we are unwrapping each tensor and then redispatching to the next,
|
| 1123 |
+
since we know that the forward score mod function is assured to be free of mutations
|
| 1124 |
+
to the other_buffers, we skip that mutate check and go straight to redispatching.
|
| 1125 |
+
"""
|
| 1126 |
+
|
| 1127 |
+
if has_user_subclass(
|
| 1128 |
+
(
|
| 1129 |
+
query,
|
| 1130 |
+
key,
|
| 1131 |
+
value,
|
| 1132 |
+
out,
|
| 1133 |
+
logsumexp,
|
| 1134 |
+
grad_out,
|
| 1135 |
+
grad_logsumexp,
|
| 1136 |
+
block_mask,
|
| 1137 |
+
scale,
|
| 1138 |
+
kernel_options,
|
| 1139 |
+
score_mod_other_buffers,
|
| 1140 |
+
mask_mod_other_buffers,
|
| 1141 |
+
),
|
| 1142 |
+
allowed_subclasses=(FakeTensor, FunctionalTensor),
|
| 1143 |
+
):
|
| 1144 |
+
return NotImplemented
|
| 1145 |
+
query_unwrapped = ctx.unwrap_tensors(query)
|
| 1146 |
+
key_unwrapped = ctx.unwrap_tensors(key)
|
| 1147 |
+
value_unwrapped = ctx.unwrap_tensors(value)
|
| 1148 |
+
out_unwrapped = ctx.unwrap_tensors(out)
|
| 1149 |
+
logsumexp_unwrapped = ctx.unwrap_tensors(logsumexp)
|
| 1150 |
+
grad_out_unwrapped = ctx.unwrap_tensors(grad_out)
|
| 1151 |
+
grad_logsumexp_unwrapped = ctx.unwrap_tensors(grad_logsumexp)
|
| 1152 |
+
block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
|
| 1153 |
+
score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
|
| 1154 |
+
mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)
|
| 1155 |
+
|
| 1156 |
+
# Appease the mypy overlords
|
| 1157 |
+
assert isinstance(query_unwrapped, torch.Tensor)
|
| 1158 |
+
assert isinstance(key_unwrapped, torch.Tensor)
|
| 1159 |
+
assert isinstance(value_unwrapped, torch.Tensor)
|
| 1160 |
+
assert isinstance(out_unwrapped, torch.Tensor)
|
| 1161 |
+
assert isinstance(logsumexp_unwrapped, torch.Tensor)
|
| 1162 |
+
assert isinstance(grad_out_unwrapped, torch.Tensor)
|
| 1163 |
+
assert isinstance(grad_logsumexp_unwrapped, torch.Tensor)
|
| 1164 |
+
assert isinstance(block_mask_unwrapped, tuple)
|
| 1165 |
+
assert isinstance(score_mod_other_buffers_unwrapped, tuple)
|
| 1166 |
+
assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
|
| 1167 |
+
|
| 1168 |
+
with ctx.redispatch_to_next():
|
| 1169 |
+
functional_fw_graph = ctx.functionalize(fw_graph)
|
| 1170 |
+
functional_joint_graph = ctx.functionalize(joint_graph)
|
| 1171 |
+
|
| 1172 |
+
(
|
| 1173 |
+
grad_query,
|
| 1174 |
+
grad_key,
|
| 1175 |
+
grad_value,
|
| 1176 |
+
grad_score_mod_captured,
|
| 1177 |
+
) = flex_attention_backward(
|
| 1178 |
+
query_unwrapped,
|
| 1179 |
+
key_unwrapped,
|
| 1180 |
+
value_unwrapped,
|
| 1181 |
+
out_unwrapped,
|
| 1182 |
+
logsumexp_unwrapped,
|
| 1183 |
+
grad_out_unwrapped,
|
| 1184 |
+
grad_logsumexp_unwrapped,
|
| 1185 |
+
functional_fw_graph, # type: ignore[arg-type]
|
| 1186 |
+
functional_joint_graph, # type: ignore[arg-type]
|
| 1187 |
+
block_mask_unwrapped,
|
| 1188 |
+
scale,
|
| 1189 |
+
kernel_options,
|
| 1190 |
+
score_mod_other_buffers_unwrapped,
|
| 1191 |
+
mask_mod_other_buffers_unwrapped,
|
| 1192 |
+
)
|
| 1193 |
+
|
| 1194 |
+
return ctx.wrap_tensors((grad_query, grad_key, grad_value, grad_score_mod_captured)) # type: ignore[return-value,arg-type]
|
| 1195 |
+
|
| 1196 |
+
|
| 1197 |
+
@register_fake(flex_attention_backward)
|
| 1198 |
+
def flex_attention_backward_fake_tensor_mode(
|
| 1199 |
+
query: torch.Tensor,
|
| 1200 |
+
key: torch.Tensor,
|
| 1201 |
+
value: torch.Tensor,
|
| 1202 |
+
out: torch.Tensor,
|
| 1203 |
+
logsumexp: torch.Tensor,
|
| 1204 |
+
grad_out: torch.Tensor,
|
| 1205 |
+
grad_logsumexp: torch.Tensor,
|
| 1206 |
+
fw_graph: Union[Callable, GraphModule],
|
| 1207 |
+
joint_graph: GraphModule,
|
| 1208 |
+
block_mask: tuple,
|
| 1209 |
+
scale: float,
|
| 1210 |
+
kernel_options: dict[str, Any],
|
| 1211 |
+
score_mod_other_buffers: tuple = (),
|
| 1212 |
+
mask_mod_other_buffers: tuple = (),
|
| 1213 |
+
) -> tuple[
|
| 1214 |
+
torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...]
|
| 1215 |
+
]:
|
| 1216 |
+
if has_user_subclass(
|
| 1217 |
+
(
|
| 1218 |
+
query,
|
| 1219 |
+
key,
|
| 1220 |
+
value,
|
| 1221 |
+
out,
|
| 1222 |
+
logsumexp,
|
| 1223 |
+
grad_out,
|
| 1224 |
+
grad_logsumexp,
|
| 1225 |
+
block_mask,
|
| 1226 |
+
scale,
|
| 1227 |
+
kernel_options,
|
| 1228 |
+
score_mod_other_buffers,
|
| 1229 |
+
mask_mod_other_buffers,
|
| 1230 |
+
),
|
| 1231 |
+
allowed_subclasses=(FakeTensor,),
|
| 1232 |
+
):
|
| 1233 |
+
return NotImplemented
|
| 1234 |
+
Bq, _, _, qk_head_dim = query.shape
|
| 1235 |
+
Bkv, Hkv, seq_len_kv, v_head_dim = value.shape
|
| 1236 |
+
|
| 1237 |
+
grad_query = torch.empty_like(query)
|
| 1238 |
+
# zeros_and_scatter creates a contiguous zeros tensor -> contiguous_format
|
| 1239 |
+
grad_score_mod_captured = tuple(
|
| 1240 |
+
[
|
| 1241 |
+
(
|
| 1242 |
+
torch.empty_like(buffer, memory_format=torch.contiguous_format)
|
| 1243 |
+
if isinstance(buffer, torch.Tensor) and buffer.requires_grad
|
| 1244 |
+
else None
|
| 1245 |
+
)
|
| 1246 |
+
for buffer in score_mod_other_buffers
|
| 1247 |
+
]
|
| 1248 |
+
)
|
| 1249 |
+
|
| 1250 |
+
broadcasted_grad_key = key.new_empty((Bq, Hkv, seq_len_kv, qk_head_dim))
|
| 1251 |
+
broadcasted_grad_key = _permute_strides(broadcasted_grad_key, key.stride())
|
| 1252 |
+
|
| 1253 |
+
broadcasted_grad_value = value.new_empty((Bq, Hkv, seq_len_kv, v_head_dim))
|
| 1254 |
+
broadcasted_grad_value = _permute_strides(broadcasted_grad_value, value.stride())
|
| 1255 |
+
|
| 1256 |
+
if Bq > 1 and Bkv == 1:
|
| 1257 |
+
grad_key = torch.sum(broadcasted_grad_key, dim=0, keepdim=True)
|
| 1258 |
+
grad_value = torch.sum(broadcasted_grad_value, dim=0, keepdim=True)
|
| 1259 |
+
else:
|
| 1260 |
+
grad_key = broadcasted_grad_key
|
| 1261 |
+
grad_value = broadcasted_grad_value
|
| 1262 |
+
|
| 1263 |
+
return grad_query, grad_key, grad_value, grad_score_mod_captured
|
| 1264 |
+
|
| 1265 |
+
|
| 1266 |
+
flex_attention_backward.py_autograd_impl(
|
| 1267 |
+
autograd_not_implemented(flex_attention_backward, deferred_error=True)
|
| 1268 |
+
)
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/foreach_map.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
# mypy: allow-untyped-defs
|
| 3 |
+
from typing import Any, Callable
|
| 4 |
+
|
| 5 |
+
from torch._higher_order_ops.base_hop import BaseHOP, FunctionWithNoFreeVars
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ForeachMap(BaseHOP):
|
| 9 |
+
def __init__(self):
|
| 10 |
+
super().__init__("foreach_map")
|
| 11 |
+
|
| 12 |
+
def __call__(self, fn, *operands, **kwargs): # type: ignore[override]
|
| 13 |
+
fn = FunctionWithNoFreeVars(fn)
|
| 14 |
+
return super().__call__(fn, *operands, **kwargs)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
_foreach_map = ForeachMap()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def foreach_map(op: Callable, *operands: Any, **kwargs: dict[str, Any]):
|
| 21 |
+
from torch._dynamo.polyfills import foreach_map_fn
|
| 22 |
+
|
| 23 |
+
return _foreach_map(foreach_map_fn, op, *operands, **kwargs)
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/hints_wrap.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
import torch.utils._pytree as pytree
|
| 4 |
+
from torch._C import DispatchKey
|
| 5 |
+
from torch._higher_order_ops.utils import (
|
| 6 |
+
autograd_not_implemented,
|
| 7 |
+
reenter_make_fx,
|
| 8 |
+
unique_graph_id,
|
| 9 |
+
)
|
| 10 |
+
from torch._ops import HigherOrderOperator
|
| 11 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 12 |
+
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# used for wrapping a function/op with context hints
|
| 16 |
+
class HintsWrapper(HigherOrderOperator):
|
| 17 |
+
def __init__(self):
|
| 18 |
+
super().__init__("hints_wrapper")
|
| 19 |
+
|
| 20 |
+
def __call__(self, body_fn, args, kwargs, hints):
|
| 21 |
+
r"""
|
| 22 |
+
Call implementation of hints_wrapper
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
body_fn (Callable): A callable function that is within the scope
|
| 26 |
+
that is being traced.
|
| 27 |
+
|
| 28 |
+
args (Tuple of torch.Tensor/int/float/bool): A tuple of inputs to
|
| 29 |
+
body_fn.
|
| 30 |
+
|
| 31 |
+
kwargs (dict): Keyword argument to the body_fn.
|
| 32 |
+
|
| 33 |
+
hints (dict): A dict of context hints which could be passed to
|
| 34 |
+
backend compiler.
|
| 35 |
+
"""
|
| 36 |
+
if not isinstance(args, tuple):
|
| 37 |
+
raise RuntimeError(f"args must be a tuple, got {type(args)}")
|
| 38 |
+
|
| 39 |
+
if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args):
|
| 40 |
+
raise RuntimeError(
|
| 41 |
+
"args must be a tuple of tensors, ints, floats, or bools, got "
|
| 42 |
+
f"{args}"
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
if not isinstance(kwargs, dict):
|
| 46 |
+
raise RuntimeError(f"kwargs must be a dict, got {type(kwargs)}")
|
| 47 |
+
|
| 48 |
+
if len(kwargs) > 0:
|
| 49 |
+
raise RuntimeError(
|
| 50 |
+
f"kwargs except for hints are not supported, got {kwargs}"
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
if not isinstance(hints, dict):
|
| 54 |
+
raise RuntimeError(f"hints must be a dict, got {type(hints)}")
|
| 55 |
+
|
| 56 |
+
for k, v in hints.items():
|
| 57 |
+
if not isinstance(k, str):
|
| 58 |
+
raise RuntimeError(f"hints key must be a str, got {k}.")
|
| 59 |
+
|
| 60 |
+
if not isinstance(v, (int, float, bool, str)):
|
| 61 |
+
raise RuntimeError(
|
| 62 |
+
"hints must be a dict containing int, float, bool or str "
|
| 63 |
+
f"value, got value {v} for key {k}."
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
return super().__call__(body_fn, args, kwargs, hints)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
hints_wrapper = HintsWrapper()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@hints_wrapper.py_impl(DispatchKey.CompositeExplicitAutograd)
|
| 73 |
+
def hints_wrapper_dense(body_fn, args, kwargs, hints):
|
| 74 |
+
return body_fn(*args, **kwargs)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
hints_wrapper.py_autograd_impl(
|
| 78 |
+
autograd_not_implemented(hints_wrapper, deferred_error=True)
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@hints_wrapper.py_impl(FakeTensorMode)
|
| 83 |
+
def hints_wrapper_fake_tensor_mode(mode, body_func, args, kwargs, hints):
|
| 84 |
+
flat_args = pytree.tree_leaves(args)
|
| 85 |
+
with mode:
|
| 86 |
+
return body_func(*flat_args, **kwargs)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@hints_wrapper.py_functionalize_impl
|
| 90 |
+
def hints_wrapper_functionalize(ctx, body_fn, args, kwargs, hints):
|
| 91 |
+
from torch._higher_order_ops.utils import _check_alias_and_mutation
|
| 92 |
+
|
| 93 |
+
unwrapped_args = ctx.unwrap_tensors(args)
|
| 94 |
+
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
|
| 95 |
+
unwrapped_hints = ctx.unwrap_tensors(hints)
|
| 96 |
+
with ctx.redispatch_to_next():
|
| 97 |
+
functional_body_fn = ctx.functionalize(body_fn)
|
| 98 |
+
pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
|
| 99 |
+
_check_alias_and_mutation(
|
| 100 |
+
body_fn, unwrapped_args, "hints_wrapper", pre_dispatch
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
outputs = hints_wrapper(
|
| 104 |
+
functional_body_fn,
|
| 105 |
+
unwrapped_args,
|
| 106 |
+
unwrapped_kwargs,
|
| 107 |
+
unwrapped_hints,
|
| 108 |
+
)
|
| 109 |
+
return ctx.wrap_tensors(outputs)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def trace_hints_wrapper(proxy_mode, hints_wrapper, body_fn, args, kwargs, hints):
|
| 113 |
+
flat_args = tuple(pytree.tree_leaves(args))
|
| 114 |
+
body_graph = reenter_make_fx(body_fn)(*flat_args, **kwargs)
|
| 115 |
+
|
| 116 |
+
_, body_graph_name = unique_graph_id(proxy_mode, prefix="hints_wrapper_body_graph")
|
| 117 |
+
proxy_mode.tracer.root.register_module(body_graph_name, body_graph)
|
| 118 |
+
|
| 119 |
+
new_args: tuple = (body_graph, flat_args, {})
|
| 120 |
+
# merge hints into kwargs
|
| 121 |
+
new_kwargs = {}
|
| 122 |
+
new_kwargs["hints"] = hints
|
| 123 |
+
|
| 124 |
+
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_args)
|
| 125 |
+
proxy_kwargs = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_kwargs)
|
| 126 |
+
|
| 127 |
+
out_proxy = proxy_mode.tracer.create_proxy(
|
| 128 |
+
"call_function", hints_wrapper, proxy_args, proxy_kwargs, name="hints_wrapper"
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
out = body_fn(*flat_args, **kwargs)
|
| 132 |
+
return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@hints_wrapper.py_impl(ProxyTorchDispatchMode)
|
| 136 |
+
def inner(proxy_mode, body_fn, args, kwargs, hints):
|
| 137 |
+
if proxy_mode.enable_tracing:
|
| 138 |
+
return trace_hints_wrapper(
|
| 139 |
+
proxy_mode, hints_wrapper, body_fn, args, kwargs, hints
|
| 140 |
+
)
|
| 141 |
+
else:
|
| 142 |
+
return hints_wrapper(body_fn, args, kwargs, hints)
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/invoke_subgraph.py
ADDED
|
@@ -0,0 +1,658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import contextlib
|
| 5 |
+
from contextlib import nullcontext
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from typing import Optional, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.utils._pytree as pytree
|
| 11 |
+
from torch._C import DispatchKey
|
| 12 |
+
from torch._dispatch.python import suspend_functionalization
|
| 13 |
+
from torch._higher_order_ops.utils import (
|
| 14 |
+
_from_fun,
|
| 15 |
+
_maybe_reenter_make_fx,
|
| 16 |
+
_set_compilation_env,
|
| 17 |
+
clone_outputs_aliasing_inputs,
|
| 18 |
+
FunctionalizeCtxWrapper,
|
| 19 |
+
get_dummy_aot_autograd_config,
|
| 20 |
+
HopInstance,
|
| 21 |
+
prepare_fw_with_masks,
|
| 22 |
+
reenter_make_fx,
|
| 23 |
+
register_fake,
|
| 24 |
+
save_tensors_and_symints_for_backward,
|
| 25 |
+
saved_tensors_and_symints,
|
| 26 |
+
)
|
| 27 |
+
from torch._ops import HigherOrderOperator
|
| 28 |
+
from torch._subclasses.functional_tensor import disable_functional_mode
|
| 29 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 30 |
+
_temp_remove_metadata_torch_function_mode,
|
| 31 |
+
_temp_remove_pre_dispatch_torch_function_mode,
|
| 32 |
+
disable_proxy_modes_tracing,
|
| 33 |
+
ProxyTorchDispatchMode,
|
| 34 |
+
track_tensor_tree,
|
| 35 |
+
)
|
| 36 |
+
from torch.fx.graph_module import GraphModule
|
| 37 |
+
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
invoke_subgraph_counter = 0
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# During the tracing of the joint graph, we construct this information. This is
|
| 44 |
+
# used to filter out grad_outs/tangents in the `backward` method of
|
| 45 |
+
# InvokeSubgraphAutogradOp.
|
| 46 |
+
@dataclass
|
| 47 |
+
class OutputMetadata:
|
| 48 |
+
num_fw_outs: Optional[int] = None
|
| 49 |
+
indexes_with_none: set[int] = field(default_factory=set)
|
| 50 |
+
indexes_with_no_grad: set[int] = field(default_factory=set)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class InvokeSubgraphHOP(HigherOrderOperator):
|
| 54 |
+
def __init__(self) -> None:
|
| 55 |
+
# Invoke subgraph does not have any state, it is just a wrapper over a
|
| 56 |
+
# subgraph, so we can safely cache the HOP.
|
| 57 |
+
super().__init__("invoke_subgraph", cacheable=True)
|
| 58 |
+
# This is used by the fake tensor cache key validator to extract the
|
| 59 |
+
# subgraph and iterate over the nodes to find if all nodes are fake
|
| 60 |
+
# tensor cacheable.
|
| 61 |
+
self.subgraph_indexes = [
|
| 62 |
+
0,
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
# identifier is setup by upper part of the stack. This helps us in
|
| 66 |
+
# identifying two invoke_subgraph calls have same subgraph.
|
| 67 |
+
def __call__(
|
| 68 |
+
self,
|
| 69 |
+
subgraph: Union[GraphModule, FunctionalizeCtxWrapper],
|
| 70 |
+
identifier: Optional[str],
|
| 71 |
+
*operands,
|
| 72 |
+
):
|
| 73 |
+
assert identifier is None or isinstance(
|
| 74 |
+
identifier, str
|
| 75 |
+
), "identifier must be a None or a string"
|
| 76 |
+
|
| 77 |
+
assert all(
|
| 78 |
+
isinstance(o, (torch.Tensor, int, torch.SymInt)) for o in operands
|
| 79 |
+
), f"invoke_subgraph operands must be a list of tensors/ints/SymInts {operands}"
|
| 80 |
+
|
| 81 |
+
return super().__call__(subgraph, identifier, *operands)
|
| 82 |
+
|
| 83 |
+
def gen_schema(self, subgraph, identifier, *operands):
|
| 84 |
+
from torch._higher_order_ops.schema import HopSchemaGenerator
|
| 85 |
+
from torch._higher_order_ops.utils import (
|
| 86 |
+
check_input_alias_and_mutation_return_outputs,
|
| 87 |
+
materialize_as_graph,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
gm: torch.fx.GraphModule = (
|
| 91 |
+
subgraph
|
| 92 |
+
if isinstance(subgraph, torch.fx.GraphModule)
|
| 93 |
+
else materialize_as_graph(subgraph, operands)
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
schema_gen = HopSchemaGenerator(self)
|
| 97 |
+
schema_gen.add_arg("subgraph", gm)
|
| 98 |
+
schema_gen.add_arg("identifier", identifier)
|
| 99 |
+
(
|
| 100 |
+
_,
|
| 101 |
+
_,
|
| 102 |
+
_,
|
| 103 |
+
mutated_inputs,
|
| 104 |
+
outputs,
|
| 105 |
+
) = check_input_alias_and_mutation_return_outputs(gm, operands)
|
| 106 |
+
for idx, arg in enumerate(operands):
|
| 107 |
+
schema_gen.add_arg(f"arg{idx}", arg, is_mutated=idx in mutated_inputs)
|
| 108 |
+
for out in outputs:
|
| 109 |
+
schema_gen.add_output(out)
|
| 110 |
+
|
| 111 |
+
return schema_gen.gen_schema()
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
invoke_subgraph = InvokeSubgraphHOP()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def invoke_subgraph_placeholder(func, *args, **kwargs):
|
| 118 |
+
if torch.compiler.is_dynamo_compiling():
|
| 119 |
+
# This is just a placeholder for Dynamo to replace with invoke_subgraph
|
| 120 |
+
raise RuntimeError("invoke_subgraph should not be called directly in Dynamo")
|
| 121 |
+
|
| 122 |
+
if torch.compiler.is_compiling():
|
| 123 |
+
# For non-strict export tracing, we still want to go through Dynamo
|
| 124 |
+
from torch._dynamo.backends.debugging import (
|
| 125 |
+
make_eager_backend_with_torch_function_mode,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
def _invoke_subgraph_placeholder_wrapper(func, args):
|
| 129 |
+
return invoke_subgraph_placeholder(func, *args)
|
| 130 |
+
|
| 131 |
+
with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode():
|
| 132 |
+
with _temp_remove_metadata_torch_function_mode() as metadata_mode:
|
| 133 |
+
if metadata_mode:
|
| 134 |
+
backend = make_eager_backend_with_torch_function_mode(metadata_mode)
|
| 135 |
+
else:
|
| 136 |
+
backend = "eager"
|
| 137 |
+
|
| 138 |
+
return torch.compile(
|
| 139 |
+
_invoke_subgraph_placeholder_wrapper,
|
| 140 |
+
backend=backend,
|
| 141 |
+
fullgraph=True,
|
| 142 |
+
)(func, args)
|
| 143 |
+
|
| 144 |
+
return func(*args, **kwargs)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def mark_compile_region(fn=None):
|
| 148 |
+
"""
|
| 149 |
+
This wrapper instructs torch.compile to compile the wrapped region once and
|
| 150 |
+
reuse the compiled artifact, instead of the usual way of aggressively
|
| 151 |
+
inlining the function.
|
| 152 |
+
|
| 153 |
+
Under the hood, it tells TorchDynamo to use InvokeSubgraph HOP for the
|
| 154 |
+
region. For PyTorch eager, this is a no-op.
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
def wrap(func):
|
| 158 |
+
def inner(*args, **kwargs):
|
| 159 |
+
# Get the innermost function to avoid nested compile regions
|
| 160 |
+
inner_func = func
|
| 161 |
+
while hasattr(inner_func, "__marked_compile_region_fn__"):
|
| 162 |
+
inner_func = inner_func.__marked_compile_region_fn__
|
| 163 |
+
return invoke_subgraph_placeholder(inner_func, *args, **kwargs)
|
| 164 |
+
|
| 165 |
+
inner.__marked_compile_region_fn__ = func # type: ignore[attr-defined]
|
| 166 |
+
|
| 167 |
+
return inner
|
| 168 |
+
|
| 169 |
+
if fn:
|
| 170 |
+
return wrap(fn)
|
| 171 |
+
else:
|
| 172 |
+
return wrap
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def get_invoke_subgraph_cache():
|
| 176 |
+
cache = None
|
| 177 |
+
if tracing_ctx := torch._guards.TracingContext.try_get():
|
| 178 |
+
cache = tracing_ctx.hop_dispatch_set_cache.get_cache(invoke_subgraph)
|
| 179 |
+
return cache
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra
|
| 183 |
+
def trace_joint_graph(fn, fw_inputs, fw_outputs):
|
| 184 |
+
"""
|
| 185 |
+
Naively trace out a joint graph. This simplifies the reconstruction of joint
|
| 186 |
+
graph in the min-cut partitioner later on.
|
| 187 |
+
"""
|
| 188 |
+
from torch._functorch.aot_autograd import create_joint
|
| 189 |
+
|
| 190 |
+
dummy_aot_config = get_dummy_aot_autograd_config()
|
| 191 |
+
|
| 192 |
+
# This joint_fn is inserted as the backward graph as is. This simplifies the
|
| 193 |
+
# min-cut partitioner work later on.
|
| 194 |
+
# Input signature - (*primals, *tangents)
|
| 195 |
+
# Output signature - (*grads, *fw_outs)
|
| 196 |
+
# The output signature is deliberately kept grads first and fw_outs second.
|
| 197 |
+
# Having grads first makes the min-cut partitioner HOP graph stitching
|
| 198 |
+
# easier.
|
| 199 |
+
def joint_fn(*primals_and_tangents):
|
| 200 |
+
primals = primals_and_tangents[: len(fw_inputs)]
|
| 201 |
+
tangents = primals_and_tangents[len(fw_inputs) :]
|
| 202 |
+
|
| 203 |
+
fw_outs, grads = create_joint(
|
| 204 |
+
prepare_fw_with_masks(fn), aot_config=dummy_aot_config
|
| 205 |
+
)(primals, tangents)
|
| 206 |
+
|
| 207 |
+
maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents)
|
| 208 |
+
|
| 209 |
+
# return signature is deliberately kept (*grads, *fw_outs). This
|
| 210 |
+
# simplifies partitioning work later on.
|
| 211 |
+
return pytree.tree_map(maybe_clone, tuple(grads + list(fw_outs)))
|
| 212 |
+
|
| 213 |
+
primals = list(fw_inputs)
|
| 214 |
+
# This assumes that the tangent strides match fw_outputs strides. Check the
|
| 215 |
+
# InvokeSubgraphAutogradOp backward op for the contiguous call.
|
| 216 |
+
tangents = [_from_fun(out) for out in fw_outputs]
|
| 217 |
+
|
| 218 |
+
joint_operands = primals + tangents
|
| 219 |
+
|
| 220 |
+
return _maybe_reenter_make_fx(joint_fn)(*joint_operands)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra
|
| 224 |
+
def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
|
| 225 |
+
with suspend_functionalization(), disable_functional_mode():
|
| 226 |
+
with disable_proxy_modes_tracing():
|
| 227 |
+
# args are functional tensors, generate some example tensors
|
| 228 |
+
fw_inputs = pytree.tree_map(_from_fun, operands)
|
| 229 |
+
|
| 230 |
+
from torch._guards import detect_fake_mode
|
| 231 |
+
|
| 232 |
+
fake_mode = detect_fake_mode(fw_inputs)
|
| 233 |
+
context = (
|
| 234 |
+
nullcontext()
|
| 235 |
+
if fake_mode is None or fake_mode.shape_env is None
|
| 236 |
+
else fake_mode.shape_env.ignore_fresh_unbacked_symbols()
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
with context:
|
| 240 |
+
fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs))
|
| 241 |
+
|
| 242 |
+
num_fw_outs = len(fw_outs)
|
| 243 |
+
|
| 244 |
+
# Collect the indexes of none in the output to check that the grad
|
| 245 |
+
# is None at the corresponding index in the backward. This check is
|
| 246 |
+
# performed in the autograd.Function - InvokeSubgraphAutogradOp.
|
| 247 |
+
# Also collect the indexes of no_grad in the output to filter out
|
| 248 |
+
# the grad_outs in the `backward` method.
|
| 249 |
+
output_metadata = OutputMetadata()
|
| 250 |
+
|
| 251 |
+
output_metadata.num_fw_outs = num_fw_outs
|
| 252 |
+
for idx, fw_out in enumerate(fw_outs):
|
| 253 |
+
if fw_out is None:
|
| 254 |
+
output_metadata.indexes_with_none.add(idx)
|
| 255 |
+
elif not fw_out.requires_grad:
|
| 256 |
+
output_metadata.indexes_with_no_grad.add(idx)
|
| 257 |
+
|
| 258 |
+
if grad_outputs is None:
|
| 259 |
+
# Infer grad_outputs to be the same properties as the fw_outputs
|
| 260 |
+
# if they're not passed in
|
| 261 |
+
# Although fw_outs are equivalent to grad_outputs for tracing
|
| 262 |
+
# purposes, we have to carefully handle the None and fw_out that do
|
| 263 |
+
# not have require_grad. At those indexes, we will have None in the
|
| 264 |
+
# backward graph.
|
| 265 |
+
grad_outputs = fw_outs
|
| 266 |
+
grad_outputs = [grad for grad in grad_outputs if grad is not None]
|
| 267 |
+
grad_outputs = [grad for grad in grad_outputs if grad.requires_grad]
|
| 268 |
+
|
| 269 |
+
# Force grad_out to be contiguous. This is because at runtime,
|
| 270 |
+
# grad_out could have different strides than fw_outs. So, we
|
| 271 |
+
# force the grad_outs to be contiguous for both tracing and
|
| 272 |
+
# runtime.
|
| 273 |
+
grad_outputs = [grad.contiguous() for grad in grad_outputs]
|
| 274 |
+
|
| 275 |
+
if any(
|
| 276 |
+
not isinstance(out, torch.Tensor)
|
| 277 |
+
for out in grad_outputs
|
| 278 |
+
if out is not None
|
| 279 |
+
):
|
| 280 |
+
raise RuntimeError(
|
| 281 |
+
"Expect outputs of invoke_subgraph to only contains tensors or None. "
|
| 282 |
+
f"Got types {[type(out) for out in grad_outputs]}."
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# Trace the forward subgraph
|
| 286 |
+
fw_graph = _maybe_reenter_make_fx(subgraph)(*fw_inputs)
|
| 287 |
+
|
| 288 |
+
# Trace the joint graph and assign it to the bwd graph
|
| 289 |
+
bw_graph = trace_joint_graph(
|
| 290 |
+
subgraph,
|
| 291 |
+
fw_inputs,
|
| 292 |
+
grad_outputs,
|
| 293 |
+
)
|
| 294 |
+
return fw_graph, bw_graph, output_metadata
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def get_output_metadata(subgraph, *operands):
|
| 298 |
+
with suspend_functionalization(), disable_functional_mode():
|
| 299 |
+
with disable_proxy_modes_tracing():
|
| 300 |
+
# args are functional tensors, generate some example tensors
|
| 301 |
+
fw_inputs = pytree.tree_map(_from_fun, operands)
|
| 302 |
+
|
| 303 |
+
from torch._guards import detect_fake_mode
|
| 304 |
+
|
| 305 |
+
fake_mode = detect_fake_mode(fw_inputs)
|
| 306 |
+
context = (
|
| 307 |
+
nullcontext()
|
| 308 |
+
if fake_mode is None or fake_mode.shape_env is None
|
| 309 |
+
else fake_mode.shape_env.ignore_fresh_unbacked_symbols()
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
with context:
|
| 313 |
+
fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs))
|
| 314 |
+
|
| 315 |
+
num_fw_outs = len(fw_outs)
|
| 316 |
+
|
| 317 |
+
# Collect the indexes of none in the output to check that the grad
|
| 318 |
+
# is None at the corresponding index in the backward. This check is
|
| 319 |
+
# performed in the autograd.Function - InvokeSubgraphAutogradOp.
|
| 320 |
+
# Also collect the indexes of no_grad in the output to filter out
|
| 321 |
+
# the grad_outs in the `backward` method.
|
| 322 |
+
output_metadata = OutputMetadata()
|
| 323 |
+
|
| 324 |
+
output_metadata.num_fw_outs = num_fw_outs
|
| 325 |
+
for idx, fw_out in enumerate(fw_outs):
|
| 326 |
+
if fw_out is None:
|
| 327 |
+
output_metadata.indexes_with_none.add(idx)
|
| 328 |
+
elif not fw_out.requires_grad:
|
| 329 |
+
output_metadata.indexes_with_no_grad.add(idx)
|
| 330 |
+
return output_metadata
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def trace_joint_graph_as_bwd(
|
| 334 |
+
subgraph, num_primals, joint_operands, include_key_set, exclude_key_set
|
| 335 |
+
):
|
| 336 |
+
"""
|
| 337 |
+
Naively trace out a joint graph. This simplifies the reconstruction of joint
|
| 338 |
+
graph in the min-cut partitioner later on.
|
| 339 |
+
"""
|
| 340 |
+
from torch._functorch.aot_autograd import create_joint
|
| 341 |
+
|
| 342 |
+
dummy_aot_config = get_dummy_aot_autograd_config()
|
| 343 |
+
|
| 344 |
+
if isinstance(subgraph, torch.fx.GraphModule):
|
| 345 |
+
|
| 346 |
+
def graph_with_interpreter(*args):
|
| 347 |
+
# Running graph with interpreter is needed for propagating the stack_trace
|
| 348 |
+
with torch.fx.traceback.preserve_node_meta():
|
| 349 |
+
return torch.fx.Interpreter(subgraph).run(*args)
|
| 350 |
+
|
| 351 |
+
fn = graph_with_interpreter
|
| 352 |
+
else:
|
| 353 |
+
fn = subgraph
|
| 354 |
+
|
| 355 |
+
# This joint_fn is inserted as the backward graph as is. This simplifies the
|
| 356 |
+
# min-cut partitioner work later on.
|
| 357 |
+
# Input signature - (*primals, *tangents)
|
| 358 |
+
# Output signature - (*grads, *fw_outs)
|
| 359 |
+
# The output signature is deliberately kept grads first and fw_outs second.
|
| 360 |
+
# Having grads first makes the min-cut partitioner HOP graph stitching
|
| 361 |
+
# easier.
|
| 362 |
+
def joint_fn(*primals_and_tangents):
|
| 363 |
+
primals = primals_and_tangents[:num_primals]
|
| 364 |
+
tangents = primals_and_tangents[num_primals:]
|
| 365 |
+
|
| 366 |
+
fw_outs, grads = create_joint(
|
| 367 |
+
prepare_fw_with_masks(fn), aot_config=dummy_aot_config
|
| 368 |
+
)(primals, tangents)
|
| 369 |
+
|
| 370 |
+
maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents)
|
| 371 |
+
|
| 372 |
+
# return signature is deliberately kept (*grads, *fw_outs). This
|
| 373 |
+
# simplifies partitioning work later on.
|
| 374 |
+
return pytree.tree_map(maybe_clone, tuple(grads + list(fw_outs)))
|
| 375 |
+
|
| 376 |
+
with suspend_functionalization(), disable_functional_mode():
|
| 377 |
+
with disable_proxy_modes_tracing():
|
| 378 |
+
joint_operands = [_from_fun(arg) for arg in joint_operands]
|
| 379 |
+
with contextlib.ExitStack() as stack:
|
| 380 |
+
stack.enter_context(
|
| 381 |
+
torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set),
|
| 382 |
+
)
|
| 383 |
+
with torch.enable_grad():
|
| 384 |
+
return _maybe_reenter_make_fx(joint_fn)(*joint_operands)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class InvokeSubgraphAutogradOp(torch.autograd.Function):
|
| 388 |
+
"""
|
| 389 |
+
Saves the subgraph, i.e. original callable, in the forward method. And then
|
| 390 |
+
traces out a joint graph in the backward. This delaying of tracing in
|
| 391 |
+
backward, also called as lazy backward, ensures that the assumptions about
|
| 392 |
+
the grad_out strides and tensor-subclass-ness are already accounted for.
|
| 393 |
+
"""
|
| 394 |
+
|
| 395 |
+
@staticmethod
|
| 396 |
+
def forward(
|
| 397 |
+
ctx,
|
| 398 |
+
subgraph,
|
| 399 |
+
identifier,
|
| 400 |
+
output_metadata,
|
| 401 |
+
*operands,
|
| 402 |
+
):
|
| 403 |
+
# We want to delay the backward graph construction until the backward.
|
| 404 |
+
# So in forward, we just run the fw callable as is. And save all the
|
| 405 |
+
# information necessary to construct the backward graph in the ctx.
|
| 406 |
+
ctx._subgraph = subgraph
|
| 407 |
+
ctx._identifier = identifier
|
| 408 |
+
ctx._output_metadata = output_metadata
|
| 409 |
+
# We snapshot the dispatch keys in forward for materializing the
|
| 410 |
+
# the bw_graph in backward.
|
| 411 |
+
ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
|
| 412 |
+
ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
|
| 413 |
+
|
| 414 |
+
save_tensors_and_symints_for_backward(ctx, operands)
|
| 415 |
+
|
| 416 |
+
with torch._C._AutoDispatchBelowAutograd():
|
| 417 |
+
out = invoke_subgraph(
|
| 418 |
+
subgraph,
|
| 419 |
+
f"fw_{identifier}",
|
| 420 |
+
*operands,
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
# Check that None is at expected indexes.
|
| 424 |
+
for idx, o in enumerate(out):
|
| 425 |
+
if o is None:
|
| 426 |
+
assert idx in output_metadata.indexes_with_none
|
| 427 |
+
|
| 428 |
+
return out
|
| 429 |
+
|
| 430 |
+
@staticmethod
|
| 431 |
+
def backward(
|
| 432 |
+
ctx,
|
| 433 |
+
*grad_outs,
|
| 434 |
+
):
|
| 435 |
+
from torch._dynamo.utils import dynamo_timed
|
| 436 |
+
|
| 437 |
+
subgraph = ctx._subgraph
|
| 438 |
+
identifier = ctx._identifier
|
| 439 |
+
output_metadata = ctx._output_metadata
|
| 440 |
+
primals = saved_tensors_and_symints(ctx)
|
| 441 |
+
|
| 442 |
+
# Filter out grads that are None or do not require_grad. This was
|
| 443 |
+
# the assumption we made during the tracing of joint_graph.
|
| 444 |
+
filtered_grad_outs = []
|
| 445 |
+
for idx, o in enumerate(grad_outs):
|
| 446 |
+
if o is None:
|
| 447 |
+
assert idx in output_metadata.indexes_with_none
|
| 448 |
+
elif idx in output_metadata.indexes_with_no_grad:
|
| 449 |
+
# Deliberately skip over the grad_outs which we know should be
|
| 450 |
+
# None because the corresponding fwd_out does not require_grad.
|
| 451 |
+
pass
|
| 452 |
+
else:
|
| 453 |
+
filtered_grad_outs.append(o)
|
| 454 |
+
filtered_grad_outs = tuple(filtered_grad_outs)
|
| 455 |
+
|
| 456 |
+
# Important note - Even though the forward graph can be same for
|
| 457 |
+
# different invoke_subgraphs, the backward graph can be different
|
| 458 |
+
# because the tangent strides can be different. So, here we cache on
|
| 459 |
+
# tangent_metadata in addition to identifier
|
| 460 |
+
from torch._guards import detect_fake_mode
|
| 461 |
+
from torch._subclasses._fake_tensor_utils import _CacheKeyState
|
| 462 |
+
from torch._subclasses.fake_tensor import extract_tensor_metadata
|
| 463 |
+
|
| 464 |
+
fake_mode = detect_fake_mode(primals + filtered_grad_outs)
|
| 465 |
+
state = _CacheKeyState(fake_mode.shape_env)
|
| 466 |
+
|
| 467 |
+
tangent_metadata: list[object] = []
|
| 468 |
+
for tangent in filtered_grad_outs:
|
| 469 |
+
metadata = extract_tensor_metadata(tangent)
|
| 470 |
+
metadata._flatten_into(tangent_metadata, fake_mode, state)
|
| 471 |
+
tangent_metadata = tuple(tangent_metadata)
|
| 472 |
+
|
| 473 |
+
# bw_graph is a joint graph with signature (*primals_and_tangents) and
|
| 474 |
+
# returns (*grads_and_fw_outs). To get the grads, we use the num_fw_outs
|
| 475 |
+
# to extract the grads.
|
| 476 |
+
primals_and_tangents = primals + filtered_grad_outs
|
| 477 |
+
|
| 478 |
+
# Check if we have already traced the bwd subgraph.
|
| 479 |
+
bw_graph = None
|
| 480 |
+
suffix = None
|
| 481 |
+
invoke_subgraph_cache = get_invoke_subgraph_cache()
|
| 482 |
+
cache_hit = False
|
| 483 |
+
if invoke_subgraph_cache:
|
| 484 |
+
bw_graph, suffix = invoke_subgraph_cache.get_lazy_bwd_entry(
|
| 485 |
+
identifier, tangent_metadata
|
| 486 |
+
)
|
| 487 |
+
cache_hit = bw_graph is not None
|
| 488 |
+
|
| 489 |
+
if bw_graph is None:
|
| 490 |
+
assert suffix is None
|
| 491 |
+
with dynamo_timed(
|
| 492 |
+
"invoke_subgraph_trace_joint_graph", log_pt2_compile_event=True
|
| 493 |
+
):
|
| 494 |
+
bw_graph = trace_joint_graph_as_bwd(
|
| 495 |
+
subgraph,
|
| 496 |
+
len(primals),
|
| 497 |
+
primals_and_tangents,
|
| 498 |
+
ctx._fw_include_key_set,
|
| 499 |
+
ctx._fw_exclude_key_set,
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
if invoke_subgraph_cache and not cache_hit:
|
| 503 |
+
suffix = invoke_subgraph_cache.add_lazy_bwd_entry(
|
| 504 |
+
identifier, tangent_metadata, bw_graph
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
grads = invoke_subgraph(
|
| 508 |
+
bw_graph, f"bw_{identifier}_{suffix}", *primals_and_tangents
|
| 509 |
+
)[: -output_metadata.num_fw_outs]
|
| 510 |
+
return None, None, None, *grads
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
@invoke_subgraph.py_autograd_impl
|
| 514 |
+
def _(subgraph, identifier, *operands):
|
| 515 |
+
# Check if we have already traced the subgraph.
|
| 516 |
+
invoke_subgraph_cache = get_invoke_subgraph_cache()
|
| 517 |
+
if invoke_subgraph_cache:
|
| 518 |
+
if saved_autograd_fn := invoke_subgraph_cache.get_autograd_key_entry(
|
| 519 |
+
identifier
|
| 520 |
+
):
|
| 521 |
+
return saved_autograd_fn(*operands)
|
| 522 |
+
|
| 523 |
+
output_metadata = get_output_metadata(subgraph, *operands)
|
| 524 |
+
|
| 525 |
+
def autograd_fn_callable(*args):
|
| 526 |
+
return InvokeSubgraphAutogradOp.apply(
|
| 527 |
+
subgraph, identifier, output_metadata, *args
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
# Save the autograd_fn_callable in the dispatch set cache.
|
| 531 |
+
if invoke_subgraph_cache:
|
| 532 |
+
invoke_subgraph_cache.add_autograd_key_entry(identifier, autograd_fn_callable)
|
| 533 |
+
|
| 534 |
+
return autograd_fn_callable(*operands)
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd)
|
| 538 |
+
def _(subgraph, identifier, *operands):
|
| 539 |
+
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
| 540 |
+
|
| 541 |
+
mode = _get_current_dispatch_mode()
|
| 542 |
+
assert mode is None, "Mode should never be enabled for CPU/CUDA key"
|
| 543 |
+
return subgraph(*operands)
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
@invoke_subgraph.py_functionalize_impl
|
| 547 |
+
def _(ctx, subgraph, identifier, *operands):
|
| 548 |
+
from torch._higher_order_ops.auto_functionalize import (
|
| 549 |
+
can_auto_functionalize,
|
| 550 |
+
do_auto_functionalize_v2,
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
unwrapped_operands = ctx.unwrap_tensors(operands)
|
| 554 |
+
hop_instance = HopInstance.create(invoke_subgraph, subgraph, identifier, *operands)
|
| 555 |
+
if can_auto_functionalize(hop_instance):
|
| 556 |
+
# NOTE: [auto_functionalize x invoke_subgraph caching]
|
| 557 |
+
# We call auto_functionalized_v2 to support input mutation of invoke_subgraph.
|
| 558 |
+
# See NOTE [Support input mutation of hops] for the overall design.
|
| 559 |
+
#
|
| 560 |
+
# invoke_subgraph is special because of its identifier based caching machanism.
|
| 561 |
+
# In invoke_subgraph's functionalization key implementation, we create a new
|
| 562 |
+
# identifer because the subgraph is replaced by FunctionWithNoFreeVars in a
|
| 563 |
+
# functional + epilogue form.
|
| 564 |
+
assert isinstance(identifier, str), identifier
|
| 565 |
+
return do_auto_functionalize_v2(
|
| 566 |
+
ctx.mode,
|
| 567 |
+
hop_instance,
|
| 568 |
+
(subgraph, "auto_functionalized_" + identifier, *operands),
|
| 569 |
+
{},
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
with ctx.redispatch_to_next():
|
| 573 |
+
# NB: There is an assumption that subgraph does not mutate inputs and
|
| 574 |
+
# there is no aliasing. Its Dynamo responsibility to prevent formation
|
| 575 |
+
# of invoke_subgraph ops if input aliasing/mutation is detected.
|
| 576 |
+
functionalized_subgraph = FunctionalizeCtxWrapper(ctx, subgraph)
|
| 577 |
+
out = invoke_subgraph(functionalized_subgraph, identifier, *unwrapped_operands)
|
| 578 |
+
return ctx.wrap_tensors(out)
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
# Register the hop fake fn. This will be called in the fake_tensor _dispatch_impl.
|
| 582 |
+
@register_fake(invoke_subgraph)
|
| 583 |
+
def _(subgraph, identifier, *operands):
|
| 584 |
+
from torch._dynamo.utils import dynamo_timed
|
| 585 |
+
|
| 586 |
+
with dynamo_timed("invoke_subgraph_fake_tensor", log_pt2_compile_event=True):
|
| 587 |
+
return subgraph(*operands)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
@invoke_subgraph.py_impl(ProxyTorchDispatchMode)
|
| 591 |
+
def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, *operands):
|
| 592 |
+
# Check if we have already traced the subgraph.
|
| 593 |
+
graph = None
|
| 594 |
+
invoke_subgraph_cache = get_invoke_subgraph_cache()
|
| 595 |
+
if invoke_subgraph_cache:
|
| 596 |
+
graph = invoke_subgraph_cache.get_proxy_dispatch_entry(identifier)
|
| 597 |
+
|
| 598 |
+
if graph is None:
|
| 599 |
+
from torch._dynamo.utils import dynamo_timed
|
| 600 |
+
|
| 601 |
+
with dynamo_timed("invoke_subgraph_proxy_tensor", log_pt2_compile_event=True):
|
| 602 |
+
graph = reenter_make_fx(subgraph)(*operands)
|
| 603 |
+
|
| 604 |
+
from torch._guards import detect_fake_mode
|
| 605 |
+
|
| 606 |
+
fake_mode = detect_fake_mode(operands)
|
| 607 |
+
insert_deferred_runtime_asserts(
|
| 608 |
+
graph,
|
| 609 |
+
fake_mode.shape_env,
|
| 610 |
+
"invoke_subgraph_proxy_torch_dispatch_mode",
|
| 611 |
+
export=True,
|
| 612 |
+
)
|
| 613 |
+
graph.recompile()
|
| 614 |
+
|
| 615 |
+
assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
|
| 616 |
+
if invoke_subgraph_cache:
|
| 617 |
+
invoke_subgraph_cache.add_proxy_dispatch_entry(identifier, graph)
|
| 618 |
+
|
| 619 |
+
node_args = (graph, identifier, *operands)
|
| 620 |
+
|
| 621 |
+
def _unwrap_proxy(arg):
|
| 622 |
+
if isinstance(arg, torch.fx.GraphModule):
|
| 623 |
+
# NOTE: [invoke_subgraph proxy_mode x auto_functionalize]
|
| 624 |
+
# Previously, we assumed that `invoke_subgraph` would always be traced with the same tracer.
|
| 625 |
+
# This allowed us to cache modules by their identifiers, assuming they were already registered.
|
| 626 |
+
#
|
| 627 |
+
# However, this assumption no longer holds when we auto-functionalize `invoke_subgraph`.
|
| 628 |
+
# auto_functionalize functionalizes the subgraph and wrap it with `FunctionWithNoFreeVars`.
|
| 629 |
+
# In the proxy mode implementation of `auto_functionalized_v2`, we need to materialize `FunctionWithNoFreeVars`
|
| 630 |
+
# input as a graph module. To do this, we re-trace the `invoke_subgraph` hop, which starts a new sub-tracer
|
| 631 |
+
# (see NOTE [materialize callable inputs as graph]). # When the new sub-tracer traces the `invoke_subgraph`
|
| 632 |
+
# with a previously cached identifier, the corresponding graph module might not
|
| 633 |
+
# exist as a submodule in the new tracer's root. Therefore, we register it as a submodule below.
|
| 634 |
+
#
|
| 635 |
+
# The alternative is to give a new identifer when we re-trace the invoke_subgraph but this will increase
|
| 636 |
+
# the compilatoin time, which defeats the purpose of caching.
|
| 637 |
+
registered_before = False
|
| 638 |
+
for (
|
| 639 |
+
_,
|
| 640 |
+
submod,
|
| 641 |
+
) in proxy_mode.tracer.root.named_modules(): # type: ignore[union-attr]
|
| 642 |
+
if arg is submod:
|
| 643 |
+
registered_before = True
|
| 644 |
+
|
| 645 |
+
if not registered_before:
|
| 646 |
+
qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph") # type: ignore[union-attr]
|
| 647 |
+
proxy_mode.tracer.root.register_module(qualname, arg) # type: ignore[union-attr]
|
| 648 |
+
return proxy_mode.tracer.unwrap_proxy(arg) # type: ignore[union-attr]
|
| 649 |
+
|
| 650 |
+
proxy_args = pytree.tree_map(_unwrap_proxy, node_args) # type: ignore[union-attr]
|
| 651 |
+
out_proxy = proxy_mode.tracer.create_proxy(
|
| 652 |
+
"call_function", invoke_subgraph, proxy_args, {}
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
example_out = invoke_subgraph(graph, identifier, *operands)
|
| 656 |
+
return track_tensor_tree(
|
| 657 |
+
example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
|
| 658 |
+
)
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/map.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
from typing import Callable, Union
|
| 4 |
+
from typing_extensions import TypeVarTuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.utils._pytree as pytree
|
| 8 |
+
from torch._C import DispatchKey
|
| 9 |
+
from torch._dispatch.python import suspend_functionalization
|
| 10 |
+
from torch._higher_order_ops.utils import _maybe_run_with_interpreter, reenter_make_fx
|
| 11 |
+
from torch._ops import HigherOrderOperator
|
| 12 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 13 |
+
from torch._subclasses.functional_tensor import disable_functional_mode
|
| 14 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 15 |
+
disable_proxy_modes_tracing,
|
| 16 |
+
make_fx,
|
| 17 |
+
ProxyTorchDispatchMode,
|
| 18 |
+
track_tensor_tree,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from .utils import (
|
| 22 |
+
_from_fun,
|
| 23 |
+
_stack_pytree,
|
| 24 |
+
_unstack_pytree,
|
| 25 |
+
clone_outputs_aliasing_inputs,
|
| 26 |
+
prepare_fw_with_masks,
|
| 27 |
+
save_tensors_and_symints_for_backward,
|
| 28 |
+
saved_tensors_and_symints,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class MapImpl(HigherOrderOperator):
|
| 33 |
+
def __init__(self):
|
| 34 |
+
super().__init__("map_impl")
|
| 35 |
+
|
| 36 |
+
def __call__(self, *args, **kwargs):
|
| 37 |
+
return super().__call__(*args, **kwargs)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
map_impl = MapImpl()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def create_fw_bw_graph(f, num_mapped_args, *args):
|
| 44 |
+
mapped_xs = args[:num_mapped_args]
|
| 45 |
+
pos_args = args[num_mapped_args:]
|
| 46 |
+
|
| 47 |
+
# See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py
|
| 48 |
+
|
| 49 |
+
with suspend_functionalization(), disable_functional_mode():
|
| 50 |
+
with disable_proxy_modes_tracing():
|
| 51 |
+
unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs)
|
| 52 |
+
example_xs = _unstack_pytree(unwrapped_mapped_xs)[0]
|
| 53 |
+
|
| 54 |
+
example_pos_args = [
|
| 55 |
+
_from_fun(arg) if isinstance(arg, torch.Tensor) else arg
|
| 56 |
+
for arg in pos_args
|
| 57 |
+
]
|
| 58 |
+
example_flat_out = pytree.tree_map(
|
| 59 |
+
_from_fun, f(*example_xs, *example_pos_args)
|
| 60 |
+
)
|
| 61 |
+
if any(
|
| 62 |
+
not isinstance(out, torch.Tensor)
|
| 63 |
+
for out in example_flat_out
|
| 64 |
+
if out is not None
|
| 65 |
+
):
|
| 66 |
+
raise RuntimeError(
|
| 67 |
+
"Expect outputs of map only contains tensors or None. "
|
| 68 |
+
f"Got types {[type(out) for out in example_flat_out]}."
|
| 69 |
+
)
|
| 70 |
+
example_grad = [_from_fun(out) for out in example_flat_out]
|
| 71 |
+
|
| 72 |
+
fw_graph = make_fx(f)(*example_xs, *example_pos_args)
|
| 73 |
+
|
| 74 |
+
from torch._functorch.aot_autograd import AOTConfig, create_joint
|
| 75 |
+
|
| 76 |
+
dummy_aot_config = AOTConfig(
|
| 77 |
+
fw_compiler=None, # type: ignore[arg-type]
|
| 78 |
+
bw_compiler=None, # type: ignore[arg-type]
|
| 79 |
+
partition_fn=None, # type: ignore[arg-type]
|
| 80 |
+
decompositions={},
|
| 81 |
+
num_params_buffers=0,
|
| 82 |
+
aot_id=0,
|
| 83 |
+
keep_inference_input_mutations=False,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def joint_f(*example_args):
|
| 87 |
+
joint_mapped_args = example_args[:joint_num_mapped]
|
| 88 |
+
args = example_args[joint_num_mapped:]
|
| 89 |
+
|
| 90 |
+
mapped_input = joint_mapped_args[:num_mapped_args]
|
| 91 |
+
mapped_grads = joint_mapped_args[num_mapped_args:]
|
| 92 |
+
|
| 93 |
+
joint = create_joint(prepare_fw_with_masks(f), aot_config=dummy_aot_config)
|
| 94 |
+
_, grads = joint(
|
| 95 |
+
list(mapped_input) + list(args),
|
| 96 |
+
[
|
| 97 |
+
grad
|
| 98 |
+
for grad in mapped_grads
|
| 99 |
+
if grad is not None and grad.requires_grad
|
| 100 |
+
],
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# In order to keep map functional for backward graph,
|
| 104 |
+
# we clone outputs that are aliasing inputs
|
| 105 |
+
maybe_clone = clone_outputs_aliasing_inputs(example_args)
|
| 106 |
+
|
| 107 |
+
return pytree.tree_map(maybe_clone, grads)
|
| 108 |
+
|
| 109 |
+
joint_num_mapped = len(example_grad) + len(example_xs)
|
| 110 |
+
joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
|
| 111 |
+
return fw_graph, joint_graph
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def map(
|
| 115 |
+
f: Callable[[pytree.PyTree, tuple[pytree.PyTree, ...]], pytree.PyTree],
|
| 116 |
+
xs: Union[pytree.PyTree, torch.Tensor],
|
| 117 |
+
*args: TypeVarTuple,
|
| 118 |
+
):
|
| 119 |
+
r"""
|
| 120 |
+
Perfoms a map of f with xs. Intuitively, you can think of the semantic being:
|
| 121 |
+
|
| 122 |
+
out = []
|
| 123 |
+
for idx in len(xs.size(0)):
|
| 124 |
+
xs_sliced = xs.select(0, idx)
|
| 125 |
+
out.append(f(xs_sliced, *args))
|
| 126 |
+
torch.stack(out)
|
| 127 |
+
|
| 128 |
+
.. warning::
|
| 129 |
+
`torch._higher_order_ops.map` is a prototype feature in PyTorch. It currently
|
| 130 |
+
does not support autograd and you may run into miscompiles.
|
| 131 |
+
Read more about feature classification at:
|
| 132 |
+
https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
f (Callable): a callable that takes an input x, that could either be a single Tensor
|
| 137 |
+
or a nested dict, list of tensors and some additional inputs
|
| 138 |
+
xs: the inputs that're to be mapped over. We'll iterate over the first dim of each x
|
| 139 |
+
and perform f on each slice.
|
| 140 |
+
|
| 141 |
+
*args: additional arguments provided to each step of f. They could also be omitted and
|
| 142 |
+
map is able to automatically figure out the read dependency.
|
| 143 |
+
|
| 144 |
+
Return:
|
| 145 |
+
the stacked output for each step of f
|
| 146 |
+
|
| 147 |
+
Example:
|
| 148 |
+
|
| 149 |
+
def f(xs):
|
| 150 |
+
return xs[0] + xs[1] + const1 + const2
|
| 151 |
+
|
| 152 |
+
xs = [torch.randn(2, 3), torch.randn(2, 3)]
|
| 153 |
+
const1 = torch.randn(2, 3)
|
| 154 |
+
const2 = torch.randn(2, 3)
|
| 155 |
+
# returns a tensor of shape [2, 2, 3]
|
| 156 |
+
torch._higher_order_ops.map(f, xs)
|
| 157 |
+
|
| 158 |
+
"""
|
| 159 |
+
flat_xs, xs_spec = pytree.tree_flatten(xs)
|
| 160 |
+
flat_args, args_spec = pytree.tree_flatten(args)
|
| 161 |
+
if not all(isinstance(t, torch.Tensor) for t in flat_xs):
|
| 162 |
+
raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")
|
| 163 |
+
|
| 164 |
+
shapes = [xs.shape for xs in flat_xs]
|
| 165 |
+
leading_dim_size = shapes[0][0]
|
| 166 |
+
if leading_dim_size == 0:
|
| 167 |
+
raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")
|
| 168 |
+
|
| 169 |
+
if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
|
| 170 |
+
raise RuntimeError(
|
| 171 |
+
f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
def run_flattened_map(f, flat_xs, flat_args):
|
| 175 |
+
def wrapped_fn(*flat_args, f, xs_tree_spec, args_tree_spec, num_xs):
|
| 176 |
+
xs = pytree.tree_unflatten(flat_args[:num_xs], xs_tree_spec)
|
| 177 |
+
args = pytree.tree_unflatten(flat_args[num_xs:], args_tree_spec)
|
| 178 |
+
return f(xs, *args)
|
| 179 |
+
|
| 180 |
+
inner_f = functools.partial(
|
| 181 |
+
wrapped_fn,
|
| 182 |
+
f=f,
|
| 183 |
+
xs_tree_spec=xs_spec,
|
| 184 |
+
args_tree_spec=args_spec,
|
| 185 |
+
num_xs=len(flat_xs),
|
| 186 |
+
)
|
| 187 |
+
return map_impl(inner_f, flat_xs, flat_args)
|
| 188 |
+
|
| 189 |
+
from torch._higher_order_ops.utils import _maybe_compile_and_run_fn
|
| 190 |
+
|
| 191 |
+
return _maybe_compile_and_run_fn(run_flattened_map, f, flat_xs, flat_args)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class MapAutogradOp(torch.autograd.Function):
    """Autograd bridge for `map_impl`.

    The forward pass maps `fw_graph` over the leading dimension of the mapped
    args; the backward pass maps `joint_graph` over the saved forward inputs
    together with the incoming gradients.
    """

    @staticmethod
    def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
        # flat_args = mapped xs (first num_mapped_args) followed by positional args.
        save_tensors_and_symints_for_backward(ctx, flat_args)
        ctx._joint_graph = joint_graph
        ctx._num_mapped_args = num_mapped_args
        # Re-dispatch below Autograd so the inner map_impl call does not
        # recursively re-enter this autograd implementation.
        with torch._C._AutoDispatchBelowAutograd():
            return (
                *map_impl(
                    fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:]
                ),
            )

    @staticmethod
    def backward(ctx, *flat_grads):
        fw_args = saved_tensors_and_symints(ctx)
        fw_mapped_args = fw_args[: ctx._num_mapped_args]
        pos_args = fw_args[ctx._num_mapped_args :]

        # The joint graph's mapped operands are (forward mapped inputs ++ output
        # grads); the positional args are passed through unmapped.
        grads = map_impl(
            ctx._joint_graph,
            fw_mapped_args + flat_grads,
            pos_args,
        )
        # The leading three Nones correspond to fw_graph, joint_graph and
        # num_mapped_args, which receive no gradients.
        return None, None, None, *grads
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def trace_map(proxy_mode, func_overload, f, xs, pos_args):
    """Trace a map call into the proxy graph.

    Materializes `f` as an fx subgraph (traced over a single example slice of
    `xs`), registers it on the tracer root, and emits a `call_function` node
    for `func_overload` whose outputs track the example (fake) results.
    """
    # One slice of the mapped inputs is enough to trace the body.
    example_input = _unstack_pytree(xs)[0]
    body_graph = f

    body_graph = reenter_make_fx(body_graph)(*example_input, *pos_args)

    next_name = proxy_mode.tracer.get_fresh_qualname("body_graph_")

    proxy_mode.tracer.root.register_module(next_name, body_graph)

    # Execute the traced body to obtain example outputs for tensor tracking.
    fake_outs = map_impl(body_graph, xs, pos_args)

    node_args = (body_graph, list(xs), list(pos_args))
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="map_impl"
    )
    return track_tensor_tree(
        fake_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
    )
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, xs, pos_args):
    """Eager implementation of map: apply `f` to every dim-0 slice of `xs`
    and stack the per-slice results back together."""
    per_slice_results = []
    for unstacked in _unstack_pytree(xs):
        per_slice_results.append(f(*unstacked, *pos_args))
    return _stack_pytree(per_slice_results)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
@map_impl.py_autograd_impl
def map_autograd(f, xs, pos_args):
    """Autograd implementation: build forward/backward graphs for `f` once,
    then execute them through the MapAutogradOp autograd.Function."""
    n_mapped = len(xs)
    fw_graph, bw_graph = create_fw_bw_graph(f, n_mapped, *xs, *pos_args)
    return MapAutogradOp.apply(fw_graph, bw_graph, n_mapped, *xs, *pos_args)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
@map_impl.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(mode, f, xs, args):
    # Proxy-mode entry point: delegate to trace_map to record a map_impl node.
    return trace_map(mode, map_impl, f, xs, args)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
@map_impl.py_impl(FakeTensorMode)
def map_fake_tensor_mode(mode, f, xs, args):
    # Under FakeTensorMode, the dense implementation runs on fake tensors and
    # thereby derives output metadata without real computation.
    with mode:
        return map_dense(f, xs, args)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
@map_impl.py_functionalize_impl
def map_functionalize(ctx, f, xs, pos_args):
    """Functionalization rule for map: unwrap functional tensors, verify the
    body is free of aliasing/mutation, then re-dispatch to the next key."""
    from torch._higher_order_ops.utils import _check_alias_and_mutation

    unwrapped_xs = ctx.unwrap_tensors(xs)
    unwrapped_args = ctx.unwrap_tensors(pos_args)
    wrapped_fn = ctx.functionalize(_maybe_run_with_interpreter(f))

    with ctx.redispatch_to_next():
        # A single input slice is enough to check `f` for aliasing/mutation.
        example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        _check_alias_and_mutation(f, example_inputs, "map", pre_dispatch)
        map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
    return ctx.wrap_tensors(map_return)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def _fake_map(f, x, *args):
|
| 285 |
+
from functorch.experimental.control_flow import _stack_pytree, _unstack_pytree
|
| 286 |
+
|
| 287 |
+
x_pytrees = _unstack_pytree(x)
|
| 288 |
+
zs = []
|
| 289 |
+
for xp in x_pytrees:
|
| 290 |
+
zs.append(f(xp, *args))
|
| 291 |
+
return _stack_pytree(zs)
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/out_dtype.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.utils._pytree as pytree
|
| 5 |
+
from torch._C import DispatchKey
|
| 6 |
+
from torch._higher_order_ops.utils import autograd_not_implemented
|
| 7 |
+
from torch._ops import HigherOrderOperator
|
| 8 |
+
from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
|
| 9 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 10 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 11 |
+
disable_proxy_modes_tracing,
|
| 12 |
+
maybe_handle_decomp,
|
| 13 |
+
ProxyTorchDispatchMode,
|
| 14 |
+
track_tensor_tree,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# TODO to figure out a more generic approach
# Closed allow-list of functional ATen ops that out_dtype may wrap; enforced
# in OutDtypeOperator.__call__ below.
ALLOWABLE_OPS = [
    torch.ops.aten.linear.default,
    torch.ops.aten.mm.default,
    torch.ops.aten.conv2d.default,
    torch.ops.aten.convolution.default,
    torch.ops.aten.mul.Tensor,
    torch.ops.aten.mul.Scalar,
    torch.ops.aten.div.Tensor,
    torch.ops.aten.div.Scalar,
]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class OutDtypeOperator(HigherOrderOperator):
    """
    The out_dtype operator takes an existing ATen functional operator, an
    `out_dtype` argument, and arguments to the original operator, and executes
    the original operator and returns a Tensor with the `out_dtype` precision.
    This operator does not mandate a compute precision so it allows the
    representation to not be opinionated about the exact implementation.

    The general implementation for all operators will be the following:
        1. Promote inputs dtypes based on default PyTorch dtype promotion rules,
            using the dtypes of all input Tensors/Scalars and the `out_dtype`
            argument.
        2. Execute the operator
        3. Cast the output to `out_dtype`
    """

    def __init__(self) -> None:
        super().__init__("out_dtype")

    def __call__(self, op, output_dtype, *args):
        # Validate `op` before dispatching: it must be a functional OpOverload
        # returning exactly one Tensor, drawn from the module allow-list.
        if not isinstance(op, torch._ops.OpOverload):
            raise ValueError("out_dtype's first argument must be an OpOverload")
        if op._schema.is_mutable:
            raise ValueError(
                "out_dtype's first argument needs to be a functional operator"
            )
        if not (
            len(op._schema.returns) == 1
            and isinstance(op._schema.returns[0].type, torch.TensorType)
        ):
            # NOTE(review): the two message fragments below concatenate with no
            # separator ("...single tensorInstead got...") — consider fixing.
            raise ValueError(
                "out_dtype's can only apply to ops that return a single tensor"
                f"Instead got {[r.type for r in op._schema.returns]}"
            )

        if op not in ALLOWABLE_OPS:
            raise ValueError(
                f"out_dtype only allows the following operators: {ALLOWABLE_OPS}."
            )

        res = super().__call__(op, output_dtype, *args)

        return res
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Module-level singleton instance; dispatch implementations register on it below.
out_dtype = OutDtypeOperator()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def trace_out_dtype(proxy_mode, func_overload, op, output_dtype, *args):
    """Record an out_dtype call as a proxy-graph node.

    Tries registered decompositions first; otherwise computes example outputs
    with proxy tracing disabled and emits a `call_function` node.
    """
    # NB: Long-term we should put the decomposition logic into
    # ProxyTorchDispatchMode so that people do not need to call maybe_handle_decomp
    # in all HigherOrderOp proxy implementations.
    r = maybe_handle_decomp(proxy_mode, func_overload, (op, output_dtype, *args), {})
    if r is not NotImplemented:
        return r

    with disable_proxy_modes_tracing():
        # This is a simplified implementation of this operator just for tracing.
        # Actual implementation may also first promote the arguments
        out = op(*args).to(dtype=output_dtype)

    node_args = (op, output_dtype, *args)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="out_dtype"
    )
    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@out_dtype.py_impl(DispatchKey.CompositeExplicitAutograd)
def out_dtype_dense(op: torch._ops.OpOverload, output_dtype: torch.dtype, *args):
    """Eager implementation: use the dedicated int8 CUDA matmul kernel when it
    applies, otherwise promote-compute-cast via the generic fallback."""
    use_int_mm_kernel = is_int_mm(op, output_dtype, args)
    if use_int_mm_kernel:
        return torch._int_mm(*args)
    return out_dtype_fallback(op, output_dtype, *args)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def is_int_mm(op, output_dtype, args):
    """Return True iff this call matches the torch._int_mm fast path:
    aten.mm of two int8 CUDA tensors producing int32."""
    if op != torch.ops.aten.mm.default or output_dtype != torch.int32:
        return False
    if len(args) != 2:
        return False
    lhs, rhs = args
    return (
        lhs.dtype == torch.int8
        and rhs.dtype == torch.int8
        and lhs.is_cuda
        and rhs.is_cuda
    )
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def out_dtype_fallback(op, output_dtype, *args):
|
| 120 |
+
flat_inputs = pytree.arg_tree_leaves(*args) + [torch.ones(1, dtype=output_dtype)]
|
| 121 |
+
promote_dtype: torch.dtype = elementwise_dtypes(
|
| 122 |
+
*flat_inputs,
|
| 123 |
+
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
|
| 124 |
+
)[0]
|
| 125 |
+
|
| 126 |
+
casted_args = pytree.tree_map_only(
|
| 127 |
+
torch.Tensor, lambda arg: arg.to(dtype=promote_dtype), args
|
| 128 |
+
)
|
| 129 |
+
res = op(*casted_args).to(dtype=output_dtype)
|
| 130 |
+
return res
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# Gradients are not implemented for out_dtype; error is raised lazily at backward.
out_dtype.py_autograd_impl(autograd_not_implemented(out_dtype, deferred_error=True))
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@out_dtype.py_impl(ProxyTorchDispatchMode)
def out_dtype_proxy(
    mode: ProxyTorchDispatchMode,
    op: torch._ops.OpOverload,
    output_dtype: torch.dtype,
    *args,
):
    # Proxy-mode entry point: delegate to trace_out_dtype to record the node.
    return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@out_dtype.py_impl(FakeTensorMode)
def out_dtype_fake_tensor_mode(
    mode: FakeTensorMode,
    op: torch._ops.OpOverload,
    output_dtype: torch.dtype,
    *args,
):
    # Fake-tensor propagation reuses the dense path to derive output metadata.
    with mode:
        return out_dtype_dense(op, output_dtype, *args)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
@out_dtype.py_functionalize_impl
def out_dtype_func(ctx, op, output_dtype, *args):
    # Functionalization: unwrap the functional tensor wrappers, re-dispatch to
    # the next key, and wrap the result back into the functional world.
    unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)

    with ctx.redispatch_to_next():
        res = out_dtype(op, output_dtype, *unwrapped_args)
    return ctx.wrap_tensors(res)
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/run_const_graph.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
from torch._C import DispatchKey
|
| 4 |
+
from torch._higher_order_ops.utils import autograd_not_implemented
|
| 5 |
+
from torch._ops import HigherOrderOperator
|
| 6 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 7 |
+
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
|
| 8 |
+
from torch.utils import _pytree as pytree
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RunConstGraph(HigherOrderOperator):
    """Higher-order op that executes a constant-folded fx.GraphModule
    (`graph`) on `args` (typically the folded weights)."""

    def __init__(self):
        super().__init__("run_const_graph")

    def __call__(self, graph, args):
        return super().__call__(graph, args)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Module-level singleton instance; dispatch implementations register on it below.
run_const_graph = RunConstGraph()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@run_const_graph.py_impl(ProxyTorchDispatchMode)
def run_const_graph_dispatch_mode(mode, graph, args):
    """Trace a run_const_graph call: attach the const graph to the tracer root
    (at most one per root) and emit a call_function node."""
    const_gm, weights = graph, args
    p_args = pytree.tree_map(mode.tracer.unwrap_proxy, (graph, args))
    assert isinstance(const_gm, torch.fx.GraphModule)
    # Only a single const graph may be registered per traced root.
    assert not hasattr(mode.tracer.root, "_const_graph")
    mode.tracer.root.register_module("_const_graph", const_gm)

    proxy = mode.tracer.create_proxy("call_function", run_const_graph, p_args, {})

    # Run the graph for real to obtain example outputs for tensor tracking.
    out = const_gm(*weights)
    return track_tensor_tree(out, proxy, constant=None, tracer=mode.tracer)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@run_const_graph.py_functionalize_impl
def run_const_graph_functional(ctx, graph, args):
    """Functionalization rule: unwrap tensor args, re-dispatch to the next
    dispatch key, and wrap the outputs back.

    Bug fix: this previously called ``run_const_graph(*unwrapped_args)``,
    which dropped ``graph`` entirely and splatted the unwrapped tensors
    across both positional parameters. The operator's signature is exactly
    ``(graph, args)`` (see RunConstGraph.__call__ above).
    """
    unwrapped_args = ctx.unwrap_tensors(args)

    with ctx.redispatch_to_next():
        out = run_const_graph(graph, unwrapped_args)
    return ctx.wrap_tensors(out)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Gradients are not implemented for run_const_graph; error is raised lazily
# at backward time.
run_const_graph.py_autograd_impl(
    autograd_not_implemented(run_const_graph, deferred_error=True)
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@run_const_graph.py_impl(FakeTensorMode)
def run_const_graph_fake_tensor_mode(mode, graph, args):
    # Under FakeTensorMode, interpreting the graph propagates output metadata.
    assert isinstance(graph, torch.fx.GraphModule)
    with mode:
        return graph(*args)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@run_const_graph.py_impl(DispatchKey.CPU)
def run_const_graph_cpu(graph, args):
    # Real (CPU) execution: simply run the stored GraphModule on the weights.
    assert isinstance(graph, torch.fx.GraphModule)
    return graph(*args)
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/scan.py
ADDED
|
@@ -0,0 +1,929 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import functools
|
| 3 |
+
import itertools
|
| 4 |
+
from collections.abc import Sequence
|
| 5 |
+
from typing import Any, Callable, Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch._prims_common as utils
|
| 9 |
+
import torch.utils._pytree as pytree
|
| 10 |
+
from torch._C import DispatchKey
|
| 11 |
+
from torch._higher_order_ops.cond import create_bw_fn
|
| 12 |
+
from torch._higher_order_ops.utils import (
|
| 13 |
+
_maybe_compile_and_run_fn,
|
| 14 |
+
check_meta_consistency,
|
| 15 |
+
first_slice_copy,
|
| 16 |
+
materialize_as_graph,
|
| 17 |
+
reenter_make_fx,
|
| 18 |
+
save_tensors_and_symints_for_backward,
|
| 19 |
+
saved_tensors_and_symints,
|
| 20 |
+
unique_graph_id,
|
| 21 |
+
validate_subgraph_args_types,
|
| 22 |
+
)
|
| 23 |
+
from torch._ops import HigherOrderOperator
|
| 24 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 25 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 26 |
+
disable_proxy_modes_tracing,
|
| 27 |
+
ProxyTorchDispatchMode,
|
| 28 |
+
track_tensor_tree,
|
| 29 |
+
)
|
| 30 |
+
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
aten = torch._ops.ops.aten
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def wrap_combine_fn_flat(
    *args, combine_fn, spec_init, spec_xs, num_init_leaves, num_inp_leaves
):
    """Adapter that lets a pytree-based `combine_fn` be called with flat leaf
    arguments: the first `num_init_leaves` args are unflattened into the carry
    pytree, the rest into the xs pytree."""
    assert len(args) == (
        num_init_leaves + num_inp_leaves
    ), f"Combin_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}"
    init_leaves, xs_leaves = args[:num_init_leaves], args[num_init_leaves:]
    unflat_carry = pytree.tree_unflatten(init_leaves, spec_init)
    unflat_xs = pytree.tree_unflatten(xs_leaves, spec_xs)
    return combine_fn(unflat_carry, unflat_xs)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _extract_carry_and_out(flat_out: list[Any], num_carry: int):
    """Split the flat outputs of a combine_fn into (carry leaves, out leaves)."""
    num_out = len(flat_out) - num_carry
    return split_into_chunks(flat_out, [num_carry, num_out])
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# We also do a clone with contiguous_format. This is to be consistent with
# eager semantic of scan, which stacks the outputs. The result is contiguous
# as a result of the stack operation.
def stack_y(y: torch.Tensor, scan_length: int) -> torch.Tensor:
    """Replicate `y` `scan_length` times along a new leading dimension,
    returning a contiguous clone (mirrors the stacking done by eager scan)."""
    repeats = [scan_length] + [1] * y.ndim
    replicated = y.unsqueeze(0).repeat(*repeats)
    return replicated.clone(memory_format=torch.contiguous_format)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# NOTE: These functions can be reused in associative_scan and eventually moved to
# torch._higher_order_ops.utils
def get_tensor_mask(tensor_list: list[Any]) -> list[bool]:
    """Return a boolean mask marking which elements of `tensor_list` are tensors."""
    # Idiom fix: `isinstance` already yields a bool, so the redundant
    # `True if ... else False` form was dropped.
    return [isinstance(v, torch.Tensor) for v in tensor_list]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def mask_list(
    mask: list[bool], inp: list[Any], other: Optional[list[Any]] = None
) -> list[Any]:
    """Filter or merge `inp` according to `mask`.

    With `other` omitted, keeps only the elements of `inp` whose mask entry is
    True. With `other` given, selects elementwise from `inp` (mask True) or
    `other` (mask False).
    """
    assert len(mask) == len(
        inp
    ), "The length of the mask needs to be identical to the length of the input"
    if other is None:
        return [value for keep, value in zip(mask, inp) if keep]
    assert len(inp) == len(
        other
    ), "If an input and an other list is provided, they need to have the same length"
    return [a if keep else b for keep, a, b in zip(mask, inp, other)]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def first_slice_copy_with_grad(li: list[Any]) -> list[Any]:
|
| 89 |
+
# First_slice_copy does not keep the original requires_grad flag,
|
| 90 |
+
# but we need it for materialize_as_graph
|
| 91 |
+
# in order to compute the correct gradients
|
| 92 |
+
# The reason why first_slice_copy doesn't keep requires_grad flag is
|
| 93 |
+
# because it's called in torch.autograd.Function.backward/forward.
|
| 94 |
+
slc = [first_slice_copy(x).requires_grad_(x.requires_grad) for x in li]
|
| 95 |
+
return slc
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]:
    """Partition `iterable` into consecutive lists whose lengths are given by
    `chunk_sizes`; the sizes must exactly cover the input."""
    assert sum(chunk_sizes) == len(
        iterable
    ), "the sum of all chunks needs to match the length of the iterable."
    chunks = []
    offset = 0
    for size in chunk_sizes:
        chunks.append(list(iterable[offset : offset + size]))
        offset += size
    return chunks
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def call_operator(operator, *args):
|
| 107 |
+
return pytree.tree_leaves(operator(*args))
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def scan(
|
| 111 |
+
combine_fn: Callable[
|
| 112 |
+
[pytree.PyTree, pytree.PyTree], tuple[pytree.PyTree, pytree.PyTree]
|
| 113 |
+
],
|
| 114 |
+
init: pytree.PyTree,
|
| 115 |
+
xs: pytree.PyTree,
|
| 116 |
+
*,
|
| 117 |
+
dim: int = 0,
|
| 118 |
+
reverse: bool = False,
|
| 119 |
+
) -> tuple[pytree.PyTree, pytree.PyTree]:
|
| 120 |
+
r"""
|
| 121 |
+
Performs an inclusive scan with a combine function.
|
| 122 |
+
|
| 123 |
+
.. warning::
|
| 124 |
+
`torch.scan` is a prototype feature in PyTorch. It currently
|
| 125 |
+
does not support autograd and you may run into miscompiles.
|
| 126 |
+
Read more about feature classification at:
|
| 127 |
+
https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> (Tensor, Tensor)``,
|
| 131 |
+
or if xs is a pytree ``(pytree, pytree) -> (pytree, pytree)``.
|
| 132 |
+
The first input to ``combine_fn`` is the previous or initial scan carry
|
| 133 |
+
and the second input element to ``combine_fn`` is a slice of the input along dim.
|
| 134 |
+
The first output element of ``combine_fn`` is the next scan carry
|
| 135 |
+
and the second output of ``combine_fn`` represents a slice of the output.
|
| 136 |
+
This function must be pure, i.e., no lifted arguments are supported at the moment
|
| 137 |
+
and may not have any side effects.
|
| 138 |
+
init (torch.Tensor or pytree with tensor leaves): The inital scan carry, a tensor, or nested pytree of tensors.
|
| 139 |
+
The ``init`` is expected to have the same pytree structure as the first output element (i.e. carry)
|
| 140 |
+
of ``combine_fn``.
|
| 141 |
+
xs (torch.Tensor or pytree with tensor leaves): The input tensor, or nested pytree of tensors.
|
| 142 |
+
|
| 143 |
+
Kwargs:
|
| 144 |
+
dim (int): the dimension to scan over, default 0.
|
| 145 |
+
reverse (bool): A boolean stating if the scan should be reversed with respect to ``dim``, default ``False``.
|
| 146 |
+
|
| 147 |
+
Returns:
|
| 148 |
+
final_carry (torch.Tensor or pytree with tensor leaves),
|
| 149 |
+
the final carry of the scan operation with same pytree structure as init.
|
| 150 |
+
out (torch.Tensor or pytree with tensor leaves),
|
| 151 |
+
each tensor leaf is a stacked output along first dim, where each slice is the output of a scan iteration.
|
| 152 |
+
|
| 153 |
+
Restrictions:
|
| 154 |
+
- The combine_fn shouldn't have any aliasing between input-input, input-output, and output-output. E.g. return a view
|
| 155 |
+
or the same tensor as input is not supported. As a workaround, can clone the output to avoid aliasing.
|
| 156 |
+
|
| 157 |
+
- The combine_fn shoudn't mutate any inputs. We'll remove the mutation restriction for inference soon. Please file an issue
|
| 158 |
+
if you input mutation support for training is needed.
|
| 159 |
+
|
| 160 |
+
- The combine_fn's init carry should match the next_carry in pytree structure and in tensor metadata.
|
| 161 |
+
|
| 162 |
+
Example::
|
| 163 |
+
|
| 164 |
+
def add(x: torch.Tensor, y: torch.Tensor):
|
| 165 |
+
next_carry = y = x + y
|
| 166 |
+
# clone the output to avoid output-output aliasing
|
| 167 |
+
return next_carry, y.clone()
|
| 168 |
+
|
| 169 |
+
i0 = torch.zeros(1)
|
| 170 |
+
xs = torch.arange(5)
|
| 171 |
+
# returns torch.tensor([10.]), torch.tensor([[0], [1.], [3.], [6.], [10.]])
|
| 172 |
+
last_carry, cumsum = scan(add, init=i0, xs=xs)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
"""
|
| 176 |
+
# The reason we flatten init and xs before calling into dynamo is that
|
| 177 |
+
# we want to create a consistent input ordering for combine_fn
|
| 178 |
+
# and we also want to the input ordering matches the output ordering.
|
| 179 |
+
leaves_init, spec_init = pytree.tree_flatten(init)
|
| 180 |
+
leaves_xs_orig, spec_xs = pytree.tree_flatten(xs)
|
| 181 |
+
|
| 182 |
+
# Shortcut if no xs is provided
|
| 183 |
+
if len(leaves_xs_orig) == 0:
|
| 184 |
+
return init, []
|
| 185 |
+
|
| 186 |
+
def _validate_input(cfn, lxs, linit, d, r):
|
| 187 |
+
# Basic arguments check
|
| 188 |
+
if not callable(cfn):
|
| 189 |
+
raise RuntimeError("Combine_fn must be a callable, but got {cfn}")
|
| 190 |
+
if not isinstance(d, int):
|
| 191 |
+
raise RuntimeError("Dim must be an int, but got " + str(type(d)))
|
| 192 |
+
if not isinstance(r, bool):
|
| 193 |
+
raise RuntimeError("Reverse must be a bool, but got " + str(type(r)))
|
| 194 |
+
|
| 195 |
+
# Checks for init
|
| 196 |
+
if len(linit) == 0:
|
| 197 |
+
raise RuntimeError("scan() operator requires init leaves.")
|
| 198 |
+
for x in linit:
|
| 199 |
+
if not isinstance(x, torch.Tensor):
|
| 200 |
+
raise RuntimeError(f"All init leaves must be a Tensor but got {x}")
|
| 201 |
+
|
| 202 |
+
# Checks for xs
|
| 203 |
+
for x in lxs:
|
| 204 |
+
if not isinstance(x, torch.Tensor):
|
| 205 |
+
raise RuntimeError(f"All xs leaves must be a Tensor but got {x}")
|
| 206 |
+
if any(x.ndim <= d for x in lxs):
|
| 207 |
+
raise RuntimeError(
|
| 208 |
+
"All xs leaves must at least have 'dim' number of dimensions and scan dimension > 0"
|
| 209 |
+
)
|
| 210 |
+
if any(x.shape[d] == 0 for x in lxs):
|
| 211 |
+
raise RuntimeError(
|
| 212 |
+
"All xs leaves must at least have 'dim' number of dimensions and scan dimension > 0"
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
ndim = leaves_xs_orig[0].ndim
|
| 216 |
+
dim = utils.canonicalize_dim(ndim, dim)
|
| 217 |
+
|
| 218 |
+
_validate_input(combine_fn, leaves_xs_orig, leaves_init, dim, reverse)
|
| 219 |
+
|
| 220 |
+
# Move scan dim to 0 and always perform scan on dim 0
|
| 221 |
+
leaves_xs = []
|
| 222 |
+
for elem in leaves_xs_orig:
|
| 223 |
+
leaves_xs.append(torch.movedim(elem, dim, 0))
|
| 224 |
+
|
| 225 |
+
if reverse:
|
| 226 |
+
leaves_xs = [torch.flip(elem, [0]) for elem in leaves_xs]
|
| 227 |
+
|
| 228 |
+
# TODO: Support _inductor lowering
|
| 229 |
+
# TODO: Unify handling of pytrees for control flow ops, such as cond, while_loop, etc.
|
| 230 |
+
|
| 231 |
+
combine_fn = functools.partial(
|
| 232 |
+
wrap_combine_fn_flat,
|
| 233 |
+
combine_fn=combine_fn,
|
| 234 |
+
spec_init=spec_init,
|
| 235 |
+
spec_xs=spec_xs,
|
| 236 |
+
num_init_leaves=len(leaves_init),
|
| 237 |
+
num_inp_leaves=len(leaves_xs),
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
def run_flattened_scan(combine_fn, leaves_init, leaves_xs):
|
| 241 |
+
return scan_op(combine_fn, leaves_init, leaves_xs, additional_inputs=())
|
| 242 |
+
|
| 243 |
+
carry, out = _maybe_compile_and_run_fn(
|
| 244 |
+
run_flattened_scan,
|
| 245 |
+
combine_fn,
|
| 246 |
+
leaves_init,
|
| 247 |
+
leaves_xs,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
if reverse:
|
| 251 |
+
out = pytree.tree_map(lambda elem: elem.flip([0]), out)
|
| 252 |
+
|
| 253 |
+
return carry, out
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class ScanOp(HigherOrderOperator):
    """The scan higher-order operator: scan(combine_fn, init, xs,
    additional_inputs) with flat (leaf-level) init/xs lists."""

    def __init__(self):
        super().__init__("scan")

    def __call__(self, combine_fn, init, xs, additional_inputs):
        # There is currently an issue that the ScanOp is sometimes called with
        # the additional_inputs being a list. See https://github.com/pytorch/pytorch/issues/145785
        # Once this issue is resolved, the assertion should only allow tuples
        # and the tuple cast should be removed
        assert isinstance(
            additional_inputs, (tuple, list)
        ), "additional_inputs must be a tuple."
        additional_inputs = (
            tuple(additional_inputs)
            if isinstance(additional_inputs, list)
            else additional_inputs
        )
        validate_subgraph_args_types(additional_inputs)
        return super().__call__(combine_fn, init, xs, additional_inputs)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# Singleton instance of the scan higher-order operator; the per-dispatch-key
# implementations below are all registered against this object.
scan_op = ScanOp()
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def generic_scan(operator, init, xs, dim=0, additional_inputs=()):
    """Eagerly evaluate a scan by looping over the scan dimension.

    Args:
        operator: flattened combine function returning ``[*next_carry, *out]``.
        init: list of carry tensors for the first step.
        xs: list of input tensors, each scanned over dimension ``dim``.
        dim: dimension to scan over (callers in this file always use 0).
        additional_inputs: extra operands passed unchanged to every step.

    Returns:
        A flat list ``[*final_carry, *stacked_outputs]``; non-tensor outputs
        are represented as ``None`` in the stacked-outputs portion.
    """

    def _scan(init, xs):
        """Perform a scan over `xs`, starting from the carry `init`."""
        carry = init
        if len(xs) == 0:
            return carry, []

        num_elems = xs[0].shape[dim]
        ind = 0

        # Compute dummy shapes for the pre-allocation by running one step
        # on the first slice of each input (no state is committed here).
        num_init_leaves = len(init)
        dummy_carry, dummy_out = _extract_carry_and_out(
            call_operator(
                operator,
                *carry,
                *[first_slice_copy(elem, dim) for elem in xs],
                *additional_inputs,
            ),
            num_init_leaves,
        )

        # Only tensor outputs get a pre-allocated buffer; non-tensor outputs
        # are masked out and re-inserted as None at the end.
        out_tensor_mask = get_tensor_mask(dummy_out)
        dummy_out_masked = mask_list(out_tensor_mask, dummy_out)

        # Pre-allocate
        # outs -> Output matrix
        # idxs -> Index matrix for scatter_
        # out: (num_elems, M, N, ...)
        # idx: (1, M, N)
        outs = [
            torch.zeros(
                [num_elems] + list(e.size()),
                dtype=e.dtype,
                device=e.device,
            )
            for i, e in enumerate(dummy_out_masked)
        ]
        idxs = [
            torch.ones_like(e, dtype=torch.int64).unsqueeze(0)
            for i, e in enumerate(dummy_out_masked)
        ]

        def store_out_in_outs(out, ind):
            # Store the intermediate out in the outs matrix
            for o, x, idx in zip(outs, out, idxs):
                # o: (num_elems, M, N ...)
                # x: (M, N, ...) -> (1, M, N)
                # ind * idx: (1, M, N,) with values to be ind
                # essentially: o[ind][n][k] = x[0][n][k]
                o.scatter_(0, ind * idx, x.unsqueeze(0))

        # Main scan loop: thread the carry through each step and scatter the
        # per-step outputs into the pre-allocated buffers.
        for i in range(num_elems):
            ind = i
            carry, out = _extract_carry_and_out(
                call_operator(
                    operator,
                    *carry,
                    *[elem.select(dim, ind) for elem in xs],
                    *additional_inputs,
                ),
                num_init_leaves,
            )

            # Store this step's tensor outputs in the outs matrix.
            store_out_in_outs(mask_list(out_tensor_mask, out), ind)

        # Expand outs with None depending on the tensor mask of the output
        outs_expanded = [outs.pop(0) if out_m else None for out_m in out_tensor_mask]

        return [*carry, *outs_expanded]

    scans = _scan(init, xs)
    return scans
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def trace_scan(
    proxy_mode,
    func_overload,
    combine_fn: Callable,
    init: list[torch.Tensor],
    xs: list[torch.Tensor],
    additional_inputs: tuple[torch.Tensor],
):
    """Trace ``scan`` under a proxy mode: materialize the combine subgraph,
    register it on the tracer root, and emit a ``call_function`` proxy node
    whose fake outputs mirror one combine step with ys stacked to scan length.
    """
    from torch._dynamo.utils import clone_input

    # Build representative single-step inputs without recording proxy nodes,
    # then trace combine_fn into an FX graph on those samples.
    with disable_proxy_modes_tracing():
        sample_inits = [clone_input(x_init) for x_init in init]
        sample_inputs = [first_slice_copy(x) for x in xs]
        sample_additional_inputs = [
            clone_input(x) if isinstance(x, torch.Tensor) else x
            for x in additional_inputs
        ]
        combine_graph = reenter_make_fx(combine_fn)(
            *sample_inits, *sample_inputs, *sample_additional_inputs
        )

    # Locate the (single) output node of the traced combine graph.
    outputs = None
    for node in combine_graph.graph.nodes:
        if node.op == "output":
            assert outputs is None
            assert len(node.args) == 1
            outputs = node.args[0]

    assert outputs is not None

    # The carry produced by combine_fn must have the same metadata as init,
    # otherwise the carry could not be fed back into the next step.
    carry, output = _extract_carry_and_out(outputs, len(init))
    init_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
        i.clone() for i in init
    ]
    carry_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
        c.meta["val"] for c in carry
    ]
    check_meta_consistency(
        init_fake_tensors, carry_fake_tensors, "init", "carry", include_contiguity=False
    )

    _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph")

    proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph)

    # Emit the scan call into the traced graph.
    args = (combine_graph, init, xs, additional_inputs)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="scan"
    )

    # Fabricate fake outputs for the proxy: the final carry keeps the
    # single-step shape, while each y gains a leading scan_length dimension.
    with disable_proxy_modes_tracing():
        scan_length = xs[0].shape[0]
        fake_carry, fake_outputs = _extract_carry_and_out(
            [o.meta["val"] for o in outputs], len(init)
        )
        out = (
            *fake_carry,
            *(stack_y(t, scan_length) for t in fake_outputs),
        )

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
@scan_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def scan_op_dense(combine_fn, init, xs, additional_inputs):
    """Dense (eager) implementation of scan for the CPU/CUDA dispatch key;
    simply defers to the Python loop in ``generic_scan``."""
    active_mode = _get_current_dispatch_mode()
    assert active_mode is None, "Mode should never be enabled for CPU/CUDA key"
    return generic_scan(combine_fn, init, xs, additional_inputs=additional_inputs)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
class ScanAutogradOp(torch.autograd.Function):
    """
    Example ::

        def combine_fn(x: torch.Tensor, y: torch.Tensor):
            next_carry = y = x * y
            return next_carry, y

    The ``combine_fn_bw``, computing the gradients for x and y of ``combine_fn`` is computed as:
        def combine_fn_bw(x: torch.Tensor, y: torch.Tensor, g_carry: torch.Tensor, g_y: torch.Tensor):
            return g_y * y + g_carry * y, g_y * x + g_carry * x

    Note: In a real use case of scan, there may be additional_inputs that participate in the
    forward as well as in the backward of the scan operator. For the sake of readability those inputs
    have been omitted in the following example, but are included in the subsequent detailed description below

    The forward output of scan is computed as:
        carry, ys = scan(combine_fn, init, xs).

    This computation can be unpacked as
        c_0, ys_0 = combine_fn(init, xs_0)
        c_1, ys_1 = combine_fn(carry_0, xs_1)
        c_2, ys_2 = combine_fn(carry_1, xs_2)
        ...
        c_T, ys_T = combine_fn(carry_(T-1), xs_T)

    We collect c_0, c_1, ..., c_T into a vector of carries that we save for the backward,
    but we only output (c_T, ys),
    where ys is the vector of all intermediate outputs [y_0, y_1, ..., y_T].

    Given the carries and the ys, the gradients for xs and for init can be computed as follows:
    We receive the upstream gradients in torch.autograd.Function, i.e., we get g_c_T and g_ys,
    where g_ys is the vector of all intermediate gradients of the outputs [g_ys_0, g_ys_1, ..., g_ys_T]

    We then proceed to compute the gradients for the init (g_init) and the xs (g_xs) by running a
    scan operation reverse over time. For example,

        g_c_(T-1), g_xs_T = combine_fn_bw(c_(T-1), xs_T, g_c_T, g_ys_T)
        g_c_(T-2), g_xs_(T-1) = combine_fn_bw(c_(T-2), xs_(T-1), g_c_(T-1), g_ys_(T-1))
        g_c_(T-3), g_xs_(T-2) = combine_fn_bw(c_(T-3), xs_(T-2), g_c_(T-2), g_ys_(T-2))
        ...
        g_init, g_xs_1 = combine_fn_bw(c_0, xs_1, g_c_0, g_ys_1)
        0     , g_xs_0 = combine_fn_bw(init, xs_0, g_init, g_ys_0),

    where combine_fn_bw takes the forward inputs of step t (i.e. c_(t-1), xs_t),
    the gradients of the carry of step t (i.e. g_c_t) and
    the upstream gradient of the output of step t (i.e. g_ys_T)
    and returns the gradient of xs_t -> g_xs_t, as well as the gradient for the carry of step t-1 -> g_c_(t-1).

    Through this procedure we end up with the
    gradients for the init -> g_init,
    the gradients for the xs -> g_xs.


    NOTE: [scan autograd implementation]

    The forward of scan can be computed as:
    1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint``:
    To use a scan operation for the backward path as well, we need access to the carries from all steps.
    Thus, the function ``combine_fn`` is wrapped such that it returns all carries and not only the last carry.
    In particular, we define ``combine_fn_with_carry_checkpoint``:
        def combine_fn_with_carry_checkpoint(x: torch.Tensor, y: torch.Tensor):
            carry, y = combine_fn(x, y)
            return carry, (carry, y)

    The scan operator will stack all outputs along the scan dimension.
    Thus, by putting next_carry also into outputs of ``combine_fn_with_carry_checkpoint``,
    the carries from all steps will be stacked and hence gives us checkpointed_carries

    2.) Compute all carries, the last carry and all outputs using ``combine_fn_with_carry_checkpoint``:
        c_T, (carries, ys) = scan_op(combine_fn_with_carry_checkpoint, init, xs, additional_inputs),
    Where c_T (last carry) and ys (all outputs) are the original results of scan with the ``combine_fn``.
    However, carries are checkpointed carries from all steps.
    As a result of the forward, only the last carry c_T and the ys are returned,
    while all carries are saved for the backward.

    The backward of scan can be computed as:

    3.) Prepare the backward graph:
    We prepare the backward graph to be used in the backward function.
    We utilize ``create_bw_fn`` to generate the joint function, i.e.,
        ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands), where fw_operands = [init, xs_0, additional_inputs]

    The ctx._combine_fn_bw requires the primals (operands)
    followed by the tangents (upstream gradients) from a single step
    and produces the gradients of that step, i.e.,
        g_c_(T-1), g_xs_T, g_additional_input_T = ctx._combine_fn_bw(c_(T-1), xs_T, additional_inputs, g_c_T, g_ys_T).

    4.) Create a wrapper of the ``combine_fn_bw``, i.e., ``combine_fn_bw_grad_accumulation``:
    In the forward, there may be additional inputs that participate in every forward step.
    The gradients for those additional inputs are also computed at every step and need to be accumulated over all steps,
    which is taken care of in this wrapper. For example:
        def combine_fn_bw_grad_accumulation(*args):
            carried_g_additional_input = args[:num_additional_inputs]
            inputs_bw_fn = args[num_additional_inputs:]
            g_c_(t-1), g_xs_t, g_additional_input_t = ctx._combine_fn_bw(*inputs_bw_fn)
            new_g_additional_inputs = carried_g_additional_input + g_additional_input_t
            # The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator
            # The ``g_xs_t`` is encoded as the output of the backward scan operator
            return [*new_g_additional_inputs, *g_c_t, *g_xs_t]

    5.) Perform the backward scan as
        g_additional_inputs, g_init, g_xs = scan_op(combine_fn_bw_grad_accumulation, bw_init, bw_xs), where
    bw_init consists of the initial gradient carry for the additional_inputs (initialized with 0s):
    initial_g_additional_inputs, and the gradient of the last carry: g_c_T. Thus:
        bwd_init = [*initial_g_additional_inputs, *g_c_T].

    bw_xs consists of the combination of the upstream gradients g_ys,
    the forward carries prepended with the fw_init, i.e., bw_carries = concat([fw_init, fw_carries[:-1]]) and
    the fw_xs. In particular,
        bwd_xs = [*g_ys, *bw_carries, *fw_xs].

    Note: g_c_T and g_ys are provided through the torch.autograd.Function.backward's input

    As demonstrated in the Example above, this backward scan then yields the gradient for the init -> g_init
    and the gradient for the xs -> g_xs

    NOTE: [scan partial grad handling]
    If any element of init, of xs, of the outputs or of the additional_inputs does not require gradients,
    i.e., requires_grad=False, there will be still gradients returned for those elements,
    but those gradients will be a tensor filled with zeros of the same shape as the element itself.

    A special case are additional_inputs that are not tensors. Such inputs can occur for example with symbolic tracing,
    where the shape symbol (SymInt) becomes an additional_input.
    For such cases, we compute a ``additional_inputs_tensor_mask``, which is True for elements of additional_inputs
    that are tensors and False otherwise. Gradients of additional_inputs are only accumulated if this mask is True,
    otherwise, the value of initial_g_additional_inputs is passed, which is None for non-Tensor values.
    """

    @staticmethod
    def forward(
        ctx,
        combine_fn,
        num_leaves_init,
        num_leaves_xs,
        num_additional_inputs,
        *operands,
    ):
        # Stash the leaf counts so backward can re-partition the flat operand
        # list saved below.
        ctx._num_leaves_init = num_leaves_init
        ctx._num_leaves_xs = num_leaves_xs
        ctx._num_additional_inputs = num_additional_inputs
        ctx._combine_fn = combine_fn
        init, xs, additional_inputs = split_into_chunks(
            operands, [num_leaves_init, num_leaves_xs, num_additional_inputs]
        )
        additional_inputs_tensor_mask = get_tensor_mask(additional_inputs)
        ctx._additional_inputs_tensor_mask = additional_inputs_tensor_mask

        # We snapshot the dispatch keys in forward for materializing the
        # bw_graph in backward.
        ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
        ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()

        # 1.) Prepare the forward graph wrapper ``combine_fn_with_carry_checkpoint``
        # The wrapper of the forward graph returns carries from all iterations,
        # not just from the last iteration. These are required in the backward path
        def combine_fn_with_carry_checkpoint(*args):
            carry, y = _extract_carry_and_out(combine_fn(*args), num_leaves_init)
            return [
                *carry,
                # We additionally checkpoint all the intermediate carry outputs for backward.
                *[
                    n_c.clone().detach() if isinstance(n_c, torch.Tensor) else n_c
                    for n_c in carry
                ],
                *y,
            ]

        with torch._C._AutoDispatchBelowAutograd():
            # 2.) Compute the all carries, the last carry and all outputs using ``combine_fn_with_carry_checkpoint``
            c_T, carries_ys = _extract_carry_and_out(
                scan_op(
                    combine_fn_with_carry_checkpoint,
                    init,
                    xs,
                    additional_inputs,
                ),
                num_leaves_init,
            )

            # Collect the carries for each time step from the outs
            # and save them for the backward path
            carries = list(carries_ys[:num_leaves_init])
            ys = list(carries_ys[num_leaves_init:])
            save_tensors_and_symints_for_backward(ctx, list(operands) + carries + ys)
            ctx._num_leaves_ys = len(ys)

            return (*c_T, *ys)

    @staticmethod
    def backward(ctx, *flat_grads):
        r"""
        This function computes the gradients of the scan operation.
        It does so by using a scan operator using all carries and the upstream gradients (see description above)

        Args:
            flat_grads (torch.Tensor): The tensor of flattened upstream gradients.
        """

        # Collect the saved items from the forward
        num_leaves_init = ctx._num_leaves_init
        num_leaves_xs = ctx._num_leaves_xs
        num_leaves_ys = ctx._num_leaves_ys
        num_additional_inputs = ctx._num_additional_inputs
        additional_inputs_tensor_mask = ctx._additional_inputs_tensor_mask

        def prepend_init_to_carries(init, carries):
            # Prepare the carries for the backward path.
            # This requires to concatenate the init and the carries
            # (the last forward carry is dropped; it has no consumer in backward).
            return [
                torch.cat([torch.unsqueeze(i, 0), c[:-1]], dim=0)
                for i, c in zip(init, carries)
            ]

        def initialize_g_additional_inputs(
            additional_inputs,
        ):
            # The initial gradients for the additional_inputs are all zeros;
            # non-tensor additional inputs get None.
            g_additional_inputs = [
                torch.zeros_like(ai) if ai_tm else None
                for ai_tm, ai in zip(additional_inputs_tensor_mask, additional_inputs)
            ]
            return g_additional_inputs

        # Retrieve the forward inputs and the forward outputs and dissect them
        flat_args = saved_tensors_and_symints(ctx)
        fw_init, fw_xs, additional_inputs, fw_carries, fw_ys = split_into_chunks(
            flat_args,
            [
                num_leaves_init,
                num_leaves_xs,
                num_additional_inputs,
                num_leaves_init,
                num_leaves_ys,
            ],
        )

        # 3.) Prepare the backward graph
        fw_operands = (
            *fw_init,
            *[first_slice_copy(xs) for xs in fw_xs],
            *additional_inputs,
        )
        ctx._combine_fn_bw = create_bw_fn(ctx._combine_fn, fw_operands)

        # 4.) Create the BW wrapper to accumulate the gradients for the additional_inputs
        def combine_fn_bw_grad_accumulation(*args):
            # Dissect args and re-order them for the ``ctx._combine_fn_bw``
            # The content of ``combine_fn_bw_tangents`` is [*carries_g, *outs_g]
            # The content of ``combine_fn_bw_primals`` is [*init, *xs, *additional_inputs]
            (
                carried_g_additional_input,
                combine_fn_bw_tangents,
                combine_fn_bw_primals,
            ) = split_into_chunks(
                args,
                [
                    num_additional_inputs,
                    num_leaves_init + num_leaves_ys,
                    num_leaves_init + num_leaves_xs + num_additional_inputs,
                ],
            )
            combine_fn_bw_args = (*combine_fn_bw_primals, *combine_fn_bw_tangents)

            g_c_t, g_xs_t, g_additional_inputs_t = split_into_chunks(
                ctx._combine_fn_bw(*combine_fn_bw_args),
                [num_leaves_init, num_leaves_xs, num_additional_inputs],
            )

            new_g_additional_inputs = [
                # If the additional inputs are ints or SymInts, those values are taken as is and no gradients are added
                carr_g + curr_g if add_inp_tm else carr_g
                for add_inp_tm, carr_g, curr_g in zip(
                    additional_inputs_tensor_mask,
                    carried_g_additional_input,
                    g_additional_inputs_t,
                )
            ]

            # The ``new_g_additional_inputs`` and the ``g_c_t`` are encoded in the carry of the backward scan operator
            # The ``g_xs_t`` is encoded as the output of the backward scan operator
            return [*new_g_additional_inputs, *g_c_t, *g_xs_t]

        # Materialize the ``combine_fn_bw_grad_accumulation``
        def construct_args_single_step_bw():
            # This function constructs the arguments for a single step of the backward scan.
            # In other words, it creates the arguments for ``combine_fn_bw_grad_accumulation``
            # The order of the arguments returned is identical to the order the backward scan
            # operations provides

            # The following arguments are used for the backward part of the joint graph
            # The first argument relates to the gradient accumulation of the additional inputs.
            # Because only tensor elements of additional inputs can have requires_grad=True,
            # the values for non-tensor elements of additional inputs are None
            masked_additional_inputs = [
                a.clone() if add_inp_tm else None
                for add_inp_tm, a in zip(
                    additional_inputs_tensor_mask, additional_inputs
                )
            ]

            # The second argument relates to the gradients of the carries.
            # Because the arguments are for a single step only,
            # only the first slice of the carries is used.
            sliced_carries = [first_slice_copy(c) for c in fw_carries]

            # The third argument relates to the gradients of the ys.
            # Because the arguments are for a single step only,
            # only the first slice of the ys is used.
            sliced_ys = [first_slice_copy(o) for o in fw_ys]

            # The following arguments are used for the forward part of the joint graph
            # The fourth argument relates to the init for the forward.
            # I.e., fw_init

            # The fifth argument relates to the xs for the forward.
            # Because the arguments are for a single step only,
            # only the first slice of the xs is used.
            # Note: It is important to preserve the requires_grad flag of xs
            # and thus we use the wrapper function ``first_slice_copy_with_grad``
            fw_xs_slice = first_slice_copy_with_grad(fw_xs)

            # The last argument relates to the additional inputs for the forward.
            # I.e., additional_inputs

            return (
                *masked_additional_inputs,
                *sliced_carries,
                *sliced_ys,
                *fw_init,
                *fw_xs_slice,
                *additional_inputs,
            )

        args_single_step_bw = construct_args_single_step_bw()

        # TODO: we need to materialize the bw graphs because dynamo is unable to
        # trace through the joint function when torch.compile torch.autograd.grad.
        combine_fn_bw_grad_accumulation_gm = materialize_as_graph(
            combine_fn_bw_grad_accumulation,
            args_single_step_bw,
            ctx._fw_include_key_set,
            ctx._fw_exclude_key_set,
            force_enable_grad=True,
        )

        # Decompose the flat_grads into g_c_T, g_ys
        g_c_T, g_ys = split_into_chunks(flat_grads, [num_leaves_init, num_leaves_ys])

        # Initialize the g_additional_inputs with zero-tensors.
        # This step is necessary because the gradients of the additional inputs are accumulated in the
        # ``wrapper_bwd_combine_fn`` and thus need a zero-initialized starting point
        initial_g_additional_inputs = initialize_g_additional_inputs(additional_inputs)

        # Prepend the inits to the carries.
        # This is needed, because when computing the gradients, the last carry is not needed
        # but the first carry, the init, is required.
        bw_carries = prepend_init_to_carries(fw_init, fw_carries)

        # Prepare the xs for the backward scan.
        bwd_xs = [*g_ys, *bw_carries, *fw_xs]

        # The flipping of the ``bwd_xs`` is necessary because the scan_op in the backward is always performed in reverse
        bwd_xs = [torch.flip(elem, [0]) for elem in bwd_xs]

        # Prepare the bwd_init
        bwd_init = [*initial_g_additional_inputs, *g_c_T]

        # 5.) Perform the backward scan:
        # The ``combine_fn_bw_wrapped`` receives the
        # initial_g_additional_inputs and the last carry as the ``bwd_init`` and the
        # gradients of the outputs (g_ys), as well as the fw_carries and the fw_xs of the forward as the ``bwd_xs``
        gradients = scan_op(
            combine_fn_bw_grad_accumulation_gm,
            bwd_init,
            bwd_xs,
            additional_inputs,
        )

        # Unpack the computed gradients
        g_additional_inputs, g_init, g_xs = split_into_chunks(
            gradients, [num_additional_inputs, num_leaves_init, num_leaves_xs]
        )

        # The flipping back along the scan dimension is required to get the gradients in the right order for ``xs``
        g_xs = [torch.flip(elem, [0]) for elem in g_xs]

        # First four Nones correspond to the non-differentiable forward inputs
        # (combine_fn and the three leaf counts).
        return *[None] * 4, *g_init, *g_xs, *g_additional_inputs
|
| 815 |
+
|
| 816 |
+
|
| 817 |
+
@scan_op.py_autograd_impl
def scan_autograd(combine_fn, init, xs, additional_inputs):
    """Autograd entry point for scan: flattens all operands into a single
    argument list and defers to ``ScanAutogradOp`` for forward/backward."""
    n_init = len(init)
    n_xs = len(xs)
    n_extra = len(additional_inputs)

    flat_out = ScanAutogradOp.apply(
        combine_fn,
        n_init,
        n_xs,
        n_extra,
        *tuple(init),
        *tuple(xs),
        *additional_inputs,
    )
    # flat_out is already (carries..., ys...); return it as one flat tuple.
    return tuple(flat_out)
|
| 831 |
+
|
| 832 |
+
|
| 833 |
+
@scan_op.py_impl(ProxyTorchDispatchMode)
def scan_proxy_mode(mode, combine_fn, init, xs, additional_inputs):
    """Proxy-mode implementation: delegate to ``trace_scan`` so the scan call
    is recorded into the FX graph being traced."""
    traced = trace_scan(mode, scan_op, combine_fn, init, xs, additional_inputs)
    return traced
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
@scan_op.py_impl(FakeTensorMode)
def scan_fake_tensor_mode(mode, combine_fn, init, xs, additional_inputs):
    """Fake-tensor implementation: run a single combine step on the first
    slice of each input and expand each output with a leading scan-length
    dimension via ``stack_y``."""
    with mode:
        scan_length = xs[0].shape[0]
        single_step_result = combine_fn(
            *init,
            *(first_slice_copy(inp) for inp in xs),
            *additional_inputs,
        )
        carry, outputs = _extract_carry_and_out(single_step_result, len(init))
        stacked_outputs = [stack_y(t, scan_length) for t in outputs]
        return (*carry, *stacked_outputs)
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
@scan_op.py_functionalize_impl
def scan_functionalize(ctx, combine_fn, init, xs, additional_inputs):
    """Functionalization implementation of scan.

    Unwraps all functional tensors, verifies the combine function does not
    alias or mutate its inputs, re-dispatches scan on the unwrapped tensors
    with a functionalized combine_fn, and re-wraps the results.
    """
    from torch._higher_order_ops.utils import (
        _check_alias_and_mutation,
        _maybe_run_with_interpreter,
    )

    unwrapped_xs = ctx.unwrap_tensors(xs)
    unwrapped_init = ctx.unwrap_tensors(init)
    unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)

    with ctx.redispatch_to_next():
        functional_combine_fn = ctx.functionalize(
            _maybe_run_with_interpreter(combine_fn)
        )
        # Build single-step sample inputs (init + first xs slice + extras)
        # used only for the alias/mutation check below.
        sample_unwrapped_xs_sliced = [first_slice_copy(inp) for inp in unwrapped_xs]
        sample_inputs = list(
            itertools.chain(
                unwrapped_init,
                sample_unwrapped_xs_sliced,
                unwrapped_additional_inputs,
            )
        )
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        # Scan requires a pure combine_fn: no input mutation, no input/output
        # aliasing. Raises if either is detected.
        _check_alias_and_mutation(combine_fn, sample_inputs, "scan", pre_dispatch)
        ret = scan_op(
            functional_combine_fn,
            unwrapped_init,
            unwrapped_xs,
            unwrapped_additional_inputs,
        )
    return ctx.wrap_tensors(ret)
|
| 889 |
+
|
| 890 |
+
|
| 891 |
+
# dense implementation for scan. Used for testing only.
def _fake_scan(combine_fn, init, xs=None, dim=0, reverse=False):
    """Reference (pure-Python) scan over arbitrary pytrees.

    Args:
        combine_fn: maps (carry_pytree, xs_slice_pytree) -> (carry_pytree, y_pytree).
        init: pytree of initial carry values.
        xs: pytree of inputs scanned over dimension ``dim``; may be None/empty.
        dim: dimension of the xs leaves to iterate over.
        reverse: if True, iterate (and re-stack) in reverse order.

    Returns:
        Tuple of (final carry pytree, pytree of stacked per-step outputs).

    NOTE(review): outputs are stacked with torch.stack along dim 0 regardless
    of ``dim`` — presumably callers always use dim=0; confirm before reusing
    with a nonzero dim.
    """
    carry_leaves, carry_spec = pytree.tree_flatten(init)
    inp_leaves, inp_spec = pytree.tree_flatten(xs)
    if xs is None or len(inp_leaves) == 0:
        return init, []
    result_flat = []
    carry = carry_leaves
    # `op` both drives the iteration order and restores the original order
    # when stacking the (reverse-collected) results below.
    op = reversed if reverse else lambda x: x

    # Dry-run one step on the first slice to learn the output tree structure.
    dummy_carry, dummy_out = combine_fn(
        pytree.tree_unflatten(carry, carry_spec),
        pytree.tree_unflatten(
            [first_slice_copy(elem, dim) for elem in inp_leaves],
            inp_spec,
        ),
    )
    dummy_out_leaves, dummy_out_spec = pytree.tree_flatten(dummy_out)
    num_leaves = len(dummy_out_leaves)

    for ind in op(range(inp_leaves[0].size(dim))):
        xs = [elem.select(dim, ind) for elem in inp_leaves]

        carry, y = combine_fn(
            pytree.tree_unflatten(carry, carry_spec),
            pytree.tree_unflatten(xs, inp_spec),
        )
        carry, _ = pytree.tree_flatten(carry)
        y, _ = pytree.tree_flatten(y)
        result_flat.append(y)

    # Stack per-step outputs leaf-by-leaf; applying `op` again restores
    # forward time order when reverse=True.
    results = [
        torch.stack([e[leave_ind] for e in op(result_flat)])
        for leave_ind in range(num_leaves)
    ]
    return (
        pytree.tree_unflatten(carry, carry_spec),
        pytree.tree_unflatten(results, dummy_out_spec),
    )
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/schema.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Any, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.utils._pytree as pytree
|
| 7 |
+
from torch.fx.node import Target
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Below is an implementation of generating FunctionSchema from example values.
|
| 11 |
+
# This is helpful for generating FunctionSchema for HigherOrderOperator, where
|
| 12 |
+
# we don't have a function to inspect and each call of the higher order operator
|
| 13 |
+
# would have different schema.
|
| 14 |
+
@dataclass(frozen=True)
class HopArgumentInfo:
    """Metadata describing one operand (or the output) of a HigherOrderOperator."""

    # Optional operand name; by default it is the empty string.
    name: str
    # Concrete example value used to infer the operand's JIT type.
    example_value: Any
    # Default value for the operand, or None when there is none.
    default_value: Any
    # True when the hop subgraph mutates this argument.
    # For the output argument this must always be False.
    is_mutated: bool
    # True when the argument is keyword-only in the schema.
    kw_only: bool
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class HopArgumentInfoGen:
    """Factory for HopArgumentInfo records built from example values."""

    @staticmethod
    def from_example(
        example_value: Any,
        *,
        name: str = "",
        default_value: Optional[Any] = None,
        is_mutated: bool = False,
        kw_only: bool = False,
    ) -> HopArgumentInfo:
        """Build a HopArgumentInfo, checking that any default matches the example's type."""
        if default_value is not None:
            # A supplied default must have exactly the same concrete type as
            # the example value, otherwise the schema would be inconsistent.
            assert type(example_value) == type(
                default_value
            ), f"example_value type {type(example_value)} doesn't match default_value type: {type(default_value)}"

        return HopArgumentInfo(
            name=name,
            example_value=example_value,
            default_value=default_value,
            is_mutated=is_mutated,
            kw_only=kw_only,
        )
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class CTypeGen:
    """Maps Python example values to torch._C JIT type objects."""

    # Base Python scalar types and their JIT type singletons.
    convert_to_base_ty = {
        int: torch._C.IntType.get(),
        float: torch._C.FloatType.get(),
        str: torch._C.StringType.get(),
        bool: torch._C.BoolType.get(),
    }

    # should return torch._C.JitType but that annotation is busted
    @staticmethod
    def from_example(obj: Any) -> Any:
        import torch

        # Graph modules are typed opaquely as Any; symbolic ints get their
        # dedicated SymInt type; everything else is inferred by the JIT.
        if isinstance(obj, torch.fx.GraphModule):
            return torch._C.AnyType.get()
        if isinstance(obj, torch.SymInt):
            return torch._C.SymIntType.get()
        return torch._C._jit_try_infer_type(obj).type()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class CArgumentGen:
    """Converts a HopArgumentInfo into a torch._C.Argument."""

    @staticmethod
    def from_hop_argument_info(
        arg_idx: int, arg_info: HopArgumentInfo, is_output: bool = False
    ) -> Any:
        typ = CTypeGen.from_example(arg_info.example_value)

        # Output arguments carry no name, default, or alias information.
        if is_output:
            return torch._C.Argument("", typ, None, None, False, None)

        # Mutated inputs are tagged with a unique alias set keyed by their
        # position so the schema records the mutation/aliasing.
        alias_set = {f"alias::a{arg_idx}"} if arg_info.is_mutated else set()
        alias_info = torch._C._AliasInfo(arg_info.is_mutated, alias_set, alias_set)  # type: ignore[attr-defined]
        return torch._C.Argument(
            arg_info.name,
            typ,
            None,
            arg_info.default_value,
            arg_info.kw_only,
            alias_info,
        )
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class HopSchemaGenerator:
|
| 93 |
+
def __init__(self, hop: torch._ops.HigherOrderOperator):
|
| 94 |
+
self.arg_infos: list[HopArgumentInfo] = []
|
| 95 |
+
self.example_outputs: list[Any] = []
|
| 96 |
+
self.schema_tree_spec: Optional[pytree.TreeSpec] = None
|
| 97 |
+
self.hop = hop
|
| 98 |
+
|
| 99 |
+
def add_arg(
|
| 100 |
+
self,
|
| 101 |
+
name: str,
|
| 102 |
+
example_value: Any,
|
| 103 |
+
default_value: Optional[Any] = None,
|
| 104 |
+
is_mutated: bool = False,
|
| 105 |
+
kw_only: bool = False,
|
| 106 |
+
) -> None:
|
| 107 |
+
if callable(example_value):
|
| 108 |
+
assert isinstance(
|
| 109 |
+
example_value, (torch.fx.GraphModule, torch._ops.OperatorBase)
|
| 110 |
+
), (
|
| 111 |
+
"Expect callable to be a GraphModule or an. Please call materialize_as_graph first "
|
| 112 |
+
f"to turn callable arguments {example_value} into a GraphModule."
|
| 113 |
+
)
|
| 114 |
+
_, flat_spec = pytree.tree_flatten(example_value)
|
| 115 |
+
if not flat_spec.is_leaf():
|
| 116 |
+
raise RuntimeError(
|
| 117 |
+
f"example_value {example_value} is not a leaf node. "
|
| 118 |
+
"Please only add flattened inputs to the hop schema. "
|
| 119 |
+
"If you need some structure in the arguments, please"
|
| 120 |
+
"add_arg for flattened args one by one then "
|
| 121 |
+
"call add_schema_tree_spec to register the original pytree "
|
| 122 |
+
" spec of the args."
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
arg_info = HopArgumentInfoGen.from_example(
|
| 126 |
+
example_value=example_value,
|
| 127 |
+
name=name,
|
| 128 |
+
default_value=default_value,
|
| 129 |
+
is_mutated=is_mutated,
|
| 130 |
+
kw_only=kw_only,
|
| 131 |
+
)
|
| 132 |
+
self.arg_infos.append(arg_info)
|
| 133 |
+
|
| 134 |
+
def add_output(self, output: Any) -> None:
|
| 135 |
+
self.example_outputs.append(output)
|
| 136 |
+
|
| 137 |
+
def add_schema_tree_spec(self, *args: Any, **kwargs: Any) -> None:
|
| 138 |
+
"""schema tree spec is the tree spec from flattening all inputs to the hop with pytree.tree_flatten
|
| 139 |
+
Since torch.FunctionSchema only have proper mutation/alias support for flattened inputs, we need
|
| 140 |
+
to store the tree spec in order to reconstruct the inputs to the hop.
|
| 141 |
+
"""
|
| 142 |
+
self.schema_tree_spec = pytree.tree_flatten((args, kwargs))[1]
|
| 143 |
+
|
| 144 |
+
def gen_schema(self) -> torch._C.FunctionSchema:
|
| 145 |
+
for i, arg_info in enumerate(self.arg_infos):
|
| 146 |
+
arg_spec = pytree.tree_flatten(arg_info.example_value)[1]
|
| 147 |
+
if not arg_spec.is_leaf() and self.schema_tree_spec is None:
|
| 148 |
+
raise RuntimeError(
|
| 149 |
+
f"example_value of arg_infos[{i}] is {arg_info.example_value}, which is not a leaf node. "
|
| 150 |
+
"Please call add_schema_tree_spec to add a schema tree spec first. "
|
| 151 |
+
"Or consider changing the hop's signature to only take flattened arguments."
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
return CFunctionSchemaGen.from_hop_argument_info(
|
| 155 |
+
str(self.hop),
|
| 156 |
+
self.arg_infos,
|
| 157 |
+
HopArgumentInfoGen.from_example(tuple(self.example_outputs), name="out"),
|
| 158 |
+
self.schema_tree_spec,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class CFunctionSchemaGen:
    """
    Note: [HigherOrderOperator schema generation]
    Each invocation of a HigherOrderOperator will have a different schema.
    For example, the schema of torch.cond varies depending on the true_fn and
    false_fn. So we need a way to generate the schema for each invocation of a HOP.

    We want to enforce the following invariants for HOP's schema:
    1. Flattened inputs. There should be no pytree structure in it.
    2. Flattened outputs. Note even if the hop returns a single value, it should be wrapped as a tuple.
    3. No aliasing. This includes inp-inp aliasing, inp-out aliasing and out-out aliasing.

    By enforcing these invariants, we could make HOP's schema meets the requirement of schema parser
    and makes hop easier to handle downstream. For example, suppose we have an invoke_quant_test HOP:

    class GraphModule(torch.nn.Module):
        def forward(self, l_x_, l_y_):
            subgraph_0 = self.subgraph_0
            invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4');

        class subgraph_0(torch.nn.Module):
            def forward(self, l_x_, l_y_):
                add_ = l_x_.add_(1)
                matmul = l_x_ @ l_y_
                sin = matmul.sin()
                child = sin.cos()
                child_1 = l_x_ + l_y_
                child_2 = l_x_ - l_y_
                child_3 = l_x_ @ l_y_
                return (child, child_1, child_2, child_3)

    By encoding the inputs of hop into a list of HopArgumentInfo and output as a single HopArgumentInfo,
    we would get the following schema:
    invoke_quant_test(Any arg0, Tensor(!) arg1, Tensor arg2, str scheme="\\"nf4\\"") -> (Tensor, Tensor, Tensor, Tensor)
    """

    @staticmethod
    def from_hop_argument_info(
        op_name: str,
        inp_argument_info: list[HopArgumentInfo],
        out_argument_info: HopArgumentInfo,
        schema_tree_spec: Optional[pytree.TreeSpec],
    ) -> Any:
        # Convert every input argument in order.
        args = [
            CArgumentGen.from_hop_argument_info(idx, info)
            for idx, info in enumerate(inp_argument_info)
        ]

        # NOTE: we want the output to always be a single argument with torch._C.TupleType.
        assert isinstance(
            out_argument_info.example_value, tuple
        ), f"expect out_argument_info's example_value to be a tuple but got {out_argument_info.example_value}"
        assert (
            not out_argument_info.is_mutated
        ), "out_argument_info.is_mutated should always be set to False."

        if len(out_argument_info.example_value) == 1:
            # Single output: reuse the aggregate argument info directly.
            rets = [CArgumentGen.from_hop_argument_info(0, out_argument_info, True)]
        else:
            # Multiple outputs: one return Argument per element of the tuple.
            rets = [
                CArgumentGen.from_hop_argument_info(
                    idx,
                    HopArgumentInfoGen.from_example(
                        name=f"out{idx}",
                        example_value=val,
                        default_value=None,
                        is_mutated=False,
                    ),
                    is_output=True,
                )
                for idx, val in enumerate(out_argument_info.example_value)
            ]

        return HopSchema(
            op_name,
            "",
            args,
            rets,
            False,
            False,
            schema_tree_spec,
        )
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class HopSchema(torch._C.FunctionSchema):
|
| 246 |
+
def __init__(
|
| 247 |
+
self,
|
| 248 |
+
name: str,
|
| 249 |
+
overload_name: str,
|
| 250 |
+
arguments: list[torch._C.Argument],
|
| 251 |
+
returns: list[torch._C.Argument],
|
| 252 |
+
is_vararg: bool,
|
| 253 |
+
is_varret: bool,
|
| 254 |
+
schema_tree_spec: Optional[pytree.TreeSpec],
|
| 255 |
+
):
|
| 256 |
+
self.tree_spec = schema_tree_spec
|
| 257 |
+
self.is_vararg = is_vararg
|
| 258 |
+
self.is_varret = is_varret
|
| 259 |
+
super().__init__(
|
| 260 |
+
name,
|
| 261 |
+
overload_name,
|
| 262 |
+
arguments,
|
| 263 |
+
returns,
|
| 264 |
+
self.is_vararg,
|
| 265 |
+
self.is_varret,
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
def __deepcopy__(self, memo: Any) -> "HopSchema":
|
| 269 |
+
# Need to additionally copy the tree_spec since
|
| 270 |
+
# it's not a member of torch._C.FunctionSchema
|
| 271 |
+
return HopSchema(
|
| 272 |
+
self.name,
|
| 273 |
+
self.overload_name,
|
| 274 |
+
self.arguments,
|
| 275 |
+
self.returns,
|
| 276 |
+
self.is_vararg,
|
| 277 |
+
self.is_varret,
|
| 278 |
+
copy.deepcopy(self.tree_spec),
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def find_hop_schema(
|
| 283 |
+
gm: torch.fx.GraphModule, target: Target
|
| 284 |
+
) -> list[torch._C.FunctionSchema]:
|
| 285 |
+
schemas = []
|
| 286 |
+
for node in gm.graph.find_nodes(op="call_function", target=target):
|
| 287 |
+
|
| 288 |
+
def _get_example_value(node: torch.fx.Node) -> Any:
|
| 289 |
+
if node.op == "get_attr":
|
| 290 |
+
assert isinstance(node.target, str)
|
| 291 |
+
return getattr(gm, node.target)
|
| 292 |
+
else:
|
| 293 |
+
return (
|
| 294 |
+
node.meta["example_value"]
|
| 295 |
+
if "example_value" in node.meta
|
| 296 |
+
else node.meta["val"]
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
fake_args, fake_kwargs = pytree.tree_map_only(
|
| 300 |
+
torch.fx.Node,
|
| 301 |
+
_get_example_value,
|
| 302 |
+
(node.args, node.kwargs),
|
| 303 |
+
)
|
| 304 |
+
schema = node.target.gen_schema(*fake_args, **fake_kwargs)
|
| 305 |
+
schemas.append(schema)
|
| 306 |
+
return schemas
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/strict_mode.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
import torch._subclasses.functional_tensor
|
| 4 |
+
import torch.utils._pytree as pytree
|
| 5 |
+
from torch._C import DispatchKey
|
| 6 |
+
from torch._functorch.utils import exposed_in
|
| 7 |
+
from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_implemented
|
| 8 |
+
from torch._ops import HigherOrderOperator
|
| 9 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 10 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 11 |
+
_temp_remove_metadata_torch_function_mode,
|
| 12 |
+
_temp_remove_pre_dispatch_torch_function_mode,
|
| 13 |
+
disable_proxy_modes_tracing,
|
| 14 |
+
make_fx,
|
| 15 |
+
ProxyTorchDispatchMode,
|
| 16 |
+
track_tensor_tree,
|
| 17 |
+
)
|
| 18 |
+
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@exposed_in("torch")
def strict_mode(callable, operands):
    """Run ``callable(*operands)`` through the strict_mode higher order op,
    compiling with dynamo when called from eager mode."""
    from torch._dynamo.backends.debugging import (
        make_eager_backend_with_torch_function_modes,
    )

    # Already inside dynamo tracing: dispatch straight to the operator.
    if torch.compiler.is_dynamo_compiling():
        return strict_mode_op(callable, operands)

    with _set_compilation_env():
        with _temp_remove_metadata_torch_function_mode() as metadata_mode:
            with _temp_remove_pre_dispatch_torch_function_mode() as predispatch_mode:
                # Any torch-function modes we temporarily removed must be
                # re-installed inside the compiled region via a custom
                # eager backend.
                active = [m for m in (metadata_mode, predispatch_mode) if m is not None]
                if active:
                    backend = make_eager_backend_with_torch_function_modes(active)
                else:
                    backend = "eager"
                with torch._dynamo.utils.disable_cache_limit():
                    return torch.compile(
                        strict_mode_op, backend=backend, fullgraph=True
                    )(callable, operands)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class StrictMode(HigherOrderOperator):
    """Higher order operator backing ``torch.strict_mode``."""

    def __init__(self):
        super().__init__("strict_mode")

    def __call__(self, callable, operands):
        # Plain pass-through; dispatch is handled by the registered py_impls.
        return super().__call__(callable, operands)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
strict_mode_op = StrictMode()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@strict_mode_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def strict_mode_op_dense(callable, operands):
    """Dense (CPU/CUDA) kernel: simply invoke the callable."""
    active_mode = _get_current_dispatch_mode()
    assert active_mode is None, "Mode should never be enabled for CPU/CUDA key"
    return callable(*operands)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
strict_mode_op.py_autograd_impl(
|
| 64 |
+
autograd_not_implemented(strict_mode_op, deferred_error=True)
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@strict_mode_op.py_impl(ProxyTorchDispatchMode)
def inner(mode, callable, operands):
    # Proxy-mode kernel: delegate to the shared tracing helper.
    return trace_strict_mode(mode, strict_mode_op, callable, operands)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def trace_strict_mode(mode, strict_mode_op, callable, operands):
    """Trace ``callable`` into a subgraph and record a strict_mode call node."""
    pre_dispatch = getattr(mode, "pre_dispatch", False)

    # Trace the body with proxy modes disabled so it becomes a plain subgraph.
    with disable_proxy_modes_tracing():
        graph = make_fx(callable, pre_dispatch=pre_dispatch)(*operands)

    # Install the subgraph on the tracer root under a fresh qualified name.
    graph_name = mode.tracer.get_fresh_qualname("strict_graph_")
    mode.tracer.root.register_module(graph_name, graph)

    node_args = (graph, operands)
    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, node_args)
    out_proxy = mode.tracer.create_proxy(
        "call_function", strict_mode_op, proxy_args, {}, name="strict_mode"
    )

    # Execute the traced graph for example outputs, then bind them to the
    # freshly created proxy.
    out = graph(*operands)
    return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@strict_mode_op.py_impl(FakeTensorMode)
def strict_mode_fake_tensor_mode(mode, callable, operands):
    """Fake-tensor kernel: evaluate the callable under the fake mode."""
    with mode:
        return callable(*operands)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@strict_mode_op.py_functionalize_impl
def strict_mode_func(ctx, callable, inputs):
    """Functionalization kernel: unwrap tensors, redispatch, rewrap the result."""
    unwrapped = ctx.unwrap_tensors(inputs)
    with ctx.redispatch_to_next():
        functional_callable = ctx.functionalize(callable)

        result = strict_mode_op(functional_callable, unwrapped)
        return ctx.wrap_tensors(result)
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/torchbind.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import logging
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch._C import DispatchKey # @manual
|
| 7 |
+
from torch._functorch._aot_autograd.utils import KNOWN_TYPES
|
| 8 |
+
from torch._higher_order_ops.utils import autograd_not_implemented
|
| 9 |
+
from torch._library.fake_class_registry import (
|
| 10 |
+
_is_script_object,
|
| 11 |
+
_ns_and_class_name,
|
| 12 |
+
FakeScriptObject,
|
| 13 |
+
)
|
| 14 |
+
from torch._ops import HigherOrderOperator
|
| 15 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 16 |
+
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
|
| 17 |
+
from torch.fx.node import has_side_effect
|
| 18 |
+
from torch.utils import _pytree as pytree
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
log = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# The call_torchbind operator represents a method invocation on a torchbind
|
| 25 |
+
# object. The calling convention is:
|
| 26 |
+
# call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs)
|
| 27 |
+
# We do not expect users to write this operator directly. Instead it will be
|
| 28 |
+
# emitted by Dynamo when tracing encounters a torchbind object.
|
| 29 |
+
class CallTorchBind(HigherOrderOperator):
    """HOP representing a method invocation on a torchbind object.

    Calling convention:
        call_torchbind(self: ScriptObject, method_name: str, *args, **kwargs)
    Users are not expected to write this directly; Dynamo emits it when
    tracing encounters a torchbind object.
    """

    def __init__(self):
        super().__init__("call_torchbind")

    def __call__(self, obj, method, *args, **kwargs):
        return super().__call__(obj, method, *args, **kwargs)

    @staticmethod
    def schema(obj, method) -> torch.FunctionSchema:
        """
        Returns the schema of ``CallTorchbind.__call__``.
        """
        assert isinstance(obj, torch._inductor.ir.TorchBindObject)
        real_obj = obj.get_real_obj()
        method_schema = real_obj._get_method(method).schema
        schema_str = str(method_schema)
        # Rebuild the schema string: op name + self argument, then an extra
        # ``str method`` parameter, then the rest of the original schema.
        prefix = f"call_torchbind({str(method_schema.arguments[0].real_type)} {method_schema.arguments[0].name},"
        split_at = schema_str.find(",")
        if split_at == -1:
            # If no comma is found, find the last closing parenthesis
            split_at = schema_str.rfind(") ->")
        new_schema_str = prefix + " str method" + schema_str[split_at:]
        return torch._C.parse_schema(new_schema_str)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
call_torchbind = CallTorchBind()
|
| 56 |
+
|
| 57 |
+
# Register this operator as side-effectful with FX.
|
| 58 |
+
# TODO: this is not really sufficient. While passes (hopefully) check
|
| 59 |
+
# Node.is_impure() and make good decisions, we also assume we can execute the
|
| 60 |
+
# graph as many times as we want without changing behavior, which is NOT true of
|
| 61 |
+
# ops that mutate torchbind object state.
|
| 62 |
+
has_side_effect(call_torchbind)
|
| 63 |
+
|
| 64 |
+
_orig_scriptmethod_call = torch.ScriptMethod.__call__
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def torchbind_method_redispatch(self, *args, **kwargs):
    # Ordinary ScriptMethods go straight to the original implementation;
    # methods bound to a torchbind object are routed through the HOP.
    if not _is_script_object(self.raw_owner):
        return _orig_scriptmethod_call(self, *args, **kwargs)
    return call_torchbind(self.raw_owner, self.name, *args, **kwargs)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@contextmanager
def enable_torchbind_tracing():
    """Context manager that acts as a feature flag to enable torchbind tracing
    behavior. Once torchbind tracing has been stabilized, we can remove this and
    turn it always on.
    """
    try:
        # Let AOTAutograd treat ScriptObject as a known leaf type, and patch
        # ScriptMethod.__call__ to redispatch through call_torchbind.
        KNOWN_TYPES.append(torch.ScriptObject)
        torch.ScriptMethod.__call__ = torchbind_method_redispatch  # type: ignore[method-assign]
        yield
    finally:
        # Undo both patches, verifying nobody else mutated KNOWN_TYPES.
        assert (
            KNOWN_TYPES.pop() is torch.ScriptObject
        ), "Someone else messed with KNOWN_TYPES during tracing, exploding."
        torch.ScriptMethod.__call__ = _orig_scriptmethod_call  # type: ignore[method-assign]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd)
def call_torchbind_impl(obj, method, *args, **kwargs):
    """Eager kernel: invoke the method on a real or fake script object."""
    if isinstance(obj, torch.ScriptObject):
        # Use the saved original __call__, since ScriptMethod.__call__ may be
        # patched to redispatch back here (see enable_torchbind_tracing).
        return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs)
    if isinstance(obj, FakeScriptObject):
        return getattr(obj.wrapped_obj, method)(*args, **kwargs)
    raise RuntimeError(f"Unsupported first arg type {type(obj)} for call_torchbind")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@call_torchbind.py_impl(ProxyTorchDispatchMode)
def inner(mode, *args, **kwargs):
    """Proxy-mode kernel: record a call_torchbind node and run the real call."""
    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)

    out_proxy = mode.tracer.create_proxy(
        "call_function",
        call_torchbind,
        proxy_args,
        proxy_kwargs,
    )
    out = call_torchbind(*args, **kwargs)

    obj, method, *_rest_args = args
    if isinstance(obj, torch.ScriptObject):
        # Warn: tracing through a real ScriptObject can mutate it.
        ns, class_name = _ns_and_class_name(
            obj._type().qualified_name()  # type: ignore[attr-defined]
        )
        log.warning(
            "Tracing torchbind method %s.%s with real ScriptObject. This may"
            " cause the original object being mutated. If this is not intended,"
            ' You can register a fake class with torch._library.register_fake_class("%s::%s").',
            class_name,
            method,
            ns,
            class_name,
        )

    ret = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    if "val" not in out_proxy.node.meta:
        # Non-tensor results are stored directly on the node; only simple
        # constants are supported.
        assert out is None or isinstance(
            out, (int, float, bool)
        ), "Currently, only these constant dtypes are supported to be returned from torchbind methods."
        out_proxy.node.meta["val"] = out
    return ret
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# When tracing with fake script object, the call_torchbind op will return a fake tensor.
# When tracing with real script object, the call_torchbind op may return a real tensor;
# we need to convert it to a fake tensor manually. Dynamic shape is supported.
@call_torchbind.py_impl(FakeTensorMode)
def call_torchbind_fake(mode, *args, **kwargs):
    with mode:
        result = call_torchbind_impl(*args, **kwargs)

        def _fakeify(t):
            # Tensors that are already fake pass through untouched.
            if isinstance(t, torch._subclasses.fake_tensor.FakeTensor):
                return t
            return mode.from_tensor(t, static_shapes=True)

        return pytree.tree_map_only(torch.Tensor, _fakeify, result)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
call_torchbind.py_autograd_impl(
|
| 154 |
+
autograd_not_implemented(call_torchbind, deferred_error=True)
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@call_torchbind.py_functionalize_impl
def call_torchbind_func(ctx, *args, **kwargs):
    """Functionalization kernel: route through the effect-token machinery."""
    from torch._higher_order_ops.effects import handle_effects

    return handle_effects(
        ctx.mode._allow_token_discovery, ctx.mode._tokens, call_torchbind, args, kwargs
    )
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py
ADDED
|
@@ -0,0 +1,2051 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import copy
|
| 3 |
+
import dataclasses
|
| 4 |
+
import functools
|
| 5 |
+
import inspect
|
| 6 |
+
import itertools
|
| 7 |
+
import logging
|
| 8 |
+
import operator
|
| 9 |
+
import threading
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
from collections.abc import Sequence
|
| 12 |
+
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
| 13 |
+
from typing_extensions import Never
|
| 14 |
+
|
| 15 |
+
import sympy
|
| 16 |
+
|
| 17 |
+
import torch.fx as fx
|
| 18 |
+
import torch.utils._pytree as pytree
|
| 19 |
+
from torch import SymInt, Tensor
|
| 20 |
+
from torch._C import DispatchKey
|
| 21 |
+
from torch._ops import HigherOrderOperator
|
| 22 |
+
from torch._prims_common import clone_preserve_strides
|
| 23 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 24 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 25 |
+
disable_proxy_modes_tracing,
|
| 26 |
+
ProxyTorchDispatchMode,
|
| 27 |
+
track_tensor_tree,
|
| 28 |
+
)
|
| 29 |
+
from torch.fx.experimental.symbolic_shapes import guard_scalar
|
| 30 |
+
from torch.types import IntLikeType
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if TYPE_CHECKING:
|
| 34 |
+
from triton._C.libtriton.ir import (
|
| 35 |
+
module as TritonIRModule,
|
| 36 |
+
operation as TritonIROperation,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
from torch._dynamo.symbolic_convert import InstructionTranslator
|
| 40 |
+
from torch._dynamo.variables.constant import ConstantVariable
|
| 41 |
+
from torch._dynamo.variables.functions import TritonKernelVariable
|
| 42 |
+
from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
|
| 43 |
+
from torch.fx.proxy import Proxy
|
| 44 |
+
from torch.utils._triton import has_triton
|
| 45 |
+
|
| 46 |
+
# Type aliases for the objects Dynamo handles when tracing user-defined Triton
# kernels.  A launch "grid" is either a concrete tuple of dimensions (possibly
# symbolic) or a callable deriving the dimensions from the kernel's
# meta-parameters (e.g. the autotuner config kwargs).
TritonMetaParamsType = dict[str, int]
TritonGridTupleType = tuple[Union[int, sympy.Expr, SymInt], ...]
TritonGridCallableType = Callable[[TritonMetaParamsType], tuple[int, ...]]
TritonGridType = Union[TritonGridTupleType, TritonGridCallableType]
|
| 50 |
+
|
| 51 |
+
# Import the real Triton classes when Triton is installed; otherwise define
# empty stand-ins so the Union type aliases below remain importable without
# Triton present.
if has_triton():
    from triton.runtime.autotuner import Autotuner, Config as TritonConfig
    from triton.runtime.jit import JITFunction
else:

    class Autotuner:  # type: ignore[no-redef]
        pass

    class JITFunction:  # type: ignore[no-redef]
        pass


TritonKernelType = Union[Autotuner, JITFunction]
# mypy specifically complains that TritonAutotunerType is not a valid type if Autotuner is not inside of a Union.
TritonAutotunerType = Union[Autotuner]

log = logging.getLogger("torch._dynamo")
|
| 67 |
+
|
| 68 |
+
# Serializable (plain-tuple) descriptions of host-side TMA descriptor creation.
# e.g. for a host-side Triton TMA API call ``create_2d_tma_descriptor(ptr, 50, 60, 32, 15, 4)``,
# the metadata will look like ``("experimental", ([50, 60], [32, 15], 4))``
TMAExperimentalMetadata = tuple[
    str,  # type of TMA (should be "experimental")
    tuple[
        list[IntLikeType],  # dims
        list[IntLikeType],  # block_dims
        IntLikeType,  # element_size
    ],
]

# e.g. for host-side Triton TMA API call ``TensorDescriptor.from_tensor(ptr, [32, 64])``
# the metadata will look like ``("stable", ([32, 64],))``
TMAStableMetadata = tuple[
    str,  # type of TMA ("experimental" or "stable")
    tuple[list[IntLikeType],],  # block_shape
]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def create_tma_experimental_metadata(
    dims: list[IntLikeType],
    block_dims: list[IntLikeType],
    element_size: IntLikeType,
) -> TMAExperimentalMetadata:
    """Pack experimental-API TMA descriptor info into its serialized tuple form."""
    payload = (dims, block_dims, element_size)
    return ("experimental", payload)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def maybe_unpack_tma_experimental_metadata(
    tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata]
) -> Optional[tuple[list[IntLikeType], list[IntLikeType], IntLikeType]]:
    """Return the (dims, block_dims, element_size) payload when *tma_meta*
    is experimental-API metadata; otherwise return None."""
    if tma_meta and len(tma_meta) == 2 and tma_meta[0] == "experimental":
        return tma_meta[1]  # type: ignore[return-value]
    return None
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def create_tma_stable_metadata(
    block_shape: list[IntLikeType],
) -> TMAStableMetadata:
    """Pack stable-API TMA descriptor info into its serialized tuple form."""
    payload = (block_shape,)
    return ("stable", payload)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def maybe_unpack_tma_stable_metadata(
    tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata]
) -> Optional[tuple[list[IntLikeType]]]:
    """Return the (block_shape,) payload when *tma_meta* is stable-API
    metadata; otherwise return None."""
    if tma_meta and len(tma_meta) == 2 and tma_meta[0] == "stable":
        return tma_meta[1]  # type: ignore[return-value]
    return None
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# TMADescriptorMetadata maps kernel parameter names to the metadata that allows
# reconstructing TMA descriptors from the underlying tensors (passed as kernel
# arguments in the fx graph, instead of the TMA descriptors).
#
# Since there are two TMA APIs (the old "experimental" API and the new "stable" API),
# each entry in the dict is a tuple that starts with a string, either "experimental"
# or "stable". The second entry in the tuple is another tuple, with data that depends
# on the API type (see TMAExperimentalMetadata and TMAStableMetadata above).
#
# These are stored as raw tuples (instead of classes) for ease of serialization.
TMADescriptorMetadata = dict[
    str,  # kernel parameter name
    Union[TMAExperimentalMetadata, TMAStableMetadata],
]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
###############################################################################
|
| 138 |
+
# Kernel Side Table
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# We cannot put Triton Kernels into the FX graph as the graph nodes
# do not support arbitrary functions.
# Use a side table.
# We use two dicts so that fetching both the kernel and id are O(1)
class KernelSideTable:
    """Process-wide registry mapping Triton kernels <-> integer ids.

    FX graph nodes cannot hold arbitrary callables, so kernels (and
    non-graphable constant args) are stashed here and referenced by index.
    The forward and reverse dicts make both lookups O(1).  State is kept on
    the class so every instance shares the same table.
    """

    id_to_kernel: dict[int, "TritonKernelType"] = {}
    kernel_to_id: dict["TritonKernelType", int] = {}
    constant_args: dict[int, dict[str, Any]] = {}
    lock = threading.Lock()

    def add_kernel(self, kernel: "TritonKernelType") -> int:
        """Register *kernel* (idempotently) and return its table index."""
        with self.lock:
            existing = self.kernel_to_id.get(kernel)
            if existing is not None:
                return existing

            new_id = len(self.id_to_kernel)
            self.id_to_kernel[new_id] = kernel
            self.kernel_to_id[kernel] = new_id
            return new_id

    def get_kernel(self, idx: int) -> "TritonKernelType":
        """Return the Triton kernel stored at *idx*."""
        # No need to lock here as fetching from dict is atomic
        assert idx in self.id_to_kernel
        return self.id_to_kernel[idx]

    def add_constant_args(self, args: dict[str, Any]) -> int:
        """Stash constant kwargs that cannot live in the graph; return their index."""
        with self.lock:
            slot = len(self.constant_args)
            self.constant_args[slot] = args
            return slot

    def get_constant_args(self, idx: int) -> dict[str, Any]:
        """Return the constant-args dict stored at *idx*."""
        # No need to lock here as fetching from dict is atomic
        assert idx in self.constant_args
        return self.constant_args[idx]

    def reset_table(self) -> None:
        """Clear the table (unit-test helper; assumes single-threaded use)."""
        self.id_to_kernel = {}
        self.kernel_to_id = {}
        self.constant_args = {}


kernel_side_table = KernelSideTable()
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
###############################################################################
|
| 194 |
+
# Mutation Tracker
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
@dataclasses.dataclass(frozen=True)
class Param:
    """Reference to a kernel function parameter, identified by its position."""

    idx: int
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@dataclasses.dataclass(frozen=True)
class Intermediate:
    """Reference to an intermediate value (an op result) in the parsed TTIR."""

    idx: int

    def fake(self) -> bool:
        # Negative indices denote synthesized intermediates that do not
        # correspond to a real TTIR value (see `next_fake_intermediate`
        # in ttir_to_functions).
        return self.idx < 0
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@dataclasses.dataclass(frozen=True)
class Op:
    """One TTIR operation in the simplified form used by the mutation tracker."""

    name: str
    # set only for "tt.call" ops: the name of the called function
    fn_call_name: Optional[str]
    args: list[Union[Param, Intermediate]]
    ret: Intermediate = dataclasses.field(repr=False)
    # used for scf.yield: see [Note: scf.yield fix-up]
    sub_idx: Optional[int] = None
    # used for tt.elementwise_inline_asm
    # `is_pure = True` assumes the asm block has no side-effects
    is_pure: bool = False

    def __post_init__(self) -> None:
        # Invariant: a callee name is present iff this op is a function call.
        if self.name == "tt.call":
            assert self.fn_call_name is not None
        else:
            assert self.fn_call_name is None
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def generate_ttir(
    kernel: "TritonKernelType",
    kwargs: dict[str, Any],
    tma_descriptor_metadata: TMADescriptorMetadata,
) -> tuple["TritonIRModule", list[str]]:
    """
    Uses Triton's internal code generation to create TTIR.

    Args:
        kernel: the user's Triton kernel (Autotuner wrappers are unwrapped
            to the underlying JITFunction).
        kwargs: the kernel call arguments, keyed by parameter name.
        tma_descriptor_metadata: serialized TMA descriptor info per parameter
            (see TMADescriptorMetadata above).

    Returns:
        A tuple of (verified TTIR module, ordered list of tensor-like
        parameter names, including placeholder names injected for TMA
        stride/size inputs).

    Raises:
        ValueError: if the kwargs do not match the kernel's parameter list.
        RuntimeError: if the generated TTIR module fails verification.
    """
    # Triton imports are deferred so this module can be imported without
    # Triton installed.
    import sympy
    import triton
    import triton.runtime.jit
    from triton.compiler.compiler import ASTSource
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    from torch._inductor.utils import (
        get_triton_attrs_descriptor_version,
        triton_version_uses_attrs_dict,
        TritonAttrsDescriptorVersion,
    )
    from torch.utils._triton import has_triton_tensor_descriptor_host_tma

    triton_version = get_triton_attrs_descriptor_version()

    import torch._inductor.ir
    from torch._subclasses.fake_tensor import FakeTensor

    if isinstance(kernel, Autotuner):
        if len(kernel.configs) > 0:
            # If we are autotuning, then it doesn't matter which version gets
            # picked for tracing purposes, so lets pick the first one
            kwargs = {**kwargs, **kernel.configs[0].kwargs}
        kernel = kernel.fn

    assert isinstance(kernel, JITFunction)

    context = triton._C.libtriton.ir.context()
    target = triton.runtime.driver.active.get_current_target()
    backend = triton.compiler.compiler.make_backend(target)
    options = backend.parse_options({})

    # ignore backend-specific kwargs same way as in the native Triton code
    # https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596
    # why this is important for user-defined Triton kernels on AMD: https://github.com/pytorch/pytorch/issues/140800
    for name in list(kwargs):
        if name not in kernel.arg_names and name in options.__dict__:
            kwargs.pop(name)

    if len(kwargs) != len(kernel.arg_names):
        raise ValueError(
            "Incorrect number of arguments passed to kernel: "
            f"passed {list(kwargs.keys())}, expected {kernel.arg_names}."
        )

    # Replace all SymExprs with a regular value for TTIR generation
    # Replace all FakeTensor/TensorBox with real tensors
    # These replacements are needed for triton's type, key and config functions
    # NOTE: the substituted values (2, small empty tensors) are dummies — only
    # their types/specialization matter for TTIR generation, not their values.
    ordered_args: dict[str, Any] = {}
    for name in kernel.arg_names:
        a = kwargs[name]
        if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool, sympy.Expr)):
            ordered_args[name] = 2
        elif (
            stable_meta := maybe_unpack_tma_stable_metadata(
                tma_descriptor_metadata.get(name, None)
            )
        ) is not None:
            from triton.tools.tensor_descriptor import TensorDescriptor

            block_shape = stable_meta[0]
            with torch._C._DisableTorchDispatch():
                # need 16-byte aligned strides
                elements_per_dim = max(1, 16 // a.dtype.itemsize)
                base_tensor = torch.empty(
                    [elements_per_dim] * len(block_shape), dtype=a.dtype
                )
            ordered_args[name] = TensorDescriptor.from_tensor(base_tensor, block_shape)
        elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)):
            with torch._C._DisableTorchDispatch():
                ordered_args[name] = torch.empty(2, dtype=a.dtype)
        else:
            ordered_args[name] = a

    def is_stable_tensor_descriptor_arg(arg: Any) -> bool:
        # True iff *arg* is a stable-API TensorDescriptor (requires new-enough Triton).
        if has_triton_tensor_descriptor_host_tma():
            from triton.tools.tensor_descriptor import TensorDescriptor

            if isinstance(arg, TensorDescriptor):
                return True
        return False

    def is_tensor_like_arg(arg: Any) -> bool:
        # Tensor-like args become TTIR inputs; everything else is a constant.
        if isinstance(arg, Tensor) or is_stable_tensor_descriptor_arg(arg):
            return True
        return False

    # Note: one would expect that each input to the triton kernel maps to
    # one input parameter in the TTIR. This is _not_ true for TMA descriptors:
    # one TMA descriptor gets converted into:
    # * one TMA descriptor input
    # * N strides, for a rank-N tensor
    # * N sizes, for a rank-N tensor
    # To account for this, we inject some fake arg names as placeholders for
    # the stride and size parameters.
    def get_tensor_names(name: str, arg: Any) -> list[str]:
        if isinstance(arg, Tensor):
            return [name]
        if is_stable_tensor_descriptor_arg(arg):
            stable_meta = maybe_unpack_tma_stable_metadata(
                tma_descriptor_metadata[name]
            )
            assert stable_meta is not None
            block_shape = stable_meta[0]
            tensor_rank = len(block_shape)
            names = [name]
            names.extend(name + f" STRIDE PLACEHOLDER {i}" for i in range(tensor_rank))
            names.extend(name + f" SIZE PLACEHOLDER {i}" for i in range(tensor_rank))
            return names
        return []

    ordered_tensor_names = list(
        itertools.chain.from_iterable(
            get_tensor_names(name, arg) for name, arg in ordered_args.items()
        )
    )

    def _get_specialization(args):  # type: ignore[no-untyped-def]
        # Support multiple triton versions.
        # This code basically copies JITFunction.run() logic to get the attrs to construct an ASTSource.
        if triton_version == TritonAttrsDescriptorVersion.V1_COMPILER:
            return kernel._get_config(*args)
        elif triton_version in {
            TritonAttrsDescriptorVersion.V2_BACKENDS,
            TritonAttrsDescriptorVersion.V3_BACKENDS_TUPLE,
        }:
            from triton.backends.compiler import AttrsDescriptor  # noqa: F401

            target = triton.runtime.driver.active.get_current_target()
            backend_ = triton.compiler.compiler.make_backend(target)
            return backend_.get_attrs_descriptor(args, kernel.params)
        else:
            assert (
                get_triton_attrs_descriptor_version()
                == TritonAttrsDescriptorVersion.V4_DICT
            )
            # specialize_impl switched to create_specialize_impl in https://github.com/triton-lang/triton/pull/6099
            if hasattr(triton.runtime.jit, "create_specialize_impl"):
                try:
                    # Latest versions of Triton take specialize_extra as an arg to create_specialize_impl
                    specialize_impl = triton.runtime.jit.create_specialize_impl(
                        specialize_extra=backend.get_arg_specialization
                    )
                except TypeError:  # Unknown arg `specialize_extra`
                    # Older versions of Triton take specialize_extra as an arg to specialize_impl
                    specialize_impl = functools.partial(
                        triton.runtime.jit.create_specialize_impl(),
                        specialize_extra=backend.get_arg_specialization,
                    )
            else:
                from triton.runtime.jit import specialize_impl as specialize_impl_orig

                specialize_impl = functools.partial(
                    specialize_impl_orig,
                    specialize_extra=backend.get_arg_specialization,
                )

            from triton._utils import find_paths_if, get_iterable_path

            # logic is copied from: binder = create_function_from_signature(self.signature, self.params, backend)
            attrvals = []
            for arg, kp in zip(args, kernel.params):
                if kp.is_constexpr:
                    attrvals.append(arg)
                else:
                    spec = specialize_impl(
                        arg,
                        is_const=kp.is_const,
                        specialize_value=not kp.do_not_specialize,
                        align=not kp.do_not_specialize_on_alignment,
                    )
                    attrvals.append(spec[1])

            attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
            attrs = {
                k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs
            }
            return attrs

    specialization = _get_specialization(ordered_args.values())
    constants = {
        name: arg for name, arg in ordered_args.items() if not is_tensor_like_arg(arg)
    }

    # Triton renamed its signature-mangling helper across versions; prefer the
    # new `mangle_type` when available, else fall back to the old private API.
    if (mangle_type := getattr(triton.runtime.jit, "mangle_type", None)) is not None:

        def get_signature_value(idx: int, arg: Any) -> str:
            if kernel.params[idx].is_constexpr:
                return "constexpr"
            return mangle_type(arg)

    else:

        def get_signature_value(idx: int, arg: Any) -> str:
            return kernel._type_of(kernel.key_of(arg))

    if triton_version_uses_attrs_dict():
        # In newer versions of Triton, the signature includes constexpr args
        signature = {
            name: get_signature_value(i, arg)
            for i, (name, arg) in enumerate(ordered_args.items())
        }
    else:
        # In older versions of Triton, the signature does not include constexpr args
        signature = {
            name: get_signature_value(i, arg)
            for i, (name, arg) in enumerate(ordered_args.items())
            if i not in kernel.constexprs
        }

    triton._C.libtriton.ir.load_dialects(context)
    backend.load_dialects(context)

    src = ASTSource(kernel, signature, constants, specialization)

    # Triton changes ASTSource.make_ir to take 3/4 arguments. Handle
    # backward compatibility here.
    make_ir_sig_params = len(inspect.signature(src.make_ir).parameters)
    get_codegen_implementation_sig_params = len(
        inspect.signature(backend.get_codegen_implementation).parameters
    )
    if make_ir_sig_params == 2:
        ttir_module = src.make_ir(options, context)
    elif make_ir_sig_params == 3:
        codegen_fns = backend.get_codegen_implementation()
        ttir_module = src.make_ir(options, codegen_fns, context)
    else:
        codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
        codegen_fns = backend.get_codegen_implementation(*codegen_args)
        module_map = backend.get_module_map()
        ttir_module = src.make_ir(options, codegen_fns, module_map, context)
    if not ttir_module.verify():
        raise RuntimeError("Verification for TTIR module has failed")

    return ttir_module, ordered_tensor_names
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def ttir_to_functions(
    ttir_module: "TritonIRModule",
) -> dict[str, dict[Intermediate, list[Op]]]:
    """
    Walk the `ttir_module` bottom up to mine the `functions` from
    the structured MLIR entities representing the Triton kernel
    (mlir::Operation, mlir::Block, mlir::Region).

    Returns a mapping: function name -> (op result Intermediate -> ops
    producing it). Ops from nested regions (scf.*, tt.reduce/tt.scan)
    are inlined into their enclosing function's flat op map.
    """
    functions: dict[str, dict[Intermediate, list[Op]]] = {}

    # block id --> op result (Intermediate) --> one or more ops
    op_stack: dict[int, dict[Intermediate, list[Op]]] = defaultdict(
        lambda: defaultdict(list)
    )
    region_id_to_block_ids: dict[int, list[int]] = defaultdict(list)
    block_id_to_block_arg_ids: dict[int, list[int]] = {}
    replacements: dict[int, Union[Intermediate, Param]] = {}
    # MLIR value ids are opaque; remap them to small dense ints so the
    # downstream analysis works with stable, compact indices.
    reindex_map: dict[int, int] = {}
    # Fake intermediates get negative indices (see Intermediate.fake()).
    next_fake_intermediate = 0

    def reindex(idx: int) -> int:
        # Assign the next dense index on first sight of an MLIR value id.
        if idx not in reindex_map:
            reindex_map[idx] = len(reindex_map)
        return reindex_map[idx]

    def mlir_to_functions(op: "TritonIROperation") -> None:
        # Callback invoked by ttir_module.walk() for every op, bottom-up.
        name: str = op.get_name()
        if name == "builtin.module":
            # this wraps all tt.func ops
            return

        operand_ids: list[int] = [
            reindex(op.get_operand(i).id()) for i in range(op.get_num_operands())
        ]
        result_ids: list[int] = [
            reindex(op.get_result(i).id()) for i in range(op.get_num_results())
        ]

        child_block_ids: list[int] = []
        for i in [op.get_region(i).id() for i in range(op.get_num_regions())]:
            # as the walk is bottom-up, the region_id_to_block_ids[i]
            # must be populated by the time we process the enclosing op
            child_block_ids.extend(region_id_to_block_ids[i])

        parent_block_id = -1
        parent_block = op.get_block()
        if parent_block is not None:
            parent_block_id = parent_block.id()
            if parent_block_id not in block_id_to_block_arg_ids:
                block_id_to_block_arg_ids[parent_block_id] = []
                for i in range(parent_block.get_num_arguments()):
                    block_id_to_block_arg_ids[parent_block_id].append(
                        reindex(parent_block.get_argument(i).id()),
                    )
                # the region info is collected via ops' parent blocks to be
                # used later when the region's enclosing op is traversed
                parent_region = parent_block.get_parent()
                if parent_region is not None:
                    region_id_to_block_ids[parent_region.id()].append(parent_block_id)

        nonlocal next_fake_intermediate

        if name == "tt.func":
            # for function ops: gather and inline
            # the ops from all child blocks
            fn_ops = defaultdict(list)
            for child_block_id in child_block_ids:
                for result, block_fn_ops in op_stack.pop(child_block_id).items():
                    for block_fn_op in block_fn_ops:
                        fn_ops[result].append(block_fn_op)

            # replace the corresponding Intermediates in the
            # child op args with the function args (Params)
            for i, idx in enumerate(block_id_to_block_arg_ids[child_block_ids[0]]):
                replacements[idx] = Param(i)

            for fn_op_list in fn_ops.values():
                for fn_op in fn_op_list:
                    for i in range(len(fn_op.args)):
                        arg = fn_op.args[i]
                        seen = set()  # to break cycles
                        # there can be transitive replacements, but likely
                        # no cycles (we keep the `seen` set just in case)
                        while (
                            isinstance(arg, Intermediate)
                            and arg.idx in replacements
                            and arg.idx not in seen
                        ):
                            seen.add(arg.idx)
                            arg = fn_op.args[i] = replacements[arg.idx]

            # next function capture starts
            # with empty replacements
            replacements.clear()

            fn_name = op.get_str_attr("sym_name")
            functions[fn_name] = fn_ops
        elif child_block_ids:
            if name in {"scf.if", "scf.for", "scf.while", "tt.reduce", "tt.scan"}:
                # for blocked ops: inline the enclosed ops into
                # the parent block + rewire the last op in each
                # child block to return the block result
                return_ops = []
                for block_id in child_block_ids:
                    if name == "scf.for":
                        # example:
                        # %result = scf.for %iv = %lb to %ub step %step iter_args(%arg = %init) -> (i32) ...
                        # block args: 2 (%iv, %arg)
                        # op operands: 4 (%lb, %ub, %step, %init)
                        # `%arg` is mapping to `%init`
                        for i, idx in enumerate(block_id_to_block_arg_ids[block_id]):
                            if i == 0:
                                # the induction variable has no producing op;
                                # stand in a fake Intermediate for it
                                next_fake_intermediate -= 1
                                replacements[idx] = Intermediate(next_fake_intermediate)
                            else:
                                # iter_args start after %lb, %ub, %step: shift by 2
                                replacements[idx] = Intermediate(operand_ids[i + 2])
                    elif name == "scf.while":
                        # example:
                        # %3:3 = scf.while (%arg2 = %1, %arg3 = %2, %arg4 = %c0_i32_8) ...
                        # block args: 3 (%arg2, %arg3, %arg4)
                        # op operands: 3 (%1, %2, %c0_i32_8)
                        # `%arg2` is mapping to `%1`, `%arg3` is mapping to `%2`, ...
                        for i, idx in enumerate(block_id_to_block_arg_ids[block_id]):
                            replacements[idx] = Intermediate(operand_ids[i])
                    elif name == "scf.if":
                        # the scf block args are ignored by the pass. but, as they
                        # may be used as operands of the ops inside the block
                        # (and nested blocks inlined in the current block by now),
                        # they are replaced by new fake Intermediates to avoid "this
                        # operand is not returned by any other op in the fn" error
                        # in the downstream analysis
                        for idx in block_id_to_block_arg_ids[block_id]:
                            next_fake_intermediate -= 1
                            replacements[idx] = Intermediate(next_fake_intermediate)
                    else:
                        assert name in ("tt.reduce", "tt.scan")
                        # wire the block arguments to the op arguments
                        num_operands = len(operand_ids)
                        block_arg_ids = block_id_to_block_arg_ids[block_id]
                        assert len(block_arg_ids) == 2 * num_operands, (
                            f"{name} is expected to have twice as "
                            "many block arguments as op arguments: "
                            f"{operand_ids=}, {block_arg_ids=}."
                        )
                        for i, idx in enumerate(block_arg_ids):
                            # for a tt.reduce/tt.scan op with N arguments, the block
                            # arguments comprise N reduced values followed by
                            # N current values corresponding to the N op args
                            replacements[idx] = Intermediate(
                                operand_ids[i % num_operands]
                            )

                    if block_id in op_stack:
                        block_ops = op_stack.pop(block_id)
                        if not block_ops:
                            continue
                        # popitem() returns the most recently inserted entry,
                        # i.e. the ops producing the block's last result
                        last_ret, last_ops = block_ops.popitem()
                        if all(
                            op.name
                            in ("scf.yield", "tt.reduce.return", "tt.scan.return")
                            for op in last_ops
                        ):
                            # if last_ops are all return ops, treat them separately
                            return_ops.extend(last_ops)
                        else:
                            # otherwise, return last_ops to the block
                            block_ops[last_ret] = last_ops
                        for op_result, child_ops in block_ops.items():
                            op_stack[parent_block_id][op_result].extend(child_ops)

                scf_results = [Intermediate(idx) for idx in result_ids]

                if return_ops and all(
                    (op.name == "scf.yield" and len(result_ids) == len(op.args))
                    for op in return_ops
                ):
                    # [Note: scf.yield fix-up]
                    #
                    # TL;DR: if our scf.yield takes N args, then we'll create N scf.yield ops to handle each of the
                    # args.
                    #
                    # **Context**:
                    # During mutation analysis, the analysis pass will identify mutating ops (e.g. tt.store)
                    # and then DFS upwards towards the parameters of the function. Specifically, the analysis pass
                    # looks at the mutated arg in tt.store; then looks for its source ops; and then recurses on the
                    # arguments to each of the source ops.
                    #
                    # In the case of scf.if/scf.for, we may have multiple return ops, each passed as an arg
                    # to scf.yield:
                    #
                    # %18:2 = scf.if %... -> (!tt.ptr<f32>, !tt.ptr<f32>) {
                    #   ...
                    #   scf.yield %1, %2
                    # } else {
                    #   scf.yield %3, %4
                    # }
                    #
                    # And for each of the returns of the scf.if, we'd naively assign the source op of each of the
                    # return values to be the scf.yields. But the scf.yields take _all_ the returns as arguments.
                    # Therefore, if _any_ of the return values of the scf.if are mutated, then the analysis pass
                    # would mark _all_ of the yield args as mutated.
                    #
                    # **Solution**:
                    # For the purposes of this analysis pass, we create N yield ops - one for each
                    # return-val/yield-arg. In the example above, we'll have two scf.yield's for each branch of the
                    # scf.if.

                    for return_op in return_ops:
                        for i, (scf_result, yield_arg) in enumerate(
                            zip(scf_results, return_op.args)
                        ):
                            sub_yield_op = Op(
                                return_op.name,
                                return_op.fn_call_name,
                                [yield_arg],
                                return_op.ret,
                                sub_idx=i,
                            )
                            op_stack[parent_block_id][scf_result].append(sub_yield_op)

                else:
                    # fall back: every block result is produced by every
                    # return op (conservative, see note above)
                    for scf_result in scf_results:
                        for return_op in return_ops:
                            op_stack[parent_block_id][scf_result].append(return_op)
            else:
                raise RuntimeError(
                    f"Unknown blocked function: {name}. Can't capture the TTIR."
                )
        else:
            # leaf op (no child regions): record it in the parent block's map
            callee = None
            if name == "tt.call":
                callee = op.get_flat_symbol_ref_attr("callee")
            args: list[Union[Param, Intermediate]] = [
                Intermediate(operand) for operand in operand_ids
            ]
            block_ops = op_stack[parent_block_id]

            is_pure = False
            # Handle the case for tt.elementwise_inline_asm to set `is_pure` for mutation analysis
            if name == "tt.elementwise_inline_asm":
                is_pure = op.get_bool_attr("pure")

            if result_ids:
                for result_id in result_ids:
                    res = Intermediate(result_id)
                    block_ops[res].append(Op(name, callee, args, res, is_pure=is_pure))
            else:
                # ops without results (e.g. tt.store) still need a key in
                # the map; synthesize a fake result for them
                next_fake_intermediate -= 1
                fake_res = Intermediate(next_fake_intermediate)
                block_ops[fake_res].append(
                    Op(name, callee, args, fake_res, is_pure=is_pure)
                )

    ttir_module.walk(mlir_to_functions)

    return functions
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
class MemoizeWithCycleCheck:
    """Memoizing decorator with recursion (cycle) detection.

    Results are cached per ``(fn_name, *args)`` key. While a computation
    for a key is in flight, the cache holds ``None`` for that key; a
    re-entrant call with the same key therefore observes ``None`` and
    raises, which turns infinite recursion into a hard error.
    """

    fn: Callable[..., Any]
    cache: dict[tuple[Any], Any]

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn
        self.reset()

    def __call__(
        self,
        functions: dict[str, dict[Intermediate, list[Op]]],
        fn_name: str,
        *args: Any,
    ) -> list[bool]:
        # NOTE: `functions` is deliberately excluded from the key; callers
        # pass the same mapping for the lifetime of one analysis.
        key: tuple[Any, ...] = (fn_name, *args)
        if key in self.cache:
            result = self.cache[key]
        else:
            # Mark this key as "in progress" so a recursive call with the
            # same key is detected below instead of looping forever.
            self.cache[key] = None
            result = self.cache[key] = self.fn(functions, fn_name, *args)
        if result is None:
            raise RuntimeError("Recursion is not supported")
        return result

    def reset(self) -> None:
        """Drop all cached results (must be called between top-level runs)."""
        self.cache = {}
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
@MemoizeWithCycleCheck
def get_tma_stores(
    functions: dict[str, dict[Intermediate, list[Op]]], fn_name: str
) -> set[Union[Intermediate, Param]]:
    """
    Identifies all intermediates and parameters that are written to by a
    `tt.experimental_descriptor_store`. It tracks only the specific values
    written to via experimental_descriptor_store and the input values to
    `tt.reinterpret_tensor_descriptor` used to construct the direct inputs
    to tt.experimental_descriptor_store - not any recursive values
    used to construct those values.

    For example: for
        tt.reinterpret_tensor_descriptor(Intermediate(idx=0), ...)
        Intermediate(idx=1) = tt.experimental_descriptor_store(Intermediate(idx=0), ...)
    this function will return [Intermediate(idx=0), Intermediate(idx=1)],

    However
        Intermediate(idx=4) = arith.addptr(Intermediate(idx=2), Intermediate(idx=3))
        Intermediate(idx=5) = tt.experimental_descriptor_store(Intermediate(idx=4), ...)
        tt.experimental_descriptor_store(Intermediate(idx=5), ...)
    this function will mark only idx=4 and idx=5 (but not idx=2 or idx=3)

    If an intermediate/parameter is passed into a function and is written to
    via experimental_descriptor_store within that function, the argument to the
    function will also be marked.
    """

    result: set[Union[Intermediate, Param]] = set()

    ops = functions[fn_name]
    for op_list in ops.values():
        for op in op_list:
            if op.name == "tt.call":
                assert op.fn_call_name in functions
                # recurse into the callee; memoization + the in-progress
                # None marker in MemoizeWithCycleCheck guard against cycles
                tma_stores = get_tma_stores(functions, op.fn_call_name)
                # if the callee TMA-stores into its i-th parameter, the
                # caller's i-th argument is TMA-stored into as well
                for i, inp in enumerate(op.args):
                    if Param(idx=i) in tma_stores:
                        result.add(inp)
            elif op.name == "tt.experimental_descriptor_store":
                assert len(op.args) >= 1
                result.add(op.args[0])

    # one extra (non-recursive) hop: also mark the inputs of the
    # tt.reinterpret_tensor_descriptor ops that produced stored values
    for val in list(result):
        if val in ops:
            if not isinstance(val, Intermediate):
                continue
            for op in ops[val]:
                if op.name == "tt.reinterpret_tensor_descriptor":
                    assert len(op.args) >= 1
                    result.add(op.args[0])

    return result
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
@MemoizeWithCycleCheck
def analyze_kernel_mutations(
    functions: dict[str, dict[Intermediate, list[Op]]], fn_name: str, num_args: int
) -> list[bool]:
    """
    Analyzes the graph to detect all sinks from a predefined list of sinks
    by using triton's MemWrite trait list. NOTE: What if triton exposed this?
    From each sink, it traverses the CFG backwards to identify all the input
    pointers that are mutated.

    Returns a list of `num_args` booleans: True at index i iff the i-th
    function parameter is (transitively) mutated.
    """
    # Name of mutation op to mutated parameter indices
    # List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td
    # All the OPs that have MemWrite trait.
    # What if Triton exposed this?
    MUTATION_OPS = {
        "tt.store": [0],
        "tt.atomic_cas": [0],
        "tt.atomic_rmw": [0],
        "tt.experimental_descriptor_store": [0],
        "tt.experimental_tensormap_create": [0],
        "tt.descriptor_store": [0],
    }
    # Ops that we want to bail out on
    UNKNOWN_OPS = {"tt.elementwise_inline_asm"}

    # seed stack: values directly written to by mutating ops
    stack: list[Union[Param, Intermediate]] = []
    visited = set()
    ops = functions[fn_name]
    tma_stores = get_tma_stores(functions, fn_name)

    for op_list in ops.values():
        for op in op_list:
            # If we encounter an operation with effects that cannot be reliably analyzed
            # (e.g. `tt.elementwise_inline_asm`), we assume it does not mutate any input parameters.
            if op.name in UNKNOWN_OPS:
                if op.name == "tt.elementwise_inline_asm" and op.is_pure:
                    log.warning(
                        "TTIR mutation analysis: Skipping pure tt.elementwise_inline_asm op (is_pure=True)"
                    )
                    continue
                raise RuntimeError(
                    f"ttir analysis hit an op we do not know how to analyze: {op.name}"
                )

            if op.name == "tt.experimental_tensormap_create":
                # Note: this is how we implement experimental_descriptor_store mutation analysis.
                # for on-device TMA.
                # experimental_tensormap_store(a, b, ...) stores b to the location specified
                # by descriptor in the memory of a.
                # To track this, we first find all the intermediates/params to which we store via
                # experimental_tensormap_store (get_tma_stores, called above). Then, during this
                # analysis we wait to find the corresponding experimental_tensormap_create (if it
                # exists), at which point we will mark the global_ptr as mutated (as done below).
                assert len(op.args) >= 2
                if op.args[0] in tma_stores:
                    stack.append(op.args[1])

            if op.name == "tt.call":
                assert op.fn_call_name in functions
                # recurse into the callee; MemoizeWithCycleCheck caches the
                # result and raises on recursive call cycles
                mutations = analyze_kernel_mutations(
                    functions, op.fn_call_name, len(op.args)
                )
                # the caller's argument is mutated iff the corresponding
                # callee parameter is mutated
                stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated)
            else:
                stack.extend(op.args[idx] for idx in MUTATION_OPS.get(op.name, []))

    # The following is an iterative DFS algorithm
    mutated = [False] * num_args
    while stack:
        arg = stack.pop()
        if arg in visited:
            continue

        visited.add(arg)

        if isinstance(arg, Param):
            if arg.idx >= num_args:
                # This is an argument defined in the kernel, not passed in
                continue
            mutated[arg.idx] = True
        elif isinstance(arg, Intermediate) and not arg.fake():
            # walk backwards through the ops that produced this value
            for op in ops[arg]:
                # Skip arguments to load
                if op.name != "tt.load":
                    stack.extend(op.args)
    return mutated
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
def identify_mutated_tensors(
    kernel: "TritonKernelType",
    kwargs: dict[str, Any],
    tma_descriptor_metadata: TMADescriptorMetadata,
) -> list[str]:
    """
    Given a triton kernel and the arguments for this kernel, this function
    1) Retrieves the TTIR converted version of the kernel from Triton's API.
    2) Parses the TTIR and creates a control flow graph
    3) Analyzes the graph to detect all input tensor mutations

    Returns the names of the mutated tensor arguments. On ANY failure the
    analysis falls back to the conservative answer: every tensor argument
    is assumed mutated (correct, just pessimal for functionalization).
    """

    ttir_module = None
    functions = None
    try:
        ttir_module, ordered_tensor_names = generate_ttir(
            kernel, kwargs, tma_descriptor_metadata
        )

        # extract functions from TTIR using MLIR bindings exposed by Triton code
        functions = ttir_to_functions(ttir_module)

        assert functions is not None
        # the kernel entry point is the first mined function
        kernel_name = next(iter(functions.keys()))
        # Triton codegen modifies the name
        assert kernel.fn.__name__ in kernel_name
        # Reset the cache between top level invocations
        # The cache for analyze kernel mutations is mainly used for cycle
        # detection, so each top level invocation needs a clean cache
        analyze_kernel_mutations.reset()
        get_tma_stores.reset()
        mutations = analyze_kernel_mutations(
            functions, kernel_name, len(ordered_tensor_names)
        )

        return [
            ordered_tensor_names[i] for i, mutated in enumerate(mutations) if mutated
        ]
    except Exception:
        # deliberate broad catch: the analysis is best-effort and must never
        # crash compilation; log the context and assume everything is mutated
        log.warning(
            "Encountered an exception in identify_mutated_tensors, assuming every input is mutated",
            exc_info=True,
        )
        if ttir_module is not None:
            log.debug("TTIR:\n%s", str(ttir_module))
        if functions is not None:
            log.debug("functions:")
            for name, fn in functions.items():
                log.debug("===\t%s\t===", name)
                for ret, ops in fn.items():
                    log.debug("%s\t=>\t%s", ret, ops)
        return [key for key, value in kwargs.items() if isinstance(value, Tensor)]
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
###############################################################################
|
| 957 |
+
# Triton Kernel Wrappers
|
| 958 |
+
|
| 959 |
+
|
| 960 |
+
# Used for wrapping a Triton Kernel
|
| 961 |
+
class TritonKernelWrapperMutation(HigherOrderOperator):
    """HigherOrderOperator wrapping an in-place (mutating) user-defined
    Triton kernel launch. The kernel and its constant args live in
    `kernel_side_table` and are referenced by index, keeping the op's
    arguments graph-friendly."""

    def __init__(self) -> None:
        super().__init__("triton_kernel_wrapper_mutation", cacheable=True)

    def __call__(
        self,
        kernel_idx: int,
        constant_args_idx: int,
        grid: list["TritonGridType"],
        tma_descriptor_metadata: TMADescriptorMetadata,
        kwargs: dict[str, Any],
    ) -> Any:
        # Forward everything by keyword: the registered py_impl handlers
        # declare these parameters as keyword-only.
        forwarded = {
            "kernel_idx": kernel_idx,
            "constant_args_idx": constant_args_idx,
            "grid": grid,
            "tma_descriptor_metadata": tma_descriptor_metadata,
            "kwargs": kwargs,
        }
        return super().__call__(**forwarded)
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
# Singleton HOP instance: all mutating Triton kernel calls dispatch through it.
triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
# Used for wrapping a Triton Kernel in a functional manner
|
| 986 |
+
class TritonKernelWrapperFunctional(HigherOrderOperator):
    """Functional twin of TritonKernelWrapperMutation: instead of mutating
    its inputs, it clones the tensors named in `tensors_to_clone`, runs the
    kernel on the clones, and returns them as a dict keyed by argument name."""

    def __init__(self) -> None:
        super().__init__("triton_kernel_wrapper_functional", cacheable=True)

    def __call__(
        self,
        kernel_idx: int,
        constant_args_idx: int,
        grid: list["TritonGridType"],
        tma_descriptor_metadata: TMADescriptorMetadata,
        kwargs: dict[str, Any],
        tensors_to_clone: list[str],
    ) -> dict[str, Any]:
        # Keyword-only forwarding, matching the py_impl handler signatures.
        forwarded = {
            "kernel_idx": kernel_idx,
            "constant_args_idx": constant_args_idx,
            "grid": grid,
            "tma_descriptor_metadata": tma_descriptor_metadata,
            "kwargs": kwargs,
            "tensors_to_clone": tensors_to_clone,
        }
        return super().__call__(**forwarded)
|
| 1007 |
+
|
| 1008 |
+
|
| 1009 |
+
# Singleton HOP instance: the functional (clone-then-run) kernel wrapper.
triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
) -> None:
    """Eager (dense) implementation: actually launches the Triton kernel,
    mutating the tensors in `kwargs` in place. Resolves the kernel and its
    constant args from the side table, materializes TMA descriptors when
    requested, and works around a Triton kwargs bug by passing as many
    arguments positionally as possible."""
    from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code

    kernel = kernel_side_table.get_kernel(kernel_idx)
    constant_args = kernel_side_table.get_constant_args(constant_args_idx)

    if len(grid) == 1:
        # single grid: use it directly
        grid_fn = grid[0]
    else:
        # multiple grids (one per autotuner config): codegen a dispatcher
        # function that picks the right grid, and exec it into a namespace
        fn_name, code = user_defined_kernel_grid_fn_code(
            kernel.fn.__name__, kernel.configs, grid
        )
        namespace: dict[str, Any] = {}
        exec(code, namespace)
        grid_fn = namespace[fn_name]

    if tma_descriptor_metadata:
        # as we need to launch the kernel here, we "unwrap" the
        # tma_descriptor_metadata, create the TMA descriptors
        # from it, and replace the tensors in the kwargs by the
        # corresponding TMA descriptors before launching
        kwargs = kwargs.copy()
        for k, v in tma_descriptor_metadata.items():
            tensor = kwargs[k]
            if (exp_meta := maybe_unpack_tma_experimental_metadata(v)) is not None:
                # legacy "experimental" TMA API
                from triton.tools.experimental_descriptor import (  # noqa: F401
                    create_1d_tma_descriptor,
                    create_2d_tma_descriptor,
                )

                dims, block_dims, element_size = exp_meta
                create_tma_descriptor = (
                    create_1d_tma_descriptor
                    if len(dims) == 1
                    else create_2d_tma_descriptor
                )
                kwargs[k] = create_tma_descriptor(
                    tensor.data_ptr(),
                    *dims,
                    *block_dims,
                    element_size,
                )
            else:
                # newer "stable" TMA API (TensorDescriptor)
                stable_meta = maybe_unpack_tma_stable_metadata(v)
                assert stable_meta is not None
                from triton.tools.tensor_descriptor import TensorDescriptor

                block_shape = stable_meta[0]
                kwargs[k] = TensorDescriptor.from_tensor(tensor, block_shape)

    # move as many positional arguments from dicts to args as we
    # can to circumvent the bug with the kwargs and pre_/post_hook:
    # https://github.com/triton-lang/triton/issues/5082
    # TODO: remove this when the Triton issue above is fixed
    args = []
    # copy kwargs and constant_args here to
    # avoid mutating the original inputs
    kwargs = kwargs.copy()
    constant_args = constant_args.copy()
    for name in kernel.arg_names:
        if name in kwargs:
            args.append(kwargs.pop(name))
        elif name in constant_args:
            args.append(constant_args.pop(name))
        else:
            # first arg not found in either dict: the rest must stay keyword
            break

    kernel[grid_fn](*args, **kwargs, **constant_args)
|
| 1088 |
+
|
| 1089 |
+
|
| 1090 |
+
@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode)
def triton_kernel_wrapper_mutation_fake_tensor_mode(
    mode: FakeTensorMode,
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
) -> None:
    """Fake-tensor implementation: the mutation HOP has no outputs, so
    there is no fake metadata to produce — enter the mode and return None."""
    with mode:
        pass
    return None
|
| 1102 |
+
|
| 1103 |
+
|
| 1104 |
+
@triton_kernel_wrapper_mutation.py_impl(DispatchKey.Meta)
def _(
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
) -> None:
    """Meta-dispatch implementation: the mutation HOP returns nothing, so
    there is no output metadata to compute — a no-op."""
    return None
|
| 1114 |
+
|
| 1115 |
+
|
| 1116 |
+
def trace_triton_kernel_wrapper(
    proxy_mode: ProxyTorchDispatchMode,
    func_overload: Callable[..., Any],
    node_args: dict[str, Any],
) -> Optional[dict[str, Any]]:
    """Record a `call_function` node for the given HOP in the proxy-traced
    graph.

    Runs `func_overload` for real (with proxy tracing disabled) to obtain
    concrete outputs, creates the corresponding proxy node with the
    unwrapped arguments, and links outputs to the proxy via
    track_tensor_tree. Returns whatever `func_overload` returned (None for
    the mutation HOP, a dict of tensors for the functional HOP).
    """
    # execute the op itself without recording its internals into the graph
    with disable_proxy_modes_tracing():
        out = func_overload(**node_args)

    proxy_args = pytree.tree_map(
        proxy_mode.tracer.unwrap_proxy, node_args  # type: ignore[union-attr]
    )
    # all arguments are passed as kwargs (empty positional tuple)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function",
        func_overload,
        (),
        proxy_args,
        name=func_overload.__name__ + "_proxy",
    )

    # associate the concrete outputs with the new proxy node
    ret = track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
    return ret
|
| 1137 |
+
|
| 1138 |
+
|
| 1139 |
+
@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode,
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
) -> None:
    """Proxy-mode implementation: record the mutation HOP as a node in the
    traced graph. The op has no outputs, so nothing is returned."""
    node_args = {
        "kernel_idx": kernel_idx,
        "constant_args_idx": constant_args_idx,
        "grid": grid,
        "tma_descriptor_metadata": tma_descriptor_metadata,
        "kwargs": kwargs,
    }
    trace_triton_kernel_wrapper(mode, triton_kernel_wrapper_mutation, node_args)
    return None
|
| 1162 |
+
|
| 1163 |
+
|
| 1164 |
+
def get_mutated_tensors(
    kernel_idx: int,
    constant_args_idx: int,
    kwargs: dict[str, Any],
    tma_descriptor_metadata: TMADescriptorMetadata,
) -> list[str]:
    """Resolve the kernel and its constant args from the side table, then
    run the TTIR mutation analysis over the combined argument dict.
    Returns the names of the arguments the kernel mutates."""
    side_kernel = kernel_side_table.get_kernel(kernel_idx)
    side_constants = kernel_side_table.get_constant_args(constant_args_idx)
    combined_args = {**kwargs, **side_constants}
    return identify_mutated_tensors(side_kernel, combined_args, tma_descriptor_metadata)
|
| 1175 |
+
|
| 1176 |
+
|
| 1177 |
+
@triton_kernel_wrapper_mutation.py_functionalize_impl
def triton_kernel_wrapper_mutation_functionalize(
    ctx: "BaseFunctionalizeAPI",
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
) -> None:
    """Functionalization rule: rewrite the mutating HOP into the functional
    HOP (clone mutated inputs, run kernel on clones), then propagate each
    cloned output back onto the corresponding wrapped input tensor."""
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
    # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each
    # other, and one gets mutated in kernel, and later another gets mutated,
    # they are no longer equal. Fix this by graph breaking on this condition
    # earlier in dynamo.
    tensors_to_clone = get_mutated_tensors(
        kernel_idx, constant_args_idx, unwrapped_kwargs, tma_descriptor_metadata
    )
    with ctx.redispatch_to_next():
        unwrapped_outputs = triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            tma_descriptor_metadata=tma_descriptor_metadata,
            kwargs=unwrapped_kwargs,
            tensors_to_clone=tensors_to_clone,
        )

    # every returned (cloned) output must correspond to an input argument
    assert set(unwrapped_outputs.keys()).issubset(set(kwargs.keys()))
    for key, output_arg in unwrapped_outputs.items():
        if not isinstance(output_arg, Tensor):
            continue
        input_arg = kwargs[key]
        assert isinstance(input_arg, Tensor)

        # reflect the kernel's mutation onto the functionalized input
        ctx.replace(input_arg, output_arg)
        # indicate that above replace is hidden from autograd
        ctx.mark_mutation_hidden_from_autograd(input_arg)
        ctx.commit_update(input_arg)
        ctx.sync(input_arg)
    return None
|
| 1217 |
+
|
| 1218 |
+
|
| 1219 |
+
@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_functional_dense(
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
    tensors_to_clone: list[str],
) -> dict[str, Any]:
    """Eager implementation of the functional HOP: clone the to-be-mutated
    tensors, launch the kernel (mutating the clones), and return the clones
    keyed by argument name."""
    # TODO(oulgen): For performance reasons, we want to ensure that these
    # `clone_preserve_strides` calls are never executed at runtime
    # (inductor should always optimize them away).
    # Requires https://github.com/pytorch/pytorch/issues/109240
    cloned_kwargs: dict[str, Any] = {}
    for arg_name, arg_value in kwargs.items():
        cloned_kwargs[arg_name] = (
            clone_preserve_strides(arg_value)
            if arg_name in tensors_to_clone
            else arg_value
        )
    triton_kernel_wrapper_mutation(
        kernel_idx=kernel_idx,
        constant_args_idx=constant_args_idx,
        grid=grid,
        tma_descriptor_metadata=tma_descriptor_metadata,
        kwargs=cloned_kwargs,
    )
    return {
        arg_name: arg_value
        for arg_name, arg_value in cloned_kwargs.items()
        if arg_name in tensors_to_clone
    }
|
| 1245 |
+
|
| 1246 |
+
|
| 1247 |
+
@triton_kernel_wrapper_functional.py_impl(FakeTensorMode)
def triton_kernel_wrapper_functional_fake_tensor_mode(
    mode: FakeTensorMode,
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
    tensors_to_clone: list[str],
) -> dict[str, Any]:
    """Fake-tensor implementation: the functional HOP's outputs are clones
    of the to-be-mutated inputs, so fake outputs are fake clones of them."""
    # TODO(oulgen): For performance reasons, we want to ensure that these
    # `clone_preserve_strides` calls are never executed at runtime
    # (inductor should always optimize them away).
    # Requires https://github.com/pytorch/pytorch/issues/109240
    with mode:
        fake_outputs: dict[str, Any] = {}
        # iterate kwargs (not tensors_to_clone) to preserve the original
        # output ordering
        for arg_name, arg_value in kwargs.items():
            if arg_name in tensors_to_clone:
                fake_outputs[arg_name] = clone_preserve_strides(arg_value)
        return fake_outputs
|
| 1268 |
+
|
| 1269 |
+
|
| 1270 |
+
@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode,
    *,
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
    tensors_to_clone: list[str],
) -> dict[str, Any]:
    """Proxy-mode implementation: record the functional HOP call into the
    graph being traced and return the traced result."""
    node_args = {
        "kernel_idx": kernel_idx,
        "constant_args_idx": constant_args_idx,
        "grid": grid,
        "tma_descriptor_metadata": tma_descriptor_metadata,
        "kwargs": kwargs,
        "tensors_to_clone": tensors_to_clone,
    }
    traced = trace_triton_kernel_wrapper(
        mode,
        triton_kernel_wrapper_functional,
        node_args,
    )
    # Unlike the mutation HOP, the functional HOP always produces outputs.
    assert traced is not None
    return traced
|
| 1295 |
+
|
| 1296 |
+
|
| 1297 |
+
@triton_kernel_wrapper_functional.py_functionalize_impl
def triton_kernel_wrapper_functional_functionalize(
    ctx: "BaseFunctionalizeAPI",
    kernel_idx: int,
    constant_args_idx: int,
    grid: list["TritonGridType"],
    tma_descriptor_metadata: TMADescriptorMetadata,
    kwargs: dict[str, Any],
    tensors_to_clone: list[str],
) -> dict[str, Any]:
    """Functionalization impl: unwrap the functional-tensor kwargs,
    redispatch below the Functionalize key, and re-wrap the outputs."""
    inner_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
    with ctx.redispatch_to_next():
        inner_outputs = triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grid,
            tma_descriptor_metadata=tma_descriptor_metadata,
            kwargs=inner_kwargs,
            tensors_to_clone=tensors_to_clone,
        )
        return ctx.wrap_tensors(inner_outputs)  # type: ignore[return-value,arg-type]
|
| 1318 |
+
|
| 1319 |
+
|
| 1320 |
+
# Register dispatcher fallthroughs: these dispatch keys need no special
# handling for either triton wrapper HOP, so dispatch proceeds to the next
# key in the chain. The mutation and functional ops deliberately register
# the same key set. (Fix: the functional op previously registered
# AutogradCUDA twice; the redundant duplicate registration is removed.)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCPU)

triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCPU)
|
| 1338 |
+
|
| 1339 |
+
|
| 1340 |
+
###############################################################################
|
| 1341 |
+
# The "TritonHOPifier": a class that transforms a call to a triton kernel into
|
| 1342 |
+
# a call to the triton_kernel_wrapper_mutation HOP.
|
| 1343 |
+
|
| 1344 |
+
|
| 1345 |
+
class TritonHOPifier:
    """Orchestrator for converting a user-defined triton kernel into a call
    to the triton_kernel_wrapper_mutation HOP.

    It has two main use cases.

    1. When Dynamo sees a triton kernel, it wraps it into a TritonKernelVariable
    and uses the TritonHOPifier to convert calls to the TritonKernelVariable
    into a call to the HOP.

    2. In order to capture a user-defined triton kernel while performing
    tracing (via make_fx or non-strict export), a user must annotate their
    triton kernel with the `wrap_triton` decorator. The decorator uses
    TritonHOPifier to convert calls to the triton kernel into a call
    to the HOP (which can then be traced).

    Because Dynamo has its own calling conventions for e.g. invoking a user-defined function
    TritonHOPifier is an abstract class that can be overridden by its subclasses.
    """

    # --- Abstract hooks: subclasses supply their own calling conventions ---

    def raise_unsupported(self, msg: str) -> Never:
        # Subclasses raise their framework-appropriate "unsupported" error.
        raise NotImplementedError("abstract method")

    def is_callable(self, maybe_callable: Any) -> bool:
        raise NotImplementedError("abstract method")

    def get_value(self, val: Any) -> Any:
        # Extract the concrete Python value from a (possibly wrapped) arg.
        raise NotImplementedError("abstract method")

    def call_grid(  # type: ignore[no-untyped-def]
        self,
        grid,
        meta,
        tx,
    ) -> Union[tuple[Union[int, sympy.Expr, SymInt], ...], tuple["Proxy", ...]]:
        # Invoke a callable grid with the kernel's meta-parameters.
        raise NotImplementedError("abstract method")

    def wrap_user_defined_obj(
        self,
        user_obj: Any,
        tx: Optional["InstructionTranslator"],
        variable: Optional[
            Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
        ],
        name: str,
    ) -> Any:
        # Wrap a user-defined object into the subclass's variable/value form.
        raise NotImplementedError("abstract method")

    def call_user_defined_fn(
        self,
        user_fn: Callable[..., Any],
        args: list,
        kwargs: dict,
        tx: Optional["InstructionTranslator"],
        variable: Optional[
            Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
        ],
    ) -> Any:
        # Invoke a user-defined function (possibly sandboxed, e.g. in Dynamo).
        raise NotImplementedError("abstract method")

    def maybe_unpack_configs(
        self, configs: list["TritonConfig"], tx: Optional["InstructionTranslator"]
    ) -> list["TritonConfig"]:
        raise NotImplementedError("abstract method")

    def maybe_unpack_heuristic_result(self, result: Any) -> Any:
        raise NotImplementedError("abstract method")

    @staticmethod
    def do_prune_configs(  # type: ignore[no-untyped-def]
        autotuner: "TritonAutotunerType",
        early_config_prune: Optional[Callable],
        perf_model: Optional[Callable],
        top_k: float,
        configs: list,
        named_args: dict,
        kwargs: dict,
    ) -> list["TritonConfig"]:
        """Prune autotuner configs via the user's early_config_prune/perf_model.

        Static so it can be passed through call_user_defined_fn and executed
        under the subclass's sandboxing.
        """
        # Reimplement autotuner.prune_configs(...) here
        # see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py # noqa: E501,B950
        # We do this to avoid calling prune_configs, which in turn calls early_config_prune and perf_model
        # These are both user-defined functions which can contain side effects, so we want to sandbox them in Dynamo

        if early_config_prune:
            configs = early_config_prune(configs, named_args, **kwargs)

        if perf_model:
            # we assert top_k is a float before calling this
            if isinstance(top_k, float) and top_k <= 1.0:
                # Fractional top_k: keep that fraction of the configs.
                top_k = int(len(configs) * top_k)
            elif not isinstance(top_k, int):
                """
                Slice index must be an integer, SupportsIndex or None
                """
                raise TypeError(
                    "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
                )
            if len(configs) > top_k:
                # Rank configs by the user's perf_model estimate and keep
                # the top_k fastest (smallest estimated time).
                est_timing = [
                    (
                        config,
                        float(
                            perf_model(**named_args, **kwargs, **config.all_kwargs())
                        ),
                    )
                    for config in configs
                ]
                configs = [
                    config[0]
                    for config in sorted(est_timing, key=operator.itemgetter(1))[:top_k]
                ]
        return configs

    def call_HOP(  # type: ignore[no-untyped-def]
        self,
        variable,
        grids,
        combined_args: dict[str, Any],
        tx,
    ) -> Optional["ConstantVariable"]:
        # Emit the actual triton_kernel_wrapper_mutation HOP call.
        raise NotImplementedError("abstract method")

    def check_grid(  # type: ignore[no-untyped-def]
        self, grid
    ) -> Union[tuple[Union[int, sympy.Expr, SymInt], ...], tuple["Proxy", ...]]:
        raise NotImplementedError("abstract method")

    def init_variable(
        self,
        variable: Union["TraceableTritonKernelWrapper", "TritonKernelVariable"],
        kernel: "TritonKernelType",
        kernel_idx: Optional[int],
        grid: Optional["TritonGridType"],
    ) -> None:
        """Populate `variable` with the kernel, its side-table index, and grid,
        validating that unsupported Autotuner features are not in use."""
        from triton.runtime.autotuner import Autotuner

        assert kernel is not None

        variable.kernel = kernel
        # add_kernel deduplicates, so kernel_idx is stable for a given kernel.
        variable.kernel_idx = kernel_side_table.add_kernel(kernel)

        assert kernel_idx is None or variable.kernel_idx == kernel_idx

        variable.grid = grid

        if isinstance(kernel, Autotuner):
            import torch
            import torch._dynamo

            # We only support configs, keys, and restore_value arguments
            # of triton.autotune. Make sure other arguments are defaulted.
            defaults = inspect.signature(Autotuner.__init__).parameters
            # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep.
            # The call to get_first_attr is to maintain backward-compatibility.

            def defaults_ok(
                attr: str, alternates: tuple[str, ...], values: tuple[Any, ...]
            ) -> bool:
                # True when the attribute is absent, still at its signature
                # default, or holds one of the explicitly allowed values.
                if attr not in defaults:
                    return True
                value = torch._dynamo.utils.get_first_attr(kernel, attr, *alternates)
                if value == defaults[attr].default:
                    return True
                return value in values

            if (
                not torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args
                and (
                    not defaults_ok("num_warmups", ("warmup",), (25, None))
                    or not defaults_ok("num_reps", ("rep",), (100, None))
                    or not defaults_ok("use_cuda_graph", (), (False,))
                )
            ):
                self.raise_unsupported(
                    "Only configs, keys, restore_value, and reset_to_zero are supported for triton.autotune"
                )
            if (
                not torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args
                and (
                    # pre_hook requires running arbitrary code at runtime, which we cannot handle at this time
                    # https://github.com/pytorch/pytorch/issues/139059
                    # we can't support pre_hook or post_hook in user defined triton kernels at the moment,
                    # as they require the ability to execute code at runtime (AOTI can't support this)
                    (
                        hasattr(kernel, "user_defined_pre_hook")
                        and kernel.user_defined_pre_hook is not False
                    )
                    or (
                        hasattr(kernel, "user_defined_post_hook")
                        and kernel.user_defined_post_hook is not False
                    )
                    or (
                        # Check Config passed to autotuner in configs
                        any(cfg.pre_hook is not None for cfg in kernel.configs)
                    )
                )
            ):
                self.raise_unsupported(
                    "pre_hook and post_hook are not supported in triton.Autotune or triton.Config"
                )

    def call_getitem(
        self,
        variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"],
        args: Sequence[Any],
    ) -> Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]:
        """Handle `kernel[grid]`: return a new variable with the grid bound."""
        # __getitem__ should only be called if we don't already have a grid
        # Only grid needs to be passed
        if variable.grid is not None or len(args) != 1:
            self.raise_unsupported(
                "Triton kernels should be called with only a single grid"
            )

        return type(variable)(
            kernel=variable.kernel,
            kernel_idx=variable.kernel_idx,
            grid=args[0],
        )

    def call_run(
        self,
        variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"],
        args: Sequence[Any],
        kwargs: dict[str, Any],
        tx: Optional["InstructionTranslator"],
    ) -> Optional["ConstantVariable"]:
        """Handle `kernel.run(*args, grid=..., warmup=...)` calls."""
        if "grid" not in kwargs:
            self.raise_unsupported("Triton kernel requires to be called with a grid")
        grid = kwargs.pop("grid")
        # warmup is a benchmarking-only argument; ignore it here.
        kwargs.pop("warmup", None)
        # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args)
        return self.call_triton_kernel(
            type(variable)(
                kernel=variable.kernel, kernel_idx=variable.kernel_idx, grid=grid
            ),
            args,
            kwargs,
            tx,
        )

    def call_triton_kernel(
        self,
        variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"],
        args: Sequence[Any],
        kwargs: dict[str, Any],
        tx: Optional["InstructionTranslator"],
    ) -> Optional["ConstantVariable"]:
        """Handle `kernel[grid](*args, **kwargs)`.

        Normalizes decorator stacks (heuristics, special config kwargs,
        config pruning) by rebuilding a simpler Autotuner-wrapped kernel and
        recursing, until the kernel is plain enough to lower to the HOP via
        call_HOP.
        """
        from triton import JITFunction
        from triton.runtime.autotuner import autotune, Autotuner, Config, Heuristics

        # Check if num_ctas is in kwargs
        if "num_ctas" in kwargs:
            self.raise_unsupported(
                "Passing num_ctas directly to the Triton kernel is not supported. "
                "Please use a Config in @triton.autotune instead."
            )

        # Make sure the kernel has a grid
        if variable.grid is None:
            self.raise_unsupported("Triton kernels should always be called with a grid")

        # raise an exception if there are multiple @triton.autotune decorators
        iter_kernel = variable.kernel
        autotuner_count = 0
        while not isinstance(iter_kernel, JITFunction):
            if isinstance(iter_kernel, Autotuner):
                autotuner_count += 1
            if autotuner_count > 1:
                self.raise_unsupported(
                    "Passing multiple @triton.autotune decorators is not supported. "
                    "Please use a single @triton.autotune decorator instead."
                )
            iter_kernel = iter_kernel.fn

        # Process the @triton.heuristics decorator:
        # - We know there is only 1 autotuner decorator here
        # - We can apply the heuristic to all triton.Configs in the order that the decorators appear
        # This way, when the config is selected, the heuristics have already been applied.
        # - Decorators that appear *before* the autotuner are already processed correctly
        if isinstance(variable.kernel, Autotuner) and isinstance(
            variable.kernel.fn, Heuristics
        ):
            # unwrap the heuristics decorator, we don't need it anymore
            # variable.kernel ==> Autotuner
            # variable.kernel.fn ==> Heuristics
            # ...
            # There can be arbitrarily many heuristics wrappers here!
            # ...
            # variable.kernel.fn ==> JITFunction

            # Copy the configs, we are going to be modifying them
            new_configs = copy.deepcopy(variable.kernel.configs)

            named_args = dict(zip(variable.kernel.arg_names, args))

            # Iterate through all of the heuristics wrappers that come after the autotune wrapper
            iter_kernel = variable.kernel.fn
            while isinstance(iter_kernel, Heuristics):
                # For each config, apply the heuristic fn(s)
                for config_idx in range(len(new_configs)):
                    for kwarg_key, heuristic_fn in iter_kernel.values.items():
                        # Run heuristics on the combined configs + kwargs
                        heuristic_result = self.call_user_defined_fn(
                            heuristic_fn,
                            [
                                {
                                    **named_args,
                                    **kwargs,
                                    **new_configs[config_idx].__dict__["kwargs"],
                                },
                            ],
                            {},
                            tx,
                            variable,
                        )

                        # Update the kwargs in each config
                        # maybe_unpack_heuristic_result raises unsupported if the value is non-constant
                        new_configs[config_idx].__dict__["kwargs"][
                            kwarg_key
                        ] = self.maybe_unpack_heuristic_result(heuristic_result)

                iter_kernel = iter_kernel.fn
            assert isinstance(iter_kernel, JITFunction)
            prune_configs_by = {
                "perf_model": variable.kernel.perf_model,
                "early_config_prune": variable.kernel.early_config_prune,
                "configs_top_k": variable.kernel.configs_top_k,
            }
            new_kernel = autotune(
                configs=new_configs, key=[], prune_configs_by=prune_configs_by
            )(iter_kernel)
            # create a new variable to contain the new (wrapped) kernel;
            # skip kernel_idx to get a new record in the kernel side table
            new_var = type(variable)(new_kernel, None, variable.grid)
            return self.call_triton_kernel(new_var, args, kwargs, tx)

        # Launch-configuration names the triton runtime treats specially
        # (they live on the Config itself, not in config.kwargs).
        SPECIAL_CONFIG_NAMES = {
            "num_warps",
            "num_stages",
            "num_ctas",
            "num_consumer_groups",
            "num_buffers_warp_spec",
        }

        # move special config names to configs out of kwargs
        special_kwargs = {}
        for name in SPECIAL_CONFIG_NAMES:
            if name in kwargs:
                # remove special kwargs from `kwargs`
                val = kwargs.pop(name)
                special_kwargs[name] = self.get_value(val)

        if special_kwargs:
            if isinstance(variable.kernel, Autotuner):
                # if there is Autotuner already, set
                # special kwargs to each of its configs
                new_configs = copy.deepcopy(variable.kernel.configs)
                for config in new_configs:
                    config.__dict__.update(special_kwargs)
                prune_configs_by = {
                    "perf_model": variable.kernel.perf_model,
                    "early_config_prune": variable.kernel.early_config_prune,
                    "configs_top_k": variable.kernel.configs_top_k,
                }

                new_kernel = autotune(
                    configs=new_configs, key=[], prune_configs_by=prune_configs_by
                )(variable.kernel.fn)
            else:
                # if there is no Autotuner, wrap the kernel into a
                # new one with a single config with special kwargs
                new_config = Config(kwargs={}, **special_kwargs)

                new_kernel = autotune(configs=[new_config], key=[])(variable.kernel)

            # create a new variable to contain the new (wrapped) kernel;
            # skip kernel_idx to get a new record in the kernel side table
            new_var = type(variable)(new_kernel, None, variable.grid)
            return self.call_triton_kernel(new_var, args, kwargs, tx)

        if isinstance(variable.kernel, Autotuner):
            special_param_names = []
            for name in SPECIAL_CONFIG_NAMES:
                if name in variable.kernel.fn.arg_names:
                    special_param_names.append(name)

            if special_param_names:
                # If the Triton kernel has SPECIAL_CONFIG_NAMES in parameters, those should
                # be passed from the kernel configs: the behavior of Triton runtime is that
                # those values get folded into the kernel arguments iff there are parameters
                # with the same name. Normally the values of those parameters are defined
                # outside the `kwargs` part of the autotuning configs. Here we move them to
                # the `kwargs` part (if they're absent there) to facilitate passing them as
                # arguments to the kernel downstream.
                updated = False
                new_configs = copy.deepcopy(variable.kernel.configs)
                for config in new_configs:
                    for name in special_param_names:
                        if name not in config.__dict__["kwargs"]:
                            assert (
                                name in config.__dict__
                            ), f"{name} must be in autotuning configs to be used as a kernel parameter"
                            config.__dict__["kwargs"][name] = config.__dict__[name]
                            updated = True

                if updated:
                    prune_configs_by = {
                        "perf_model": variable.kernel.perf_model,
                        "early_config_prune": variable.kernel.early_config_prune,
                        "configs_top_k": variable.kernel.configs_top_k,
                    }

                    new_kernel = autotune(
                        configs=new_configs, prune_configs_by=prune_configs_by, key=[]
                    )(variable.kernel.fn)
                    new_var = type(variable)(new_kernel, None, variable.grid)
                    return self.call_triton_kernel(new_var, args, kwargs, tx)

        # These are the default values in upstream Triton
        # see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py # noqa: E501,B950
        default_perf_model = None
        default_early_config_prune = None

        # run prune_configs_by
        if isinstance(variable.kernel, Autotuner) and (
            variable.kernel.perf_model != default_perf_model
            or variable.kernel.early_config_prune != default_early_config_prune
        ):
            # Prune the configs
            named_args = dict(zip(variable.kernel.arg_names, args))

            # The source information is important here so the guards are installed correctly

            wrapped_early_configs_prune = self.wrap_user_defined_obj(
                variable.kernel.early_config_prune,
                tx,
                variable,
                "early_config_prune",
            )

            wrapped_perf_model = self.wrap_user_defined_obj(
                variable.kernel.perf_model, tx, variable, "perf_model"
            )

            wrapped_configs_top_k = self.wrap_user_defined_obj(
                variable.kernel.configs_top_k, tx, variable, "configs_top_k"
            )

            wrapped_configs = self.wrap_user_defined_obj(
                variable.kernel.configs, tx, variable, "configs"
            )

            pruned_configs = self.call_user_defined_fn(
                self.do_prune_configs,
                [
                    variable,
                    wrapped_early_configs_prune,
                    wrapped_perf_model,
                    wrapped_configs_top_k,
                    wrapped_configs,
                    named_args,
                    kwargs,
                ],
                {},
                tx,
                variable,
            )

            pruned_configs = self.maybe_unpack_configs(pruned_configs, tx)

            # after pruning the configs, create a new autotuner object with
            # these configs and recurse.
            new_kernel = autotune(configs=pruned_configs, key=[])(variable.kernel.fn)
            # create a new variable to contain the new (wrapped) kernel;
            # skip kernel_idx to get a new record in the kernel side table
            new_var = type(variable)(new_kernel, None, variable.grid)
            return self.call_triton_kernel(new_var, args, kwargs, tx)

        # Both for grid's meta as well as for the kernel, we need combined
        # args and kwargs combined and normalized
        combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs}

        # precompute the grid for the kernel
        configs = (
            [config.kwargs for config in variable.kernel.configs]
            if isinstance(variable.kernel, Autotuner)
            else [{}]
        )
        grids = []
        for config_args in configs:
            # If the grid is a function, then lets execute it and convert it to
            # a list
            grid = variable.grid
            assert grid is not None
            if self.is_callable(grid):
                # Populate the special "meta" argument to call the grid function
                meta = {**combined_args_raw, **config_args}
                grid = self.call_grid(grid, meta, tx)  # type: ignore[arg-type]
            grids.append(self.check_grid(grid))

        for i in range(len(grids)):
            if not isinstance(grids[i], tuple):
                self.raise_unsupported("Only tuple grids are supported")
            # inductor expects all grids to be 3-tuple so lets make it
            if len(grids[i]) == 1:
                grids[i] = (grids[i][0], 1, 1)
            elif len(grids[i]) == 2:
                grids[i] = (grids[i][0], grids[i][1], 1)
            elif len(grids[i]) > 3:
                self.raise_unsupported("Grid can have at most rank 3")

        assert len(grids) != 0
        if isinstance(variable.kernel, JITFunction):
            constexprs = variable.kernel.constexprs
        else:
            # If we are looking at an @triton.autotune decorator, the nested function should be a JITFunction
            # This is because we don't support @triton.heuristics or nested @triton.autotune decorators yet
            assert isinstance(variable.kernel, Autotuner)
            constexprs = variable.kernel.fn.constexprs

        for idx, arg_name in enumerate(variable.kernel.arg_names):
            if idx in constexprs:
                if arg_name in combined_args_raw:
                    # [Note: Specialize tl.constexpr args in user-defined triton kernels]
                    # This arg is marked as tl.constexpr. That means that triton will recompile every time
                    # this value changes.
                    # https://github.com/pytorch/pytorch/issues/136504
                    # One option is to correctly pass the symints in so that the symbolic expressions are defined
                    # when the triton code is being executed.
                    # But since triton will have to recompile either way, we instead just specialize on the value.
                    #
                    # Depending on the type of `variable` we might expect different types for the symbolic args:
                    # either SymNodeVariables (for TritonKernelVariables) or SymInts (TracingTritonKernelWrapper)
                    combined_args_raw[arg_name] = variable.specialize_symbolic(
                        combined_args_raw[arg_name]
                    )
        return self.call_HOP(variable, grids, combined_args_raw, tx)
|
| 1883 |
+
|
| 1884 |
+
|
| 1885 |
+
###############################################################################
|
| 1886 |
+
# Helpers for wrap_triton API that makes a user-defined triton kernel traceable into
|
| 1887 |
+
# a graph via make_fx or non-strict export (coming soon)
|
| 1888 |
+
|
| 1889 |
+
|
| 1890 |
+
class TracingTritonHOPifier(TritonHOPifier):
|
| 1891 |
+
def raise_unsupported(self, msg: str) -> Never:
|
| 1892 |
+
raise RuntimeError(msg)
|
| 1893 |
+
|
| 1894 |
+
def is_callable(self, maybe_callable: Any) -> bool:
|
| 1895 |
+
return callable(maybe_callable)
|
| 1896 |
+
|
| 1897 |
+
def get_value(self, val: Any) -> Any:
|
| 1898 |
+
return val
|
| 1899 |
+
|
| 1900 |
+
def call_grid(
|
| 1901 |
+
self,
|
| 1902 |
+
grid: "TritonGridCallableType",
|
| 1903 |
+
meta: "TritonMetaParamsType",
|
| 1904 |
+
tx: None,
|
| 1905 |
+
) -> tuple[Union[int, sympy.Expr, SymInt], ...]:
|
| 1906 |
+
assert tx is None
|
| 1907 |
+
assert isinstance(meta, dict)
|
| 1908 |
+
assert callable(grid)
|
| 1909 |
+
return grid(meta)
|
| 1910 |
+
|
| 1911 |
+
def wrap_user_defined_obj(
|
| 1912 |
+
self,
|
| 1913 |
+
user_obj: Any,
|
| 1914 |
+
tx: Optional["InstructionTranslator"],
|
| 1915 |
+
variable: Optional[
|
| 1916 |
+
Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
|
| 1917 |
+
],
|
| 1918 |
+
name: str,
|
| 1919 |
+
) -> Any:
|
| 1920 |
+
assert tx is None
|
| 1921 |
+
return user_obj
|
| 1922 |
+
|
| 1923 |
+
def call_user_defined_fn(
|
| 1924 |
+
self,
|
| 1925 |
+
user_fn: Callable[..., Any],
|
| 1926 |
+
args: list,
|
| 1927 |
+
kwargs: dict,
|
| 1928 |
+
tx: Optional["InstructionTranslator"],
|
| 1929 |
+
variable: Optional[
|
| 1930 |
+
Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]
|
| 1931 |
+
],
|
| 1932 |
+
) -> Any:
|
| 1933 |
+
assert isinstance(args, list)
|
| 1934 |
+
assert isinstance(kwargs, dict)
|
| 1935 |
+
assert callable(user_fn)
|
| 1936 |
+
return user_fn(*args, **kwargs)
|
| 1937 |
+
|
| 1938 |
+
def maybe_unpack_configs(
|
| 1939 |
+
self, configs: list["TritonConfig"], tx: Optional["InstructionTranslator"]
|
| 1940 |
+
) -> list["TritonConfig"]:
|
| 1941 |
+
assert isinstance(configs, list)
|
| 1942 |
+
return configs
|
| 1943 |
+
|
| 1944 |
+
def maybe_unpack_heuristic_result(self, result: Any) -> Any:
|
| 1945 |
+
return result
|
| 1946 |
+
|
| 1947 |
+
def check_grid(
|
| 1948 |
+
self,
|
| 1949 |
+
grid: "TritonGridType",
|
| 1950 |
+
) -> tuple[Union[int, sympy.Expr, SymInt], ...]:
|
| 1951 |
+
if not isinstance(grid, collections.abc.Sequence):
|
| 1952 |
+
raise RuntimeError(
|
| 1953 |
+
"wrap_triton can only handle grids that resolve to Sequence[int]."
|
| 1954 |
+
)
|
| 1955 |
+
# normalize to tuple
|
| 1956 |
+
return tuple(grid)
|
| 1957 |
+
|
| 1958 |
+
    def store_non_graphable_args(
        self,
        combined_args: dict[str, Any],
    ) -> tuple[dict, int]:
        """Split kernel args into FX-graphable args and side-table constants.

        Some args cannot be stored in the FX graph.
        Put them in the side table.

        Returns:
            (graphable_args, constant_args_idx) where ``graphable_args`` keeps
            only values FX can represent (base types and Nodes) and
            ``constant_args_idx`` is the side-table key for everything else.
        """

        def is_graphable(val: Any) -> bool:
            # fx.node.base_types covers the scalar/primitive types FX accepts.
            return isinstance(val, (fx.node.base_types, fx.Node))

        non_graphable_args = {
            k: v for k, v in combined_args.items() if not is_graphable(v)
        }
        graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)}

        constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args)

        return graphable_args, constant_args_idx
|
| 1978 |
+
|
| 1979 |
+
    def call_HOP(
        self,
        variable: "TraceableTritonKernelWrapper",
        grids: list["TritonGridTupleType"],
        combined_args: dict[str, Any],
        tx: None,
    ) -> None:
        """Lower the kernel call to the ``triton_kernel_wrapper_mutation`` HOP.

        Non-graphable args are first moved into the side table; the HOP then
        references them by ``constant_args_idx``. ``tx`` must be None in this
        non-Dynamo path.
        """
        assert tx is None
        assert isinstance(variable, TraceableTritonKernelWrapper)

        graphable_args, constant_args_idx = self.store_non_graphable_args(combined_args)

        assert isinstance(variable.kernel_idx, int)
        return triton_kernel_wrapper_mutation(
            kernel_idx=variable.kernel_idx,
            constant_args_idx=constant_args_idx,
            grid=grids,  # type: ignore[arg-type]
            # TMA descriptor capturing not yet
            # supported in non-dynamo tracing
            tma_descriptor_metadata={},
            kwargs=graphable_args,
        )
|
| 2001 |
+
|
| 2002 |
+
|
| 2003 |
+
# Module-level shared instance used by TraceableTritonKernelWrapper below;
# the hopifier carries no per-call state, so a single instance suffices.
tracing_triton_hopifier_singleton = TracingTritonHOPifier()
|
| 2004 |
+
|
| 2005 |
+
|
| 2006 |
+
class TraceableTritonKernelWrapper:
    """Wrapper that makes a user-defined Triton kernel traceable.

    Mirrors the ``kernel[grid](...)`` / ``kernel.run(...)`` calling
    conventions of a real Triton kernel, but routes calls through
    ``tracing_triton_hopifier_singleton`` when ``wrap_triton`` tracing is
    enabled, falling back to the raw kernel otherwise.
    """

    # The wrapped Triton kernel object.
    kernel: "TritonKernelType"
    # Index of the kernel in the kernel side table (set by init_variable).
    kernel_idx: Optional[int]
    # Launch grid bound via __getitem__, if any.
    grid: Optional["TritonGridType"]

    def __init__(
        self,
        kernel: "TritonKernelType",
        kernel_idx: Optional[int],
        grid: Optional["TritonGridType"],
    ) -> None:
        # init_variable populates self.kernel/kernel_idx/grid; the None
        # pre-assignments give the attributes a defined state beforehand.
        self.kernel = None
        self.grid = None
        tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)
        assert self.kernel is not None

    def __getitem__(self, *args: Sequence[Any]) -> "TraceableTritonKernelWrapper":
        """Bind a launch grid, mimicking ``kernel[grid]``."""
        return tracing_triton_hopifier_singleton.call_getitem(self, args)  # type: ignore[return-value]

    def run(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any:
        """Mimic ``kernel.run``; traced only when wrap_triton is enabled."""
        from torch._library.triton import is_wrap_triton_enabled

        if is_wrap_triton_enabled():
            return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)
        else:
            assert self.kernel is not None
            return self.kernel.run(*args, **kwargs)

    def __call__(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any:
        """Mimic ``kernel[grid](...)``; traced only when wrap_triton is enabled."""
        from torch._library.triton import is_wrap_triton_enabled

        if is_wrap_triton_enabled():
            return tracing_triton_hopifier_singleton.call_triton_kernel(
                self, args, kwargs, None
            )
        else:
            assert self.kernel is not None
            return self.kernel[self.grid](*args, **kwargs)

    def specialize_symbolic(self, arg: Sequence[Any]) -> Any:
        """Concretize a symbolic scalar argument via guard_scalar; pass others through."""
        import torch

        # See [Note: Specialize tl.constexpr args in user-defined triton kernels]
        if isinstance(arg, (torch.SymInt, torch.SymBool, torch.SymFloat)):
            return guard_scalar(arg)
        return arg
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/utils.py
ADDED
|
@@ -0,0 +1,1134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
import functools
|
| 4 |
+
from contextlib import contextmanager, ExitStack, nullcontext
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Any, Callable, Optional, overload, TypeVar, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.fx.traceback as fx_traceback
|
| 10 |
+
import torch.utils._pytree as pytree
|
| 11 |
+
from torch._dispatch.python import suspend_functionalization
|
| 12 |
+
from torch._guards import detect_fake_mode
|
| 13 |
+
from torch._higher_order_ops.schema import HopSchema
|
| 14 |
+
from torch._ops import HigherOrderOperator, OperatorBase, OpOverload
|
| 15 |
+
from torch._subclasses.fake_tensor import FakeTensor
|
| 16 |
+
from torch._subclasses.functional_tensor import (
|
| 17 |
+
disable_functional_mode,
|
| 18 |
+
FunctionalTensor,
|
| 19 |
+
)
|
| 20 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 21 |
+
_temp_remove_metadata_torch_function_mode,
|
| 22 |
+
disable_proxy_modes_tracing,
|
| 23 |
+
make_fx,
|
| 24 |
+
)
|
| 25 |
+
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
|
| 26 |
+
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
|
| 27 |
+
from torch.multiprocessing.reductions import StorageWeakRef
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    """Raised when a HOP subgraph aliases or mutates its inputs, which the
    higher-order-op infrastructure cannot support."""

    # Human-readable explanation of the offending alias/mutation.
    reason: str
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def autograd_not_implemented_inner(
    operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
    """Run ``operator`` below autograd; error if any input requires grad.

    If autograd is enabled and any of the arguments require grad this will
    either raise an error immediately or attach a DelayedError (which raises
    only when the output is later used in backward), depending on
    ``delayed_error``.

    Args:
        operator: The Operator to call with the *args and **kwargs
        delayed_error: If True, return outputs wrapped in a DelayedError
            instead of raising eagerly
        args: The flattened operands to the Operator
        kwargs: The keyword arguments to the Operator

    Raises:
        RuntimeError: If autograd is enabled, an input requires grad, and
            ``delayed_error`` is False.
    """
    # Dispatch below Autograd so the op itself never records a grad graph.
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )

                def fake_requires_grad(tensor):
                    # Detach before flipping requires_grad so the output is a
                    # fresh leaf; only float/complex tensors can require grad.
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

                # Wrap every tensor output so using it in backward raises.
                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")
        return result
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
    """Build a dispatcher for *op* that refuses autograd.

    The returned callable forwards all positional and keyword arguments to
    :func:`autograd_not_implemented_inner` with ``op`` and ``deferred_error``
    pre-bound.
    """
    return functools.partial(autograd_not_implemented_inner, op, deferred_error)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _maybe_run_with_interpreter(fn):
    """Return ``fn``, or an Interpreter-backed wrapper when node meta is preserved.

    Running a GraphModule through ``torch.fx.Interpreter`` (instead of its
    compiled forward) is needed to propagate ``stack_trace`` metadata while
    re-tracing.
    """
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running graph with interpreter is needed for propagating the stack_trace
        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter
    return maybe_interpreted_fn
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _maybe_compile_and_run_fn(fn, *args):
    """Run ``fn(*args)``, compiling it with Dynamo first when not already inside Dynamo.

    When a metadata torch-function mode is active it must be temporarily
    removed for compilation and re-installed via a dedicated eager backend,
    otherwise Dynamo would trace through it.
    """
    if not torch.compiler.is_dynamo_compiling():
        from torch._dynamo.backends.debugging import (
            make_eager_backend_with_torch_function_mode,
        )

        with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
            with _temp_remove_metadata_torch_function_mode() as metadata_mode:
                if metadata_mode:
                    backend = make_eager_backend_with_torch_function_mode(metadata_mode)
                else:
                    backend = "eager"
                return torch.compile(fn, backend=backend, fullgraph=True)(*args)
    else:
        # Already under Dynamo: just call through and let the outer trace handle it.
        return fn(*args)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def reenter_make_fx(fn):
    """Wrap ``fn`` so calling it traces a subgraph on the active make_fx tracer.

    Must be called while a make_fx tracing session is active; the wrapper
    asserts this at call time and delegates to
    ``_CURRENT_MAKE_FX_TRACER.trace_subgraph``.
    """
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    @functools.wraps(fn)
    def wrapped(*args):
        assert (
            _CURRENT_MAKE_FX_TRACER is not None
        ), "Cannot reenter make_fx when we're not under a make_fx tracing session"
        return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
            _maybe_run_with_interpreter(fn), *args
        )

    return wrapped
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _maybe_reenter_make_fx(fn):
    """Trace ``fn`` as a subgraph of the active make_fx session, or standalone.

    If a make_fx tracer is already active, re-enter it; otherwise build a
    fresh make_fx wrapper that picks fake vs. real tracing based on whether
    the call-time args already carry a fake mode.
    """
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    if _CURRENT_MAKE_FX_TRACER is not None:
        return reenter_make_fx(fn)
    else:

        def _maybe_make_fx_with_fake_mode(fn):
            @functools.wraps(fn)
            def wrapped(*args):
                from torch._guards import detect_fake_mode

                fake_mode = detect_fake_mode(args)
                if fake_mode is None:
                    # we create a fake_mode here to make sure we could
                    # trace the graph with data-dependent calls e.g. .item()
                    return make_fx(fn, tracing_mode="fake")(*args)
                # Tracing with real mode if all inputs have already been fakeified
                return make_fx(fn)(*args)

            return wrapped

        return _maybe_make_fx_with_fake_mode(fn)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def check_meta_consistency(
    lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
    rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
    lhs_name: str,
    rhs_name: str,
    include_contiguity: bool = True,
) -> None:
    """Check that two lists of values agree element-wise in type and metadata.

    Tensors are compared via their TensorMetadata (optionally including
    contiguity) plus a manual device check; int/SymInt pairs are accepted
    as matching types. Used by HOPs (e.g. while_loop/scan) to verify that
    branch/carry outputs line up.

    Raises:
        torch._dynamo.exc.UncapturedHigherOrderOpError: if the lists have
            different lengths or any pair differs in metadata.
    """

    def diff_meta_pairs(
        lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
        rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
    ) -> list[str]:
        def diff_meta(
            lhs: Union[torch.Tensor, torch.SymInt, int],
            rhs: Union[torch.Tensor, torch.SymInt, int],
        ) -> str:
            # Returns "" when the pair matches, else a description of the diff.
            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
                return ", ".join(
                    diff_tensor_meta(
                        _extract_tensor_metadata(
                            lhs, include_contiguity=include_contiguity
                        ),
                        _extract_tensor_metadata(
                            rhs, include_contiguity=include_contiguity
                        ),
                        check_grad=False,
                    )
                )
            else:

                def _both_int_types(lhs, rhs):
                    # int and SymInt are interchangeable for this check.
                    return isinstance(lhs, (int, torch.SymInt)) and isinstance(
                        rhs, (int, torch.SymInt)
                    )

                def _both_tensor(lhs, rhs):
                    return isinstance(lhs, torch.Tensor) and isinstance(
                        rhs, torch.Tensor
                    )

                if not _both_int_types(lhs, rhs) and not _both_tensor(lhs, rhs):
                    return f"type: {lhs} vs {rhs}"

                return ""

        # Manually check the device of lhs and rhs as this field is currently not part of TensorMetadata
        def diff_device(
            lhs: Union[torch.Tensor, torch.SymInt, int],
            rhs: Union[torch.Tensor, torch.SymInt, int],
        ) -> str:
            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
                if (
                    rhs.device.type == lhs.device.type
                    and rhs.device.index == lhs.device.index
                ):
                    return ""
                else:
                    return "device"
            return ""

        if len(lhs_list) != len(rhs_list):
            raise torch._dynamo.exc.UncapturedHigherOrderOpError(
                f"Expected {lhs_name} and {rhs_name} to have same number of outputs but got lhs:{lhs_list} and rhs:{rhs_list}"
            )
        all_diffs = []
        for i, (lhs, rhs) in enumerate(zip(lhs_list, rhs_list)):
            if diff := diff_meta(lhs, rhs):
                all_diffs.append(
                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
                )
            if diff := diff_device(lhs, rhs):
                all_diffs.append(
                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
                )
        return all_diffs

    if all_diffs := diff_meta_pairs(lhs_list, rhs_list):
        diff_str = "\n".join(all_diffs)
        raise torch._dynamo.exc.UncapturedHigherOrderOpError(
            f"Expected {lhs_name} and {rhs_name} to have same metadata but found:\n{diff_str}"
        )
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@contextmanager
def _set_compilation_env():
    """Temporarily configure global flags so HOPs can be compiled with Dynamo.

    Disables the FX ``_is_fx_tracing_flag`` and enables
    ``allow_empty_graphs``; both are restored on exit.
    """
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    _old_allow_empty_graphs = torch._dynamo.config.allow_empty_graphs
    # The issue is tracked in https://github.com/pytorch/pytorch/issues/144360: when dynamo finds
    # the top-level frame produces no graph, the default behavior is to fallback to eager.
    # Then when it encounters an inner function, it will try to trace that function again, which is unnecessary.
    # For while_loop, during inspecting the inner call, we trace into the python dispatcher
    # logic, which is not traceable as of today. So the proper fix can be either 1. allow dispatch
    # logic to be dynamo traceable or 2. fixing https://github.com/pytorch/pytorch/issues/144360.
    # but it exposes some bugs in existing tests so we have to have a temporary flag to control
    # the behavior, which allows dynamo to store an empty graph for a frame without falling back to eager
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
        # once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
        torch._dynamo.config.allow_empty_graphs = True
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing
        torch._dynamo.config.allow_empty_graphs = _old_allow_empty_graphs
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# The invariant here is that we always trace the branch with fake tensor
|
| 259 |
+
# The invariant here is that we always trace the branch with fake tensor
def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
    """Trace ``fn`` with make_fx, guaranteeing fake-tensor semantics.

    If the inputs already carry a fake mode, trace under that mode with
    ``tracing_mode="real"`` (the inputs are fake already); otherwise ask
    make_fx to fakeify via ``tracing_mode="fake"``. Deferred runtime asserts
    are inserted afterwards when a shape env is present.
    """
    fake_mode = detect_fake_mode(inputs)
    tracing_mode = "real"
    if fake_mode is None:
        fake_mode = nullcontext()
        tracing_mode = "fake"

    # Note: we need to turn off proxy tensor mode to avoid tracing infra
    # code that happens in make_fx e.g. we now call as_strided when wrapping tensor
    # as fake tensor.
    with fake_mode, disable_proxy_modes_tracing():
        gm = make_fx(
            fn,
            tracing_mode=tracing_mode,
            pre_dispatch=pre_dispatch,
            _error_on_data_dependent_ops=False,
        )(*inputs)
        if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None:
            insert_deferred_runtime_asserts(
                gm, fake_mode.shape_env, "hoo_maybe_fake_tracing", export=True
            )
        return gm
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
    """Fake-trace ``gm`` and report potential input aliasing and mutation.

    Returns:
        ``((inp_inp_alias_map, inp_out_alias_map, out_out_alias_map), inp_mutation)``
        — the three alias maps plus the list of mutated inputs.

        NOTE: if tracing raises ``UnsupportedAliasMutationException`` (which
        can happen when a nested cond_op is functionalized), this returns the
        bare value ``True`` instead; callers must handle both shapes.
    """
    try:
        gm = _maybe_fake_tracing(gm, inputs, pre_dispatch)
    except UnsupportedAliasMutationException:
        # Nested HOP functionalization already detected alias/mutation:
        # conservatively report it.
        return True
    # (A previous `except Exception as e: raise e` clause was removed: it was
    # a no-op that only rewrote the traceback.)

    example_inputs = [
        ph.meta.get("val", None) for ph in gm.graph.find_nodes(op="placeholder")
    ]
    (
        inp_inp_alias_map,
        inp_out_alias_map,
        out_out_alias_map,
        inp_mutation,
    ) = check_input_alias_and_mutation(gm, example_inputs)
    return (inp_inp_alias_map, inp_out_alias_map, out_out_alias_map), inp_mutation
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def analyze_potential_input_alias_or_mutation(name, aliases, input_mutations):
    """Raise a descriptive error when aliasing or input mutation was detected.

    Args:
        name: human-readable description of the offending op/branch, used as
            the prefix of the error message.
        aliases: iterable of alias maps (dict-like); any non-empty map means
            some input is aliased.
        input_mutations: collection of nodes that mutate inputs.

    Raises:
        RuntimeError: if any alias map is non-empty or any input is mutated.

    Note: the original messages were built with backslash-continued f-strings,
    which embedded long runs of literal indentation whitespace into the text;
    the messages are now single clean sentences.
    """
    aliased_inputs = {el for el_map in aliases if el_map for el in el_map}
    if aliased_inputs:
        # TODO: Investigate here further which node is exactly aliasing
        raise RuntimeError(
            f"{name} where aliases appear. "
            f"In particular, these inputs {aliased_inputs} get aliased. "
            "Please ensure that this doesn't happen."
        )
    if input_mutations:
        # TODO: Investigate here further which node is exactly mutating the inputs
        raise RuntimeError(
            f"{name} where the inputs are mutated. "
            f"In particular, these nodes are mutating the inputs {set(input_mutations)}. "
            "Please ensure that this doesn't happen."
        )
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def _has_potential_branch_input_mutation(gm, inputs, pre_dispatch=False):
    """Return True when fake-tracing ``gm`` reveals any mutated input."""
    _alias_maps, mutated_inputs = potential_input_alias_or_mutation(
        gm, inputs, pre_dispatch
    )
    return len(mutated_inputs) > 0
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False):
    """Return ``(has_any_alias, has_input_mutation)`` for ``gm``.

    Collapses the three alias maps produced by
    :func:`potential_input_alias_or_mutation` into a single boolean, and
    reports whether any input is mutated.
    """
    alias_maps, mutated_inputs = potential_input_alias_or_mutation(
        gm, inputs, pre_dispatch
    )
    has_alias = any(len(alias_map) > 0 for alias_map in alias_maps)
    return has_alias, len(mutated_inputs) > 0
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def _collect_fake_inputs(inputs):
    """Extract fake example values from FX proxies/nodes (plus plain ints).

    Each input is either an FX Proxy/Node carrying an ``example_value`` in its
    meta, or a plain int. Batched/functional tensors are unwrapped down to the
    underlying FakeTensor.
    """
    from torch._subclasses.fake_tensor import FakeTensor

    # Get the example values of the inputs.
    inputs_fake: list[Union[FakeTensor, torch.Tensor, int]] = []
    for inp in inputs:
        if isinstance(inp, (torch.fx.proxy.Proxy, torch.fx.node.Node)):
            inp = inp.node if isinstance(inp, torch.fx.proxy.Proxy) else inp
            if hasattr(inp, "meta"):
                val = inp.meta["example_value"]
                if isinstance(val, torch.Tensor):
                    if torch._C._functorch.is_batchedtensor(
                        val
                    ) or torch._C._functorch.is_functionaltensor(val):
                        # This case is for batched or functional tensors:
                        # unwrap repeatedly, since wrappers can be nested.
                        while torch._C._functorch.is_batchedtensor(
                            val
                        ) or torch._C._functorch.is_functionaltensor(val):
                            val = torch._C._functorch.get_unwrapped(val)
                        assert isinstance(val, FakeTensor)
                        inputs_fake.append(val)
                    else:
                        # This is the standard case of a TensorVariable
                        assert isinstance(val, FakeTensor)
                        inputs_fake.append(val)
                else:
                    # This case is for SymInts and other non-Tensor elements
                    assert not isinstance(val, torch.Tensor)
                    inputs_fake.append(val)
        else:
            # This case is for ints
            assert isinstance(inp, int)
            inputs_fake.append(inp)

    return inputs_fake
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def _check_alias_and_mutation(graph_module, inputs_fake, name, pre_dispatch):
    """Raise RuntimeError if ``graph_module`` aliases or mutates its inputs."""
    is_aliasing, is_mutating = has_potential_input_alias_or_mutation(
        graph_module, inputs_fake, pre_dispatch=pre_dispatch
    )
    if is_aliasing:
        raise RuntimeError(
            f"{name} might be aliasing the input or the output!"
        )  # noqa: F541
    if is_mutating:
        raise RuntimeError(f"{name} might be modifying the input!")  # noqa: F541
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def unique_graph_id(proxy_mode, prefix):
    """Returns a unique name and id for a graph to be added to a proxy_mode tracer.

    Delegates to :func:`unique_graph_name_with_root` using the tracer's root
    module as the attribute namespace. (create_arg has self-incrementing name
    magic, but register_module needs the name up front, so this simulates it.)
    """
    return unique_graph_name_with_root(proxy_mode.tracer.root, prefix)
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def unique_graph_name_with_root(
    root: torch.fx.GraphModule, prefix: str
) -> tuple[int, str]:
    """Find the smallest index ``i`` such that ``f"{prefix}_{i}"`` is not yet
    an attribute of ``root``, and return ``(i, name)``."""
    idx = 0
    while hasattr(root, f"{prefix}_{idx}"):
        idx += 1
    return idx, f"{prefix}_{idx}"
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def _from_fun(t):
    """Materialize a plain (non-functional) tensor with the same metadata as ``t``.

    Non-bool tensors get a fresh ``empty_strided`` allocation (values are not
    copied — callers use this for metadata-only purposes). Bool tensors are
    cloned from an unwrapped, non-functional version, because cloning a
    functional tensor would yield another functional tensor. Non-tensors pass
    through unchanged.
    """
    from torch._functorch.aot_autograd import from_fun

    if isinstance(t, torch.Tensor):
        if t.dtype != torch.bool:
            return torch.empty_strided(
                t.size(),
                t.stride(),
                dtype=t.dtype,
                requires_grad=t.requires_grad,
                device=t.device,
            )
        else:
            # clone of a functional tensor produces a functional tensor
            # but we want to avoid it so we clone a non-functional version
            maybe_unfunc_t = t
            if isinstance(t, FunctionalTensor):
                torch._sync(t)
                maybe_unfunc_t = from_fun(t)
            elif torch._is_functional_tensor(t):
                # need to handle both types of functionalization here:
                # these are the tensors that came from the user,
                # which could be either FunctionalTensorWrapper or FunctionalTensor
                torch._sync(t)
                maybe_unfunc_t = torch._from_functional_tensor(t)
            return maybe_unfunc_t.clone()
    return t
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def clone_outputs_aliasing_inputs(args):
    """Build a mapper that clones any tensor sharing storage with ``args``.

    The returned callable leaves non-tensors and non-aliasing tensors
    untouched, but clones tensors whose typed storage matches one of the
    input tensors' storages — ensuring outputs never alias inputs.
    """
    aliased_storages = {
        StorageWeakRef(tensor._typed_storage())
        for tensor in args
        if isinstance(tensor, torch.Tensor)
    }

    def _clone_if_aliasing(out):
        is_aliasing = (
            isinstance(out, torch.Tensor)
            and StorageWeakRef(out._typed_storage()) in aliased_storages
        )
        return out.clone() if is_aliasing else out

    return _clone_if_aliasing
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def prepare_fw_with_masks(fn):
    """Augment ``fn`` to also report which outputs require grad.

    The wrapper returns ``(outputs, masks)`` where ``masks[i]`` is True
    exactly when ``outputs[i]`` is a tensor with ``requires_grad`` set.
    """

    def fw_with_masks(*operands):
        outputs = fn(*operands)
        grad_masks = [
            isinstance(out, torch.Tensor) and out.requires_grad for out in outputs
        ]
        return outputs, grad_masks

    return fw_with_masks
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def prepare_fw_with_masks_all_requires_grad(fn):
    """Like ``prepare_fw_with_masks`` but forces all outputs to require grad
    whenever any input does.

    NOTE(review): ``requires_grad_(True)`` mutates the output tensors in
    place (it does not copy); callers appear to rely on this — confirm before
    changing.
    """

    def fw_with_masks(*args):
        fw_out = fn(*args)
        # Note [force all outputs to be require grad]
        # Instead of using the original fn, we set the output of original
        # fn to all require grad. This is consistent with the behavior
        # of autograd.Function, where if any one of the inputs requires grad
        # all output will be require grad. This also makes the downstream
        # require_gradness reasoning much easier.
        if pytree.tree_any_only(torch.Tensor, lambda t: t.requires_grad, args):
            fw_out = pytree.tree_map_only(
                torch.Tensor, lambda x: x.requires_grad_(True), fw_out
            )
        return fw_out, pytree.tree_map_only(
            torch.Tensor, lambda x: x.requires_grad, fw_out
        )

    return fw_with_masks
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# This function replaces None gradients with all-zero gradients.
# `None` gradients are problematic for CUDA graphs. Those gradients are
# replaced with an all-zero tensor for better optimization
def unmask_none_gradients(grads, operands):
    """Replace None grads of tensor operands with zeros; keep None otherwise."""
    allowed_types = (torch.Tensor, int, torch.SymInt)
    assert all(
        isinstance(o, allowed_types) for o in operands
    ), f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}"

    def _unmask(grad, operand):
        if grad is not None:
            return grad
        # A None grad for an int / torch.SymInt operand stays None. This can
        # happen for lifted arguments, e.g. shapes of a dynamic tensor that
        # are lifted and passed as additional arguments.
        return torch.zeros_like(operand) if isinstance(operand, torch.Tensor) else None

    return [_unmask(g, o) for g, o in zip(grads, operands)]
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def _maybe_fake_prop_ignore_unbacked(fn, args):
|
| 527 |
+
with ExitStack() as ctx_stack:
|
| 528 |
+
if (fake_mode := detect_fake_mode(args)) is not None:
|
| 529 |
+
ctx_stack.enter_context(fake_mode)
|
| 530 |
+
if fake_mode.shape_env is not None:
|
| 531 |
+
ctx_stack.enter_context(
|
| 532 |
+
fake_mode.shape_env.ignore_fresh_unbacked_symbols()
|
| 533 |
+
)
|
| 534 |
+
return fn(*args)
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def redirect_to_mode(hop: OperatorBase, mode):
    """Utility for redispatching HOP to underlying mode

    Args:
        hop: The HOP to redispatch
        mode: The mode to redispatch to

    Returns:
        A decorated function that implements the HOP for the given mode
    """

    @hop.py_impl(mode)
    def _mode_impl(mode, *args, **kwargs):
        # Hand the HOP straight to the mode's own dispatch hook.
        return mode.__torch_dispatch__(hop, [], args, kwargs)

    return _mode_impl
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
    """Trace ``fn`` into a forward graph and a joint (gradient-producing) graph.

    Args:
        fn: callable to trace.
        use_output_and_grad_bw: when True, the joint graph is traced with a
            single packed argument ``(grads, [*inputs, *outputs])`` instead of
            the flat ``(*grads, *inputs)`` calling convention.
        fw_inputs: example inputs used to trace the forward graph.
        fw_outputs: example outputs; their metadata seeds the example grads.

    Returns:
        ``(fw_graph, joint_graph)`` — two fx graphs produced by make_fx.
    """
    from torch._functorch.aot_autograd import AOTConfig, create_joint

    # Note:[HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and Python key. Currently, we only suspend functionalization but more can be
    # added when required. Will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    # However, it's the outside wrapper that tracer creates proxies for. This casuses tracer fail to
    # fetch the proxy for the inputs and fail to capture any operations on them.
    #
    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
    # wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
    # only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
    # Instead, it will create _tensor_constant as output.

    # Placeholder config: create_joint requires an AOTConfig but no compilers run here.
    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    # Example tangents mirror the forward outputs' metadata.
    example_grad = [_from_fun(out) for out in fw_outputs]
    num_grads = len(example_grad)
    fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)

    def joint_fn(*joint_operands_grads):
        # Split the traced operands into (grads, inputs) according to the
        # calling convention selected by use_output_and_grad_bw.
        if use_output_and_grad_bw:
            grads = joint_operands_grads[0]
            inputs = joint_operands_grads[1][-1:]
        else:
            grads = joint_operands_grads[:num_grads]
            inputs = joint_operands_grads[num_grads:]

        joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
        # Only grads for differentiable outputs are fed to the joint fn.
        _, grads = joint(
            list(inputs),
            [grad for grad in grads if grad is not None and grad.requires_grad],
        )

        # Unmask None gradients to all-zero gradients
        unmasked_grads = unmask_none_gradients(grads, inputs)

        # In order to keep map functional for backward graph,
        # we clone outputs that are aliasing inputs
        maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)

        return pytree.tree_map(maybe_clone, unmasked_grads)

    if use_output_and_grad_bw:
        example_xs_out = list(fw_inputs) + list(fw_outputs)
        # Packed convention: one tuple argument of (grads, inputs+outputs).
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            (list(example_grad), list(example_xs_out))
        )
    else:
        example_xs_out = list(fw_inputs)
        # Flat convention: *grads followed by *inputs.
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            *(list(example_grad) + list(example_xs_out))
        )

    return fw_graph, joint_graph
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def _unstack_pytree(xs):
|
| 627 |
+
flat_xs, inspec = pytree.tree_flatten(xs)
|
| 628 |
+
if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
|
| 629 |
+
raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
|
| 630 |
+
|
| 631 |
+
if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
|
| 632 |
+
raise RuntimeError(
|
| 633 |
+
f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
+
a = zip(*flat_xs)
|
| 637 |
+
|
| 638 |
+
pytrees = [pytree.tree_unflatten(tuple, inspec) for tuple in a]
|
| 639 |
+
return pytrees
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def _stack_pytree(pytrees):
|
| 643 |
+
flat_out = []
|
| 644 |
+
out_spec = None
|
| 645 |
+
for pt in pytrees:
|
| 646 |
+
flat_pt, out_spec = pytree.tree_flatten(pt)
|
| 647 |
+
flat_out.append(flat_pt)
|
| 648 |
+
assert out_spec is not None
|
| 649 |
+
b = zip(*flat_out)
|
| 650 |
+
stacked_out = []
|
| 651 |
+
for leaves in b:
|
| 652 |
+
if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
|
| 653 |
+
stacked_out.append(torch.stack(leaves))
|
| 654 |
+
elif all(leaf is None for leaf in leaves):
|
| 655 |
+
# Backward graph can return None output when forward inputs doesn't require grad.
|
| 656 |
+
# When we eagerly execute backward graph, we need to call _stack_pytree on its output,
|
| 657 |
+
# therefore we need to deal with None output.
|
| 658 |
+
stacked_out.append(None) # type: ignore[arg-type]
|
| 659 |
+
else:
|
| 660 |
+
raise RuntimeError(f"Cannot stack {leaves}.")
|
| 661 |
+
return pytree.tree_unflatten(stacked_out, out_spec)
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
# We cannot call save_for_backward for symints. This helper function
|
| 665 |
+
# can be used to save symints as direct attributes of ctx in autograd.Function.
|
| 666 |
+
#
|
| 667 |
+
# For example, if args = (x, y, s0, z, s1),
|
| 668 |
+
# save_tensors_and_symints_for_backward will partition the args into two lists, and a bookkeeping list pos:
|
| 669 |
+
# partitioned_args[0] = (x, y, z)
|
| 670 |
+
# partitioned_args[1] = (s0, s1)
|
| 671 |
+
# pos = (0, 0, 1, 0, 1)
|
| 672 |
+
# pos list keeps track of which partition the args
|
| 673 |
+
# is partitioned into in order to recover it in saved_tensors_and_symints.
|
| 674 |
+
#
|
| 675 |
+
# In saved_tensors_and_symints, we can recover the original args by:
|
| 676 |
+
# iterating over the pos list and pop one item from the front of paritioned_args[pos[i]].
|
| 677 |
+
# We use t_idx and s_idx to keep track of the next index of the item we are going to pop for the two lists.
|
| 678 |
+
def save_tensors_and_symints_for_backward(ctx, args):
    """Stash ``args`` on an autograd ``ctx``, routing tensors through
    ``ctx.save_for_backward`` and non-tensors onto ``ctx.sym_int_args``.

    ``ctx.pos`` records, per original argument, which partition it went to
    (0 = tensor, 1 = other) so saved_tensors_and_symints can restore order.
    """
    assert all(
        isinstance(arg, (torch.Tensor, torch.SymInt, int, type(None))) for arg in args
    ), args
    tensors: list[Any] = []
    others: list[Any] = []
    pos = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            pos.append(0)
            tensors.append(arg)
        else:
            pos.append(1)
            others.append(arg)

    assert not hasattr(ctx, "sym_int_args"), "ctx already has sym_int_args attribute."
    assert not hasattr(ctx, "pos"), "ctx already has pos attribute."
    ctx.save_for_backward(*tensors)
    ctx.sym_int_args = others
    ctx.pos = pos
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def saved_tensors_and_symints(ctx):
    """Inverse of save_tensors_and_symints_for_backward: rebuild the original
    flat argument tuple by walking ``ctx.pos`` and drawing from each partition."""
    reassembled = []
    t_idx = 0
    s_idx = 0
    saved_tensors = ctx.saved_tensors
    for partition in ctx.pos:
        if partition == 0:
            reassembled.append(saved_tensors[t_idx])
            t_idx += 1
        else:
            reassembled.append(ctx.sym_int_args[s_idx])
            s_idx += 1
    # Every recorded position must have been consumed exactly once.
    assert t_idx + s_idx == len(ctx.pos)
    return tuple(reassembled)
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
def get_dummy_aot_autograd_config():
    """Return a placeholder AOTConfig for tracing-only uses: no compilers or
    partitioner attached, no params/buffers, id 0."""
    from torch._functorch.aot_autograd import AOTConfig

    dummy_kwargs = dict(
        fw_compiler=None,
        bw_compiler=None,
        partition_fn=None,
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )
    return AOTConfig(**dummy_kwargs)  # type: ignore[arg-type]
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
# Slices off the first element of a given dimension
def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Return a copy of the index-0 slice of ``t`` along ``dim``."""
    return torch.select_copy(t, dim, 0)
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
# Reports the difference between meta of two tensors in a string
def diff_tensor_meta(
    meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
) -> list[str]:
    """Return a list of "'field: lhs vs rhs'" strings for every TensorMetadata
    field that differs; requires_grad is skipped when check_grad is False."""
    from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode

    differences = []
    for field_name in TensorMetadata._fields:
        if field_name == "requires_grad" and not check_grad:
            continue
        lhs = getattr(meta1, field_name)
        rhs = getattr(meta2, field_name)
        try:
            differs = lhs != rhs
        except GuardOnDataDependentSymNode:
            # Comparison itself is data-dependent; conservatively report a diff.
            differences.append(f"'{field_name}: {lhs} vs {rhs}'")
            continue
        if differs:
            differences.append(f"'{field_name}: {lhs} vs {rhs}'")
    return differences
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
# Note [lifted arg types in hop]
# For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
# This has implications for the types of lifted args for different dispatch keys:
# 1. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd need to support torch.Symint
#    lifted args because it's on the path of torch.compile(dynamic=True).
# 2. functionalization, FakeTensorMode, ProxyTorchDispatchMode, Autograd, CompositeExplicitAutograd need
#    to support int arguments. In the eager run case, we re-trace the subgraph in AutogradKey, so inner
#    hops may receive int inputs from the shape of outer tensor inputs.
# However, CompositeExplicitAutograd won't receive SymInt inputs because it only accepts real tensor inputs.
def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]]):
    """Assert every lifted subgraph argument is a Tensor, int, or SymInt.

    See Note [lifted arg types in hop] above for why only these types appear.
    """
    allowed_types = (torch.Tensor, int, torch.SymInt)
    # Fixed DRY issue: the isinstance check now reuses `allowed_types` instead
    # of repeating the tuple literal (keeps check and message in sync).
    assert all(
        isinstance(arg, allowed_types) for arg in lifted_args
    ), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}"
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
# TODO: Return a more detailed information as to which node
# causes a mutation or an alias. This may requires a per operator tensor version checking
def check_input_alias_and_mutation(
    gm: torch.fx.GraphModule,
    fake_args: list[FakeTensor],
) -> tuple[dict[int, int], dict[int, int], dict[int, int], list[int]]:
    """Alias/mutation analysis of ``gm`` over ``fake_args``.

    Same as check_input_alias_and_mutation_return_outputs but the outputs of
    running ``gm`` are dropped from the result.
    """
    full_result = check_input_alias_and_mutation_return_outputs(gm, fake_args)
    inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs = (
        full_result[:-1]
    )
    return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
# Fake-props `gm` once on cloned fake args and compares tensor versions /
# storages before and after the run to discover input mutation and aliasing.
def check_input_alias_and_mutation_return_outputs(
    gm: torch.fx.GraphModule,
    fake_args: Union[list[FakeTensor], tuple[FakeTensor, ...]],
) -> tuple[
    dict[int, int],  # inp-inp alias: arg index -> earlier arg index sharing storage
    dict[int, int],  # inp-out alias: input index -> output index sharing storage
    dict[int, int],  # out-out alias: output index -> earlier output sharing storage
    list[int],  # indices of inputs whose version counter changed (mutated)
    Union[tuple[Any, ...], list[Any]],  # the outputs produced by running gm
]:
    # This function can be called under autograd, functional, proxy and fake tensor mode.
    # We need to return either a fake tensor or a real tensor depending on the mode.
    # to detect the input mutation/aliasing.
    with disable_proxy_modes_tracing(), disable_functional_mode(), suspend_functionalization():

        def _from_functional_tensor(t: torch.Tensor) -> torch.Tensor:
            # Replace functional wrappers with plain empty tensors of matching
            # meta so the analysis below sees ordinary (fake) tensors.
            if isinstance(t, FunctionalTensor) or torch._is_functional_tensor(t):
                return torch.empty_strided(
                    t.size(),
                    t.stride(),
                    dtype=t.dtype,
                    requires_grad=t.requires_grad,
                    device=t.device,
                )
            return t

        fake_args = pytree.tree_map_only(
            torch.Tensor, _from_functional_tensor, fake_args
        )
        # We want to disable active functional, proxy and fake modes if any.
        # to create a encapsulated environment for fake tensor prop
        with torch.utils._python_dispatch._disable_current_modes():
            """This function returns mutated inputs, inp-inp alias, inp-out alias, out-out alias
            in the graph module gm. It checks whether input tensor versions have
            changed after run gm once to detect mutation and checks tensor storage
            to detect alias.
            """

            def _tensor_version(t) -> Optional[int]:
                # Version counters are only meaningful for (fake) tensors;
                # non-tensors yield None and never count as mutated.
                if isinstance(t, torch.Tensor):
                    if not isinstance(t, FakeTensor):
                        raise RuntimeError("Only fake tensor is allowed")
                    return t._version
                return None

            def _tensor_storage(t) -> StorageWeakRef:
                # Weak storage refs are the identity used for alias detection.
                return StorageWeakRef(t._typed_storage())

            def _get_shape_env(
                fake_args,
            ) -> Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv]:
                # detect_fake_mode requires there could be only one active fake mode. This
                # restricts the usage of this function because the global TracingContext
                # has a persistent fake mode but fake tensors can be created
                # outside of the tracing context (e.g. in testing).
                # Instead, we just look at fake_args fake tensor mode
                if len(fake_args) == 0:
                    return torch.fx.experimental.symbolic_shapes.ShapeEnv()

                for arg in fake_args:
                    if isinstance(arg, FakeTensor):
                        return arg.fake_mode.shape_env
                return None

            # Clone the fake args to avoid mutating the original fake args
            with ExitStack() as ctx_stack:
                # We need to re-use prev_fake_mode's shape env to resolve
                # the runtime assertions for unbacked symbols.
                new_fake_mode = torch._subclasses.FakeTensorMode(
                    shape_env=_get_shape_env(fake_args),
                    allow_non_fake_inputs=False,
                )
                # We need to temporarily turn inference_mode off because
                # under inference mode, tensor version counter is not tracked.
                no_inference_mode_ctx = torch.inference_mode(False)
                ctx_stack.enter_context(new_fake_mode)
                ctx_stack.enter_context(no_inference_mode_ctx)
                if new_fake_mode.shape_env is not None:
                    ctx_stack.enter_context(
                        new_fake_mode.shape_env.ignore_fresh_unbacked_symbols()
                    )

                # create new fake tensors in new fake mode to avoid mutating original tensors
                cloned = [
                    torch.empty_strided(
                        arg.size(),
                        arg.stride(),
                        dtype=arg.dtype,
                        device=arg.device,
                        requires_grad=arg.requires_grad,
                        layout=arg.layout,
                    )
                    if isinstance(arg, torch.Tensor)
                    else arg
                    for arg in fake_args
                ]
                before = [_tensor_version(arg) for arg in cloned]
                outputs = gm(*cloned)
                # Normalize a single return value into a list for uniform handling.
                outputs = [outputs] if not isinstance(outputs, (list, tuple)) else outputs
                after = [_tensor_version(arg) for arg in cloned]
                # An input was mutated iff its version counter advanced.
                mutated_inputs = [
                    i for i, (v1, v2) in enumerate(zip(before, after)) if v1 != v2
                ]
                # We need to analyze the original fake_args to detect
                # inp-inp alias.
                inp_storage_map = {
                    _tensor_storage(inp): i
                    for i, inp in enumerate(fake_args)
                    if isinstance(inp, torch.Tensor)
                }
                inp_inp_alias_map = {
                    i: inp_storage_map[_tensor_storage(inp)]
                    for i, inp in enumerate(fake_args)
                    if isinstance(inp, torch.Tensor)
                    and inp_storage_map[_tensor_storage(inp)] != i
                }
                out_storage_map = {
                    _tensor_storage(out): i
                    for i, out in enumerate(outputs)
                    if isinstance(out, torch.Tensor)
                }
                out_out_alias_map = {
                    i: out_storage_map[_tensor_storage(out)]
                    for i, out in enumerate(outputs)
                    if isinstance(out, torch.Tensor)
                    and out_storage_map[_tensor_storage(out)] != i
                }
                # inp-out aliasing is detected on the cloned inputs, since those
                # are the tensors gm actually ran on.
                inp_out_alias_map = {
                    i: out_storage_map[_tensor_storage(inp)]
                    for i, inp in enumerate(cloned)
                    if isinstance(inp, torch.Tensor) and _tensor_storage(inp) in out_storage_map
                }
                return (
                    inp_inp_alias_map,
                    inp_out_alias_map,
                    out_out_alias_map,
                    mutated_inputs,
                    outputs,
                )
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
# Global registry mapping a HOP to its user-registered fake implementation;
# populated by register_fake below and consulted by fake tensor dispatch.
registered_hop_fake_fns: dict[torch._ops.OpOverload, Callable] = {}


# Callable-bound TypeVar so register_fake preserves the decorated function's type.
F = TypeVar("F", bound=Callable)


# Overload: bare decorator form — register_fake(hop) returns a decorator.
@overload
def register_fake(hop, fn: None = None) -> Callable[[F], F]:
    ...


# Overload: direct form — register_fake(hop, fn) returns fn unchanged.
@overload
def register_fake(hop, fn: F) -> F:
    ...
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
def register_fake(hop, fn=None):
    """
    Register a fake function for a HOP. This is conceptually equivalent of the
    register_fake utility for the custom ops. The registered function is called
    inside the fake_tensor _dispatch_impl. Usable as ``register_fake(hop, fn)``
    or as a decorator ``@register_fake(hop)``.
    """
    assert hop not in registered_hop_fake_fns

    def _do_register(func: F) -> F:
        from torch._subclasses.fake_tensor import FakeTensorMode

        # Route the HOP's FakeTensorMode dispatch through the mode itself,
        # which will consult registered_hop_fake_fns.
        redirect_to_mode(hop, FakeTensorMode)
        registered_hop_fake_fns[hop] = func
        return func

    return _do_register if fn is None else _do_register(fn)
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
class FunctionalizeCtxWrapper:
|
| 962 |
+
"""
|
| 963 |
+
This is a dummy wrapper to facilitate fake tensor caching.
|
| 964 |
+
|
| 965 |
+
For AOT Dispatcher metadata collection pass, HOPs go from functionalization
|
| 966 |
+
key to fake tensor key. The functionalization key wraps the subgraphs in a
|
| 967 |
+
function, which changes from call to call even though the subgraph might
|
| 968 |
+
still be same.
|
| 969 |
+
|
| 970 |
+
To enable fake tensor caching, we just wrap the ctx and subgraph in this
|
| 971 |
+
class and then use the subgraph as the hash.
|
| 972 |
+
"""
|
| 973 |
+
|
| 974 |
+
# Prevents PYTORCH_TEST_WITH_DYNAMO=1 test failures
|
| 975 |
+
@torch._disable_dynamo
|
| 976 |
+
def __init__(self, ctx, subgraph):
|
| 977 |
+
self.ctx = ctx
|
| 978 |
+
self.subgraph = subgraph
|
| 979 |
+
|
| 980 |
+
def __hash__(self):
|
| 981 |
+
return id(self.subgraph)
|
| 982 |
+
|
| 983 |
+
def __repr__(self):
|
| 984 |
+
return f"FunctionalizeCtxWrapper on subgraph {self.subgraph})"
|
| 985 |
+
|
| 986 |
+
def __call__(self, *args, **kwargs):
|
| 987 |
+
if isinstance(self.subgraph, torch.fx.GraphModule):
|
| 988 |
+
# Running graph with interpreter is needed for propagating the stack_trace
|
| 989 |
+
with fx_traceback.preserve_node_meta():
|
| 990 |
+
return self.ctx.functionalize(torch.fx.Interpreter(self.subgraph).run)(
|
| 991 |
+
*args, **kwargs
|
| 992 |
+
)
|
| 993 |
+
return self.ctx.functionalize(self.subgraph)(*args, **kwargs)
|
| 994 |
+
|
| 995 |
+
|
| 996 |
+
# A wrapper over HigherOrderOperator that also carries its schema
|
| 997 |
+
class HopInstance:
|
| 998 |
+
def __init__(self, op: HigherOrderOperator, schema: HopSchema):
|
| 999 |
+
assert isinstance(op, HigherOrderOperator), op
|
| 1000 |
+
self._op = op
|
| 1001 |
+
# Using "_" to be consistent with how we access _schema of OpOverload
|
| 1002 |
+
self._schema = schema
|
| 1003 |
+
|
| 1004 |
+
def __call__(self, *args, **kwargs):
|
| 1005 |
+
return self._op(*args, **kwargs)
|
| 1006 |
+
|
| 1007 |
+
@staticmethod
|
| 1008 |
+
def create(hop: HigherOrderOperator, *args, **kwargs):
|
| 1009 |
+
return HopInstance(hop, hop.gen_schema(*args, **kwargs))
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
# This call_op can be used to call a HopInstance with
|
| 1013 |
+
# flat args and kwargs. We need to make use of the hop's schema's tree_spec
|
| 1014 |
+
# to unflatten the args and kwargs before calling the hop.
|
| 1015 |
+
def call_op(op: Union[OpOverload, HopInstance], args, kwargs):
|
| 1016 |
+
if isinstance(op, OpOverload):
|
| 1017 |
+
return op(*args, **kwargs)
|
| 1018 |
+
|
| 1019 |
+
assert isinstance(op, HopInstance), op
|
| 1020 |
+
schema = op._schema
|
| 1021 |
+
bound_args = list(args)
|
| 1022 |
+
bound_kwargs = {}
|
| 1023 |
+
for arg in schema.arguments[len(bound_args) :]:
|
| 1024 |
+
assert arg.name in kwargs, (arg.name, kwargs)
|
| 1025 |
+
val = kwargs[arg.name]
|
| 1026 |
+
if not arg.kwarg_only:
|
| 1027 |
+
bound_args.append(val)
|
| 1028 |
+
else:
|
| 1029 |
+
bound_kwargs[arg.name] = val
|
| 1030 |
+
|
| 1031 |
+
if schema.tree_spec is not None:
|
| 1032 |
+
assert len(bound_args) == len(schema.arguments) and len(bound_kwargs) == 0
|
| 1033 |
+
args, kwargs = pytree.tree_unflatten(bound_args, schema.tree_spec)
|
| 1034 |
+
return op(*args, **kwargs)
|
| 1035 |
+
else:
|
| 1036 |
+
assert len(bound_args) + len(bound_kwargs) == len(schema.arguments)
|
| 1037 |
+
return op(*bound_args, **bound_kwargs)
|
| 1038 |
+
|
| 1039 |
+
|
| 1040 |
+
def materialize_as_graph(
    fn: Callable,
    args: tuple[Any],
    include_key_set: Optional[torch._C.DispatchKeySet] = None,
    exclude_key_set: Optional[torch._C.DispatchKeySet] = None,
    force_enable_grad=False,
) -> torch.fx.GraphModule:
    """Trace ``fn`` into a GraphModule under a chosen dispatch-key configuration.

    Args:
        fn: callable to trace with make_fx.
        args: example arguments; functional tensors are unwrapped first.
        include_key_set / exclude_key_set: dispatch TLS to force during tracing;
            default to the current thread-local include/exclude sets.
        force_enable_grad: when True, tracing runs with grad enabled.
    """
    if include_key_set is None:
        include_key_set = torch._C._dispatch_tls_local_include_set()
    if exclude_key_set is None:
        exclude_key_set = torch._C._dispatch_tls_local_exclude_set()

    # Dynamo must not trace into this helper (it manipulates dispatch TLS).
    @torch._dynamo.disable(recursive=True, reason=None)
    def _materialize_as_graph_inner():
        # Suspend functionalization so make_fx sees plain tensors; see
        # Note:[HOP create fw_bw graph] earlier in this file.
        with suspend_functionalization(), disable_functional_mode():
            with disable_proxy_modes_tracing():
                unfunc_t = [_from_fun(arg) for arg in args]
            with contextlib.ExitStack() as stack:
                stack.enter_context(
                    torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set),
                )
                if force_enable_grad:
                    stack.enter_context(torch.enable_grad())
                return _maybe_reenter_make_fx(fn)(*unfunc_t)

    gm = _materialize_as_graph_inner()
    # The dynamo-disabled inner function must have produced a graph.
    assert gm is not None
    return gm
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
def materialize_callable_in_args(op: HopInstance, args, kwargs):
    """Replace callable arguments of a HOP call with traced GraphModules.

    Traces one invocation of ``op`` with the given args/kwargs, then walks the
    resulting HOP node's argument proxies: wherever a proxy is a get_attr
    pointing at a traced GraphModule (i.e. the original flat arg was a
    callable), that GraphModule is substituted for the callable. All other
    args are passed through unchanged, and the original (args, kwargs)
    structure is restored via the recorded tree spec.
    """
    schema = op._schema
    hop = op._op
    flat_args, flat_spec = pytree.tree_flatten((args, kwargs))

    def wrapped_fn(*flat_args):
        return call_op(op, args, kwargs)

    # We need to trace the higher order op in order to materilaize the callable inputs that
    # are a callable (e.g. after functionalization key)
    gm = reenter_make_fx(wrapped_fn)(*flat_args)
    hop_node = gm.graph.find_nodes(op="call_function", target=hop)[0]
    arg_proxies = pytree.tree_leaves((hop_node.args, hop_node.kwargs))
    assert isinstance(schema, torch._C.FunctionSchema) and len(arg_proxies) == len(
        schema.arguments
    )

    # call_op preserves ordering of proxies via schema
    materialized_args = []
    for i, (proxy, arg) in enumerate(zip(arg_proxies, schema.arguments)):
        if (
            isinstance(proxy, torch.fx.Node)
            and proxy.op == "get_attr"
            and isinstance(getattr(gm, proxy.target), torch.fx.GraphModule)  # type: ignore[arg-type]
        ):
            # A get_attr to a GraphModule corresponds to a callable arg.
            assert callable(flat_args[i]), (schema, args, kwargs)
            materialized_args.append(getattr(gm, proxy.target))  # type: ignore[arg-type]
        else:
            materialized_args.append(flat_args[i])

    return pytree.tree_unflatten(materialized_args, flat_spec)
|
| 1101 |
+
|
| 1102 |
+
|
| 1103 |
+
def has_user_subclass(args, allowed_subclasses):
|
| 1104 |
+
"""Check if any tensor arguments are user subclasses.
|
| 1105 |
+
|
| 1106 |
+
This is used to determine if tensor subclasses should get a chance to run
|
| 1107 |
+
their own implementation first before falling back to the default implementation.
|
| 1108 |
+
|
| 1109 |
+
Args:
|
| 1110 |
+
args: Arguments to check (will be flattened with pytree)
|
| 1111 |
+
allowed_subclasses: Tuple of allowed subclass types
|
| 1112 |
+
|
| 1113 |
+
Returns:
|
| 1114 |
+
True if user tensor subclasses are found, False otherwise
|
| 1115 |
+
"""
|
| 1116 |
+
flat_args, _ = pytree.tree_flatten(args)
|
| 1117 |
+
|
| 1118 |
+
val = any(
|
| 1119 |
+
isinstance(a, torch.Tensor)
|
| 1120 |
+
and type(a) is not torch.Tensor
|
| 1121 |
+
and not isinstance(a, allowed_subclasses)
|
| 1122 |
+
for a in flat_args
|
| 1123 |
+
)
|
| 1124 |
+
return val
|
| 1125 |
+
|
| 1126 |
+
|
| 1127 |
+
def _has_gen_schema(op: HigherOrderOperator):
|
| 1128 |
+
# There is an InvokeQuant argument we cannot gen_schema.
|
| 1129 |
+
if op is torch.ops.higher_order.invoke_quant_packed:
|
| 1130 |
+
return False
|
| 1131 |
+
method = "gen_schema"
|
| 1132 |
+
return hasattr(type(op), method) and getattr(type(op), method) is not getattr(
|
| 1133 |
+
HigherOrderOperator, method
|
| 1134 |
+
)
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/while_loop.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
from typing import Callable, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.utils._pytree as pytree
|
| 7 |
+
from torch._C import DispatchKey
|
| 8 |
+
from torch._higher_order_ops.utils import (
|
| 9 |
+
_maybe_run_with_interpreter,
|
| 10 |
+
_set_compilation_env,
|
| 11 |
+
autograd_not_implemented,
|
| 12 |
+
check_meta_consistency,
|
| 13 |
+
reenter_make_fx,
|
| 14 |
+
validate_subgraph_args_types,
|
| 15 |
+
)
|
| 16 |
+
from torch._ops import HigherOrderOperator
|
| 17 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 18 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 19 |
+
_temp_remove_metadata_torch_function_mode,
|
| 20 |
+
ProxyTorchDispatchMode,
|
| 21 |
+
track_tensor_tree,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class WhileLoopOp(HigherOrderOperator):
    """HOP implementing ``while_loop(cond_fn, body_fn, carried, additional)``;
    validates container and element types before dispatching."""

    def __init__(self) -> None:
        super().__init__("while_loop")

    def __call__(
        self,
        cond_fn: Callable,
        body_fn: Callable,
        carried_inputs: tuple[Union[torch.Tensor, int, float, bool]],
        additional_inputs: tuple[Union[torch.Tensor, torch.SymInt, int], ...],
        /,
    ):
        # Both input groups must arrive as sequences (tuple or list).
        for arg_name, arg_value in (
            ("carried_inputs", carried_inputs),
            ("additional_inputs", additional_inputs),
        ):
            if not isinstance(arg_value, (tuple, list)):
                raise RuntimeError(
                    f"{arg_name} must be a tuple or list, got {type(arg_value)}"
                )

        validate_subgraph_args_types(carried_inputs)
        validate_subgraph_args_types(additional_inputs)
        return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Module-level singleton; all dispatch-key implementations below register on it.
while_loop_op = WhileLoopOp()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def while_loop(cond_fn, body_fn, carried_inputs):
    r"""
    Run body_fn(*carried_inputs) while cond_fn(*carried_inputs) returns a True scalar tensor. Returns the output of body_fn or
    initial carried_inputs.

    .. warning::
        `torch.while_loop` is a prototype feature in PyTorch. It has limited support for input and output types and
        doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
        Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

    `while_loop` is a structured control flow operator. It preserves the loop semantic across the torch.compile and torch.export.

    `while_loop` is equivalent to the following:

        def while_loop(cond_fn, body_fn, carried_inputs):
            val = carried_inputs
            while cond_fn(*val):
                val = body_fn(*val)
            return val

    Args:
        cond_fn (Callable): A callable function that returns a boolean Scalar tensor or a python boolean.

        body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors or ints

        carried_inputs (Tuple of possibly nested dict/list/tuple of tensors or ints): A tuple of inputs to cond_fn and body_fn.
            It's also the initial value of states that are carried across iterations. Note that when pass an integer as carry,
            the corresponding return of while_loop will be another int with unknown values because we don't know how many
            iterations while_loop will run.

    Example 1:

        def cond_fn(iter, x):
            return iter.sum() < 10

        def body_fn(iter, x):
            return iter + 1, x.sin()

        while_loop(cond_fn, body_fn, (torch.zeros(1), torch.randn(3, 4)))

    Example 2:

        def cond_fn(int_iter, x):
            return 2 * int_iter < x.shape[0]

        def body_fn(int_iter, x):
            return int_iter + 1, x + int_iter

        while_loop(cond_fn, body_fn, (0, torch.randn(3, 4)))

    Restrictions:

        - body_fn must return tensors or int with the same metadata (e.g.shape, dtype) as inputs.

        - body_fn and cond_fn must not in-place mutate the carried_inputs. A clone before the mutation is required.

        - body_fn and cond_fn must not mutate python variables (e.g. list/dict) created outside of the body_fn.

        - body_fn and cond_fn's output cannot alias any of the inputs. A clone is required.

    .. warning::
        Temporal Limitations:

        - 'while_loop' only supports **inference** right now. Autograd will be supported in the future.

    """
    from torch._dynamo.backends.debugging import (
        make_eager_backend_with_torch_function_mode,
    )

    # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo.
    # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs.
    additional_inputs: tuple = ()

    # The reason we flatten the output before calling into dynamo is that
    # we want to create a consistent input ordering for cond_fn and body_fn.
    # and we also want to the input ordering matches the output ordering.
    # Also see NOTE: [why we cannot use "automatic" for while_loop]
    # Construct flat cond_fn and flat_body_fn, which takes flattened inputs
    flat_inputs, in_spec = pytree.tree_flatten((carried_inputs, additional_inputs))

    def flat_cond_fn(*flat_args):
        carried, additional = pytree.tree_unflatten(flat_args, in_spec)
        return cond_fn(*carried, *additional)

    def flat_body_fn(*flat_args):
        carried, additional = pytree.tree_unflatten(flat_args, in_spec)
        return body_fn(*carried, *additional)

    if torch.compiler.is_dynamo_compiling():
        return while_loop_op(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())

    def _validate_input(cond_fn, body_fn, carried_inputs):
        from torch._higher_order_ops.utils import validate_subgraph_args_types

        if not callable(cond_fn) or not callable(body_fn):
            raise RuntimeError("Expect cond_fn and body_fn to be callable.")

        validate_subgraph_args_types(flat_inputs)

        if not pytree.tree_all(
            lambda t: isinstance(t, (torch.Tensor, torch.SymInt, int)), carried_inputs
        ):
            raise RuntimeError(
                "Expect carried_inputs to be a tuple of possibly nested dict/list/tuple that only"
                f"consists of tensor or int leaves, but got {carried_inputs}."
            )

    _validate_input(cond_fn, body_fn, carried_inputs)

    # Dynamo is expecting a callable with "__code__" attribute.
    # We cannot directly pass cond_op to it. So we wrap it in a dummy function.
    def _while_loop_op_wrapper(*args, **kwargs):
        return while_loop_op(*args, **kwargs)

    with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
        # Fix: this context manager was entered twice (nested) in the original;
        # the inner call found no metadata mode left on the stack and rebound
        # `metadata_mode` to None, so the eager-backend-with-metadata-mode
        # branch below could never be taken. Enter it exactly once.
        with _temp_remove_metadata_torch_function_mode() as metadata_mode:
            if metadata_mode:
                backend = make_eager_backend_with_torch_function_mode(metadata_mode)
            else:
                backend = "eager"
            return torch.compile(
                _while_loop_op_wrapper, backend=backend, fullgraph=True
            )(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)
|
| 182 |
+
def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
    """Eager (dense) implementation: literally loop until cond_fn is falsy.

    Each iteration's body output becomes the next iteration's carry; the loop
    validates that cond_fn yields a boolean (or scalar bool tensor) and that
    body_fn returns a tuple of the same arity as the carries.
    """
    if not isinstance(carried_inputs, (tuple, list)):
        raise RuntimeError(
            f"carried_inputs must be a tuple or list but got {type(carried_inputs)}"
        )

    def _validate_cond_output(pred):
        # Accept a plain Python bool or a 0-dim bool tensor; anything else is an error.
        is_scalar_bool_tensor = (
            isinstance(pred, torch.Tensor)
            and pred.size() == torch.Size([])
            and pred.dtype == torch.bool
        )
        if not (is_scalar_bool_tensor or isinstance(pred, bool)):
            raise RuntimeError(
                f"cond_fn must return a boolean scalar tensor or a boolean but got {pred}"
            )

    state = carried_inputs
    # Note: the predicate's truthiness is tested first; validation only runs
    # for truthy predicates (matching the original control flow).
    while pred := cond_fn(*state, *additional_inputs):
        _validate_cond_output(pred)
        step_out = body_fn(*state, *additional_inputs)
        assert isinstance(
            step_out, tuple
        ), f"body_fn should return a tuple but got {type(step_out)}"
        assert len(step_out) == len(
            carried_inputs
        ), "body_fn should return the same number of elements as carried_inputs"
        state = step_out
    return state
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# Autograd is not implemented for while_loop; with deferred_error=True the
# error is raised at call time (when grad is actually required) rather than here.
while_loop_op.py_autograd_impl(
    autograd_not_implemented(while_loop_op, deferred_error=True)
)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _find_or_create_fake_mode() -> FakeTensorMode:
|
| 221 |
+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
| 222 |
+
|
| 223 |
+
fake_mode = torch._guards.detect_fake_mode()
|
| 224 |
+
if fake_mode is None:
|
| 225 |
+
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
| 226 |
+
|
| 227 |
+
return fake_mode
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _create_unbacked_symint(
|
| 231 |
+
fake_mode: FakeTensorMode, ignore_fresh_unbacked_symbols: bool
|
| 232 |
+
) -> torch.SymInt:
|
| 233 |
+
assert (
|
| 234 |
+
fake_mode is not None and fake_mode.shape_env is not None
|
| 235 |
+
), "Must provide a fake_mode with shape_env."
|
| 236 |
+
ctx = (
|
| 237 |
+
contextlib.nullcontext()
|
| 238 |
+
if not ignore_fresh_unbacked_symbols
|
| 239 |
+
else fake_mode.shape_env.ignore_fresh_unbacked_symbols()
|
| 240 |
+
)
|
| 241 |
+
with ctx:
|
| 242 |
+
return fake_mode.shape_env.create_unbacked_symint()
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
@while_loop_op.py_impl(ProxyTorchDispatchMode)
def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
    # ProxyTorchDispatchMode implementation: trace cond_fn/body_fn into FX
    # subgraphs, register them on the tracer root, and emit a single
    # `while_loop` call_function node whose outputs are tracked tensors.
    def _trace_while_loop(
        proxy_mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
    ):
        # NOTE [unspecialize int carry with unbacked symints]
        # When we support int carry, we'll also need to support int output of body_fn because
        # previous iteration's output is next iteration's input and they must match.
        # For carries, when we start tracing while_loop, they can be
        # - constants e.g. (0, [1, 3])
        # - backed symints (x.shape[0], [x.shape[1] + x.stride[1], x.shape[2]])
        # - unbacked symints e.g. (u0, [u0 + u1, u2])
        # We choose the most conservative design: in all cases, we create new unbacked symints to trace the
        # subgraph. It's possible to do some analysis on initial carry and the output of first
        # iteration to determine a better range for the output unbacked symbol e.g. when input is an unbacked
        # symint >= 0 before the while_loop but in general this is difficult because we don't know
        # the number of iterations. Users would have to re-constrain the unbacked symint in subgraph if needed.
        #
        # For output of fake cond_fn, it could be constant bool or SymBool (e.g. return x.shape[0] < 4,
        # where x.shape[0] can be either static or dynamic). In the case of constant bool, we should do a
        # specialization (NYI).
        #
        # For output of fake body_fn, it could be all three types though from user's point of view,
        # they're all integers e.g.
        #
        # init_carry = (0, s0, u1, t)
        # def body_fn(u0, s0, u1, t):
        #     ...
        #     return (t.shape[0], t.shape[1], t.shape[2], y + 1)
        #
        # It may seem that a constant output isn't possible: users shouldn't write a while_loop
        # that always return 0. But it could be that a shape is not set as dynamic properly (e.g.
        # automatic dynamic hasn't been triggered).
        #
        # For this reason, we treat int, symint outputs in the same way:
        # - they can match against any of int, symint carry
        # - we unspecialize them with new unbacked symints in fake while_loop
        # Similarly, we could do some analysis to refine the output ranges but it's easier to start with
        # fresh unbacked symints. One surprising case can be: an input unbacked symint is constrained by
        # users to be >= 0 (either before while_loop or inside body_fn) and it increments by 1 in each
        # iteration. Ideally, we should know that the final output is >= 0 but we didn't constrain the
        # unbacked symint output of subgraph as of today because this requires a smart range analysis.
        fake_mode: FakeTensorMode = _find_or_create_fake_mode()
        unspecialized_carried_inputs = pytree.tree_map_only(
            (int, torch.SymInt),
            # For temporarily created unbacked symints, we don't need to bind them to any proxy
            lambda _: _create_unbacked_symint(
                fake_mode, ignore_fresh_unbacked_symbols=True
            ),
            carried_inputs,
        )

        # Trace both subgraphs against the unspecialized carries so int carries
        # are not baked in as constants.
        cond_graph = reenter_make_fx(cond_fn)(
            *unspecialized_carried_inputs, *additional_inputs
        )
        body_graph = reenter_make_fx(body_fn)(
            *unspecialized_carried_inputs, *additional_inputs
        )

        # Find an unused attribute name on the tracer root for the cond graph;
        # the body graph reuses the same index so the pair stays aligned.
        next_name = None
        i = 0
        while not next_name:
            candidate = f"while_loop_cond_graph_{i}"
            if hasattr(proxy_mode.tracer.root, candidate):
                i += 1
            else:
                next_name = candidate
        cond_graph_name = next_name
        body_graph_name = f"while_loop_body_graph_{i}"
        assert not hasattr(proxy_mode.tracer.root, body_graph_name)

        proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph)
        proxy_mode.tracer.root.register_module(body_graph_name, body_graph)

        args = (cond_graph, body_graph, carried_inputs, additional_inputs)

        proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

        out_proxy = proxy_mode.tracer.create_proxy(
            "call_function", while_loop_op, proxy_args, {}, name="while_loop"
        )

        # Execute the op on the unspecialized carries to obtain example outputs,
        # then associate those outputs with the proxy node.
        out = while_loop_op(
            cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs
        )
        return track_tensor_tree(
            out, out_proxy, constant=None, tracer=proxy_mode.tracer
        )

    return _trace_while_loop(
        mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
    )
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
@while_loop_op.py_impl(FakeTensorMode)
def while_loop_fake_tensor_mode(
    mode, cond_fn, body_fn, carried_inputs, additional_inputs
):
    # FakeTensorMode implementation: run body_fn once to derive output metadata,
    # check it is consistent with the carries, then unspecialize int/SymInt
    # outputs into fresh unbacked symints.
    with mode:
        # NOTE: [Handling unbacked symints in subgraph of while_loop]
        # The idea is that the scope of unbacked symints are limited to the subgraph.
        #
        # We're implementing the fake tensor mode of while_loop operator.
        # and we run body_fn once to get an fake output.
        # Let's first consider the case that unbacked symints are tensor shapes:
        #
        # Case 1:
        # if the unbacked symints is local to the subgraph e.g.
        # def body_fn(it, x):
        #     nz = x.nonzero()
        #     return it+1, nz.sum()
        # we can just ignore the newly created unbacked symints because it has
        # no effect on the output of while_loop and it's tracked when we tracing
        # the subgraph.
        #
        # Case 2:
        # if the unbacked symints are shape of output of while_loop e.g.
        # def body_fn(it, x):
        #     nz = x.nonzero()
        #     return it+1, nz
        # This will fail the shape check because in each iteration, the carried_input's shape
        # must match the output shape as nz.shape contains newly allocated unbacked symint, this
        # won't match the carried_input's shape.
        #
        # Case 3:
        # if the unbacked symints are shape of carried_inputs e.g.
        # nz = a.nonzero()
        # body_fn(it, nz):
        #     return it+1, nz.sin() + 1
        # There's no new unbacked symints allocated in subgraph, so we're safe.
        # NOTE(review): `mode.shape_env` is assumed non-None here — confirm callers
        # always construct this mode with a ShapeEnv.
        with mode.shape_env.ignore_fresh_unbacked_symbols():
            # body_fn return output with the same pytree and tensor meta data as carried_inputs
            # so we could just return the output after one iteration.
            body_outs = body_fn(*carried_inputs, *additional_inputs)
            check_meta_consistency(
                carried_inputs,
                body_outs,
                "carried_inputs",
                "body_output",
                include_contiguity=False,
            )
        # See NOTE [unspecialize int carry with unbacked symints]
        return pytree.tree_map_only(
            (int, torch.SymInt),
            # For while_loop's unbacked symint output, we want them to be bound
            # to the proxy of while_loop's output.
            lambda _: _create_unbacked_symint(
                mode, ignore_fresh_unbacked_symbols=False
            ),
            body_outs,
        )
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
@while_loop_op.py_functionalize_impl
def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
    """Functionalization rule: unwrap functional tensors, verify that neither
    subgraph mutates or aliases its inputs, then redispatch to while_loop_op
    and re-wrap the result."""
    from torch._higher_order_ops.utils import _check_alias_and_mutation

    plain_carried = ctx.unwrap_tensors(carried_inputs)
    plain_additional = ctx.unwrap_tensors(additional_inputs)
    all_plain_inputs = plain_carried + plain_additional
    with ctx.redispatch_to_next():
        cond_functional = ctx.functionalize(_maybe_run_with_interpreter(cond_fn))
        body_functional = ctx.functionalize(_maybe_run_with_interpreter(body_fn))
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        # Input mutation / aliasing is not allowed inside either subgraph.
        _check_alias_and_mutation(cond_fn, all_plain_inputs, "cond_fn", pre_dispatch)
        _check_alias_and_mutation(body_fn, all_plain_inputs, "body_fn", pre_dispatch)
        result = while_loop_op(
            cond_functional,
            body_functional,
            plain_carried,
            plain_additional,
        )
        return ctx.wrap_tensors(result)
|
archive/.venv/Lib/site-packages/torch/_higher_order_ops/wrap.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import inspect
|
| 3 |
+
import itertools
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
from torch._logging import warning_once
|
| 8 |
+
from torch._ops import HigherOrderOperator
|
| 9 |
+
from torch.types import _dtype
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Module-level logger for this file.
log = logging.getLogger(__name__)

# Monotonic counter assigning a unique id to each checkpointed graph (see
# TagActivationCheckpoint.tag_nodes).
uid = itertools.count(1)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Used for testing the HigherOrderOperator mechanism
|
| 18 |
+
class Wrap(HigherOrderOperator):
|
| 19 |
+
def __init__(self) -> None:
|
| 20 |
+
super().__init__("wrap")
|
| 21 |
+
|
| 22 |
+
def __call__(self, func, *args, **kwargs):
|
| 23 |
+
# Dynamo already traces the body of HigherOrderOp beforehand when it
|
| 24 |
+
# so no need to trace into it.
|
| 25 |
+
import torch._dynamo # noqa: F401
|
| 26 |
+
from torch._dynamo import disable
|
| 27 |
+
|
| 28 |
+
@disable
|
| 29 |
+
def wrapper():
|
| 30 |
+
result = func(*args, **kwargs)
|
| 31 |
+
return result
|
| 32 |
+
|
| 33 |
+
return wrapper()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Module-level singleton instance of the `wrap` HOP.
wrap = Wrap()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class WrapWithSetGradEnabled(HigherOrderOperator):
|
| 40 |
+
def __init__(self) -> None:
|
| 41 |
+
super().__init__("wrap_with_set_grad_enabled")
|
| 42 |
+
|
| 43 |
+
def __call__(self, enable_grad, wrapped_func, *args, **kwargs):
|
| 44 |
+
# Dynamo already traces the body of HigherOrderOp beforehand when it
|
| 45 |
+
# so no need to trace into it.
|
| 46 |
+
import torch._dynamo # noqa: F401
|
| 47 |
+
from torch._dynamo import disable
|
| 48 |
+
|
| 49 |
+
@disable
|
| 50 |
+
def wrapper():
|
| 51 |
+
with torch.set_grad_enabled(enable_grad):
|
| 52 |
+
return wrapped_func(*args, **kwargs)
|
| 53 |
+
|
| 54 |
+
return wrapper()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# Module-level singleton instance of the `wrap_with_set_grad_enabled` HOP.
wrap_with_set_grad_enabled = WrapWithSetGradEnabled()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class WrapWithAutocast(HigherOrderOperator):
|
| 61 |
+
def __init__(self):
|
| 62 |
+
super().__init__("wrap_with_autocast")
|
| 63 |
+
|
| 64 |
+
def __call__(
|
| 65 |
+
self,
|
| 66 |
+
device_type: str,
|
| 67 |
+
dtype: Optional[_dtype],
|
| 68 |
+
enabled: bool,
|
| 69 |
+
cache_enabled: Optional[bool],
|
| 70 |
+
wrapped_func,
|
| 71 |
+
*args,
|
| 72 |
+
**kwargs,
|
| 73 |
+
):
|
| 74 |
+
# Dynamo already traces the body of HigherOrderOp beforehand when it
|
| 75 |
+
# so no need to trace into it.
|
| 76 |
+
import torch._dynamo # noqa: F401
|
| 77 |
+
from torch._dynamo import disable
|
| 78 |
+
|
| 79 |
+
@disable
|
| 80 |
+
def wrapper():
|
| 81 |
+
with torch.autocast(device_type, dtype, enabled, cache_enabled):
|
| 82 |
+
return wrapped_func(*args, **kwargs)
|
| 83 |
+
|
| 84 |
+
return wrapper()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Module-level singleton instance of the `wrap_with_autocast` HOP.
wrap_with_autocast = WrapWithAutocast()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# This HOP allows you to bypass dynamo tracing of the wrapper function while
|
| 91 |
+
# still tracing the inner function.
|
| 92 |
+
# Takes two callables: The first, `wrapper_fn`, accepts `inner_fn` and returns a
|
| 93 |
+
# callable with the same signature. The second is the `inner_fn` itself. Any
|
| 94 |
+
# extra *args and **kwargs are forwarded to `wrapper_fn(inner_fn)` when it is
|
| 95 |
+
# executed.
|
| 96 |
+
class DynamoBypassingWrapper(HigherOrderOperator):
|
| 97 |
+
def __init__(self):
|
| 98 |
+
super().__init__("dynamo_bypassing_wrapper")
|
| 99 |
+
|
| 100 |
+
def __call__(
|
| 101 |
+
self,
|
| 102 |
+
wrapper_fn_or_key,
|
| 103 |
+
inner_fn,
|
| 104 |
+
*args,
|
| 105 |
+
**kwargs,
|
| 106 |
+
):
|
| 107 |
+
# Dynamo already traces the body of HigherOrderOp beforehand when it
|
| 108 |
+
# so no need to trace into it.
|
| 109 |
+
import torch._dynamo # noqa: F401
|
| 110 |
+
from torch._dynamo import disable
|
| 111 |
+
|
| 112 |
+
is_compiling = isinstance(wrapper_fn_or_key, str)
|
| 113 |
+
if is_compiling:
|
| 114 |
+
assert isinstance(inner_fn, torch.fx.GraphModule)
|
| 115 |
+
wrapper_fn = inner_fn.meta[wrapper_fn_or_key]
|
| 116 |
+
else:
|
| 117 |
+
wrapper_fn = wrapper_fn_or_key
|
| 118 |
+
|
| 119 |
+
@disable
|
| 120 |
+
def wrapper():
|
| 121 |
+
return wrapper_fn(inner_fn)(*args, **kwargs)
|
| 122 |
+
|
| 123 |
+
return wrapper()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# Module-level singleton instance of the `dynamo_bypassing_wrapper` HOP.
dynamo_bypassing_wrapper = DynamoBypassingWrapper()
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class WrapActivationCheckpoint(HigherOrderOperator):
|
| 130 |
+
"""
|
| 131 |
+
This operator is used to wrap torch.utils.checkpoint. This avoids
|
| 132 |
+
TorchDynamo to look into saved tensor hooks and directly passes the control
|
| 133 |
+
to AOT Autograd, which is ok with tracing saved tensor hooks. As a result of
|
| 134 |
+
AOT tracing torch.utils.checkpoint code, we have a backward graph with
|
| 135 |
+
recomputed forward nodes.
|
| 136 |
+
|
| 137 |
+
However, we might deprecate this operator soon. The difficulty arises in the
|
| 138 |
+
functionalization of rng ops. Today, there are two different
|
| 139 |
+
functionalization of rng ops - one at AOT autograd and other at Inductor.
|
| 140 |
+
And they are difficult to map to each other. The rng states also complicate
|
| 141 |
+
pattern matching in Inductor. Due to the ease of implementation, we are
|
| 142 |
+
currently inclined towards functionalization at Inductor level, which means
|
| 143 |
+
that duplication/recomputation is done as a compiler pass in the
|
| 144 |
+
partitioners. See TagActivationCheckpoint for more information.
|
| 145 |
+
"""
|
| 146 |
+
|
| 147 |
+
def __init__(self) -> None:
|
| 148 |
+
super().__init__("wrap_activation_checkpoint", cacheable=False)
|
| 149 |
+
|
| 150 |
+
def __call__(self, function, *args, **kwargs):
|
| 151 |
+
# use_reentrant is set to False because this op is going to be traced.
|
| 152 |
+
# And we ensure that AOT Autograd traces through the non reentrant
|
| 153 |
+
# version of checkpointing.
|
| 154 |
+
import torch.fx.traceback as fx_traceback
|
| 155 |
+
from torch.fx import Interpreter
|
| 156 |
+
|
| 157 |
+
kwargs["use_reentrant"] = False
|
| 158 |
+
kwargs["preserve_rng_state"] = False
|
| 159 |
+
# Using interpreter allows preservation of metadata through torch.compile stack.
|
| 160 |
+
with fx_traceback.preserve_node_meta():
|
| 161 |
+
from torch.utils.checkpoint import checkpoint
|
| 162 |
+
|
| 163 |
+
return checkpoint(Interpreter(function).run, *args, **kwargs)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# Module-level singleton instance of the `wrap_activation_checkpoint` HOP.
wrap_activation_checkpoint = WrapActivationCheckpoint()
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class TagActivationCheckpoint(HigherOrderOperator):
    """
    This operator is supposed to be used only with torch.compile stack. This
    accepts a Fx graph module which needs to be checkpointed. This operator adds
    "recomputable" tag to the nodes of the Fx graph that should be recomputed.

    The goal is to:
    1. Avoid using Dynamo to trace through saved tensor hooks.
    2. For selective checkpointing case, let AOTAutograd trace through
       saved tensor hooks but has special logic with TorchDispatchMode to override
       the usual saved_tensor_hooks fn logic in order to tag the nodes.
    3. Rely on the partitioners to actually duplicate the nodes.
    This sits well in the torch.compile stack, because by the time graph
    reaches partitioner, inductor has already run its functionalization of rng
    ops (by setting fixed seed for each random op, see `replace_random_passes`).
    Therefore, the duplication of nodes, by design, respects the rng states in
    the forward and recomputed forward in backward.
    """

    def __init__(self) -> None:
        super().__init__("tag_activation_checkpoint", cacheable=False)

    @staticmethod
    def divide_kwargs(kwargs):
        """
        checkpoint fn can have mixed kwargs between checkpointed fn and
        checkpoint fn itself. For example
        >> def gn(x, y, z=None):
        >>     a = torch.matmul(x, y)
        >>     if z is not None:
        >>         return torch.matmul(a, z)
        >>     return a
        >> def fn(x, y, z):
        >>     return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))
        In the above case, z belongs to checkpointed function gn, but
        use_reentrant belongs to the checkpoint function. This function splits
        the kwargs into checkpoint_kwargs and gmod_kwargs (or
        checkpointed_fn_kwargs).
        We do sorting to ensure same graph from run to run for better
        debuggability. It is not required for correctness.
        (NOTE(review): no explicit sorting happens below — the comprehensions
        preserve the caller's kwarg order; confirm whether sorting is still
        intended.)
        """
        from torch.utils.checkpoint import checkpoint

        # Any parameter of torch.utils.checkpoint.checkpoint (other than the
        # checkpointed function and its *args/**kwargs) is a "checkpoint" kwarg.
        ckpt_signature = inspect.signature(checkpoint)
        checkpoint_keys = set()
        for name in ckpt_signature.parameters:
            if name in ("function", "args", "kwargs"):
                continue
            checkpoint_keys.add(name)

        # `preserve_rng_state` is not a regular kwarg
        checkpoint_keys.add("preserve_rng_state")

        checkpoint_kwargs = {
            name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys
        }
        gmod_kwargs = {
            name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys
        }
        return checkpoint_kwargs, gmod_kwargs

    def tag_nodes(self, gmod, is_sac):
        # Stamp every compute node with this checkpoint region's unique id and
        # an initial recompute policy (refined later for selective checkpointing).
        from torch.utils.checkpoint import CheckpointPolicy

        unique_graph_id = next(uid)
        for node in gmod.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                node.meta["ac_graph_id"] = unique_graph_id
                if is_sac:
                    # For selective checkpointing, we will populate this tag later in _CachingTorchDispatchMode.
                    node.meta["recompute"] = None
                else:
                    # Under vanilla activation checkpointing, all nodes should be recomputed.
                    node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
        return gmod

    def __call__(self, gmod, *args, **kwargs):
        # Run the checkpointed graph module: the selective-checkpointing path
        # (a context_fn stashed in gmod.meta) goes through
        # torch.utils.checkpoint; the vanilla path tags nodes and runs directly.
        import torch.fx.traceback as fx_traceback
        from torch.fx import Interpreter

        if "_checkpoint_context_fn" in gmod.meta:
            warning_once(
                log,
                """
Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
""",
            )
            # use_reentrant is set to False because this op is going to be traced.
            # And we ensure that AOT Autograd traces through the non reentrant
            # version of checkpointing.
            kwargs["use_reentrant"] = False
            # preserve_rng_state is set to False because we want to prevent AOTAutograd from tracing through
            # `torch.random.fork_rng` op (which is not supported yet under CUDA).
            # This doesn't mean that we don't preserve RNG state. Instead, we will always preserve RNG state
            # regardless of this flag (by doing RNG functionalization via `replace_random_passes` in Inductor
            # instead of in AOTAutograd).
            kwargs["preserve_rng_state"] = False
            kwargs["context_fn"] = gmod.meta["_checkpoint_context_fn"]
            # We first tag all nodes as "recompute" in this graph, and then we undo the "recompute" tag
            # for specific nodes in _CachingTorchDispatchMode in torch/utils/checkpoint.py.
            gmod = self.tag_nodes(gmod, is_sac=True)
            # Using interpreter allows preservation of metadata through torch.compile stack.
            with fx_traceback.preserve_node_meta():
                from torch.utils.checkpoint import checkpoint

                return checkpoint(Interpreter(gmod).run, *args, **kwargs)
        else:
            gmod = self.tag_nodes(gmod, is_sac=False)
            # Using interpreter allows preservation of metadata through torch.compile stack.
            # TODO: We want to use the same `checkpoint(Interpreter(gmod).run, *args, **kwargs)` here
            # as the `context_fn != None` case, but that depends on in-place op support in TorchDispatchMode + torch.compile.
            # (for details on in-place op issue, run `test_compile_selective_checkpoint_inplace_op` unit test)
            with fx_traceback.preserve_node_meta():
                return Interpreter(gmod).run(*args)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# Module-level singleton instance of the `tag_activation_checkpoint` HOP.
tag_activation_checkpoint = TagActivationCheckpoint()
|