Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__init__.py +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/auto_functionalize.py +261 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/cond.py +349 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/effects.py +204 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/map.py +358 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/strict_mode.py +100 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/torchbind.py +94 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py +842 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/while_loop.py +232 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/mkl/__init__.py +56 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/mps/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/openmp/__init__.py +6 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/autograd/__init__.py +52 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/events/api.py +112 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/events/handlers.py +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/metrics/api.py +201 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py +375 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py +16 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py +32 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__init__.py +44 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/local_timer.py +125 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/__init__.py +4 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/api/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/jit/templates/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/copy.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/worker.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/api.py +108 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .cond import cond
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (263 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-311.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-311.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-311.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-311.pyc
ADDED
|
Binary file (21 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-311.pyc
ADDED
|
Binary file (5.79 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-311.pyc
ADDED
|
Binary file (42 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-311.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/auto_functionalize.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.utils._pytree as pytree
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from torch._C import DispatchKey
|
| 7 |
+
from torch._ops import HigherOrderOperator
|
| 8 |
+
from torch._prims_common import clone_preserve_strides
|
| 9 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 10 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 11 |
+
disable_proxy_modes_tracing,
|
| 12 |
+
ProxyTorchDispatchMode,
|
| 13 |
+
track_tensor_tree,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# NOTE: [auto-functionalizing custom ops]
|
| 18 |
+
# Users may wish to torch.compile custom ops that mutate their inputs.
|
| 19 |
+
# torch.compile will automatically support this op without anyone needing
|
| 20 |
+
# to provide a functionalization kernel for it. Here's how.
|
| 21 |
+
#
|
| 22 |
+
# Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> ()
|
| 23 |
+
# op. First, when FakeTensor sees this op:
|
| 24 |
+
# - If the schema says it returns nothing, we can generate a trivial
|
| 25 |
+
# FakeTensor rule for it (that returns nothing).
|
| 26 |
+
# - Otherwise, the user needs to provide a FakeTensor rule (abstract impl)
|
| 27 |
+
#
|
| 28 |
+
# Next, when Python FunctionalTensor sees the op, it will functionalize
|
| 29 |
+
# it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...})
|
| 30 |
+
# HOP and replacing the mutated inputs with corresponding outputs of this HOP.
|
| 31 |
+
# This HOP effectively runs the functional version of the op when
|
| 32 |
+
# called: it clones inputs that will be mutated, runs the op, and
|
| 33 |
+
# then returns (output, Tensors with the new values)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class AutoFunctionalized(HigherOrderOperator):
|
| 37 |
+
"""auto_functionalized(_mutable_op, **kwargs)
|
| 38 |
+
|
| 39 |
+
This HOP runs a "functional" version of _mutable_op.
|
| 40 |
+
|
| 41 |
+
Concretely, it looks at all the arguments that are mutable through
|
| 42 |
+
_mutable_op's operator schema, clones those kwargs, runs
|
| 43 |
+
`out = _mutable_op(**kwargs)` with the cloned values, and then returns the
|
| 44 |
+
operator output concatenated with the cloned values that were mutated.
|
| 45 |
+
|
| 46 |
+
We have some restrictions on `_mutable_op`.
|
| 47 |
+
See `can_auto_functionalize` for the restrictions. We can likely lift
|
| 48 |
+
many of these if users request it.
|
| 49 |
+
|
| 50 |
+
The reason why _mutable_op is prefixed with an
|
| 51 |
+
underscore is to prevent collisions with kwarg names in **kwargs.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(self):
|
| 55 |
+
super().__init__("auto_functionalized")
|
| 56 |
+
|
| 57 |
+
def __call__(
|
| 58 |
+
self,
|
| 59 |
+
_mutable_op: torch._ops.OpOverload,
|
| 60 |
+
**kwargs: Dict[str, Any],
|
| 61 |
+
) -> Tuple[Any, Tuple[Tensor, ...]]:
|
| 62 |
+
assert can_auto_functionalize(_mutable_op)
|
| 63 |
+
assert isinstance(kwargs, dict)
|
| 64 |
+
return super().__call__(_mutable_op, **kwargs)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
auto_functionalized = AutoFunctionalized()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def can_auto_functionalize(op: torch._ops.OperatorBase) -> bool:
|
| 71 |
+
if not isinstance(op, torch._ops.OpOverload):
|
| 72 |
+
return False
|
| 73 |
+
|
| 74 |
+
if torch._library.utils.is_builtin(op):
|
| 75 |
+
# We control the built-ins. These may (in rare cases)
|
| 76 |
+
# do input metadata mutation (which we have banned on custom ops)
|
| 77 |
+
return False
|
| 78 |
+
schema = op._schema
|
| 79 |
+
if not schema.is_mutable:
|
| 80 |
+
return False
|
| 81 |
+
schema = op._schema
|
| 82 |
+
|
| 83 |
+
for arg in schema.arguments:
|
| 84 |
+
if arg.alias_info is None:
|
| 85 |
+
continue
|
| 86 |
+
if not arg.alias_info.is_write:
|
| 87 |
+
continue
|
| 88 |
+
if type(arg.type) is torch.TensorType:
|
| 89 |
+
continue
|
| 90 |
+
if (
|
| 91 |
+
type(arg.type) is torch.OptionalType
|
| 92 |
+
and type(arg.type.getElementType()) is torch.TensorType
|
| 93 |
+
):
|
| 94 |
+
continue
|
| 95 |
+
# Not yet supported: other Tensor types. This includes things like
|
| 96 |
+
# Tensor[], Tensor?[], Tensor[]?.
|
| 97 |
+
return False
|
| 98 |
+
|
| 99 |
+
# The returns must not alias anything
|
| 100 |
+
for ret in schema.returns:
|
| 101 |
+
if ret.alias_info is None and type(ret.type) is torch.TensorType:
|
| 102 |
+
continue
|
| 103 |
+
# Not yet supported: List[Tensor] return.
|
| 104 |
+
return False
|
| 105 |
+
return True
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd)
|
| 109 |
+
def auto_functionalized_dense(
|
| 110 |
+
_mutable_op: torch._ops.OpOverload,
|
| 111 |
+
_only_clone_these_tensors: Optional[Tuple[str, ...]] = None,
|
| 112 |
+
**kwargs: Dict[str, Any],
|
| 113 |
+
) -> Tuple[Any, Tuple[Tensor, ...]]:
|
| 114 |
+
new_kwargs = dict(**kwargs)
|
| 115 |
+
result = []
|
| 116 |
+
|
| 117 |
+
_mutable_args_names = get_mutable_arg_names(_mutable_op)
|
| 118 |
+
for name in _mutable_args_names:
|
| 119 |
+
if (
|
| 120 |
+
_only_clone_these_tensors is not None
|
| 121 |
+
and name not in _only_clone_these_tensors
|
| 122 |
+
):
|
| 123 |
+
new_kwargs[name] = kwargs[name]
|
| 124 |
+
else:
|
| 125 |
+
new_kwargs[name] = (
|
| 126 |
+
clone_preserve_strides(kwargs[name])
|
| 127 |
+
if kwargs[name] is not None
|
| 128 |
+
else None
|
| 129 |
+
)
|
| 130 |
+
result.append(new_kwargs[name])
|
| 131 |
+
out = _mutable_op(**new_kwargs)
|
| 132 |
+
|
| 133 |
+
if isinstance(out, tuple):
|
| 134 |
+
return (*out, *result) # type: ignore[return-value]
|
| 135 |
+
else:
|
| 136 |
+
return (out, *result) # type: ignore[return-value]
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@auto_functionalized.py_impl(FakeTensorMode)
|
| 140 |
+
def auto_functionalized_fake(
|
| 141 |
+
mode,
|
| 142 |
+
_mutable_op: torch._ops.OpOverload,
|
| 143 |
+
**kwargs: Dict[str, Any],
|
| 144 |
+
) -> Tuple[Any, Tuple[Tensor, ...]]:
|
| 145 |
+
with mode:
|
| 146 |
+
result = auto_functionalized_dense(_mutable_op, **kwargs)
|
| 147 |
+
return result
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@auto_functionalized.py_impl(ProxyTorchDispatchMode)
|
| 151 |
+
def auto_functionalized_proxy(
|
| 152 |
+
mode,
|
| 153 |
+
_mutable_op: torch._ops.OpOverload,
|
| 154 |
+
**kwargs: Dict[str, Any],
|
| 155 |
+
) -> Tuple[Any, Tuple[Tensor, ...]]:
|
| 156 |
+
if not mode.enable_tracing:
|
| 157 |
+
return auto_functionalized(_mutable_op, **kwargs)
|
| 158 |
+
|
| 159 |
+
with disable_proxy_modes_tracing():
|
| 160 |
+
out = auto_functionalized(_mutable_op, **kwargs)
|
| 161 |
+
|
| 162 |
+
proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
|
| 163 |
+
out_proxy = mode.tracer.create_proxy(
|
| 164 |
+
"call_function",
|
| 165 |
+
auto_functionalized,
|
| 166 |
+
(_mutable_op,),
|
| 167 |
+
proxy_kwargs,
|
| 168 |
+
)
|
| 169 |
+
result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
|
| 170 |
+
return result
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
auto_functionalized.fallthrough(DispatchKey.AutogradCPU)
|
| 174 |
+
auto_functionalized.fallthrough(DispatchKey.AutogradCUDA)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def get_mutable_arg_names(op: torch._ops.OpOverload) -> List[str]:
|
| 178 |
+
"""
|
| 179 |
+
Returns the list of argument names that get mutated according to the
|
| 180 |
+
schema.
|
| 181 |
+
"""
|
| 182 |
+
mutable_args_names = [
|
| 183 |
+
arg.name
|
| 184 |
+
for arg in op._schema.arguments
|
| 185 |
+
if arg.alias_info is not None and arg.alias_info.is_write
|
| 186 |
+
]
|
| 187 |
+
return mutable_args_names
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def do_auto_functionalize(
|
| 191 |
+
op: torch._ops.OpOverload, args: Tuple[Any, ...], kwargs: Dict[str, Any]
|
| 192 |
+
) -> Any:
|
| 193 |
+
"""Functionalizes a call to op(*args, **kwargs) by emitting a call to
|
| 194 |
+
`outs = auto_functionalized(op, normalized_kwargs)`
|
| 195 |
+
and replacing the mutated (args, kwargs) with the corresponding outputs.
|
| 196 |
+
|
| 197 |
+
The normalized_kwargs are just the (args, kwargs), but all in kwarg form.
|
| 198 |
+
This makes handling easier for the auto_functionalized HOP.
|
| 199 |
+
"""
|
| 200 |
+
from torch._subclasses.functional_tensor import PythonFunctionalizeAPI
|
| 201 |
+
|
| 202 |
+
ctx = PythonFunctionalizeAPI()
|
| 203 |
+
|
| 204 |
+
# All of the (args, kwargs), but all as kwargs. The names for the
|
| 205 |
+
# args come from the schema. This makes it easier for us to work with them.
|
| 206 |
+
normalized_kwargs = {}
|
| 207 |
+
schema = op._schema
|
| 208 |
+
for idx, arg in enumerate(schema.arguments):
|
| 209 |
+
# NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema
|
| 210 |
+
if arg.name in kwargs:
|
| 211 |
+
normalized_kwargs[arg.name] = kwargs[arg.name]
|
| 212 |
+
elif idx < len(args):
|
| 213 |
+
# if its out of bounds we don't need to do anything
|
| 214 |
+
# as it means the the optional arg was passed with its default
|
| 215 |
+
# value
|
| 216 |
+
normalized_kwargs[arg.name] = args[idx]
|
| 217 |
+
else:
|
| 218 |
+
normalized_kwargs[arg.name] = arg.default_value
|
| 219 |
+
|
| 220 |
+
unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type]
|
| 221 |
+
with ctx.redispatch_to_next():
|
| 222 |
+
unwrapped_outs = auto_functionalized(
|
| 223 |
+
op, **unwrapped_kwargs # type: ignore[arg-type]
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# List of the name of args that get mutated (according to the schema)
|
| 227 |
+
mutable_args_names = get_mutable_arg_names(op)
|
| 228 |
+
|
| 229 |
+
unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[
|
| 230 |
+
: -len(mutable_args_names)
|
| 231 |
+
]
|
| 232 |
+
unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :]
|
| 233 |
+
|
| 234 |
+
if len(op._schema.returns) == 0:
|
| 235 |
+
assert unwrapped_actual_out[0] is None
|
| 236 |
+
unwrapped_actual_out = None
|
| 237 |
+
elif len(op._schema.returns) == 1:
|
| 238 |
+
assert len(unwrapped_actual_out) == 1
|
| 239 |
+
unwrapped_actual_out = unwrapped_actual_out[0]
|
| 240 |
+
else:
|
| 241 |
+
assert len(unwrapped_actual_out) == len(op._schema.returns)
|
| 242 |
+
|
| 243 |
+
for name, unwrapped_out in zip(mutable_args_names, unwrapped_mutable_out):
|
| 244 |
+
# Can be None if input was `Tensor(a!)?`
|
| 245 |
+
if unwrapped_out is None:
|
| 246 |
+
continue
|
| 247 |
+
assert isinstance(unwrapped_out, torch.Tensor)
|
| 248 |
+
orig_arg = normalized_kwargs[name]
|
| 249 |
+
ctx.replace(orig_arg, unwrapped_out)
|
| 250 |
+
ctx.commit_update(orig_arg)
|
| 251 |
+
ctx.sync(orig_arg)
|
| 252 |
+
|
| 253 |
+
return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type]
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
@auto_functionalized.py_functionalize_impl
|
| 257 |
+
def auto_functionalized_func(ctx, _mutable_op, **kwargs):
|
| 258 |
+
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
|
| 259 |
+
with ctx.redispatch_to_next():
|
| 260 |
+
result = auto_functionalized(_mutable_op, **unwrapped_kwargs)
|
| 261 |
+
return ctx.wrap_tensors(result)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/cond.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch._subclasses.functional_tensor
|
| 3 |
+
|
| 4 |
+
import torch.utils._pytree as pytree
|
| 5 |
+
|
| 6 |
+
from torch._C import DispatchKey
|
| 7 |
+
from torch._C._functorch import (
|
| 8 |
+
_add_batch_dim,
|
| 9 |
+
get_unwrapped,
|
| 10 |
+
is_batchedtensor,
|
| 11 |
+
maybe_get_bdim,
|
| 12 |
+
)
|
| 13 |
+
from torch._functorch.utils import exposed_in
|
| 14 |
+
|
| 15 |
+
from torch._higher_order_ops.utils import (
|
| 16 |
+
_has_potential_branch_input_alias,
|
| 17 |
+
_has_potential_branch_input_mutation,
|
| 18 |
+
_set_compilation_env,
|
| 19 |
+
autograd_not_implemented,
|
| 20 |
+
reenter_make_fx,
|
| 21 |
+
UnsupportedAliasMutationException,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
from torch._ops import HigherOrderOperator
|
| 25 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 26 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 27 |
+
disable_proxy_modes_tracing,
|
| 28 |
+
ProxyTorchDispatchMode,
|
| 29 |
+
track_tensor_tree,
|
| 30 |
+
)
|
| 31 |
+
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
| 32 |
+
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@exposed_in("torch")
|
| 36 |
+
def cond(pred, true_fn, false_fn, operands):
|
| 37 |
+
r"""
|
| 38 |
+
Conditionally applies `true_fn` or `false_fn`.
|
| 39 |
+
|
| 40 |
+
.. warning::
|
| 41 |
+
`torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
|
| 42 |
+
doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
|
| 43 |
+
Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
|
| 44 |
+
|
| 45 |
+
`cond` is structured control flow operator. That is, it is like a Python if-statement,
|
| 46 |
+
but has restrictions on `true_fn`, `false_fn`, and `operands` that enable it to be
|
| 47 |
+
capturable using torch.compile and torch.export.
|
| 48 |
+
|
| 49 |
+
Assuming the constraints on `cond`'s arguments are met, `cond` is equivalent to the following::
|
| 50 |
+
|
| 51 |
+
def cond(pred, true_branch, false_branch, operands):
|
| 52 |
+
if pred:
|
| 53 |
+
return true_branch(*operands)
|
| 54 |
+
else:
|
| 55 |
+
return false_branch(*operands)
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
pred (Union[bool, torch.Tensor]): A boolean expression or a tensor with one element,
|
| 59 |
+
indicating which branch function to apply.
|
| 60 |
+
|
| 61 |
+
true_fn (Callable): A callable function (a -> b) that is within the
|
| 62 |
+
scope that is being traced.
|
| 63 |
+
|
| 64 |
+
false_fn (Callable): A callable function (a -> b) that is within the
|
| 65 |
+
scope that is being traced. The true branch and false branch must
|
| 66 |
+
have consistent input and outputs, meaning the inputs have to be
|
| 67 |
+
the same, and the outputs have to be the same type and shape.
|
| 68 |
+
|
| 69 |
+
operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the true/false functions.
|
| 70 |
+
|
| 71 |
+
Example::
|
| 72 |
+
|
| 73 |
+
def true_fn(x: torch.Tensor):
|
| 74 |
+
return x.cos()
|
| 75 |
+
def false_fn(x: torch.Tensor):
|
| 76 |
+
return x.sin()
|
| 77 |
+
return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
|
| 78 |
+
|
| 79 |
+
Restrictions:
|
| 80 |
+
- The conditional statement (aka `pred`) must meet one of the following constraints:
|
| 81 |
+
|
| 82 |
+
- It's a `torch.Tensor` with only one element, and torch.bool dtype
|
| 83 |
+
|
| 84 |
+
- It's a boolean expression, e.g. `x.shape[0] > 10` or `x.dim() > 1 and x.shape[1] > 10`
|
| 85 |
+
|
| 86 |
+
- The branch function (aka `true_fn`/`false_fn`) must meet all of the following constraints:
|
| 87 |
+
|
| 88 |
+
- The function signature must match with operands.
|
| 89 |
+
|
| 90 |
+
- The function must return a tensor with the same metadata, e.g. shape,
|
| 91 |
+
dtype, etc.
|
| 92 |
+
|
| 93 |
+
- The function cannot have in-place mutations on inputs or global variables.
|
| 94 |
+
(Note: in-place tensor operations such as `add_` for intermediate results
|
| 95 |
+
are allowed in a branch)
|
| 96 |
+
|
| 97 |
+
.. warning::
|
| 98 |
+
Temporal Limitations:
|
| 99 |
+
|
| 100 |
+
- `cond` only supports **inference** right now. Autograd will be supported in the future.
|
| 101 |
+
|
| 102 |
+
- The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future.
|
| 103 |
+
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
if torch.compiler.is_dynamo_compiling():
|
| 107 |
+
return cond_op(pred, true_fn, false_fn, operands)
|
| 108 |
+
|
| 109 |
+
def _validate_input(pred, true_fn, false_fn, operands):
|
| 110 |
+
if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
|
| 111 |
+
raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")
|
| 112 |
+
|
| 113 |
+
if isinstance(pred, torch.Tensor) and pred.numel() != 1:
|
| 114 |
+
raise RuntimeError(
|
| 115 |
+
f"Expected pred to be bool or single-element tensor, but got {pred}."
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
if not callable(true_fn) or not callable(false_fn):
|
| 119 |
+
raise RuntimeError("Expect both branches to be callbale.")
|
| 120 |
+
|
| 121 |
+
if not isinstance(operands, (tuple, list)) or pytree.tree_any(
|
| 122 |
+
lambda t: not isinstance(t, torch.Tensor), operands
|
| 123 |
+
):
|
| 124 |
+
raise RuntimeError(
|
| 125 |
+
"Expect operands to be a tuple of possibly nested dict/list/tuple that only"
|
| 126 |
+
f"consists of tensor leaves, but got {operands}."
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
_validate_input(pred, true_fn, false_fn, operands)
|
| 130 |
+
|
| 131 |
+
if not torch._dynamo.is_dynamo_supported():
|
| 132 |
+
raise RuntimeError("torch.cond requires dynamo support.")
|
| 133 |
+
|
| 134 |
+
with _set_compilation_env():
|
| 135 |
+
with torch._dynamo.utils.disable_cache_limit():
|
| 136 |
+
return torch.compile(cond_op, backend="eager", fullgraph=True)(
|
| 137 |
+
pred, true_fn, false_fn, operands
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
"""
|
| 142 |
+
We're going to define a `cond_op` operation.
|
| 143 |
+
In order to do this, we need implementations for each of the dispatch keys.
|
| 144 |
+
"""
|
| 145 |
+
cond_op = HigherOrderOperator("cond")
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
|
| 149 |
+
assert isinstance(
|
| 150 |
+
operands, (list, tuple)
|
| 151 |
+
), "Cond operands must be a list or tuple of tensors"
|
| 152 |
+
assert all(
|
| 153 |
+
isinstance(o, torch.Tensor) for o in operands
|
| 154 |
+
), "Cond operands must be a list of tensors"
|
| 155 |
+
|
| 156 |
+
pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)
|
| 157 |
+
|
| 158 |
+
with disable_proxy_modes_tracing():
|
| 159 |
+
true_graph = reenter_make_fx(true_fn, pre_dispatch)(*operands)
|
| 160 |
+
false_graph = reenter_make_fx(false_fn, pre_dispatch)(*operands)
|
| 161 |
+
|
| 162 |
+
true_outs = []
|
| 163 |
+
false_outs = []
|
| 164 |
+
for node in true_graph.graph.nodes:
|
| 165 |
+
if node.op == "output":
|
| 166 |
+
true_outs.extend(node.args)
|
| 167 |
+
|
| 168 |
+
for node in false_graph.graph.nodes:
|
| 169 |
+
if node.op == "output":
|
| 170 |
+
false_outs.extend(node.args)
|
| 171 |
+
|
| 172 |
+
flat_true_outs = pytree.arg_tree_leaves(*true_outs)
|
| 173 |
+
flat_false_outs = pytree.arg_tree_leaves(*false_outs)
|
| 174 |
+
if len(flat_true_outs) != len(flat_false_outs):
|
| 175 |
+
raise torch._dynamo.exc.CondOpArgsMismatchError(
|
| 176 |
+
f"Expected to return same number of outputs but got:"
|
| 177 |
+
f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
|
| 178 |
+
f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
for i in range(0, len(flat_true_outs)):
|
| 182 |
+
true_out = flat_true_outs[i]
|
| 183 |
+
false_out = flat_false_outs[i]
|
| 184 |
+
if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
|
| 185 |
+
raise torch._dynamo.exc.CondOpArgsMismatchError(
|
| 186 |
+
f"Expected each tensor to have same metadata but got:"
|
| 187 |
+
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
|
| 188 |
+
f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# There are probably better ways - I know that create_arg has some self incrementing name
|
| 192 |
+
# magic to it, but since we explicitly have to get the name for register_module,
|
| 193 |
+
# I was not sure how to do that. This kinda simulates it.
|
| 194 |
+
next_name = None
|
| 195 |
+
i = 0
|
| 196 |
+
while not next_name:
|
| 197 |
+
candidate = f"true_graph_{i}"
|
| 198 |
+
if hasattr(proxy_mode.tracer.root, candidate):
|
| 199 |
+
i += 1
|
| 200 |
+
else:
|
| 201 |
+
next_name = candidate
|
| 202 |
+
|
| 203 |
+
true_name = next_name
|
| 204 |
+
false_name = f"false_graph_{i}"
|
| 205 |
+
assert not hasattr(proxy_mode.tracer.root, false_name)
|
| 206 |
+
|
| 207 |
+
proxy_mode.tracer.root.register_module(true_name, true_graph)
|
| 208 |
+
proxy_mode.tracer.root.register_module(false_name, false_graph)
|
| 209 |
+
|
| 210 |
+
args = (pred, true_graph, false_graph, operands)
|
| 211 |
+
|
| 212 |
+
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
|
| 213 |
+
|
| 214 |
+
out_proxy = proxy_mode.tracer.create_proxy(
|
| 215 |
+
"call_function", func_overload, proxy_args, {}, name="conditional"
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# At this point, we're *guaranteed* that whether an output came from the
|
| 219 |
+
# true or false branch is indistinguishable. So, as this is just for tracing
|
| 220 |
+
# purposes, choose the true branch.
|
| 221 |
+
|
| 222 |
+
# TODO: Uhh.... it shouldn't matter, but changing this to true_fn results in
|
| 223 |
+
# a FakeTensorMode error :
|
| 224 |
+
# `Current active mode <class 'torch._subclasses.fake_tensor.FakeTensorMode'> not registered`
|
| 225 |
+
# TODO Sometimes the operands are not completely FakeTensor, something seems went wrong in
|
| 226 |
+
# dynamo? Because of that it runs real computation sometimes and re-triggering downstream dispatch keys.
|
| 227 |
+
out = false_fn(*operands)
|
| 228 |
+
|
| 229 |
+
return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def cond_op_dense(pred, true_fn, false_fn, operands):
    """Eager (dense) implementation of ``cond``: evaluate the predicate and
    run exactly one of the two branch callables on ``operands``."""
    # This key only runs in plain eager execution; any active dispatch mode
    # should have been handled by a higher-priority key.
    assert (
        _get_current_dispatch_mode() is None
    ), "Mode should never be enabled for CPU/CUDA key"
    chosen_branch = true_fn if pred else false_fn
    return chosen_branch(*operands)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
# cond has no autograd formula; defer the "not implemented" error so it is
# raised only when a backward pass actually needs gradients through cond.
cond_op.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(cond_op, deferred_error=True)
)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
@cond_op.py_impl(ProxyTorchDispatchMode)
def inner(mode, pred, true_fn, false_fn, operands):
    """Proxy-mode implementation of ``cond``: record the op into the traced
    graph when tracing is enabled, otherwise redispatch normally."""
    if not mode.enable_tracing:
        return cond_op(pred, true_fn, false_fn, operands)
    return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
@cond_op.py_impl(FakeTensorMode)
def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
    # Fake-tensor implementation: run BOTH branches on fake inputs and verify
    # they produce the same number of outputs with identical tensor metadata,
    # since at trace time we cannot know which branch will be taken.
    with mode:
        true_outs = true_fn(*operands)
        flat_true_outs = pytree.tree_leaves(true_outs)
        flat_false_outs = pytree.tree_leaves(false_fn(*operands))
        if len(flat_true_outs) != len(flat_false_outs):
            raise RuntimeError("Unmatched number of outputs from cond() branches.")

        for true_out, false_out in zip(flat_true_outs, flat_false_outs):
            true_meta = _extract_tensor_metadata(true_out)
            false_meta = _extract_tensor_metadata(false_out)
            if true_meta != false_meta:
                raise torch._dynamo.exc.CondOpArgsMismatchError(
                    f"Expected each tensor to have same metadata but got:"
                    f"\n  {true_fn.__name__} returns {true_meta}"
                    f"\n  {false_fn.__name__} returns {false_meta}"
                )
    # Branches were verified interchangeable above, so the true branch's
    # outputs stand in for the result.
    return true_outs
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
@cond_op.py_functionalize_impl
def cond_func(ctx, pred, true_fn, false_fn, inputs):
    # Functionalization implementation: unwrap functional tensors, reject
    # branches that mutate or alias their inputs (cond must stay functional),
    # then redispatch on the unwrapped values and re-wrap the result.
    unwrapped_inputs = ctx.unwrap_tensors(inputs)
    unwrapped_pred = ctx.unwrap_tensors(pred)
    with ctx.redispatch_to_next() as m:
        functional_true = ctx.functionalize(true_fn)
        functional_false = ctx.functionalize(false_fn)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        # Mutation is detected on the functionalized branches...
        for branch in [functional_true, functional_false]:
            if _has_potential_branch_input_mutation(
                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond branch might be modifying the input!"
                )
        # ...while aliasing is detected on the raw branches.
        for branch in [true_fn, false_fn]:
            if _has_potential_branch_input_alias(
                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond branch might be aliasing the input!"
                )

        cond_return = cond_op(
            unwrapped_pred, functional_true, functional_false, unwrapped_inputs
        )
        return ctx.wrap_tensors(cond_return)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
@cond_op.py_impl(torch._C._functorch.TransformType.Vmap)
def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs):
    # vmap batching rule for cond. Two regimes:
    #  * batched predicate: both branches must be evaluated elementwise and
    #    combined with torch.where on the predicate.
    #  * unbatched predicate: only the branch bodies are vmapped and cond
    #    itself picks one as usual.
    assert isinstance(
        inputs, (list, tuple)
    ), "Cond inputs must be a list or tuple of tensors"
    assert all(
        isinstance(i, torch.Tensor) for i in inputs
    ), "Cond inputs must be a list of tensors"

    pred_ = get_unwrapped(pred) if is_batchedtensor(pred) else pred

    # unbatched tensors are not vmapped
    tensors, in_dims = zip(
        *[
            (get_unwrapped(t), maybe_get_bdim(t)) if is_batchedtensor(t) else (t, None)
            for t in inputs
        ]
    )

    if is_batchedtensor(pred):
        # prepend "pred" and vmap everything
        tensors = (pred_,) + tensors
        in_dims = (0,) + in_dims

        def fn(p, *args):
            t = true_fn(*args)
            f = false_fn(*args)
            # NOTE(review): only the first output of each branch is combined
            # (t[0], f[0]) — multi-output branches appear unsupported on this
            # path; confirm against callers.
            return torch.where(p, t[0], f[0])

        with interpreter.lower():
            result = torch.vmap(fn, in_dims=in_dims)(*tensors)

    else:
        # predicate is known at this stage and it is a boolean expression or a
        # tensor with one element.
        true_fn = torch.vmap(true_fn, in_dims=in_dims)
        false_fn = torch.vmap(false_fn, in_dims=in_dims)

        with interpreter.lower():
            result = cond_op(pred, true_fn, false_fn, tensors)

    if not isinstance(result, tuple):
        result = (result,)
    # Re-attach the batch dimension at dim 0 for the current vmap level.
    lvl = interpreter.level()
    return tuple([_add_batch_dim(r, 0, lvl) for r in result])
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/effects.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
from typing import Any, Dict, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.utils._pytree as pytree
|
| 6 |
+
from torch._C import DispatchKey
|
| 7 |
+
from torch._ops import HigherOrderOperator
|
| 8 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 9 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 10 |
+
disable_proxy_modes_tracing,
|
| 11 |
+
ProxyTorchDispatchMode,
|
| 12 |
+
track_tensor_tree,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class _EffectType(Enum):
    # Kinds of side effects a token can guard. ORDERED means calls carrying
    # this effect must keep their original program order relative to each
    # other.
    ORDERED = "Ordered"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Ops known to be side-effectful, mapped to the kind of effect they carry.
# Ops that take a ScriptObject argument are additionally treated as ORDERED
# even when not listed here (see get_effect_key).
SIDE_EFFECTS: Dict[torch._ops.OpOverload, _EffectType] = {
    torch.ops.aten._print.default: _EffectType.ORDERED,
}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class WithEffects(HigherOrderOperator):
    """
    with_effects(token, op, args, kwargs) -> (new_token, op_results)

    This HOP helps ensure ordering between side effectful ops like prints or ops
    using torchbind objects. This is needed to ensure a traced graph from
    AOTAutograd is functional so that future optimization passes do not reorder
    these operators. This is done through threading "effect tokens" through the
    graph to enforce data dependence between side effectful ops.

    The tokens are basically dummy values (torch.tensor([])). We create a token
    per "effect type", which are enumerated in the _EffectType enum.
    """

    def __init__(self):
        super().__init__("with_effects")

    def __call__(
        self,
        token,
        op: torch._ops.OpOverload,
        *args: Tuple[Any, ...],
        **kwargs: Dict[str, Any],
    ) -> Tuple[Any, ...]:
        # Only non-aliasing, effectful OpOverloads may be wrapped; an aliasing
        # op would break the functional-graph guarantee this HOP exists for.
        assert isinstance(op, torch._ops.OpOverload)
        assert not has_aliasing(op), "Ops with aliasing is not supported"
        assert has_effects(op, args, kwargs)
        assert isinstance(kwargs, dict)
        return super().__call__(token, op, *args, **kwargs)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Singleton instance used by all the dispatch registrations below.
with_effects = WithEffects()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def has_aliasing(op: torch._ops.OpOverload):
    """Return True when any argument or return in ``op``'s schema carries an
    alias annotation (i.e. the op may view or mutate one of its inputs)."""
    schema = op._schema
    return any(
        entry.alias_info is not None
        for entry in (*schema.arguments, *schema.returns)
    )
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def has_effects(op, args, kwargs) -> bool:
    """An op participates in effect-token threading only when it is a real
    OpOverload, has no aliasing, and maps to some effect type."""
    if not isinstance(op, torch._ops.OpOverload):
        return False
    if has_aliasing(op):
        return False
    return get_effect_key(op, args, kwargs) is not None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
    """Look up the effect type for ``op``.

    Returns the registered entry from SIDE_EFFECTS if present; otherwise any
    op receiving a ScriptObject argument is implicitly ORDERED. Returns None
    for effect-free ops.
    """
    registered = SIDE_EFFECTS.get(op)
    if registered is not None:
        return registered

    if any(isinstance(arg, torch.ScriptObject) for arg in args):
        return _EffectType.ORDERED

    return None
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@with_effects.py_impl(DispatchKey.CompositeExplicitAutograd)
def with_effects_dense(
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """Eager implementation: run ``op`` for real and thread a fresh dummy
    token through the result as the first element."""
    result = op(*args, **kwargs)
    new_token = torch.tensor([])
    if not isinstance(result, tuple):
        return (new_token, result)
    return (new_token, *result)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@with_effects.py_impl(FakeTensorMode)
def with_effects_fake(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """Fake-tensor implementation: run the dense path under ``mode`` so every
    output (including the fresh token) is a fake tensor."""
    with mode:
        return with_effects_dense(token, op, *args, **kwargs)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@with_effects.py_impl(ProxyTorchDispatchMode)
def with_effects_proxy(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    # Proxy-mode implementation: record a single with_effects node in the
    # traced graph rather than tracing into `op` itself.
    if not mode.enable_tracing:
        return with_effects(token, op, *args, **kwargs)

    # Execute once with tracing disabled to obtain concrete outputs to attach
    # to the proxy node.
    with disable_proxy_modes_tracing():
        out = with_effects(token, op, *args, **kwargs)

    proxy_token = mode.tracer.unwrap_proxy(token)
    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)

    # `op` is embedded directly as a constant argument of the node.
    out_proxy = mode.tracer.create_proxy(
        "call_function",
        with_effects,
        (proxy_token, op, *proxy_args),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# No autograd formula: let the autograd keys fall through so dispatch
# continues down to the backend implementation.
with_effects.fallthrough(DispatchKey.AutogradCPU)
with_effects.fallthrough(DispatchKey.AutogradCUDA)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def handle_effects(
    allow_token_discovery: bool,
    tokens: Dict[_EffectType, torch.Tensor],
    op: torch._ops.OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """
    Wrap a single effectful ``op`` call in with_effects, threading the token
    for its effect type through ``tokens`` (mutated in place).

    Args:
        allow_token_discovery: Whether or not we are discovering tokens. If this
        is true, we will create a token for every side effect type seen that
        does not have a token assigned yet. If this is false, the tokens
        should've all been created ahead of time, so we will error if there is
        no token mapping to every effect type.

        tokens: Map of effect type to tokens. This is to chain operators of the
        same effects together so that they do not get reordered in later
        optimization passes.

    Returns:
        The op's result, reshaped to its schema arity (None / single value /
        tuple), wrapped back into functional tensors.
    """

    # Get a token. We can't do `tokens.get(op, torch.tensor([]))` because
    # this will create an empty tensor during proxy mode tracing if the token
    # doesn't exist. But the tokens should always exist during proxy mode tracing.
    key = get_effect_key(op, args, kwargs)
    assert key is not None
    if key not in tokens:
        assert allow_token_discovery, f"Could not find a token for effect {key}"
        tokens[key] = torch.tensor([])
    token = tokens[key]

    # Imported locally — presumably to avoid an import cycle; confirm before
    # hoisting to module level.
    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    unwrapped_token = ctx.unwrap_tensors([token])[0]  # type: ignore[arg-type]
    unwrapped_args = ctx.unwrap_tensors(args)  # type: ignore[arg-type]
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
    with ctx.redispatch_to_next():
        (new_token, *unwrapped_outs) = with_effects(
            unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    # Reshape the flat output list back to the op's declared return arity.
    if len(op._schema.returns) == 0:
        assert unwrapped_outs[0] is None
        unwrapped_outs = None  # type: ignore[assignment]
    elif len(op._schema.returns) == 1:
        assert len(unwrapped_outs) == 1
        unwrapped_outs = unwrapped_outs[0]
    else:
        assert len(unwrapped_outs) == len(op._schema.returns)

    # Add the newly created token into the tokens map for a following call to
    # use this token.
    wrapped_token = ctx.wrap_tensors(new_token)
    assert isinstance(wrapped_token, torch.Tensor)
    tokens[key] = wrapped_token

    return ctx.wrap_tensors(unwrapped_outs)  # type: ignore[arg-type]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/map.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.utils._pytree as pytree
|
| 3 |
+
from torch._C import DispatchKey
|
| 4 |
+
from torch._dispatch.python import suspend_functionalization
|
| 5 |
+
from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun
|
| 6 |
+
|
| 7 |
+
from torch._higher_order_ops.utils import (
|
| 8 |
+
_has_potential_branch_input_alias,
|
| 9 |
+
_has_potential_branch_input_mutation,
|
| 10 |
+
reenter_make_fx,
|
| 11 |
+
UnsupportedAliasMutationException,
|
| 12 |
+
)
|
| 13 |
+
from torch._ops import HigherOrderOperator
|
| 14 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 15 |
+
from torch._subclasses.functional_tensor import (
|
| 16 |
+
disable_functional_mode,
|
| 17 |
+
FunctionalTensor,
|
| 18 |
+
)
|
| 19 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 20 |
+
disable_proxy_modes_tracing,
|
| 21 |
+
make_fx,
|
| 22 |
+
ProxyTorchDispatchMode,
|
| 23 |
+
track_tensor_tree,
|
| 24 |
+
)
|
| 25 |
+
from torch.multiprocessing.reductions import StorageWeakRef
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# TODO: We add this to prevent dynamo from tracing into map_wrapper,
|
| 29 |
+
# remove the wrapper call when it's ready.
|
| 30 |
+
class MapWrapper(HigherOrderOperator):
    # Thin dispatcher-facing wrapper that forwards to map_wrapper, which does
    # the pytree flattening/validation before hitting the real map_impl HOP.
    def __call__(self, xs, *args):
        # NOTE(review): the first positional argument here is actually the
        # mapped function `f` — map_wrapper(f, xs, *args) binds it as `f`.
        # The parameter name is misleading but the forwarding is positional.
        return map_wrapper(xs, *args)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Public `map` entry point plus the underlying HOP that carries the dispatch
# registrations below. (Shadows the builtin `map` deliberately — this is the
# module's public API name.)
map = MapWrapper("map")
map_impl = HigherOrderOperator("map_impl")
|
| 37 |
+
|
| 38 |
+
# Minimal AOTConfig used only to drive create_joint when building the joint
# backward graph in create_fw_bw_graph; the compiler/partitioner fields are
# never invoked on this path, hence the None placeholders.
dummy_aot_config = AOTConfig(
    fw_compiler=None,  # type: ignore[arg-type]
    bw_compiler=None,  # type: ignore[arg-type]
    partition_fn=None,  # type: ignore[arg-type]
    decompositions={},
    num_params_buffers=0,
    aot_id=0,
    keep_inference_input_mutations=False,
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def create_fw_bw_graph(f, num_mapped_args, *args):
    # Build the traced forward graph for `f` and a joint (fw+bw) graph used
    # by MapAutogradOp.backward. The first `num_mapped_args` entries of
    # `args` are the mapped (stacked) inputs; the rest are positional extras.
    mapped_xs = args[:num_mapped_args]
    pos_args = args[num_mapped_args:]

    # Note: We create "clean" environments for make_fx by suspending all dispatch keys
    # between Autograd and Python key. Currently, we only suspend functionalization but more can be
    # added when required. Will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    # However, it's the outside wrapper that tracer creates proxies for. This causes tracer fail to
    # fetch the proxy for the inputs and fail to capture any operations on them.
    #
    # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
    # wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
    # only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
    # Instead, it will create _tensor_constant as output.

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():

            def _from_fun(t):
                # Produce a fresh, non-functional example tensor matching t's
                # metadata (shape/stride/dtype/requires_grad) for tracing.
                if isinstance(t, torch.Tensor):
                    if t.dtype != torch.bool:
                        return torch.empty_strided(
                            t.size(),
                            t.stride(),
                            dtype=t.dtype,
                            requires_grad=t.requires_grad,
                        )
                    else:
                        # clone of a functional tensor produces a functional tensor
                        # but we want to avoid it so we clone a non-functional version
                        maybe_unfunc_t = t
                        if isinstance(t, FunctionalTensor):
                            torch._sync(t)
                            maybe_unfunc_t = from_fun(t)
                        elif torch._is_functional_tensor(t):
                            # need to handle both types of functionalization here:
                            # these are the tensors that came from the user,
                            # which could be either FunctionalTensorWrapper or FunctionalTensor
                            torch._sync(t)
                            maybe_unfunc_t = torch._from_functional_tensor(t)
                        return maybe_unfunc_t.clone()
                return t

            unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs)
            # One slice along the leading dim serves as the example input.
            example_xs = _unstack_pytree(unwrapped_mapped_xs)[0]

            example_pos_args = [
                _from_fun(arg) if isinstance(arg, torch.Tensor) else arg
                for arg in pos_args
            ]
            example_flat_out = pytree.tree_map(
                _from_fun, f(*example_xs, *example_pos_args)
            )
            if any(
                not isinstance(out, torch.Tensor)
                for out in example_flat_out
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of map only contains tensors or None. "
                    f"Got types {[type(out) for out in example_flat_out]}."
                )
            example_grad = [_from_fun(out) for out in example_flat_out]

            fw_graph = make_fx(f)(*example_xs, *example_pos_args)

            def joint_f(*example_args):
                # Joint function: takes (mapped inputs, incoming grads, extra
                # args) and returns the input gradients via create_joint.
                joint_mapped_args = example_args[:joint_num_mapped]
                args = example_args[joint_num_mapped:]

                mapped_input = joint_mapped_args[:num_mapped_args]
                mapped_grads = joint_mapped_args[num_mapped_args:]

                def fw_with_masks(*args):
                    fw_out = f(*args)
                    return fw_out, [
                        True
                        if isinstance(ret, torch.Tensor) and ret.requires_grad
                        else False
                        for ret in fw_out
                    ]

                joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
                _, grads = joint(
                    list(mapped_input) + list(args),
                    [
                        grad
                        for grad in mapped_grads
                        if grad is not None and grad.requires_grad
                    ],
                )

                # In order to keep map functional for backward graph,
                # we clone outputs that are aliasing inputs
                input_storage = {
                    StorageWeakRef(arg._typed_storage())
                    for arg in example_args
                    if isinstance(arg, torch.Tensor)
                }

                def maybe_clone(t):
                    if (
                        isinstance(t, torch.Tensor)
                        and StorageWeakRef(t._typed_storage()) in input_storage
                    ):
                        return t.clone()
                    return t

                return pytree.tree_map(maybe_clone, grads)

            joint_num_mapped = len(example_grad) + len(example_xs)
            joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
            return fw_graph, joint_graph
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def map_wrapper(f, xs, *args):
    # Flatten the mapped input pytree and validate it (tensor leaves only,
    # consistent non-zero leading dimension) before dispatching to map_impl.
    flat_xs, xs_spec = pytree.tree_flatten(xs)
    if not all(isinstance(t, torch.Tensor) for t in flat_xs):
        raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")

    num_mapped_args = len(flat_xs)
    shapes = [xs.shape for xs in flat_xs]
    leading_dim_size = shapes[0][0]
    if leading_dim_size == 0:
        raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")

    if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
        raise RuntimeError(
            f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
        )

    out_spec = None

    def flat_fn(*flat_args):
        # Rebuild the user-facing pytree for `f`, flatten its output, and
        # capture the output spec (via nonlocal) for unflattening below.
        xs = pytree.tree_unflatten(list(flat_args[:num_mapped_args]), xs_spec)
        unflattened_out = f(xs, *flat_args[num_mapped_args:])
        flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out)

        nonlocal out_spec
        out_spec = tmp_out_spec
        return flat_out

    # out_spec is populated as a side effect of map_impl invoking flat_fn.
    return pytree.tree_unflatten(
        map_impl(flat_fn, flat_xs, args), out_spec  # type: ignore[arg-type]
    )
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class MapAutogradOp(torch.autograd.Function):
    # Bridges map_impl into autograd: forward maps the traced fw_graph over
    # the inputs; backward maps the joint graph over (inputs, grads).
    @staticmethod
    def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
        ctx.save_for_backward(*flat_args)
        ctx._joint_graph = joint_graph
        ctx._num_mapped_args = num_mapped_args
        # Dispatch below Autograd so map_impl's dense implementation runs
        # without re-entering this autograd key.
        with torch._C._AutoDispatchBelowAutograd():
            return (
                *map_impl(
                    fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:]
                ),
            )

    @staticmethod
    def backward(ctx, *flat_grads):
        fw_args = ctx.saved_tensors
        fw_mapped_args = fw_args[: ctx._num_mapped_args]
        pos_args = fw_args[ctx._num_mapped_args :]

        # The joint graph's mapped inputs are the forward mapped inputs
        # followed by the incoming gradients.
        grads = map_impl(
            ctx._joint_graph,
            fw_mapped_args + flat_grads,
            pos_args,
        )
        # Leading Nones correspond to fw_graph, joint_graph, num_mapped_args.
        return None, None, None, *grads
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def trace_map(proxy_mode, func_overload, f, xs, pos_args):
    # Record a map_impl node in the currently-traced graph: trace the body on
    # one example slice, register it as a submodule, and emit a proxy whose
    # outputs have the body's shapes expanded by the leading dimension.
    leading_dim_size = xs[0].shape[0]

    example_input = _unstack_pytree(xs)[0]
    body_graph = f

    pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)
    body_graph = reenter_make_fx(body_graph, pre_dispatch)(*example_input, *pos_args)

    # Find the first unused "body_graph_<i>" name on the tracer root.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"body_graph_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    proxy_mode.tracer.root.register_module(next_name, body_graph)

    with disable_proxy_modes_tracing():
        example_outs = body_graph(*example_input, *pos_args)

        def expand_tensor(t):
            if isinstance(t, torch.Tensor):
                return t.expand(leading_dim_size, *t.shape)
            return t

        # Example outputs for the whole map: one body output per slice,
        # broadcast along the new leading dimension.
        expanded_outs = pytree.tree_map(expand_tensor, example_outs)

    node_args = (body_graph, list(xs), list(pos_args))
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="map_impl"
    )
    return track_tensor_tree(
        expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
    )
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def _unstack_pytree(xs):
|
| 268 |
+
flat_xs, inspec = pytree.tree_flatten(xs)
|
| 269 |
+
if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
|
| 270 |
+
raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
|
| 271 |
+
|
| 272 |
+
if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
|
| 273 |
+
raise RuntimeError(
|
| 274 |
+
f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
a = zip(*flat_xs)
|
| 278 |
+
|
| 279 |
+
pytrees = []
|
| 280 |
+
for tuple in a:
|
| 281 |
+
pytrees.append(pytree.tree_unflatten(tuple, inspec))
|
| 282 |
+
return pytrees
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def _stack_pytree(pytrees):
|
| 286 |
+
flat_out = []
|
| 287 |
+
out_spec = None
|
| 288 |
+
for pt in pytrees:
|
| 289 |
+
flat_pt, out_spec = pytree.tree_flatten(pt)
|
| 290 |
+
flat_out.append(flat_pt)
|
| 291 |
+
assert out_spec is not None
|
| 292 |
+
b = zip(*flat_out)
|
| 293 |
+
stacked_out = []
|
| 294 |
+
for leaves in b:
|
| 295 |
+
if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
|
| 296 |
+
stacked_out.append(torch.stack(leaves))
|
| 297 |
+
elif all(leaf is None for leaf in leaves):
|
| 298 |
+
# Backward graph can return None output when forward inputs doesn't require grad.
|
| 299 |
+
# When we eagerly execute backward graph, we need to call _stack_pytree on its output,
|
| 300 |
+
# therefore we need to deal with None output.
|
| 301 |
+
stacked_out.append(None) # type: ignore[arg-type]
|
| 302 |
+
else:
|
| 303 |
+
raise RuntimeError(f"Cannot stack {leaves}.")
|
| 304 |
+
return pytree.tree_unflatten(stacked_out, out_spec)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
@map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
def map_dense(f, xs, pos_args):
    """Eager implementation of map: apply ``f`` to each slice along the
    leading dimension of ``xs`` and stack the per-slice results."""
    per_slice_results = [
        f(*unstacked, *pos_args) for unstacked in _unstack_pytree(xs)
    ]
    return _stack_pytree(per_slice_results)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@map_impl.py_impl(DispatchKey.Autograd)
def map_autograd(f, xs, pos_args):
    # Build forward/joint graphs for `f` once, then route through an
    # autograd.Function so gradients flow through the mapped computation.
    num_mapped_args = len(xs)
    fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args)
    flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args)
    return flat_out
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@map_impl.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(mode, f, xs, args):
    """Proxy-mode implementation of map: record a map_impl node in the traced
    graph when tracing is on, otherwise redispatch as usual."""
    if not mode.enable_tracing:
        return map_impl(f, xs, args)
    return trace_map(mode, map_impl, f, xs, args)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
@map_impl.py_impl(FakeTensorMode)
def map_fake_tensor_mode(mode, f, xs, args):
    """Fake-tensor implementation: run the dense path under ``mode`` so the
    outputs are fake tensors with the correct metadata."""
    with mode:
        result = map_dense(f, xs, args)
    return result
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
@map_impl.py_functionalize_impl
def map_functionalize(ctx, f, xs, pos_args):
    # Functionalization implementation: unwrap inputs, reject bodies that
    # mutate or alias their inputs (map must stay functional), then
    # redispatch on the unwrapped values and re-wrap the result.
    unwrapped_xs = ctx.unwrap_tensors(xs)
    unwrapped_args = ctx.unwrap_tensors(pos_args)
    wrapped_fn = ctx.functionalize(f)

    with ctx.redispatch_to_next():
        with disable_proxy_modes_tracing():
            # One slice plus the extra args serves as the example input for
            # the mutation/alias checks below.
            example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        if _has_potential_branch_input_mutation(
            f, example_inputs, pre_dispatch=pre_dispatch
        ):
            raise UnsupportedAliasMutationException("torch.map is mutating the input!")

        if _has_potential_branch_input_alias(
            f, example_inputs, pre_dispatch=pre_dispatch
        ):
            raise UnsupportedAliasMutationException("torch.map is aliasing the input!")

        map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
        return ctx.wrap_tensors(map_return)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/strict_mode.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch._subclasses.functional_tensor
|
| 3 |
+
|
| 4 |
+
import torch.utils._pytree as pytree
|
| 5 |
+
|
| 6 |
+
from torch._C import DispatchKey
|
| 7 |
+
from torch._functorch.utils import exposed_in
|
| 8 |
+
|
| 9 |
+
from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_implemented
|
| 10 |
+
from torch._ops import HigherOrderOperator
|
| 11 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 12 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 13 |
+
disable_proxy_modes_tracing,
|
| 14 |
+
make_fx,
|
| 15 |
+
ProxyTorchDispatchMode,
|
| 16 |
+
track_tensor_tree,
|
| 17 |
+
)
|
| 18 |
+
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@exposed_in("torch")
def strict_mode(callable, operands):
    """Public entry point for torch.strict_mode.

    Inside an active Dynamo trace, dispatch straight to strict_mode_op.
    Otherwise compile strict_mode_op with the eager backend and
    fullgraph=True so `callable` must be captured as a single graph
    (any graph break becomes an error).
    """
    if torch.compiler.is_dynamo_compiling():
        return strict_mode_op(callable, operands)

    with _set_compilation_env():
        # Disable the Dynamo cache limit so repeated strict_mode calls
        # cannot trip the recompilation cap.
        with torch._dynamo.utils.disable_cache_limit():
            return torch.compile(strict_mode_op, backend="eager", fullgraph=True)(
                callable, operands
            )
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# The HigherOrderOperator behind torch.strict_mode; per-dispatch-key
# implementations are registered on it below.
strict_mode_op = HigherOrderOperator("strict_mode")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@strict_mode_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def strict_mode_op_dense(callable, operands):
    """Dense (eager) implementation: simply invoke the callable."""
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    return callable(*operands)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Autograd is not implemented for strict_mode_op: requesting gradients
# through it produces a deferred error rather than failing eagerly.
strict_mode_op.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(strict_mode_op, deferred_error=True)
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@strict_mode_op.py_impl(ProxyTorchDispatchMode)
def inner(mode, callable, operands):
    """Proxy-mode dispatch: trace `callable` into the FX graph when
    tracing is enabled, otherwise redispatch to strict_mode_op."""
    if mode.enable_tracing:
        return trace_strict_mode(mode, strict_mode_op, callable, operands)
    else:
        return strict_mode_op(callable, operands)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def trace_strict_mode(mode, strict_mode_op, callable, operands):
    """Trace `callable` into a sub-graph and emit a strict_mode_op node.

    The sub-graph is produced with make_fx, registered on the tracer root
    under a fresh ``strict_graph_{i}`` attribute, and referenced by a
    call_function proxy node named "strict_mode". Returns the real outputs
    (from running the sub-graph) tracked against that proxy.
    """
    pre_dispatch = getattr(mode, "pre_dispatch", False)

    # Trace the body outside the ambient proxy mode so its internals do
    # not leak into the enclosing graph.
    with disable_proxy_modes_tracing():
        graph = make_fx(callable, pre_dispatch=pre_dispatch)(*operands)

    # Find the first unused strict_graph_{i} attribute name on the root.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"strict_graph_{i}"
        if hasattr(mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    graph_name = next_name
    mode.tracer.root.register_module(graph_name, graph)

    args = (graph, operands)

    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)

    out_proxy = mode.tracer.create_proxy(
        "call_function", strict_mode_op, proxy_args, {}, name="strict_mode"
    )

    # Execute the traced sub-graph to obtain concrete outputs, then
    # associate them with the proxy node.
    out = graph(*operands)
    return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@strict_mode_op.py_impl(FakeTensorMode)
def strict_mode_fake_tensor_mode(mode, callable, operands):
    """FakeTensorMode dispatch: run the callable under the fake mode so
    outputs carry fake-tensor metadata."""
    with mode:
        true_outs = callable(*operands)
    return true_outs
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@strict_mode_op.py_functionalize_impl
def strict_mode_func(ctx, callable, inputs):
    """Functionalization rule: unwrap inputs, functionalize the callable,
    redispatch strict_mode_op, and rewrap the outputs."""
    unwrapped_inputs = ctx.unwrap_tensors(inputs)
    with ctx.redispatch_to_next():
        functional_callable = ctx.functionalize(callable)

        cond_return = strict_mode_op(functional_callable, unwrapped_inputs)
        return ctx.wrap_tensors(cond_return)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/torchbind.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import contextmanager
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch._C import DispatchKey # @manual
|
| 5 |
+
from torch._functorch._aot_autograd.utils import KNOWN_TYPES
|
| 6 |
+
from torch._higher_order_ops.utils import autograd_not_implemented
|
| 7 |
+
from torch._ops import HigherOrderOperator
|
| 8 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 9 |
+
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
|
| 10 |
+
from torch.fx.node import has_side_effect
|
| 11 |
+
from torch.utils import _pytree as pytree
|
| 12 |
+
|
| 13 |
+
# The call_torchbind operator represents a method invocation on a torchbind
|
| 14 |
+
# object. The calling convention is:
|
| 15 |
+
# call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs)
|
| 16 |
+
# We do not expect users to write this operator directly. Instead it will be
|
| 17 |
+
# emitted by Dynamo when tracing encounters a torchbind object.
|
| 18 |
+
call_torchbind = HigherOrderOperator("call_torchbind")

# Register this operator as side-effectful with FX.
# TODO: this is not really sufficient. While passes (hopefully) check
# Node.is_impure() and make good decisions, we also assume we can execute the
# graph as many times as we want without changing behavior, which is NOT true of
# ops that mutate torchbind object state.
has_side_effect(call_torchbind)

# The original ScriptMethod.__call__, saved so torchbind_method_redispatch
# can fall back to it and enable_torchbind_tracing can restore it.
_orig_scriptmethod_call = torch.ScriptMethod.__call__
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def torchbind_method_redispatch(self, *args, **kwargs):
    """Patched ScriptMethod.__call__: route method calls whose owner is a
    torchbind ScriptObject through the call_torchbind HOP (so tracing sees
    them); all other script methods use the original call."""
    if isinstance(self.raw_owner, torch.ScriptObject):
        return call_torchbind(self.raw_owner, self.name, *args, **kwargs)
    return _orig_scriptmethod_call(self, *args, **kwargs)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@contextmanager
def enable_torchbind_tracing():
    """Context manager that acts as a feature flag to enable torchbind tracing
    behavior. Once torchbind tracing has been stabilized, we can remove this and
    turn it always on.

    While active, torch.ScriptObject is appended to KNOWN_TYPES and
    ScriptMethod.__call__ is monkey-patched to redispatch through the
    call_torchbind operator. Both changes are undone on exit.
    """
    try:
        KNOWN_TYPES.append(torch.ScriptObject)
        torch.ScriptMethod.__call__ = torchbind_method_redispatch  # type: ignore[method-assign]
        yield
    finally:
        # The pop doubles as a sanity check that nobody else appended to
        # KNOWN_TYPES while tracing was enabled.
        assert (
            KNOWN_TYPES.pop() is torch.ScriptObject
        ), "Someone else messed with KNOWN_TYPES during tracing, exploding."
        torch.ScriptMethod.__call__ = _orig_scriptmethod_call  # type: ignore[method-assign]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd)
def call_torchbind_impl(obj, method, *args, **kwargs):
    """Dense implementation: look up the bound method on the object and
    invoke it via the original (unpatched) ScriptMethod.__call__."""
    return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@call_torchbind.py_impl(ProxyTorchDispatchMode)
def inner(mode, *args, **kwargs):
    """Proxy-mode dispatch: record a call_torchbind node in the graph and
    run the real implementation for concrete outputs; when tracing is
    disabled, just redispatch."""
    if mode.enable_tracing:
        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
        proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)

        out_proxy = mode.tracer.create_proxy(
            "call_function",
            call_torchbind,
            proxy_args,
            proxy_kwargs,
        )
        # Execute for real so the proxy node can be associated with
        # concrete outputs.
        out = call_torchbind_impl(*args, **kwargs)

        return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    else:
        return call_torchbind(*args, **kwargs)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# TODO: currently we just run the C++ implementation with fake tensors.
|
| 78 |
+
# But we should make it possible to register a fake torchbind implementation.
|
| 79 |
+
@call_torchbind.py_impl(FakeTensorMode)
def call_torchbind_fake(mode, *args, **kwargs):
    """FakeTensorMode dispatch: runs the real (C++) implementation under
    the fake mode; there is no fake torchbind implementation yet."""
    with mode:
        return call_torchbind_impl(*args, **kwargs)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Autograd is not implemented for call_torchbind: requesting gradients
# through it produces a deferred error rather than failing eagerly.
call_torchbind.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(call_torchbind, deferred_error=True)
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@call_torchbind.py_functionalize_impl
def call_torchbind_func(ctx, *args, **kwargs):
    """Functionalization rule: unwrap positional tensor args, redispatch,
    and rewrap the result."""
    # NOTE(review): kwargs are not unwrapped here — confirm callers never
    # pass functional tensors by keyword.
    args = ctx.unwrap_tensors(args)
    with ctx.redispatch_to_next():
        return ctx.wrap_tensors(call_torchbind(*args, **kwargs))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py
ADDED
|
@@ -0,0 +1,842 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import logging
|
| 3 |
+
import threading
|
| 4 |
+
import warnings
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from typing import Any, Dict, List, Optional, Union
|
| 7 |
+
|
| 8 |
+
import torch.utils._pytree as pytree
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
from torch._C import DispatchKey
|
| 11 |
+
from torch._ops import HigherOrderOperator
|
| 12 |
+
from torch._prims_common import clone_preserve_strides
|
| 13 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 14 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 15 |
+
disable_proxy_modes_tracing,
|
| 16 |
+
ProxyTorchDispatchMode,
|
| 17 |
+
track_tensor_tree,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
log = logging.getLogger("torch._dynamo")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
###############################################################################
|
| 24 |
+
# Kernel Side Table
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# We cannot put Triton Kernels into the FX graph as the graph nodes
|
| 28 |
+
# do not support arbitrary functions.
|
| 29 |
+
# Use a side table.
|
| 30 |
+
# We use two dicts so that fetching both the kernel and id are O(1)
|
| 31 |
+
class KernelSideTable:
    """Process-wide registry mapping Triton kernels to small integer ids.

    FX graph nodes cannot hold arbitrary callables, so kernels live here
    and graphs reference them by index. Two mirrored dicts make both
    directions of the lookup O(1). State is stored on the class, so all
    instances share a single table.
    """

    id_to_kernel: Dict[int, Any] = dict()
    kernel_to_id: Dict[Any, int] = dict()
    lock = threading.Lock()

    def add_kernel(self, kernel) -> int:
        """Return the table index for `kernel`, registering it if unseen."""
        with self.lock:
            existing = self.kernel_to_id.get(kernel)
            if existing is not None:
                return existing

            new_id = len(self.id_to_kernel)
            self.id_to_kernel[new_id] = kernel
            self.kernel_to_id[kernel] = new_id
            return new_id

    def get_kernel(self, idx: int):
        """Return the triton kernel stored at index `idx`."""
        # Reading a dict entry is atomic, so no lock is taken here.
        assert idx in self.id_to_kernel
        return self.id_to_kernel[idx]

    def reset_table(self) -> None:
        """Clear the table (unit-test helper; assumes single-threaded use)."""
        self.id_to_kernel = dict()
        self.kernel_to_id = dict()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Module-level singleton used by the triton kernel wrapper ops below.
kernel_side_table = KernelSideTable()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
###############################################################################
|
| 64 |
+
# Mutation Tracker
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclasses.dataclass(frozen=True)
class Param:
    """A reference to a kernel function parameter, by positional index."""

    idx: int
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@dataclasses.dataclass(frozen=True)
class Intermediate:
    """A reference to an op result (SSA value) in the TTIR, by index."""

    idx: int

    def fake(self):
        # Negative indices mark synthetic placeholders created during
        # parsing (e.g. for ops with no results), not real TTIR values.
        return self.idx < 0
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@dataclasses.dataclass(frozen=True)
class Op:
    """A single TTIR operation: name, arguments and result.

    For "tt.call" ops, `fn_call_name` carries the callee symbol; for all
    other ops it must be None (enforced in __post_init__).
    """

    name: str
    fn_call_name: Optional[str]
    args: List[Union[Param, Intermediate]]
    ret: Intermediate = dataclasses.field(repr=False)

    def __post_init__(self):
        if self.name == "tt.call":
            assert self.fn_call_name is not None
        else:
            assert self.fn_call_name is None
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def generate_ttir(kernel, kwargs):
    """
    Uses Triton's internal code generation to create TTIR.

    Returns a tuple of (ttir_module, ordered_tensor_names) where the
    names list contains the kernel argument names that are tensors, in
    declaration order.
    """
    from triton.compiler.compiler import ASTSource
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    import torch
    from torch._subclasses.fake_tensor import FakeTensor

    if isinstance(kernel, Autotuner):
        if len(kernel.configs) > 0:
            # If we are autotuning, then it doesn't matter which version gets
            # picked for tracing purposes, so lets pick the first one
            kwargs = {**kwargs, **kernel.configs[0].kwargs}
        kernel = kernel.fn

    assert isinstance(kernel, JITFunction)

    if len(kwargs) != len(kernel.arg_names):
        raise Exception("Incorrect number of arguments passed to kernel")

    # Replace all SymExprs with a regular value for TTIR generation
    # Replace all FakeTensor with real tensors
    # These replacements are needed for triton's type, key and config functions
    ordered_args: Dict[str, Any] = {}
    for name in kernel.arg_names:
        a = kwargs[name]
        if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            # Arbitrary concrete stand-in value for symbolic scalars.
            ordered_args[name] = 2
        elif isinstance(a, FakeTensor):
            # Small real tensor with the right dtype; only dtype matters
            # for signature computation here — TODO confirm.
            ordered_args[name] = torch.empty(2, dtype=a.dtype)
        else:
            ordered_args[name] = a

    ordered_tensor_names = [
        name for name, arg in ordered_args.items() if isinstance(arg, Tensor)
    ]
    specialization = kernel._get_config(*ordered_args.values())
    # Non-tensor args are treated as compile-time constants.
    constants = {
        i: arg
        for i, arg in enumerate(ordered_args.values())
        if not isinstance(arg, Tensor)
    }

    # Build kernel signature -- doesn't include constexpr arguments.
    signature = {
        i: kernel._type_of(kernel._key_of(arg))
        for i, arg in enumerate(ordered_args.values())
        if i not in kernel.constexprs
    }

    def get_backend():
        # Obtain the CUDA backend for the current device target.
        from triton.compiler.backends.cuda import CUDABackend
        from triton.runtime.driver import driver

        target = driver.get_current_target()
        return CUDABackend(target)

    backend = get_backend()

    options = backend.parse_options(dict())
    # triton._C.libtriton.triton.ir.load_dialects(context)
    # backend.load_dialects(context)

    src = ASTSource(kernel, signature, constants, specialization)
    ttir_module = src.make_ir(options)
    if not ttir_module.verify():
        raise Exception("Verification for TTIR module has failed")

    return ttir_module, ordered_tensor_names
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]:
    """
    Walk the `ttir_module` bottom up to mine the `functions` from
    the structured MLIR entities representing the Triton kernel
    (mlir::Operation, mlir::Block, mlir::Region).

    Returns a mapping: function name -> (result Intermediate -> producing Ops).
    """
    functions: Dict[str, Dict[Intermediate, List[Op]]] = {}

    # block id --> op result (Intermediate) --> one or more ops
    op_stack: Dict[int, Dict[Intermediate, List[Op]]] = defaultdict(
        lambda: defaultdict(list)
    )
    region_id_to_block_ids: Dict[int, List[int]] = defaultdict(list)
    block_id_to_block_arg_ids: Dict[int, List[int]] = {}
    replacements: Dict[int, Union[Intermediate, Param]] = {}
    # Raw MLIR value ids are remapped to small dense indices.
    reindex_map: Dict[int, int] = {}
    # Counter for synthetic Intermediates (always negative; see Intermediate.fake).
    next_fake_intermediate = 0

    def reindex(idx):
        # Assign ids densely in first-seen order.
        if idx not in reindex_map:
            reindex_map[idx] = len(reindex_map)
        return reindex_map[idx]

    def mlir_to_functions(op) -> None:
        # Callback invoked by ttir_module.walk for every MLIR operation,
        # bottom-up.
        name: str = op.get_name()
        if name == "builtin.module":
            # this wraps all tt.func ops
            return

        operand_ids: List[int] = [
            reindex(op.get_operand(i).id()) for i in range(op.get_num_operands())
        ]
        result_ids: List[int] = [
            reindex(op.get_result(i).id()) for i in range(op.get_num_results())
        ]

        child_block_ids: List[int] = []
        for i in [op.get_region(i).id() for i in range(op.get_num_regions())]:
            # as the walk is bottom-up, the region_id_to_block_ids[i]
            # must be populated by the time we process the enclosing op
            child_block_ids.extend(region_id_to_block_ids[i])

        parent_block_id = -1
        parent_block = op.get_block()
        if parent_block is not None:
            parent_block_id = parent_block.id()
            if parent_block_id not in block_id_to_block_arg_ids:
                block_id_to_block_arg_ids[parent_block_id] = []
                for i in range(parent_block.get_num_arguments()):
                    block_id_to_block_arg_ids[parent_block_id].append(
                        reindex(parent_block.get_argument(i).id()),
                    )
                # the region info is collected via ops' parent blocks to be
                # used later when the region's enclosing op is traversed
                parent_region = parent_block.get_parent()
                if parent_region is not None:
                    region_id_to_block_ids[parent_region.id()].append(parent_block_id)

        nonlocal next_fake_intermediate

        if name == "tt.func":
            # for function ops: gather and inline
            # the ops from all child blocks
            fn_ops = defaultdict(list)
            for child_block_id in child_block_ids:
                for result, block_fn_ops in op_stack.pop(child_block_id).items():
                    for block_fn_op in block_fn_ops:
                        fn_ops[result].append(block_fn_op)

            # replace the corresponding Intermediates in the
            # child op args with the function args (Params)
            for i, idx in enumerate(block_id_to_block_arg_ids[child_block_ids[0]]):
                replacements[idx] = Param(i)

            for fn_op_list in fn_ops.values():
                for fn_op in fn_op_list:
                    for i in range(len(fn_op.args)):
                        arg = fn_op.args[i]
                        if isinstance(arg, Intermediate) and arg.idx in replacements:
                            fn_op.args[i] = replacements[arg.idx]

            # next function capture starts
            # with empty replacements
            replacements.clear()

            fn_name = op.get_str_attr("sym_name")
            functions[fn_name] = fn_ops
        elif child_block_ids:
            if name in ("scf.if", "scf.for", "scf.while"):
                # for blocked control flow ops: inline the enclosed
                # ops into the parent block + rewire the last op in
                # each child block (yield) to return the scf result
                yield_ops = []
                for block_id in child_block_ids:
                    # the block args used as operands of the ops in the block
                    # (and nested blocks inlined in the current block by now)
                    # are replaced by new fake Intermediates to avoid "this
                    # operand is not returned by anything other op in the fn"
                    # error in the downstream analysis
                    for idx in block_id_to_block_arg_ids[block_id]:
                        next_fake_intermediate -= 1
                        replacements[idx] = Intermediate(next_fake_intermediate)

                    if block_id in op_stack:
                        block_ops = op_stack.pop(block_id)
                        if not block_ops:
                            continue
                        last_ret, last_ops = block_ops.popitem()
                        if all(op.name == "scf.yield" for op in last_ops):
                            # if last_ops are scf.yield, treat them separately
                            yield_ops.extend(last_ops)
                        else:
                            # otherwise, return last_ops to the block
                            block_ops[last_ret] = last_ops
                        for op_result, child_ops in block_ops.items():
                            op_stack[parent_block_id][op_result].extend(child_ops)

                # Attach every collected yield op as a producer of each
                # scf result value.
                scf_results = [Intermediate(idx) for idx in result_ids]
                for scf_result in scf_results:
                    for yield_op in yield_ops:
                        op_stack[parent_block_id][scf_result].append(yield_op)
            else:
                # TODO(oulgen): add support for tt.reduce
                raise Exception(
                    f"Unknown blocked function: {name}. Can't capture the TTIR."
                )
        else:
            callee = None
            if name == "tt.call":
                callee = op.get_flat_symbol_ref_attr("callee")
            args: List[Union[Param, Intermediate]] = [
                Intermediate(operand) for operand in operand_ids
            ]
            block_ops = op_stack[parent_block_id]
            if result_ids:
                for result_id in result_ids:
                    res = Intermediate(result_id)
                    block_ops[res].append(Op(name, callee, args, res))
            else:
                # Ops without results get a synthetic (negative) result id
                # so they still appear in the per-block op map.
                next_fake_intermediate -= 1
                fake_res = Intermediate(next_fake_intermediate)
                block_ops[fake_res].append(Op(name, callee, args, fake_res))

    ttir_module.walk(mlir_to_functions)

    return functions
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def parse_ttir(ttir, kwargs):
|
| 317 |
+
"""
|
| 318 |
+
Given a Triton emitted TTIR text, this function lexes and parses the
|
| 319 |
+
code using a minimal grammar defined inside. During the lexing/parsing,
|
| 320 |
+
we drop any constant value and type information as they are not
|
| 321 |
+
necessary to us.
|
| 322 |
+
Being able to choose what we need makes this not a general purpose TTIR
|
| 323 |
+
parser which further makes parsing much simpler.
|
| 324 |
+
"""
|
| 325 |
+
# TODO(oulgen):
|
| 326 |
+
# - Support closures (e.g. "tt.reduce")
|
| 327 |
+
|
| 328 |
+
try:
|
| 329 |
+
import lark # type: ignore[import-not-found]
|
| 330 |
+
from lark import Lark, Transformer, v_args
|
| 331 |
+
except ModuleNotFoundError:
|
| 332 |
+
warnings.warn(
|
| 333 |
+
"Using slow path for user-defined Triton kernels. `pip install lark` to fix this."
|
| 334 |
+
)
|
| 335 |
+
raise
|
| 336 |
+
|
| 337 |
+
# Ops looks like one of the following forms:
|
| 338 |
+
#
|
| 339 |
+
# %14 = tt.addptr %13, %4 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>
|
| 340 |
+
# tt.store %14, %12, %5 {cache = 1 : i32, evict = 1 : i32} : tensor<4xf32>
|
| 341 |
+
# %15 = "tt.atomic_rmw"(%14, %12, %5) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 4 : i32}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xf32>, tensor<4xi1>) -> tensor<4xf32> # noqa: B950
|
| 342 |
+
grammar = """
|
| 343 |
+
start: (module_block | loc_line)+
|
| 344 |
+
|
| 345 |
+
loc_line: "#loc" /.+/ NEWLINE
|
| 346 |
+
|
| 347 |
+
module_block: "module" "{" func_block+ "}" LOC
|
| 348 |
+
|
| 349 |
+
func_block: "tt.func" ("public"|"private") FN_NAME "(" /.+/ NEWLINE stmt* "}" LOC -> process_func
|
| 350 |
+
|
| 351 |
+
?stmt: op | if | for | while | condition_stmt | label_stmt | cf_stmt
|
| 352 |
+
|
| 353 |
+
if: [assign_lhs "="] "scf.if" args rest stmt* "}" "else" "{" stmt* "}" LOC -> process_if
|
| 354 |
+
for: [assign_lhs "="] "scf.for" args rest stmt* "}" divisibility_annot? LOC -> process_for
|
| 355 |
+
while: [assign_lhs "="] "scf.while" args rest stmt* "}" "do" "{" stmt* "}" LOC -> process_while
|
| 356 |
+
|
| 357 |
+
condition_stmt: "scf.condition" "(" arg ")" args rest
|
| 358 |
+
label_stmt: LABEL ":" "// pred:" LABEL
|
| 359 |
+
| LABEL "(" /.+/ NEWLINE
|
| 360 |
+
cf_stmt: "cf" "." NAME /.+/ NEWLINE
|
| 361 |
+
|
| 362 |
+
op: OP_NAME LOC
|
| 363 |
+
| [assign_lhs "="] OP_NAME [FN_NAME] args rest? -> process_op
|
| 364 |
+
|
| 365 |
+
?rest: (":" | "{" | "\\"" | "->" | "<" | "=") /.+/ NEWLINE
|
| 366 |
+
divisibility_annot: "{" "tt.divisibility_arg1" /[^}]+/ "}"
|
| 367 |
+
|
| 368 |
+
args: | "(" ")" | "("? arg ("," arg)* ")"?
|
| 369 |
+
|
| 370 |
+
?arg: INTERMEDIATE
|
| 371 |
+
| INTERMEDIATE_CONSTANT
|
| 372 |
+
| CONSTANT
|
| 373 |
+
| PARAM
|
| 374 |
+
| "[" args "]"
|
| 375 |
+
| arg_with_index
|
| 376 |
+
|
| 377 |
+
?arg_with_index: arg "#" DIGIT+
|
| 378 |
+
|
| 379 |
+
?assign_lhs: (INTERMEDIATE | INTERMEDIATE_CONSTANT) [":" DIGIT+]
|
| 380 |
+
|
| 381 |
+
PARAM.5: "%arg" DIGIT+
|
| 382 |
+
INTERMEDIATE.4: "%" DIGIT+
|
| 383 |
+
INTERMEDIATE_CONSTANT.3: "%" NAME
|
| 384 |
+
CONSTANT: FLOAT | DIGIT+ | NAME ("<" DIGIT+ ">")?
|
| 385 |
+
LABEL: "^bb" DIGIT+
|
| 386 |
+
|
| 387 |
+
NAME: (LETTER | DIGIT | "_")+
|
| 388 |
+
NON_CF_NAME: /(?!(cf))/ NAME
|
| 389 |
+
FN_NAME: "@" (NAME | ESCAPED_STRING)
|
| 390 |
+
OP_NAME: "\\""? NON_CF_NAME ("." NAME)+ "\\""?
|
| 391 |
+
|
| 392 |
+
LOC.5: "loc(#loc" DIGIT* ")"
|
| 393 |
+
|
| 394 |
+
%import common.LETTER
|
| 395 |
+
%import common.DIGIT
|
| 396 |
+
%import common.WS
|
| 397 |
+
%import common.NEWLINE
|
| 398 |
+
%import common.ESCAPED_STRING
|
| 399 |
+
%import common.FLOAT
|
| 400 |
+
%ignore WS
|
| 401 |
+
"""
|
| 402 |
+
|
| 403 |
+
next_fake_intermediate = 0
|
| 404 |
+
|
| 405 |
+
def convert(token):
|
| 406 |
+
if isinstance(token, lark.tree.Tree):
|
| 407 |
+
if token.data == "args":
|
| 408 |
+
res = []
|
| 409 |
+
for a in token.children:
|
| 410 |
+
c = convert(a)
|
| 411 |
+
if isinstance(c, list):
|
| 412 |
+
res.extend(c)
|
| 413 |
+
else:
|
| 414 |
+
res.append(c)
|
| 415 |
+
return res
|
| 416 |
+
elif token.data in {"assign_lhs", "arg_with_index"}:
|
| 417 |
+
# Drop length/index qualifier
|
| 418 |
+
return convert(token.children[0])
|
| 419 |
+
else:
|
| 420 |
+
raise AssertionError(f"Tree node with {token.data}")
|
| 421 |
+
|
| 422 |
+
if token is None or (
|
| 423 |
+
isinstance(token, lark.lexer.Token)
|
| 424 |
+
and token.type in ("CONSTANT", "INTERMEDIATE_CONSTANT")
|
| 425 |
+
):
|
| 426 |
+
nonlocal next_fake_intermediate
|
| 427 |
+
next_fake_intermediate -= 1
|
| 428 |
+
return Intermediate(next_fake_intermediate)
|
| 429 |
+
|
| 430 |
+
assert isinstance(token, lark.lexer.Token)
|
| 431 |
+
|
| 432 |
+
if token.type == "INTERMEDIATE":
|
| 433 |
+
return Intermediate(int(token.value[len("%") :]))
|
| 434 |
+
if token.type == "PARAM":
|
| 435 |
+
return Param(int(token.value[len("%arg") :]))
|
| 436 |
+
|
| 437 |
+
raise AssertionError(f"{type(token.type)} => {token.value} invalid")
|
| 438 |
+
|
| 439 |
+
# In alternative representation, function names are quoted.
|
| 440 |
+
# It should be possible to move this into the grammar alltogether.
|
| 441 |
+
def convert_name(token):
|
| 442 |
+
if token is None:
|
| 443 |
+
return None
|
| 444 |
+
s = token.value
|
| 445 |
+
if len(s) > 2 and s[0] == '"' and s[-1] == '"':
|
| 446 |
+
return s[1:-1]
|
| 447 |
+
return s
|
| 448 |
+
|
| 449 |
+
functions: Dict[str, Dict[Intermediate, List[Op]]] = {}
|
| 450 |
+
|
| 451 |
+
def extend_dict_list(d1, d2):
|
| 452 |
+
for key, values in d2.items():
|
| 453 |
+
d1[key].extend(values)
|
| 454 |
+
|
| 455 |
+
@v_args(inline=True)
|
| 456 |
+
class TransformOps(Transformer):
|
| 457 |
+
def process_op(self, ret, op_name, fn_name, args, *rest):
|
| 458 |
+
return Op(
|
| 459 |
+
convert_name(op_name),
|
| 460 |
+
convert_name(fn_name),
|
| 461 |
+
convert(args),
|
| 462 |
+
convert(ret),
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
def process_func(self, name, _args, *stmts):
|
| 466 |
+
ops: Dict[Intermediate, List[Op]] = defaultdict(list)
|
| 467 |
+
for e in stmts:
|
| 468 |
+
if isinstance(e, Op):
|
| 469 |
+
ops[e.ret].append(e)
|
| 470 |
+
elif isinstance(e, dict):
|
| 471 |
+
extend_dict_list(ops, e)
|
| 472 |
+
functions[name.value] = ops
|
| 473 |
+
|
| 474 |
+
def _process_scf(self, ret, stmts):
|
| 475 |
+
ret = convert(ret)
|
| 476 |
+
ops: Dict[Intermediate, List[Op]] = defaultdict(list)
|
| 477 |
+
for e in stmts:
|
| 478 |
+
if isinstance(e, Op):
|
| 479 |
+
if e.name == "scf.yield":
|
| 480 |
+
ops[ret].append(Op(e.name, None, e.args, ret))
|
| 481 |
+
else:
|
| 482 |
+
ops[e.ret].append(e)
|
| 483 |
+
elif isinstance(e, dict):
|
| 484 |
+
extend_dict_list(ops, e)
|
| 485 |
+
return ops
|
| 486 |
+
|
| 487 |
+
def process_if(self, ret, _args, _rest, *stmts):
|
| 488 |
+
return self._process_scf(ret, stmts)
|
| 489 |
+
|
| 490 |
+
def process_for(self, ret, _args, _rest, *stmts):
|
| 491 |
+
return self._process_scf(ret, stmts)
|
| 492 |
+
|
| 493 |
+
def process_while(self, ret, _args, _rest, *stmts):
|
| 494 |
+
return self._process_scf(ret, stmts)
|
| 495 |
+
|
| 496 |
+
parser = Lark(
|
| 497 |
+
grammar, parser="lalr", maybe_placeholders=True, transformer=TransformOps()
|
| 498 |
+
)
|
| 499 |
+
parser.parse(ttir)
|
| 500 |
+
return functions
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class MemoizeWithCycleCheck:
    """Memoizing wrapper for ``fn(functions, fn_name, num_args)`` that also
    detects recursion in the call graph.

    Before computing a result, the cache entry for ``(fn_name, num_args)`` is
    set to a ``None`` sentinel; if the same key is requested again while the
    computation is still in flight, the sentinel is observed and a cycle
    (i.e. a recursive TTIR function) is reported.
    """

    def __init__(self, fn):
        self.fn = fn
        self.reset()

    def __call__(self, functions, fn_name, num_args):
        key = (fn_name, num_args)
        if key not in self.cache:
            # Sentinel marks "computation in progress" so that a recursive
            # re-entry with the same key is detectable below.
            self.cache[key] = None
            self.cache[key] = self.fn(functions, fn_name, num_args)
        if self.cache[key] is None:
            # RuntimeError instead of bare Exception; existing callers that
            # catch Exception still catch this.
            raise RuntimeError("Recursion is not supported")
        return self.cache[key]

    def reset(self):
        """Clear the memoization cache (also the cycle-detection state)."""
        self.cache = {}
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
@MemoizeWithCycleCheck
def analyze_kernel_mutations(functions, fn_name, num_args):
    """
    Analyzes the graph to detect all sinks from a predefined list of sinks
    by using triton's MemWrite trait list. NOTE: What if triton exposed this?
    From each sink, it traverses the CFG backwards to identify all the input
    pointers that are mutated.
    """
    # Name of mutation op to mutated parameter indices.
    # List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td
    # All the OPs that have MemWrite trait. What if Triton exposed this?
    mutation_ops = {"tt.store": [0], "tt.atomic_cas": [0], "tt.atomic_rmw": [0]}
    # Ops whose memory behavior we cannot reason about; bail out entirely.
    unknown_ops = {"tt.elementwise_inline_asm"}

    ops = functions[fn_name]
    pending: List[Union[Param, Intermediate]] = []
    # Seed the worklist with every value written to by some sink op.
    for op_list in ops.values():
        for op in op_list:
            if op.name in unknown_ops:
                raise Exception(
                    f"ttir analysis hit an op we do not know how to analyze: {op.name}"
                )
            if op.name == "tt.call":
                assert op.fn_call_name in functions
                callee_mutations = analyze_kernel_mutations(
                    functions, op.fn_call_name, len(op.args)
                )
                pending.extend(
                    arg
                    for arg, is_mut in zip(op.args, callee_mutations)
                    if is_mut
                )
            else:
                pending.extend(op.args[i] for i in mutation_ops.get(op.name, []))

    # Iterative DFS walking backwards from each sink towards the kernel params.
    mutated = [False] * num_args
    seen = set()
    while pending:
        arg = pending.pop()
        if arg in seen:
            continue
        seen.add(arg)

        if isinstance(arg, Param):
            # Params at index >= num_args are defined inside the kernel,
            # not passed in by the caller.
            if arg.idx < num_args:
                mutated[arg.idx] = True
        elif isinstance(arg, Intermediate) and not arg.fake():
            for op in ops[arg]:
                # Arguments to a load are only read, never written.
                if op.name != "tt.load":
                    pending.extend(op.args)
    return mutated
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def identify_mutated_tensors(kernel, kwargs):
    """
    Given a triton kernel and the arguments for this kernel, this function
    1) Retrieves the TTIR converted version of the kernel from Triton's API.
    2) Parses the TTIR and creates a control flow graph
    3) Analyzes the graph to detect all input tensor mutations

    Any failure along the way is treated as best-effort: the except branch
    below conservatively reports every tensor argument as mutated.
    """

    ttir_module = None
    functions = None
    try:
        from torch._dynamo import config

        if not config.optimize_user_defined_triton_kernels:
            raise Exception("optimize_user_defined_triton_kernels is False")

        ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)

        # extract functions from TTIR
        if hasattr(ttir_module, "walk"):
            # use MLIR bindings exposed by Triton code
            functions = ttir_to_functions(ttir_module)
        else:
            # parse string representation of Triton IR
            functions = parse_ttir(str(ttir_module), kwargs)

        assert functions is not None
        # The first function in the mapping is the entry-point kernel.
        kernel_name = next(iter(functions.keys()))
        # Triton codegen modifies the name
        assert kernel.fn.__name__ in kernel_name
        # Reset the cache between top level invocations
        # The cache for analyze kernel mutations is mainly used for cycle
        # detection, so each top level invocation needs a clean cache
        analyze_kernel_mutations.reset()
        mutations = analyze_kernel_mutations(
            functions, kernel_name, len(ordered_tensor_names)
        )

        return [
            ordered_tensor_names[i] for i, mutated in enumerate(mutations) if mutated
        ]
    except Exception as e:
        # Deliberate broad catch: analysis is an optimization only. Warn,
        # dump debug state, and fall back to "every tensor input is mutated".
        import traceback

        warnings.warn(
            "Encountered an exception in identify_mutated_tensors, "
            "assuming every input is mutated:\n"
            "".join(
                traceback.TracebackException.from_exception(e).format()  # noqa: G001
            )
        )
        if ttir_module is not None:
            log.debug("TTIR:\n%s", str(ttir_module))
        if functions is not None:
            log.debug("functions:")
            for name, fn in functions.items():
                log.debug("===\t%s\t===", name)
                for ret, ops in fn.items():
                    log.debug("%s\t=>\t%s", ret, ops)
        return [key for key, value in kwargs.items() if isinstance(value, Tensor)]
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
###############################################################################
|
| 642 |
+
# Triton Kernel Wrappers
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
# Used for wrapping a Triton Kernel
class TritonKernelWrapperMutation(HigherOrderOperator):
    # In-place form: the wrapped kernel may mutate its tensor arguments and
    # the op itself returns nothing.
    def __init__(self):
        super().__init__("triton_kernel_wrapper_mutation")


# Singleton op instance used by the dispatch registrations in this file.
triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
# Used for wrapping a Triton Kernel in a functional manner
class TritonKernelWrapperFunctional(HigherOrderOperator):
    # Pure form: inputs that would be mutated are cloned first and the clones
    # are returned, leaving the original arguments untouched.
    def __init__(self):
        super().__init__("triton_kernel_wrapper_functional")


# Singleton op instance used by the dispatch registrations in this file.
triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(*, kernel_idx, grid, kwargs):
    """Eagerly launch the side-table kernel, building a grid fn if needed."""
    from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code

    triton_kernel = kernel_side_table.get_kernel(kernel_idx)

    if len(grid) != 1:
        # Multiple configs: codegen a dispatcher that selects the grid.
        grid_fn_name, grid_fn_code = user_defined_kernel_grid_fn_code(
            triton_kernel.fn.__name__, triton_kernel.configs, grid
        )
        scope: Dict[str, Any] = {}
        exec(grid_fn_code, scope)
        grid_fn = scope[grid_fn_name]
    else:
        grid_fn = grid[0]

    triton_kernel[grid_fn](**kwargs)
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode)
def triton_kernel_wrapper_mutation_fake_tensor_mode(mode, *, kernel_idx, grid, kwargs):
    # Under fake tensors the kernel is never executed; the mutation op has no
    # outputs, so there is nothing to fake-propagate.
    with mode:
        return None
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
def trace_triton_kernel_wrapper(proxy_mode, func_overload, node_args):
    """Record func_overload into the proxy graph and return its traced output."""
    # Execute the op for real (tracing disabled) to get concrete outputs.
    with disable_proxy_modes_tracing():
        real_out = func_overload(**node_args)

    unwrapped_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    proxy = proxy_mode.tracer.create_proxy(
        "call_function",
        func_overload,
        (),
        unwrapped_args,
        name=func_overload.__name__ + "_proxy",
    )
    return track_tensor_tree(real_out, proxy, constant=None, tracer=proxy_mode.tracer)
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
    mode, *, kernel_idx, grid, kwargs
):
    """Proxy-mode impl: record the op into the graph when tracing is enabled."""
    node_args = {"kernel_idx": kernel_idx, "grid": grid, "kwargs": kwargs}
    if not mode.enable_tracing:
        triton_kernel_wrapper_mutation(**node_args)
    else:
        trace_triton_kernel_wrapper(mode, triton_kernel_wrapper_mutation, node_args)

    # The mutation variant produces no outputs.
    return None
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
@triton_kernel_wrapper_mutation.py_functionalize_impl
def triton_kernel_wrapper_mutation_functionalize(ctx, kernel_idx, grid, kwargs):
    # Functionalization: route the mutating op through the functional variant,
    # then propagate the returned (cloned-and-mutated) tensors back onto the
    # original functional-tensor inputs.
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    kernel = kernel_side_table.get_kernel(kernel_idx)
    # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each
    # other, and one gets mutated in kernel, and later another gets mutated,
    # they are no longer equal. Fix this by graph breaking on this condition
    # earlier in dynamo.
    tensors_to_clone = identify_mutated_tensors(kernel, unwrapped_kwargs)
    with ctx.redispatch_to_next():
        unwrapped_outputs = triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            grid=grid,
            kwargs=unwrapped_kwargs,
            tensors_to_clone=tensors_to_clone,
        )

    # The functional op only returns entries for (potentially) mutated inputs.
    assert set(unwrapped_outputs.keys()).issubset(set(kwargs.keys()))
    for key, output_arg in unwrapped_outputs.items():
        if not isinstance(output_arg, Tensor):
            continue
        input_arg = kwargs[key]
        assert isinstance(input_arg, Tensor)

        # The replace/commit/sync sequence below is order-sensitive.
        ctx.replace(input_arg, output_arg)
        # indicate that above replace is hidden from autograd
        ctx.mark_mutation_hidden_from_autograd(input_arg)
        ctx.commit_update(input_arg)
        ctx.sync(input_arg)
        # sync calls replace_ under the hood, so again indicate that
        # this indirect replace is hidden from autograd
        ctx.mark_mutation_hidden_from_autograd(input_arg)
    return None
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_functional_dense(
    *, kernel_idx, grid, kwargs, tensors_to_clone
):
    """Clone the to-be-mutated inputs, run the mutating op, return the clones."""
    # TODO(oulgen): For performance reasons, we want to ensure that these
    # `clone_preserve_strides` calls are never executed at runtime
    # (inductor should always optimize them away).
    # Requires https://github.com/pytorch/pytorch/issues/109240
    cloned_kwargs = {}
    for name, value in kwargs.items():
        if name in tensors_to_clone:
            value = clone_preserve_strides(value)
        cloned_kwargs[name] = value
    triton_kernel_wrapper_mutation(
        kernel_idx=kernel_idx, grid=grid, kwargs=cloned_kwargs
    )
    return {
        name: value
        for name, value in cloned_kwargs.items()
        if name in tensors_to_clone
    }
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
@triton_kernel_wrapper_functional.py_impl(FakeTensorMode)
def triton_kernel_wrapper_functional_fake_tensor_mode(
    mode, *, kernel_idx, grid, kwargs, tensors_to_clone
):
    """Fake-tensor impl: produce fresh fakes only for to-be-mutated inputs."""
    # TODO(oulgen): For performance reasons, we want to ensure that these
    # `clone_preserve_strides` calls are never executed at runtime
    # (inductor should always optimize them away).
    # Requires https://github.com/pytorch/pytorch/issues/109240
    with mode:
        result = {}
        for name, value in kwargs.items():
            if name in tensors_to_clone:
                result[name] = clone_preserve_strides(value)
        return result
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
    mode, *, kernel_idx, grid, kwargs, tensors_to_clone
):
    """Proxy-mode impl of the functional wrapper: trace or redispatch."""
    node_args = {
        "kernel_idx": kernel_idx,
        "grid": grid,
        "kwargs": kwargs,
        "tensors_to_clone": tensors_to_clone,
    }
    if not mode.enable_tracing:
        return triton_kernel_wrapper_functional(**node_args)
    return trace_triton_kernel_wrapper(
        mode, triton_kernel_wrapper_functional, node_args
    )
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
@triton_kernel_wrapper_functional.py_functionalize_impl
def triton_kernel_wrapper_functional_functionalize(
    ctx, kernel_idx, grid, kwargs, tensors_to_clone
):
    """Functionalization: unwrap inputs, redispatch, and rewrap the outputs."""
    inner_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        inner_outputs = triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            grid=grid,
            kwargs=inner_kwargs,
            tensors_to_clone=tensors_to_clone,
        )
        return ctx.wrap_tensors(inner_outputs)
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
# These dispatch keys carry no extra behavior for the triton-kernel wrappers;
# register fallthroughs so dispatch proceeds to the next key in the set.
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCPU)

triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
# Dropped a duplicated AutogradCUDA registration that appeared twice here.
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCPU)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/while_loop.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.utils._pytree as pytree
|
| 3 |
+
|
| 4 |
+
from torch._C import DispatchKey
|
| 5 |
+
|
| 6 |
+
from torch._higher_order_ops.utils import (
|
| 7 |
+
_has_potential_branch_input_alias,
|
| 8 |
+
_has_potential_branch_input_mutation,
|
| 9 |
+
_set_compilation_env,
|
| 10 |
+
autograd_not_implemented,
|
| 11 |
+
reenter_make_fx,
|
| 12 |
+
UnsupportedAliasMutationException,
|
| 13 |
+
)
|
| 14 |
+
from torch._ops import HigherOrderOperator
|
| 15 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 16 |
+
from torch.fx.experimental.proxy_tensor import (
|
| 17 |
+
disable_proxy_modes_tracing,
|
| 18 |
+
ProxyTorchDispatchMode,
|
| 19 |
+
track_tensor_tree,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class WhileLoopOp(HigherOrderOperator):
    # Validating wrapper: __call__ checks argument kinds before dispatching to
    # HigherOrderOperator.__call__.
    # NOTE(review): `while_loop_op` below is constructed as a plain
    # HigherOrderOperator, so this validating subclass appears unused in this
    # file — confirm whether while_loop_op should be a WhileLoopOp instead.
    def __call__(self, cond_fn, body_fn, operands):
        if not isinstance(cond_fn, torch.fx.GraphModule) or not isinstance(
            body_fn, torch.fx.GraphModule
        ):
            raise RuntimeError(
                "cond_fn and body_fn must be torch.fx.GraphModule, got "
                f"{type(cond_fn)} and {type(body_fn)}"
            )
        if not isinstance(operands, tuple):
            raise RuntimeError("operands must be a tuple, got " f"{type(operands)}")
        if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in operands):
            raise RuntimeError(
                "operands must be a tuple of tensors, ints, floats, or bools, got "
                f"{operands}"
            )
        return super().__call__(cond_fn, body_fn, operands)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
while_loop_op = HigherOrderOperator("while_loop")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def while_loop(cond_fn, body_fn, operands):
    r"""
    Run body_fn(*operands) while cond_fn(*operands) returns a True scalar tensor. Returns the output of body_fn or
    initial operands.

    .. warning::
        `torch.while_loop` is a prototype feature in PyTorch. It has limited support for input and output types and
        doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
        Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

    `while_loop` is a structured control flow operator. It preserves the loop semantic across the torch.compile and torch.export.

    `while_loop` is equivalent to the following:

        def while_loop(cond_fn, body_fn, operands):
            val = operands
            while cond_fn(*val):
                val = body_fn(*val)
            return val

    Args:
        cond_fn (Callable): A callable function that returns a boolean Scalar tensor.

        body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors

        operands (Tuple of possibly nested dict/list/tuple of tensors): A tuple of inputs to cond_fn and body_fn. It's also
            the initial value of states that are carried across iterations.

    Example:

        def cond_fn(iter, x):
            return iter.sum() < 10

        def body_fn(iter, x):
            return iter + 1, x.sin()

        while_loop(cond_fn, body_fn, (torch.zeros(1), torch.randn(3, 4)))

    Restrictions:

        - body_fn must return tensors with the same metadata (e.g.shape, dtype) as inputs.

        - body_fn and cond_fn must not in-place mutate the operands. A clone before the mutation is required.

        - body_fn and cond_fn must not mutate python variables (e.g. list/dict) created outside of the body_fn.

        - body_fn and cond_fn's output cannot alias any of the inputs. A clone is required.

    .. warning::
        Temporal Limitations:

        - 'while_loop' only supports **inference** right now. Autograd will be supported in the future.

    """
    # Inside dynamo the op is handled directly; skip eager validation/compile.
    if torch.compiler.is_dynamo_compiling():
        return while_loop_op(cond_fn, body_fn, operands)

    def _validate_input(cond_fn, body_fn, operands):
        if not callable(cond_fn) or not callable(body_fn):
            # typo fix: error message previously read "callbale"
            raise RuntimeError("Expect cond_fn and body_fn to be callable.")

        if not isinstance(operands, (tuple, list)) or pytree.tree_any(
            lambda t: not isinstance(t, torch.Tensor), operands
        ):
            raise RuntimeError(
                "Expect operands to be a tuple of possibly nested dict/list/tuple that only"
                f"consists of tensor leaves, but got {operands}."
            )

    _validate_input(cond_fn, body_fn, operands)

    # Compile with the eager backend so the loop is captured as a single
    # while_loop_op node even when called outside an outer torch.compile.
    with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
        return torch.compile(while_loop_op, backend="eager", fullgraph=True)(
            cond_fn, body_fn, operands
        )
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def while_loop_dense(cond_fn, body_fn, operands):
    """Eager loop: run body_fn until cond_fn yields a False 0-dim bool tensor."""
    carried = operands

    def _pred_ok(candidate):
        # cond_fn must produce a zero-dimensional boolean tensor.
        return (
            isinstance(candidate, torch.Tensor)
            and candidate.dtype == torch.bool
            and candidate.size() == torch.Size([])
        )

    if not isinstance(operands, tuple):
        raise RuntimeError(f"operands must be a tuple but got {type(operands)}")

    while pred := cond_fn(*carried):
        if not _pred_ok(pred):
            raise RuntimeError(
                f"cond_fn must return a boolean scalar tensor but got {pred}"
            )
        stepped = body_fn(*carried)
        assert isinstance(
            stepped, tuple
        ), f"body_fn should return a tuple but got {type(stepped)}"
        assert len(stepped) == len(
            carried
        ), "body_fn should return the same number of elements as operands"
        carried = stepped
    return carried
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# Autograd is not yet supported for while_loop; register a deferred error
# that fires only if a gradient is actually requested through the op.
while_loop_op.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(while_loop_op, deferred_error=True)
)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
@while_loop_op.py_impl(ProxyTorchDispatchMode)
def while_loop_tracing(mode, cond_fn, body_fn, operands):
    # Proxy-mode impl: trace cond_fn/body_fn into subgraphs, register them on
    # the tracer root, and emit a single while_loop node into the graph.
    def _trace_while_loop(proxy_mode, while_loop_op, cond_fn, body_fn, operands):
        pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)
        with disable_proxy_modes_tracing():
            cond_graph = reenter_make_fx(cond_fn, pre_dispatch)(*operands)
            body_graph = reenter_make_fx(body_fn, pre_dispatch)(*operands)

        # Pick the first unused "while_loop_cond_graph_{i}" attribute name so
        # repeated while_loops in one trace don't collide.
        next_name = None
        i = 0
        while not next_name:
            candidate = f"while_loop_cond_graph_{i}"
            if hasattr(proxy_mode.tracer.root, candidate):
                i += 1
            else:
                next_name = candidate
        cond_graph_name = next_name
        # Reuse the same index for the paired body graph.
        body_graph_name = f"while_loop_body_graph_{i}"
        assert not hasattr(proxy_mode.tracer.root, body_graph_name)

        proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph)
        proxy_mode.tracer.root.register_module(body_graph_name, body_graph)

        args = (cond_graph, body_graph, operands)

        proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

        out_proxy = proxy_mode.tracer.create_proxy(
            "call_function", while_loop_op, proxy_args, {}, name="while_loop"
        )

        # body_fn return output with the same pytree and tensor meta data as operands
        # so we could just return the output after one iteration.
        out = body_fn(*operands)
        return track_tensor_tree(
            out, out_proxy, constant=None, tracer=proxy_mode.tracer
        )

    if mode.enable_tracing:
        return _trace_while_loop(mode, while_loop_op, cond_fn, body_fn, operands)
    else:
        return while_loop_op(cond_fn, body_fn, operands)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@while_loop_op.py_impl(FakeTensorMode)
def while_loop_fake_tensor_mode(mode, cond_fn, body_fn, operands):
    # body_fn is required to return outputs with the same metadata as its
    # inputs, so a single (fake) iteration yields the right output structure.
    return body_fn(*operands)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@while_loop_op.py_functionalize_impl
def while_loop_func(ctx, cond_fn, body_fn, operands):
    """Functionalization wrapper for while_loop.

    Unwraps functional tensors, rejects branch functions that mutate or alias
    their inputs, then redispatches to the next key and rewraps the result.
    """
    unwrapped_operands = ctx.unwrap_tensors(operands)
    with ctx.redispatch_to_next():
        functional_cond_fn = ctx.functionalize(cond_fn)
        functional_body_fn = ctx.functionalize(body_fn)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        for fn, fn_name in [
            (functional_cond_fn, "cond_fn"),
            (functional_body_fn, "body_fn"),
        ]:
            if _has_potential_branch_input_mutation(
                fn, unwrapped_operands, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    f"torch.while_loop's {fn_name} might be modifying the input!"
                )

        # Bug fix: the alias-check loop previously iterated bare `fn` values,
        # so the message below always used the stale `fn_name` left over from
        # the mutation loop ("body_fn"). Bind the names here as well so the
        # error blames the correct branch function.
        for fn, fn_name in [
            (functional_cond_fn, "cond_fn"),
            (functional_body_fn, "body_fn"),
        ]:
            if _has_potential_branch_input_alias(
                fn, unwrapped_operands, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    f"torch.while_loop's {fn_name} might be aliasing the input!"
                )
        ret = while_loop_op(functional_cond_fn, functional_body_fn, unwrapped_operands)
        return ctx.wrap_tensors(ret)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (222 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/mkl/__init__.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def is_available():
    r"""Return whether PyTorch is built with MKL support."""
    # Build-time flag baked into the C extension.
    return torch._C.has_mkl
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# oneMKL verbosity levels accepted by torch.backends.mkl.verbose.
VERBOSE_OFF = 0
VERBOSE_ON = 1
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class verbose:
    """
    On-demand oneMKL verbosing functionality.

    To make it easier to debug performance issues, oneMKL can dump verbose
    messages containing execution information like duration while executing
    the kernel. The verbosing functionality can be invoked via an environment
    variable named `MKL_VERBOSE`. However, this methodology dumps messages in
    all steps. Those are a large amount of verbose messages. Moreover, for
    investigating the performance issues, generally taking verbose messages
    for one single iteration is enough. This on-demand verbosing functionality
    makes it possible to control scope for verbose message dumping. In the
    following example, verbose messages will be dumped out for the second
    inference only.

    .. highlight:: python
    .. code-block:: python

        import torch
        model(data)
        with torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON):
            model(data)

    Args:
        level: Verbose level
            - ``VERBOSE_OFF``: Disable verbosing
            - ``VERBOSE_ON``:  Enable verbosing
    """

    def __init__(self, enable):
        # Requested verbosity level (VERBOSE_OFF or VERBOSE_ON).
        self.enable = enable

    def __enter__(self):
        if self.enable == VERBOSE_OFF:
            # Nothing to enable, but still return self so that
            # `with verbose(VERBOSE_OFF) as v:` binds the manager instead of
            # None (bug fix: previously this branch returned None).
            return self
        st = torch._C._verbose.mkl_set_verbose(self.enable)
        assert (
            st
        ), "Failed to set MKL into verbose mode. Please consider to disable this verbose scope."
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always restore verbosity to OFF; never suppress exceptions.
        torch._C._verbose.mkl_set_verbose(VERBOSE_OFF)
        return False
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/mps/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (2.84 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (2.19 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/openmp/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def is_available():
    r"""Return whether PyTorch is built with OpenMP support."""
    # Compile-time flag exposed by the C extension.
    return torch._C.has_openmp
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (723 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (291 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/autograd/__init__.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def is_available():
    """Return True if the C extension exposing distributed autograd was built in."""
    return hasattr(torch._C, "_dist_autograd_init")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
if is_available() and not torch._C._dist_autograd_init():
|
| 11 |
+
raise RuntimeError("Failed to initialize torch.distributed.autograd")
|
| 12 |
+
|
| 13 |
+
if is_available():
|
| 14 |
+
from torch._C._distributed_autograd import (
|
| 15 |
+
get_gradients,
|
| 16 |
+
backward,
|
| 17 |
+
_init,
|
| 18 |
+
_new_context,
|
| 19 |
+
_release_context,
|
| 20 |
+
_get_max_id,
|
| 21 |
+
_is_valid_context,
|
| 22 |
+
_retrieve_context,
|
| 23 |
+
_current_context,
|
| 24 |
+
_get_debug_info,
|
| 25 |
+
DistAutogradContext,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class context:
    '''
    Context object to wrap forward and backward passes when using
    distributed autograd. The ``context_id`` generated in the ``with``
    statement is required to uniquely identify a distributed backward pass
    on all workers. Each worker stores metadata associated with this
    ``context_id``, which is required to correctly execute a distributed
    autograd pass.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.distributed.autograd as dist_autograd
        >>> with dist_autograd.context() as context_id:
        >>>     t1 = torch.rand((3, 3), requires_grad=True)
        >>>     t2 = torch.rand((3, 3), requires_grad=True)
        >>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
        >>>     dist_autograd.backward(context_id, [loss])
    '''
    def __enter__(self):
        # Create a fresh distributed autograd context; the integer id it
        # yields is what callers pass to dist_autograd.backward().
        self.autograd_context = _new_context()
        return self.autograd_context._context_id()

    def __exit__(self, type, value, traceback):
        # Release the context (and its metadata) even if the body raised.
        _release_context(self.autograd_context._context_id())
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-311.pyc
ADDED
|
Binary file (3.51 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/events/api.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This source code is licensed under the BSD-style license found in the
|
| 7 |
+
# LICENSE file in the root directory of this source tree.
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
from dataclasses import asdict, dataclass, field
|
| 11 |
+
from enum import Enum
|
| 12 |
+
from typing import Dict, Union, Optional
|
| 13 |
+
|
| 14 |
+
__all__ = ['EventSource', 'Event', 'NodeState', 'RdzvEvent']
|
| 15 |
+
|
| 16 |
+
EventMetadataValue = Union[str, int, float, bool, None]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class EventSource(str, Enum):
    """Known identifiers of the event producers."""

    # The str mixin keeps members JSON-serializable as plain strings.
    AGENT = "AGENT"
    WORKER = "WORKER"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class Event:
    """
    The class represents the generic event that occurs during the torchelastic job execution.

    The event can be any kind of meaningful action.

    Args:
        name: event name.
        source: the event producer, e.g. agent or worker
        timestamp: timestamp in milliseconds when event occurred.
        metadata: additional data that is associated with the event.
    """

    name: str
    source: EventSource
    timestamp: int = 0
    metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)

    def __str__(self):
        return self.serialize()

    @staticmethod
    def deserialize(data: Union[str, "Event"]) -> "Event":
        """Return *data* as-is if already an ``Event``, else parse it from a JSON string.

        Raises:
            TypeError: if *data* is neither an ``Event`` nor a ``str``.
        """
        if isinstance(data, Event):
            return data
        if not isinstance(data, str):
            # Previously this fell through to a NameError on `data_dict`
            # (see the old type: ignore[possibly-undefined]); fail fast
            # with a descriptive error instead.
            raise TypeError(f"Cannot deserialize event from {type(data).__name__}")
        data_dict = json.loads(data)
        # Re-hydrate the enum member from its serialized name.
        data_dict["source"] = EventSource[data_dict["source"]]
        return Event(**data_dict)

    def serialize(self) -> str:
        """Serialize this event (including nested metadata) to a JSON string."""
        return json.dumps(asdict(self))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class NodeState(str, Enum):
    """The states that a node can be in rendezvous."""

    # The str mixin keeps members JSON-serializable as plain strings.
    INIT = "INIT"
    RUNNING = "RUNNING"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"
+
|
| 69 |
+
|
| 70 |
+
@dataclass
class RdzvEvent:
    """
    Dataclass to represent any rendezvous event.

    Args:
        name: Event name. (E.g. Current action being performed)
        run_id: The run id of the rendezvous
        message: The message describing the event
        hostname: Hostname of the node
        pid: The process id of the node
        node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED)
        master_endpoint: The master endpoint for the rendezvous store, if known
        rank: The rank of the node, if known
        local_id: The local_id of the node, if defined in dynamic_rendezvous.py
        error_trace: Error stack trace, if this is an error event.
    """

    name: str
    run_id: str
    message: str
    hostname: str
    pid: int
    node_state: NodeState
    master_endpoint: str = ""
    rank: Optional[int] = None
    local_id: Optional[int] = None
    error_trace: str = ""

    def __str__(self):
        return self.serialize()

    @staticmethod
    def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent":
        """Return *data* as-is if already an ``RdzvEvent``, else parse it from a JSON string.

        Raises:
            TypeError: if *data* is neither an ``RdzvEvent`` nor a ``str``.
        """
        if isinstance(data, RdzvEvent):
            return data
        if not isinstance(data, str):
            # Previously this fell through to a NameError on `data_dict`
            # (see the old type: ignore[possibly-undefined]); fail fast
            # with a descriptive error instead.
            raise TypeError(f"Cannot deserialize event from {type(data).__name__}")
        data_dict = json.loads(data)
        # Re-hydrate the enum member from its serialized name.
        data_dict["node_state"] = NodeState[data_dict["node_state"]]
        return RdzvEvent(**data_dict)

    def serialize(self) -> str:
        """Serialize this event to a JSON string."""
        return json.dumps(asdict(self))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/events/handlers.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This source code is licensed under the BSD-style license found in the
|
| 7 |
+
# LICENSE file in the root directory of this source tree.
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Dict
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Shared, process-wide handler instances keyed by destination name.
_log_handlers: Dict[str, logging.Handler] = {
    "console": logging.StreamHandler(),
    "dynamic_rendezvous": logging.NullHandler(),
    "null": logging.NullHandler(),
}


def get_logging_handler(destination: str = "null") -> logging.Handler:
    """Return the shared ``logging.Handler`` registered for *destination*.

    Args:
        destination: one of ``"console"``, ``"dynamic_rendezvous"``, ``"null"``.

    Raises:
        KeyError: if *destination* is not a known handler name.
    """
    # The previous `global _log_handlers` was unnecessary: the mapping is
    # only read here, never rebound.
    return _log_handlers[destination]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/metrics/api.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This source code is licensed under the BSD-style license found in the
|
| 7 |
+
# LICENSE file in the root directory of this source tree.
|
| 8 |
+
|
| 9 |
+
import abc
|
| 10 |
+
import time
|
| 11 |
+
import warnings
|
| 12 |
+
from collections import namedtuple
|
| 13 |
+
from functools import wraps
|
| 14 |
+
from typing import Dict, Optional
|
| 15 |
+
|
| 16 |
+
__all__ = ['MetricsConfig', 'MetricHandler', 'ConsoleMetricHandler', 'NullMetricHandler', 'MetricStream',
|
| 17 |
+
'configure', 'getStream', 'prof', 'profile', 'put_metric', 'publish_metric', 'get_elapsed_time_ms',
|
| 18 |
+
'MetricData']
|
| 19 |
+
|
| 20 |
+
MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"])
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class MetricsConfig:
    """Thin container for metric-handler configuration parameters."""

    __slots__ = ["params"]

    def __init__(self, params: Optional[Dict[str, str]] = None):
        # Normalize the optional mapping to an empty dict so callers can
        # always index into ``self.params`` safely.
        self.params = {} if params is None else params
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class MetricHandler(abc.ABC):
    """Abstract sink for emitted metric data points."""

    @abc.abstractmethod
    def emit(self, metric_data: MetricData):
        """Publish a single ``MetricData`` sample to the backing sink."""
        pass
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class ConsoleMetricHandler(MetricHandler):
    """MetricHandler that prints each sample to stdout."""

    def emit(self, metric_data: MetricData):
        print(
            f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}"
        )
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class NullMetricHandler(MetricHandler):
    """MetricHandler that silently discards every sample (the default)."""

    def emit(self, metric_data: MetricData):
        pass
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class MetricStream:
    """Binds a metric group name to the handler that receives its samples."""

    def __init__(self, group_name: str, handler: MetricHandler):
        self.group_name = group_name
        self.handler = handler

    def add_value(self, metric_name: str, metric_value: int):
        # The timestamp is captured at emission time (time.time() float,
        # seconds since epoch).
        self.handler.emit(
            MetricData(time.time(), self.group_name, metric_name, metric_value)
        )
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
_metrics_map: Dict[str, MetricHandler] = {}
|
| 62 |
+
_default_metrics_handler: MetricHandler = NullMetricHandler()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# pyre-fixme[9]: group has type `str`; used as `None`.
def configure(handler: MetricHandler, group: Optional[str] = None):
    """Install *handler* for *group*, or as the process-wide default when *group* is None."""
    if group is None:
        global _default_metrics_handler
        # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used
        # as `MetricHandler`.
        _default_metrics_handler = handler
    else:
        _metrics_map[group] = handler
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def getStream(group: str):
    """Return a MetricStream for *group*, falling back to the default handler
    when no handler was configured for that group."""
    handler = _metrics_map.get(group, _default_metrics_handler)
    return MetricStream(group, handler)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _get_metric_name(fn):
    """Derive a metric key for *fn*.

    For methods (or nested functions) the qualified name already includes the
    owning scope, so it is used as-is. For free functions the leaf module name
    is prepended to disambiguate (``<module>.<name>``).
    """
    qualname = fn.__qualname__
    split = qualname.split(".")
    if len(split) == 1:
        module = fn.__module__
        if module:
            # e.g. "pkg.sub.mod" -> "mod.fn_name"
            return module.split(".")[-1] + "." + split[0]
        else:
            return split[0]
    else:
        return qualname
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def prof(fn=None, group: str = "torchelastic"):
    r"""
    @profile decorator publishes duration.ms, count, success, failure metrics for the function that it decorates.

    The metric name defaults to the qualified name (``class_name.def_name``) of the function.
    If the function does not belong to a class, it uses the leaf module name instead.

    Usage

    ::

     @metrics.prof
     def x():
         pass

     @metrics.prof(group="agent")
     def y():
         pass
    """

    def wrap(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            key = _get_metric_name(f)
            try:
                start = time.time()
                result = f(*args, **kwargs)
                # Only counted when f() returned without raising.
                put_metric(f"{key}.success", 1, group)
            except Exception:
                put_metric(f"{key}.failure", 1, group)
                raise
            finally:
                # Duration is recorded on both the success and failure paths.
                # NOTE(review): if time.time() itself raised, `start` would be
                # unbound here — hence the type-ignore.
                put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group)  # type: ignore[possibly-undefined]
            return result

        return wrapper

    # Support both bare `@prof` (fn is the function) and `@prof(group=...)`.
    if fn:
        return wrap(fn)
    else:
        return wrap
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def profile(group=None):
    """
    @profile decorator adds latency and success/failure metrics to any given function.

    .. deprecated:: use :func:`prof` instead.

    Usage

    ::

     @metrics.profile("my_metric_group")
     def some_function(<arguments>):
    """
    # Warn once at decoration time, not on every call.
    warnings.warn("Deprecated, use @prof instead", DeprecationWarning)

    def wrap(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                start_time = time.time()
                result = func(*args, **kwargs)
                publish_metric(group, f"{func.__name__}.success", 1)
            except Exception:
                publish_metric(group, f"{func.__name__}.failure", 1)
                raise
            finally:
                # Duration is recorded on both the success and failure paths.
                publish_metric(
                    group,
                    f"{func.__name__}.duration.ms",
                    get_elapsed_time_ms(start_time),  # type: ignore[possibly-undefined]
                )
            return result

        return wrapper

    return wrap
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"):
    """
    Publish a metric data point.

    Usage

    ::

     put_metric("metric_name", 1)
     put_metric("metric_name", 1, "metric_group_name")
    """
    # Routes through whichever handler is configured for the group.
    getStream(metric_group).add_value(metric_name, metric_value)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def publish_metric(metric_group: str, metric_name: str, metric_value: int):
    """Deprecated alias of :func:`put_metric` with a different argument order."""
    warnings.warn(
        "Deprecated, use put_metric(metric_group)(metric_name, metric_value) instead"
    )
    metric_stream = getStream(metric_group)
    metric_stream.add_value(metric_name, metric_value)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def get_elapsed_time_ms(start_time_in_seconds: float):
    """Return the elapsed time in millis from the given start time."""
    elapsed_seconds = time.time() - start_time_in_seconds
    # Truncate (not round) to whole milliseconds, matching int() semantics.
    return int(elapsed_seconds * 1000)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-311.pyc
ADDED
|
Binary file (4.54 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This source code is licensed under the BSD-style license found in the
|
| 7 |
+
# LICENSE file in the root directory of this source tree.
|
| 8 |
+
|
| 9 |
+
"""
|
| 10 |
+
Each host in a distributed PyTorch job runs with a single TorchElastic agent,
|
| 11 |
+
and multiple workers (as children processes of the TorchElastic agent).
|
| 12 |
+
Since the workers are user-provided (your PyTorch script/job), TorchElastic
|
| 13 |
+
has a way to propagate errors on the trainers through the agent and up to the
|
| 14 |
+
scheduler, which ultimately informs the end-user about the state of the job
|
| 15 |
+
and applies any retry policies.
|
| 16 |
+
|
| 17 |
+
TorchElastic categorizes errors into 3 categories:
|
| 18 |
+
|
| 19 |
+
+----------------+----------------+--------------------------------------------------------------+
|
| 20 |
+
| Category | Sub-Category | Description |
|
| 21 |
+
+================+================+==============================================================+
|
| 22 |
+
| User Error | Input Error | invalid inputs to TorchElastic APIs (e.g. min > max nodes) |
|
| 23 |
+
| +----------------+--------------------------------------------------------------+
|
| 24 |
+
| | Worker Failure | any failures on the worker child process |
|
| 25 |
+
+----------------+----------------+--------------------------------------------------------------+
|
| 26 |
+
| Platform Error | n/a | failures caused by the agent |
|
| 27 |
+
+----------------+----------------+--------------------------------------------------------------+
|
| 28 |
+
| Infra Error | n/a | failures outside the domain of the agent and workers |
|
| 29 |
+
| | | (e.g. host failures) |
|
| 30 |
+
+----------------+----------------+--------------------------------------------------------------+
|
| 31 |
+
|
| 32 |
+
All errors other than "Worker Failure" are either raised canonically from the
|
| 33 |
+
agent process or implicitly or explicitly crash the agent process. So the
|
| 34 |
+
standard language (python) provided exception handling strategies apply.
|
| 35 |
+
|
| 36 |
+
Worker Failures are special because the exception/failure originates on a different
|
| 37 |
+
process from the agent so the error needs to be propagated inter-process
|
| 38 |
+
(e.g. the agent cannot simply ``try-catch`` an exception raised on the worker process).
|
| 39 |
+
|
| 40 |
+
TorchElastic agents use :func:`torch.distributed.elastic.multiprocessing.start_processes`
|
| 41 |
+
to launch the workers which has a simple file based inter-process error propagation
|
| 42 |
+
built-in.
|
| 43 |
+
|
| 44 |
+
Any function or binary entrypoint decorated with :func:`record`
|
| 45 |
+
will write uncaught exceptions (with the trace information) to a file specified by the
|
| 46 |
+
environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent)
|
| 47 |
+
sets this env var on each child it launches, then aggregates the error files for all
|
| 48 |
+
children, and propagates the one with the **smallest** timestamp (e.g. the **first** error).
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
import json
|
| 52 |
+
import os
|
| 53 |
+
import signal
|
| 54 |
+
import socket
|
| 55 |
+
import time
|
| 56 |
+
import warnings
|
| 57 |
+
from dataclasses import dataclass, field
|
| 58 |
+
from datetime import datetime
|
| 59 |
+
from functools import wraps
|
| 60 |
+
from string import Template
|
| 61 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
|
| 62 |
+
|
| 63 |
+
from torch.distributed.elastic.utils.logging import get_logger
|
| 64 |
+
|
| 65 |
+
from .error_handler import ErrorHandler # noqa: F401
|
| 66 |
+
from .handlers import get_error_handler # noqa: F401
|
| 67 |
+
|
| 68 |
+
__all__ = ["ProcessFailure", "ChildFailedError", "record", "ErrorHandler", "get_error_handler"]
|
| 69 |
+
|
| 70 |
+
log = get_logger(__name__)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
JSON = Dict
|
| 74 |
+
|
| 75 |
+
_EMPTY_ERROR_DATA = {"message": "<NONE>"}
|
| 76 |
+
_NOT_AVAILABLE = "<N/A>"
|
| 77 |
+
|
| 78 |
+
T = TypeVar("T")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@dataclass
class ProcessFailure:
    """
    Represent the failed process result. When the worker process fails, it may record failure root cause into the file.

    Tries to read the failure timestamp from the provided ``error_file``,
    if the ``error_file`` does not exist, the timestamp is the current
    timestamp (seconds since epoch).

    The ``message`` field is a concise explanation of the failure. If
    the error file exists then the message is obtained from the error file.
    Otherwise one is generated based on the failure signature.

    .. note:: It is assumed that the ``error_file`` is written by
              ``torch.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``.
              Otherwise the behavior is undefined.

    """

    local_rank: int
    pid: int
    exitcode: int
    error_file: str
    # The three fields below are derived in __post_init__, not by callers.
    error_file_data: JSON = field(init=False)
    message: str = field(init=False)
    timestamp: int = field(init=False)

    def __post_init__(self):
        self.error_file_data = _EMPTY_ERROR_DATA
        if os.path.isfile(self.error_file):
            try:
                with open(self.error_file) as fp:
                    self.error_file_data = json.load(fp)
                    log.debug(
                        "User process failed with error data: %s", json.dumps(self.error_file_data, indent=2)
                    )
                    self.message, self.timestamp = self._get_error_data(
                        self.error_file_data
                    )
            except Exception:
                # Re-raise after logging: a malformed reply file is a bug
                # in the error-handler protocol, not something to mask.
                log.exception("Failed to parse reply file: %s", self.error_file)
                raise
        else:
            self._set_no_reply_file()

        # make up an informative message if not already present
        if not self.message:
            # signals typically do not generate an error file message
            if self.exitcode < 0:
                self.message = (
                    f"Signal {-self.exitcode} ({self.signal_name()})"
                    f" received by PID {self.pid}"
                )
            else:
                self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html"

    def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]:
        """Extract (message, timestamp) from a parsed reply file."""
        message = error_file_data["message"]
        if isinstance(message, str):
            timestamp = int(error_file_data.get("timestamp", 0))
        else:
            # Structured message: {"message": ..., "extraInfo": {"timestamp": ...}}
            timestamp = int(message["extraInfo"]["timestamp"])
        return (message, timestamp)

    def _set_no_reply_file(self):
        # Defaults used when the worker never wrote an error file.
        self.error_file = _NOT_AVAILABLE
        self.error_file_data = _EMPTY_ERROR_DATA
        self.message = ""
        self.timestamp = int(time.time())

    def signal_name(self) -> str:
        """Return the symbolic signal name for a signal exit, else ``<N/A>``."""
        if self.exitcode < 0:
            # We don't want to kill the parent process trying to find the signal name.
            # if the signal doesn't map to a known name, use not available.
            try:
                return signal.Signals(-self.exitcode).name
            except Exception:
                return _NOT_AVAILABLE
        else:
            return _NOT_AVAILABLE

    def timestamp_isoformat(self):
        """Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS)."""
        return datetime.fromtimestamp(self.timestamp).isoformat(sep="_")
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
GlobalRank = int
|
| 168 |
+
|
| 169 |
+
_FAILURE_FORMAT_TEMPLATE = """[${idx}]:
|
| 170 |
+
time : ${time}
|
| 171 |
+
host : ${hostname}
|
| 172 |
+
rank : ${rank} (local_rank: ${local_rank})
|
| 173 |
+
exitcode : ${exitcode} (pid: ${pid})
|
| 174 |
+
error_file: ${error_file}
|
| 175 |
+
traceback : ${message}"""
|
| 176 |
+
|
| 177 |
+
# extra new lines before and after are intentional
|
| 178 |
+
_MSG_FORMAT_TEMPLATE = """
|
| 179 |
+
${boarder}
|
| 180 |
+
${title}
|
| 181 |
+
${section}
|
| 182 |
+
Failures:
|
| 183 |
+
${other_failures}
|
| 184 |
+
${section}
|
| 185 |
+
Root Cause (first observed failure):
|
| 186 |
+
${root_failure}
|
| 187 |
+
${boarder}"""
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class ChildFailedError(Exception):
|
| 191 |
+
"""
|
| 192 |
+
Special exception type that can be raised from a function annotated with the
|
| 193 |
+
``@record`` decorator to have the child process' (root exception) propagate
|
| 194 |
+
up the stack as-is (e.g. without being wrapped in the parent's traceback).
|
| 195 |
+
|
| 196 |
+
Useful in cases where the parent is a simple nanny process
|
| 197 |
+
and the child (worker) processes are actually doing meaningful compute.
|
| 198 |
+
In this case, errors typically occur on the child process as the parent
|
| 199 |
+
is not doing anything non-trivial, and child errors should be propagated
|
| 200 |
+
to the scheduler for accurate root cause diagnostics.
|
| 201 |
+
|
| 202 |
+
.. note:: The propagation relies on error files rather than exception handling to
|
| 203 |
+
support both function and binary launches.
|
| 204 |
+
|
| 205 |
+
Example:
|
| 206 |
+
::
|
| 207 |
+
|
| 208 |
+
# process tree on a host (container)
|
| 209 |
+
0: scheduler-init-process:
|
| 210 |
+
|- 1: torchelastic_agent:
|
| 211 |
+
|- 2: trainer_0 (ok)
|
| 212 |
+
|- 3: trainer_1 (fail) -> error.json
|
| 213 |
+
|- ...
|
| 214 |
+
|- n+2: trainer_n (ok)
|
| 215 |
+
|- n+3: other processes
|
| 216 |
+
|- ...
|
| 217 |
+
|
| 218 |
+
In the example above, trainer 1's failure (written into error.json) is
|
| 219 |
+
the root cause and should be reported to the scheduler's init process.
|
| 220 |
+
The torchelastic agent raises a ``ChildFailedError("trainer", {1: "trainer_1/error.json"})``
|
| 221 |
+
upon detecting trainer 1's failure which would propagate the contents
|
| 222 |
+
of trainer 1's error file to the scheduler's init process.
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]):
|
| 226 |
+
self.name = name
|
| 227 |
+
self.failures = failures
|
| 228 |
+
assert (
|
| 229 |
+
self.failures
|
| 230 |
+
) # does not make sense to create a ChildFaileError with no failures
|
| 231 |
+
super().__init__(self.format_msg())
|
| 232 |
+
|
| 233 |
+
def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]:
|
| 234 |
+
rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp)
|
| 235 |
+
return rank, self.failures[rank]
|
| 236 |
+
|
| 237 |
+
def format_msg(self, boarder_delim="=", section_delim="-"):
|
| 238 |
+
title = f"{self.name} FAILED"
|
| 239 |
+
root_rank, root_failure = self.get_first_failure()
|
| 240 |
+
|
| 241 |
+
root_failure_fmt: str = ""
|
| 242 |
+
other_failures_fmt: List[str] = []
|
| 243 |
+
width = len(title)
|
| 244 |
+
for idx, (rank, failure) in enumerate(self.failures.items()):
|
| 245 |
+
fmt, w = self._format_failure(idx, rank, failure)
|
| 246 |
+
width = max(width, w)
|
| 247 |
+
if rank == root_rank:
|
| 248 |
+
root_failure_fmt = fmt
|
| 249 |
+
else:
|
| 250 |
+
other_failures_fmt.append(fmt)
|
| 251 |
+
|
| 252 |
+
# upper boundary on width
|
| 253 |
+
width = min(width, 60)
|
| 254 |
+
|
| 255 |
+
return Template(_MSG_FORMAT_TEMPLATE).substitute(
|
| 256 |
+
boarder=boarder_delim * width,
|
| 257 |
+
title=title,
|
| 258 |
+
section=section_delim * width,
|
| 259 |
+
root_failure=root_failure_fmt,
|
| 260 |
+
other_failures="\n".join(other_failures_fmt or [" <NO_OTHER_FAILURES>"]),
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
def _format_failure(
    self, idx: int, rank: int, failure: ProcessFailure
) -> Tuple[str, int]:
    """Format a single failure entry; return the text and its widest line length.

    ``failure.message`` is either a plain string (failures without a traceback,
    e.g. signals) or a json-like dict of the form
    ``{"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}}``.
    Display preference: py_callstack, then message, then ``<N/A>``.
    """
    msg = failure.message
    if isinstance(failure.message, dict):
        extra_info = failure.message.get("extraInfo", {})
        fallback = failure.message.get("message", "<N/A>")
        # indent continuation lines so the traceback nests under the entry
        msg = extra_info.get("py_callstack", fallback).replace("\n", "\n ")

    fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute(
        idx=idx,
        time=failure.timestamp_isoformat(),
        hostname=socket.getfqdn(),
        rank=rank,
        local_rank=failure.local_rank,
        exitcode=failure.exitcode,
        pid=failure.pid,
        error_file=failure.error_file,
        message=msg,
    )
    # fmt.split("\n") is never empty, so max() is safe without a default
    width = max(len(line) for line in fmt.split("\n"))
    return fmt, width
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def record(
    fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None
) -> Callable[..., T]:
    """
    Syntactic sugar to record errors/exceptions that happened in the decorated
    function using the provided ``error_handler``.

    Using this decorator is equivalent to:

    ::

     error_handler = get_error_handler()
     error_handler.initialize()
     try:
        foobar()
     except ChildFailedError as e:
        _, failure = e.get_first_failure()
        error_handler.dump_error_file(failure.error_file, failure.exitcode)
        raise
     except Exception as e:
        error_handler.record(e)
        raise

    .. important:: use this decorator once per process at the top level method,
                   typically this is the main method.

    Example

    ::

     @record
     def main():
         pass

     if __name__=="__main__":
        main()

    """
    if not error_handler:
        error_handler = get_error_handler()

    @wraps(fn)
    def wrapper(*args, **kwargs):
        assert error_handler is not None  # assertion for mypy type checker
        error_handler.initialize()
        try:
            return fn(*args, **kwargs)
        except SystemExit as se:
            # For run_path based entrypoints, SystemExit with code = 0 will never exit.
            # Handling it here by returning a value:
            if se.code == 0:
                return None
            else:
                raise
        except ChildFailedError as e:
            rank, failure = e.get_first_failure()
            if failure.error_file != _NOT_AVAILABLE:
                error_handler.dump_error_file(failure.error_file, failure.exitcode)
            else:
                # BUGFIX: previously the message and ``rank`` were wrapped in a
                # single tuple passed to log.info, so ``%s`` was never
                # interpolated and the log line printed a tuple repr.
                # Pass ``rank`` as a lazy %-style argument instead.
                log.info(
                    "local_rank %s FAILED with no error file."
                    " Decorate your entrypoint fn with @record for traceback info."
                    " See: https://pytorch.org/docs/stable/elastic/errors.html",
                    rank,
                )
            raise
        except Exception as e:
            error_handler.record_exception(e)
            raise

    return wrapper
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This source code is licensed under the BSD-style license found in the
|
| 7 |
+
# LICENSE file in the root directory of this source tree.
|
| 8 |
+
# Multiprocessing error-reporting module
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler
|
| 12 |
+
|
| 13 |
+
__all__ = ['get_error_handler']
|
| 14 |
+
|
| 15 |
+
def get_error_handler():
    """Return the default :class:`ErrorHandler` (a fresh instance per call)."""
    return ErrorHandler()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-311.pyc
ADDED
|
Binary file (937 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-311.pyc
ADDED
|
Binary file (3.72 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This source code is licensed under the BSD-style license found in the
|
| 7 |
+
# LICENSE file in the root directory of this source tree.
|
| 8 |
+
from typing import Dict, Tuple
|
| 9 |
+
|
| 10 |
+
from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
|
| 11 |
+
SubprocessHandler,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
__all__ = ["get_subprocess_handler"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_subprocess_handler(
    entrypoint: str,
    args: Tuple,
    env: Dict[str, str],
    stdout: str,
    stderr: str,
    local_rank_id: int,
):
    """Factory: build a :class:`SubprocessHandler` for a single local worker.

    All arguments are forwarded unchanged; ``stdout``/``stderr`` are file
    paths (as strings) that the handler redirects the child's streams to.
    """
    return SubprocessHandler(
        entrypoint=entrypoint,
        args=args,
        env=env,
        stdout=stdout,
        stderr=stderr,
        local_rank_id=local_rank_id,
    )
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__init__.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Expiration timers are set up on the same process as the agent and
|
| 9 |
+
used from your script to deal with stuck workers. When you go into
|
| 10 |
+
a code-block that has the potential to get stuck you can acquire
|
| 11 |
+
an expiration timer, which instructs the timer server to kill the
|
| 12 |
+
process if it does not release the timer by the self-imposed expiration
|
| 13 |
+
deadline.
|
| 14 |
+
|
| 15 |
+
Usage::
|
| 16 |
+
|
| 17 |
+
import torchelastic.timer as timer
|
| 18 |
+
import torchelastic.agent.server as agent
|
| 19 |
+
|
| 20 |
+
def main():
|
| 21 |
+
start_method = "spawn"
|
| 22 |
+
message_queue = mp.get_context(start_method).Queue()
|
| 23 |
+
server = timer.LocalTimerServer(message, max_interval=0.01)
|
| 24 |
+
server.start() # non-blocking
|
| 25 |
+
|
| 26 |
+
spec = WorkerSpec(
|
| 27 |
+
fn=trainer_func,
|
| 28 |
+
args=(message_queue,),
|
| 29 |
+
...<OTHER_PARAMS...>)
|
| 30 |
+
agent = agent.LocalElasticAgent(spec, start_method)
|
| 31 |
+
agent.run()
|
| 32 |
+
|
| 33 |
+
def trainer_func(message_queue):
|
| 34 |
+
timer.configure(timer.LocalTimerClient(message_queue))
|
| 35 |
+
with timer.expires(after=60): # 60 second expiry
|
| 36 |
+
# do some work
|
| 37 |
+
|
| 38 |
+
In the example above if ``trainer_func`` takes more than 60 seconds to
|
| 39 |
+
complete, then the worker process is killed and the agent retries the worker group.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
from .api import TimerClient, TimerRequest, TimerServer, configure, expires # noqa: F401
|
| 43 |
+
from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401
|
| 44 |
+
from .file_based_local_timer import FileTimerClient, FileTimerServer, FileTimerRequest # noqa: F401
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-311.pyc
ADDED
|
Binary file (7.63 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/local_timer.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
import logging
|
| 7 |
+
import multiprocessing as mp
|
| 8 |
+
import os
|
| 9 |
+
import signal
|
| 10 |
+
import time
|
| 11 |
+
from queue import Empty
|
| 12 |
+
from typing import Any, Dict, List, Set, Tuple
|
| 13 |
+
|
| 14 |
+
from .api import RequestQueue, TimerClient, TimerRequest, TimerServer
|
| 15 |
+
|
| 16 |
+
__all__ = ['LocalTimerClient', 'MultiprocessingRequestQueue', 'LocalTimerServer']
|
| 17 |
+
|
| 18 |
+
log = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
class LocalTimerClient(TimerClient):
    """
    Client side of ``LocalTimerServer``. Meant to run on the same host as the
    server; the worker's pid uniquely identifies it, which is convenient when
    one subprocess (trainer) is spawned per GPU on a multi-GPU host.
    """

    def __init__(self, mp_queue):
        super().__init__()
        self._mp_queue = mp_queue

    def acquire(self, scope_id, expiration_time):
        # pid identifies this worker to the server on the shared host
        self._mp_queue.put(TimerRequest(os.getpid(), scope_id, expiration_time))

    def release(self, scope_id):
        # a negative expiration time is the protocol's "release" marker
        self._mp_queue.put(TimerRequest(os.getpid(), scope_id, -1))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class MultiprocessingRequestQueue(RequestQueue):
    """
    A ``RequestQueue`` backed by python ``multiprocessing.Queue``
    """

    def __init__(self, mp_queue: mp.Queue):
        super().__init__()
        self._mp_queue = mp_queue

    def size(self) -> int:
        return self._mp_queue.qsize()

    def get(self, size, timeout: float) -> List[TimerRequest]:
        """Dequeue up to ``size`` requests, spending at most ``timeout`` seconds total."""
        collected: List[TimerRequest] = []
        remaining = timeout
        while len(collected) < size:
            started = time.time()
            try:
                item = self._mp_queue.get(block=True, timeout=remaining)
            except Empty:
                break
            collected.append(item)
            # charge the blocking wait against the overall budget
            remaining -= time.time() - started
            if remaining <= 0:
                break
        return collected
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class LocalTimerServer(TimerServer):
    """
    Server that pairs with ``LocalTimerClient``. Clients are expected to be
    child processes of the process running this server; every host in the job
    runs its own server instance, which manages timers only for workers on
    that same host.
    """

    def __init__(
        self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
    ):
        super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
        # (worker pid, scope id) -> latest acquire request for that scope
        self._timers: Dict[Tuple[Any, str], TimerRequest] = {}

    def register_timers(self, timer_requests: List[TimerRequest]) -> None:
        for req in timer_requests:
            key = (req.worker_id, req.scope_id)
            # a negative expiration is the client's "release" marker
            if req.expiration_time < 0:
                self._timers.pop(key, None)
            else:
                self._timers[key] = req

    def clear_timers(self, worker_ids: Set[int]) -> None:
        # snapshot keys first: we mutate the dict while scanning
        stale = [key for key in self._timers if key[0] in worker_ids]
        for key in stale:
            del self._timers[key]

    def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
        """Group timers with ``expiration_time <= deadline`` by worker pid."""
        expired: Dict[Any, List[TimerRequest]] = {}
        for req in self._timers.values():
            if req.expiration_time <= deadline:
                expired.setdefault(req.worker_id, []).append(req)
        return expired

    def _reap_worker(self, worker_id: int) -> bool:
        """SIGKILL ``worker_id``; True when the process is gone (or already was)."""
        try:
            os.kill(worker_id, signal.SIGKILL)
        except ProcessLookupError:
            log.info("Process with pid=%s does not exist. Skipping", worker_id)
        except Exception:
            log.exception("Error terminating pid=%s", worker_id)
            return False
        return True
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
if torch.distributed.rpc.is_available():
|
| 3 |
+
from .api.remote_module import RemoteModule
|
| 4 |
+
from .functional import * # noqa: F403
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/api/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (225 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/jit/templates/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-311.pyc
ADDED
|
Binary file (6.34 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/copy.cpython-311.pyc
ADDED
|
Binary file (5.95 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/worker.cpython-311.pyc
ADDED
|
Binary file (7.46 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (225 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-311.pyc
ADDED
|
Binary file (5.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-311.pyc
ADDED
|
Binary file (21 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-311.pyc
ADDED
|
Binary file (28 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/api.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 2 |
+
from typing import Dict, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.distributed._tensor.random as random
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.distributed._tensor import (
|
| 8 |
+
DeviceMesh,
|
| 9 |
+
)
|
| 10 |
+
from torch.distributed._tensor.random import (
|
| 11 |
+
is_rng_supported_mesh,
|
| 12 |
+
TensorParallelRNGTracker,
|
| 13 |
+
)
|
| 14 |
+
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
|
| 15 |
+
from torch.distributed.tensor.parallel.style import (
|
| 16 |
+
ParallelStyle,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"parallelize_module",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
) -> nn.Module:
    """
    Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.

    The ``parallelize_plan`` is either a single :class:`ParallelStyle` applied
    to ``module`` itself, or a dict mapping sub-module fully qualified names
    (FQNs) to the :class:`ParallelStyle` each should use.

    Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`; for a
    2-D or N-D mesh, slice out a 1-D sub-mesh first (i.e. ``device_mesh["tp"]``).

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices for the DTensor.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
            The plan used to parallelize the module — a single style or a
            dict of module FQN -> style.
    Return:
        A :class:`nn.Module` object parallelized.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>>
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})

    .. note:: For complex module architectures like Attention or MLP layers,
        compose different ParallelStyles (e.g. ``ColwiseParallel`` and
        ``RowwiseParallel``) into one parallelize_plan to achieve the desired
        sharding computation.
    """
    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

    _validate_tp_mesh_dim(device_mesh)

    # Lazily install a tensor-parallel RNG state tracker if one is not present.
    if is_rng_supported_mesh(device_mesh) and not isinstance(
        random._rng_tracker, TensorParallelRNGTracker
    ):
        random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
        # TODO: we should allow user to pass in the default seed from a config
        random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
        # By default we execute random ops in non-tensor-parallel region. If users want
        # to execute in tensor-parallel region, they can manually set this field to True
        # after parallelizing the model.
        random._rng_tracker.distribute_region_enabled = False

    if isinstance(parallelize_plan, ParallelStyle):
        return parallelize_plan._apply(module, device_mesh)

    if isinstance(parallelize_plan, dict):
        for path, style in parallelize_plan.items():
            child = module.get_submodule(path)
            parent, leaf_name = module, path
            if "." in path:
                # re-anchor on the direct parent so we can re-register the child
                parent_path, _, leaf_name = path.rpartition(".")
                parent = module.get_submodule(parent_path)
            parent.register_module(  # type: ignore[call-arg] # pyre-ignore[20]
                leaf_name,
                parallelize_module(  # type: ignore[arg-type]
                    child, device_mesh, style  # type: ignore[arg-type] # pyre-ignore[6]
                ),
            )
        return module

    raise RuntimeError(  # pyre-ignore[7]
        "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
        f" parallelize_plan, {type(parallelize_plan)} found!"
    )
|