koichi12 commited on
Commit
de8bd69
·
verified ·
1 Parent(s): 71147a9

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__init__.py +1 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-311.pyc +0 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-311.pyc +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-311.pyc +0 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/auto_functionalize.py +261 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/cond.py +349 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/effects.py +204 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/map.py +358 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/strict_mode.py +100 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/torchbind.py +94 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py +842 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/while_loop.py +232 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/mkl/__init__.py +56 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/mps/__pycache__/__init__.cpython-311.pyc +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/openmp/__init__.py +6 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-311.pyc +0 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-311.pyc +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/autograd/__init__.py +52 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-311.pyc +0 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/events/api.py +112 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/events/handlers.py +22 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/metrics/api.py +201 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-311.pyc +0 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py +375 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py +16 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-311.pyc +0 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-311.pyc +0 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py +32 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__init__.py +44 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-311.pyc +0 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/local_timer.py +125 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/__init__.py +4 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/api/__init__.py +0 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-311.pyc +0 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/jit/templates/__init__.py +0 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-311.pyc +0 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/copy.cpython-311.pyc +0 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/worker.cpython-311.pyc +0 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-311.pyc +0 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-311.pyc +0 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-311.pyc +0 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-311.pyc +0 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/api.py +108 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .cond import cond
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (263 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-311.pyc ADDED
Binary file (11.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/cond.cpython-311.pyc ADDED
Binary file (18.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/effects.cpython-311.pyc ADDED
Binary file (10.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/map.cpython-311.pyc ADDED
Binary file (21 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/strict_mode.cpython-311.pyc ADDED
Binary file (5.79 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-311.pyc ADDED
Binary file (42 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/__pycache__/wrap.cpython-311.pyc ADDED
Binary file (11.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/auto_functionalize.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.utils._pytree as pytree
5
+ from torch import Tensor
6
+ from torch._C import DispatchKey
7
+ from torch._ops import HigherOrderOperator
8
+ from torch._prims_common import clone_preserve_strides
9
+ from torch._subclasses.fake_tensor import FakeTensorMode
10
+ from torch.fx.experimental.proxy_tensor import (
11
+ disable_proxy_modes_tracing,
12
+ ProxyTorchDispatchMode,
13
+ track_tensor_tree,
14
+ )
15
+
16
+
17
+ # NOTE: [auto-functionalizing custom ops]
18
+ # Users may wish to torch.compile custom ops that mutate their inputs.
19
+ # torch.compile will automatically support this op without anyone needing
20
+ # to provide a functionalization kernel for it. Here's how.
21
+ #
22
+ # Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> ()
23
+ # op. First, when FakeTensor sees this op:
24
+ # - If the schema says it returns nothing, we can generate a trivial
25
+ # FakeTensor rule for it (that returns nothing).
26
+ # - Otherwise, the user needs to provide a FakeTensor rule (abstract impl)
27
+ #
28
+ # Next, when Python FunctionalTensor sees the op, it will functionalize
29
+ # it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...})
30
+ # HOP and replacing the mutated inputs with corresponding outputs of this HOP.
31
+ # This HOP effectively runs the functional version of the op when
32
+ # called: it clones inputs that will be mutated, runs the op, and
33
+ # then returns (output, Tensors with the new values)
34
+
35
+
36
class AutoFunctionalized(HigherOrderOperator):
    """auto_functionalized(_mutable_op, **kwargs)

    A higher-order operator that evaluates a "functional" variant of
    ``_mutable_op``: every argument that the operator's schema marks as
    mutable is cloned first, the op runs against the clones, and the call
    returns the op's own outputs followed by the cloned (now updated)
    tensors.

    Only operators accepted by ``can_auto_functionalize`` may be used here;
    see that function for the exact restrictions (many could be lifted on
    request).

    The leading underscore on ``_mutable_op`` keeps the parameter name from
    colliding with operator kwarg names forwarded through ``**kwargs``.
    """

    def __init__(self):
        super().__init__("auto_functionalized")

    def __call__(
        self,
        _mutable_op: torch._ops.OpOverload,
        **kwargs: Dict[str, Any],
    ) -> Tuple[Any, Tuple[Tensor, ...]]:
        # Guard: only ops that satisfy the auto-functionalization
        # restrictions are allowed to reach this HOP.
        assert can_auto_functionalize(_mutable_op)
        assert isinstance(kwargs, dict)
        return super().__call__(_mutable_op, **kwargs)
65
+
66
+
67
+ auto_functionalized = AutoFunctionalized()
68
+
69
+
70
+ def can_auto_functionalize(op: torch._ops.OperatorBase) -> bool:
71
+ if not isinstance(op, torch._ops.OpOverload):
72
+ return False
73
+
74
+ if torch._library.utils.is_builtin(op):
75
+ # We control the built-ins. These may (in rare cases)
76
+ # do input metadata mutation (which we have banned on custom ops)
77
+ return False
78
+ schema = op._schema
79
+ if not schema.is_mutable:
80
+ return False
81
+ schema = op._schema
82
+
83
+ for arg in schema.arguments:
84
+ if arg.alias_info is None:
85
+ continue
86
+ if not arg.alias_info.is_write:
87
+ continue
88
+ if type(arg.type) is torch.TensorType:
89
+ continue
90
+ if (
91
+ type(arg.type) is torch.OptionalType
92
+ and type(arg.type.getElementType()) is torch.TensorType
93
+ ):
94
+ continue
95
+ # Not yet supported: other Tensor types. This includes things like
96
+ # Tensor[], Tensor?[], Tensor[]?.
97
+ return False
98
+
99
+ # The returns must not alias anything
100
+ for ret in schema.returns:
101
+ if ret.alias_info is None and type(ret.type) is torch.TensorType:
102
+ continue
103
+ # Not yet supported: List[Tensor] return.
104
+ return False
105
+ return True
106
+
107
+
108
@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd)
def auto_functionalized_dense(
    _mutable_op: torch._ops.OpOverload,
    _only_clone_these_tensors: Optional[Tuple[str, ...]] = None,
    **kwargs: Dict[str, Any],
) -> Tuple[Any, Tuple[Tensor, ...]]:
    """Eager implementation of the auto_functionalized HOP.

    Clones each mutable argument (unless excluded via
    ``_only_clone_these_tensors``), runs ``_mutable_op`` on the copies, and
    returns the op's outputs followed by the (mutated) copies.
    """
    updated_kwargs = dict(**kwargs)
    mutated_values = []

    for arg_name in get_mutable_arg_names(_mutable_op):
        value = kwargs[arg_name]
        # When an explicit clone-list is given, arguments outside it are
        # passed through unchanged (and may be mutated in place).
        keep_original = (
            _only_clone_these_tensors is not None
            and arg_name not in _only_clone_these_tensors
        )
        if not keep_original and value is not None:
            value = clone_preserve_strides(value)
        updated_kwargs[arg_name] = value
        mutated_values.append(value)

    op_out = _mutable_op(**updated_kwargs)

    # Flatten into one tuple: op outputs first, then the mutated copies.
    if isinstance(op_out, tuple):
        return (*op_out, *mutated_values)  # type: ignore[return-value]
    return (op_out, *mutated_values)  # type: ignore[return-value]
137
+
138
+
139
@auto_functionalized.py_impl(FakeTensorMode)
def auto_functionalized_fake(
    mode,
    _mutable_op: torch._ops.OpOverload,
    **kwargs: Dict[str, Any],
) -> Tuple[Any, Tuple[Tensor, ...]]:
    """FakeTensorMode rule: just run the dense implementation under the mode."""
    with mode:
        return auto_functionalized_dense(_mutable_op, **kwargs)
148
+
149
+
150
@auto_functionalized.py_impl(ProxyTorchDispatchMode)
def auto_functionalized_proxy(
    mode,
    _mutable_op: torch._ops.OpOverload,
    **kwargs: Dict[str, Any],
) -> Tuple[Any, Tuple[Tensor, ...]]:
    """Proxy-tracing rule: record an auto_functionalized node in the graph."""
    # Without tracing there is nothing to record; fall through to the HOP.
    if not mode.enable_tracing:
        return auto_functionalized(_mutable_op, **kwargs)

    # Compute concrete outputs with proxy tracing suspended, so the inner op
    # call is not itself recorded.
    with disable_proxy_modes_tracing():
        real_out = auto_functionalized(_mutable_op, **kwargs)

    traced_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
    node_proxy = mode.tracer.create_proxy(
        "call_function",
        auto_functionalized,
        (_mutable_op,),
        traced_kwargs,
    )
    # Associate the concrete outputs with the freshly created proxy node.
    return track_tensor_tree(real_out, node_proxy, constant=None, tracer=mode.tracer)
171
+
172
+
173
+ auto_functionalized.fallthrough(DispatchKey.AutogradCPU)
174
+ auto_functionalized.fallthrough(DispatchKey.AutogradCUDA)
175
+
176
+
177
def get_mutable_arg_names(op: torch._ops.OpOverload) -> List[str]:
    """Return the names of ``op``'s arguments whose schema marks them as
    written to (i.e. mutated by the op)."""
    names: List[str] = []
    for arg in op._schema.arguments:
        alias = arg.alias_info
        if alias is not None and alias.is_write:
            names.append(arg.name)
    return names
188
+
189
+
190
def do_auto_functionalize(
    op: torch._ops.OpOverload, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Any:
    """Functionalizes a call to op(*args, **kwargs) by emitting a call to
    `outs = auto_functionalized(op, normalized_kwargs)`
    and replacing the mutated (args, kwargs) with the corresponding outputs.

    The normalized_kwargs are just the (args, kwargs), but all in kwarg form.
    This makes handling easier for the auto_functionalized HOP.
    """
    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    # All of the (args, kwargs), but all as kwargs. The names for the
    # args come from the schema. This makes it easier for us to work with them.
    normalized_kwargs = {}
    schema = op._schema
    for idx, arg in enumerate(schema.arguments):
        # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema
        if arg.name in kwargs:
            normalized_kwargs[arg.name] = kwargs[arg.name]
        elif idx < len(args):
            # if it's out of bounds we don't need to do anything
            # as it means the optional arg was passed with its default
            # value
            normalized_kwargs[arg.name] = args[idx]
        else:
            normalized_kwargs[arg.name] = arg.default_value

    # Strip FunctionalTensor wrappers and run the HOP one dispatch key below,
    # so we do not re-enter functionalization.
    unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs)  # type: ignore[arg-type]
    with ctx.redispatch_to_next():
        unwrapped_outs = auto_functionalized(
            op, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    # List of the name of args that get mutated (according to the schema)
    # NOTE(review): assumed non-empty on this path (the op is mutable), since
    # the negative slices below would not behave as intended for length 0.
    mutable_args_names = get_mutable_arg_names(op)

    # The HOP's output layout is (op outputs..., updated mutated inputs...);
    # split it back into the two halves.
    unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[
        : -len(mutable_args_names)
    ]
    unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :]

    if len(op._schema.returns) == 0:
        # A no-return op still contributes a single leading None slot.
        assert unwrapped_actual_out[0] is None
        unwrapped_actual_out = None
    elif len(op._schema.returns) == 1:
        # Single-return ops are unpacked out of the 1-tuple.
        assert len(unwrapped_actual_out) == 1
        unwrapped_actual_out = unwrapped_actual_out[0]
    else:
        assert len(unwrapped_actual_out) == len(op._schema.returns)

    # Propagate the new values back onto the original wrapped inputs so the
    # caller observes the mutation through the functional layer.
    for name, unwrapped_out in zip(mutable_args_names, unwrapped_mutable_out):
        # Can be None if input was `Tensor(a!)?`
        if unwrapped_out is None:
            continue
        assert isinstance(unwrapped_out, torch.Tensor)
        orig_arg = normalized_kwargs[name]
        ctx.replace(orig_arg, unwrapped_out)
        ctx.commit_update(orig_arg)
        ctx.sync(orig_arg)

    return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]
254
+
255
+
256
@auto_functionalized.py_functionalize_impl
def auto_functionalized_func(ctx, _mutable_op, **kwargs):
    # Functionalization rule: strip the functional wrappers, re-dispatch the
    # HOP at the next key, then re-wrap the results for the caller.
    plain_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        outs = auto_functionalized(_mutable_op, **plain_kwargs)
    return ctx.wrap_tensors(outs)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/cond.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch._subclasses.functional_tensor
3
+
4
+ import torch.utils._pytree as pytree
5
+
6
+ from torch._C import DispatchKey
7
+ from torch._C._functorch import (
8
+ _add_batch_dim,
9
+ get_unwrapped,
10
+ is_batchedtensor,
11
+ maybe_get_bdim,
12
+ )
13
+ from torch._functorch.utils import exposed_in
14
+
15
+ from torch._higher_order_ops.utils import (
16
+ _has_potential_branch_input_alias,
17
+ _has_potential_branch_input_mutation,
18
+ _set_compilation_env,
19
+ autograd_not_implemented,
20
+ reenter_make_fx,
21
+ UnsupportedAliasMutationException,
22
+ )
23
+
24
+ from torch._ops import HigherOrderOperator
25
+ from torch._subclasses.fake_tensor import FakeTensorMode
26
+ from torch.fx.experimental.proxy_tensor import (
27
+ disable_proxy_modes_tracing,
28
+ ProxyTorchDispatchMode,
29
+ track_tensor_tree,
30
+ )
31
+ from torch.fx.passes.shape_prop import _extract_tensor_metadata
32
+ from torch.utils._python_dispatch import _get_current_dispatch_mode
33
+
34
+
35
+ @exposed_in("torch")
36
+ def cond(pred, true_fn, false_fn, operands):
37
+ r"""
38
+ Conditionally applies `true_fn` or `false_fn`.
39
+
40
+ .. warning::
41
+ `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
42
+ doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
43
+ Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
44
+
45
+ `cond` is structured control flow operator. That is, it is like a Python if-statement,
46
+ but has restrictions on `true_fn`, `false_fn`, and `operands` that enable it to be
47
+ capturable using torch.compile and torch.export.
48
+
49
+ Assuming the constraints on `cond`'s arguments are met, `cond` is equivalent to the following::
50
+
51
+ def cond(pred, true_branch, false_branch, operands):
52
+ if pred:
53
+ return true_branch(*operands)
54
+ else:
55
+ return false_branch(*operands)
56
+
57
+ Args:
58
+ pred (Union[bool, torch.Tensor]): A boolean expression or a tensor with one element,
59
+ indicating which branch function to apply.
60
+
61
+ true_fn (Callable): A callable function (a -> b) that is within the
62
+ scope that is being traced.
63
+
64
+ false_fn (Callable): A callable function (a -> b) that is within the
65
+ scope that is being traced. The true branch and false branch must
66
+ have consistent input and outputs, meaning the inputs have to be
67
+ the same, and the outputs have to be the same type and shape.
68
+
69
+ operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the true/false functions.
70
+
71
+ Example::
72
+
73
+ def true_fn(x: torch.Tensor):
74
+ return x.cos()
75
+ def false_fn(x: torch.Tensor):
76
+ return x.sin()
77
+ return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
78
+
79
+ Restrictions:
80
+ - The conditional statement (aka `pred`) must meet one of the following constraints:
81
+
82
+ - It's a `torch.Tensor` with only one element, and torch.bool dtype
83
+
84
+ - It's a boolean expression, e.g. `x.shape[0] > 10` or `x.dim() > 1 and x.shape[1] > 10`
85
+
86
+ - The branch function (aka `true_fn`/`false_fn`) must meet all of the following constraints:
87
+
88
+ - The function signature must match with operands.
89
+
90
+ - The function must return a tensor with the same metadata, e.g. shape,
91
+ dtype, etc.
92
+
93
+ - The function cannot have in-place mutations on inputs or global variables.
94
+ (Note: in-place tensor operations such as `add_` for intermediate results
95
+ are allowed in a branch)
96
+
97
+ .. warning::
98
+ Temporal Limitations:
99
+
100
+ - `cond` only supports **inference** right now. Autograd will be supported in the future.
101
+
102
+ - The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future.
103
+
104
+ """
105
+
106
+ if torch.compiler.is_dynamo_compiling():
107
+ return cond_op(pred, true_fn, false_fn, operands)
108
+
109
+ def _validate_input(pred, true_fn, false_fn, operands):
110
+ if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
111
+ raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")
112
+
113
+ if isinstance(pred, torch.Tensor) and pred.numel() != 1:
114
+ raise RuntimeError(
115
+ f"Expected pred to be bool or single-element tensor, but got {pred}."
116
+ )
117
+
118
+ if not callable(true_fn) or not callable(false_fn):
119
+ raise RuntimeError("Expect both branches to be callbale.")
120
+
121
+ if not isinstance(operands, (tuple, list)) or pytree.tree_any(
122
+ lambda t: not isinstance(t, torch.Tensor), operands
123
+ ):
124
+ raise RuntimeError(
125
+ "Expect operands to be a tuple of possibly nested dict/list/tuple that only"
126
+ f"consists of tensor leaves, but got {operands}."
127
+ )
128
+
129
+ _validate_input(pred, true_fn, false_fn, operands)
130
+
131
+ if not torch._dynamo.is_dynamo_supported():
132
+ raise RuntimeError("torch.cond requires dynamo support.")
133
+
134
+ with _set_compilation_env():
135
+ with torch._dynamo.utils.disable_cache_limit():
136
+ return torch.compile(cond_op, backend="eager", fullgraph=True)(
137
+ pred, true_fn, false_fn, operands
138
+ )
139
+
140
+
141
+ """
142
+ We're going to define a `cond_op` operation.
143
+ In order to do this, we need implementations for each of the dispatch keys.
144
+ """
145
+ cond_op = HigherOrderOperator("cond")
146
+
147
+
148
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    # Trace both branches into FX sub-graphs, attach them as modules on the
    # tracer root, and record one `conditional` call_function node whose
    # concrete value comes from actually running a branch.
    assert isinstance(
        operands, (list, tuple)
    ), "Cond operands must be a list or tuple of tensors"
    assert all(
        isinstance(o, torch.Tensor) for o in operands
    ), "Cond operands must be a list of tensors"

    pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)

    # Trace each branch with proxy modes disabled so branch internals are not
    # recorded into the outer graph.
    with disable_proxy_modes_tracing():
        true_graph = reenter_make_fx(true_fn, pre_dispatch)(*operands)
        false_graph = reenter_make_fx(false_fn, pre_dispatch)(*operands)

    true_outs = []
    false_outs = []
    for node in true_graph.graph.nodes:
        if node.op == "output":
            true_outs.extend(node.args)

    for node in false_graph.graph.nodes:
        if node.op == "output":
            false_outs.extend(node.args)

    # Both branches must produce the same number of outputs...
    flat_true_outs = pytree.arg_tree_leaves(*true_outs)
    flat_false_outs = pytree.arg_tree_leaves(*false_outs)
    if len(flat_true_outs) != len(flat_false_outs):
        raise torch._dynamo.exc.CondOpArgsMismatchError(
            f"Expected to return same number of outputs but got:"
            f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
            f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
        )

    # ...and matching tensor metadata, position by position.
    for i in range(0, len(flat_true_outs)):
        true_out = flat_true_outs[i]
        false_out = flat_false_outs[i]
        if true_out.meta["tensor_meta"] != false_out.meta["tensor_meta"]:
            raise torch._dynamo.exc.CondOpArgsMismatchError(
                f"Expected each tensor to have same metadata but got:"
                f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
                f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
            )

    # There are probably better ways - I know that create_arg has some self incrementing name
    # magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This kinda simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"true_graph_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    true_name = next_name
    false_name = f"false_graph_{i}"
    assert not hasattr(proxy_mode.tracer.root, false_name)

    proxy_mode.tracer.root.register_module(true_name, true_graph)
    proxy_mode.tracer.root.register_module(false_name, false_graph)

    args = (pred, true_graph, false_graph, operands)

    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="conditional"
    )

    # At this point, we're *guaranteed* that whether an output came from the
    # true or false branch is indistinguishable. So, as this is just for tracing
    # purposes, choose the true branch.

    # TODO: Uhh.... it shouldn't matter, but changing this to true_fn results in
    # a FakeTensorMode error :
    # `Current active mode <class 'torch._subclasses.fake_tensor.FakeTensorMode'> not registered`
    # TODO Sometimes the operands are not completely FakeTensor, something seems went wrong in
    # dynamo? Because of that it runs real computation sometimes and re-triggering downstream dispatch keys.
    out = false_fn(*operands)

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
230
+
231
+
232
@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def cond_op_dense(pred, true_fn, false_fn, operands):
    """Eager (dense) implementation of cond: an ordinary Python branch."""
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    chosen_branch = true_fn if pred else false_fn
    return chosen_branch(*operands)
240
+
241
+
242
# Autograd is not implemented for cond yet; deferred_error makes the failure
# lazy, raised only if a backward pass actually reaches this op.
cond_op.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(cond_op, deferred_error=True)
)
245
+
246
+
247
@cond_op.py_impl(ProxyTorchDispatchMode)
def inner(mode, pred, true_fn, false_fn, operands):
    # With tracing disabled there is nothing to record; fall through to the
    # HOP. Otherwise capture a `conditional` node via trace_cond.
    if not mode.enable_tracing:
        return cond_op(pred, true_fn, false_fn, operands)
    return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
253
+
254
+
255
@cond_op.py_impl(FakeTensorMode)
def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
    # FakeTensorMode rule: run BOTH branches under the fake mode, check that
    # they agree in output count and tensor metadata, and return the true
    # branch's outputs (either branch would do; shapes must match).
    with mode:
        true_outs = true_fn(*operands)
        flat_true_outs = pytree.tree_leaves(true_outs)
        flat_false_outs = pytree.tree_leaves(false_fn(*operands))
    if len(flat_true_outs) != len(flat_false_outs):
        raise RuntimeError("Unmatched number of outputs from cond() branches.")

    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        true_meta = _extract_tensor_metadata(true_out)
        false_meta = _extract_tensor_metadata(false_out)
        if true_meta != false_meta:
            raise torch._dynamo.exc.CondOpArgsMismatchError(
                f"Expected each tensor to have same metadata but got:"
                f"\n {true_fn.__name__} returns {true_meta}"
                f"\n {false_fn.__name__} returns {false_meta}"
            )
    return true_outs
274
+
275
+
276
@cond_op.py_functionalize_impl
def cond_func(ctx, pred, true_fn, false_fn, inputs):
    # Functionalization rule for cond: unwrap functional tensors, verify that
    # neither branch mutates or aliases its inputs (unsupported), then
    # re-dispatch the HOP below functionalization and re-wrap the results.
    unwrapped_inputs = ctx.unwrap_tensors(inputs)
    unwrapped_pred = ctx.unwrap_tensors(pred)
    with ctx.redispatch_to_next() as m:
        functional_true = ctx.functionalize(true_fn)
        functional_false = ctx.functionalize(false_fn)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        for branch in [functional_true, functional_false]:
            if _has_potential_branch_input_mutation(
                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond branch might be modifying the input!"
                )
        # NOTE(review): the alias check below runs on the *original* branches
        # while the mutation check above used the functionalized wrappers --
        # confirm this asymmetry is intentional.
        for branch in [true_fn, false_fn]:
            if _has_potential_branch_input_alias(
                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond branch might be aliasing the input!"
                )

        cond_return = cond_op(
            unwrapped_pred, functional_true, functional_false, unwrapped_inputs
        )
        return ctx.wrap_tensors(cond_return)
303
+
304
+
305
@cond_op.py_impl(torch._C._functorch.TransformType.Vmap)
def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs):
    # vmap rule for cond. If the predicate is batched, a single branch can no
    # longer be chosen per call, so both branches are evaluated and merged
    # element-wise with torch.where. Otherwise the branches themselves are
    # vmapped and normal cond dispatch proceeds on the scalar predicate.
    assert isinstance(
        inputs, (list, tuple)
    ), "Cond inputs must be a list or tuple of tensors"
    assert all(
        isinstance(i, torch.Tensor) for i in inputs
    ), "Cond inputs must be a list of tensors"

    pred_ = get_unwrapped(pred) if is_batchedtensor(pred) else pred

    # unbatched tensors are not vmapped
    tensors, in_dims = zip(
        *[
            (get_unwrapped(t), maybe_get_bdim(t)) if is_batchedtensor(t) else (t, None)
            for t in inputs
        ]
    )

    if is_batchedtensor(pred):
        # prepend "pred" and vmap everything
        tensors = (pred_,) + tensors
        in_dims = (0,) + in_dims

        def fn(p, *args):
            t = true_fn(*args)
            f = false_fn(*args)
            # NOTE(review): only the first output of each branch is merged
            # here, which appears to assume single-output branches -- confirm.
            return torch.where(p, t[0], f[0])

        with interpreter.lower():
            result = torch.vmap(fn, in_dims=in_dims)(*tensors)

    else:
        # predicate is known at this stage and it is a boolean expression or a
        # tensor with one element.
        true_fn = torch.vmap(true_fn, in_dims=in_dims)
        false_fn = torch.vmap(false_fn, in_dims=in_dims)

        with interpreter.lower():
            result = cond_op(pred, true_fn, false_fn, tensors)

    if not isinstance(result, tuple):
        result = (result,)
    lvl = interpreter.level()
    # Re-attach the batch dimension (dim 0) at the current vmap level.
    return tuple([_add_batch_dim(r, 0, lvl) for r in result])
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/effects.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import Any, Dict, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.utils._pytree as pytree
6
+ from torch._C import DispatchKey
7
+ from torch._ops import HigherOrderOperator
8
+ from torch._subclasses.fake_tensor import FakeTensorMode
9
+ from torch.fx.experimental.proxy_tensor import (
10
+ disable_proxy_modes_tracing,
11
+ ProxyTorchDispatchMode,
12
+ track_tensor_tree,
13
+ )
14
+
15
+
16
+ class _EffectType(Enum):
17
+ ORDERED = "Ordered"
18
+
19
+
20
+ SIDE_EFFECTS: Dict[torch._ops.OpOverload, _EffectType] = {
21
+ torch.ops.aten._print.default: _EffectType.ORDERED,
22
+ }
23
+
24
+
25
+ class WithEffects(HigherOrderOperator):
26
+ """
27
+ with_effects(token, op, args, kwargs) -> (new_token, op_results)
28
+
29
+ This HOP helps ensure ordering between side effectful ops like prints or ops
30
+ using torchbind objects. This is needed to ensure a traced graph from
31
+ AOTAutograd is functional so that future optimization passes do not reorder
32
+ these operators. This is done through threading "effect tokens" through the
33
+ graph to enforce data dependence between side effectful ops.
34
+
35
+ The tokens are basically dummy values (torch.tensor([])). We create a token
36
+ per "effect type", which are enumerated in the _EffectType enum.
37
+ """
38
+
39
+ def __init__(self):
40
+ super().__init__("with_effects")
41
+
42
+ def __call__(
43
+ self,
44
+ token,
45
+ op: torch._ops.OpOverload,
46
+ *args: Tuple[Any, ...],
47
+ **kwargs: Dict[str, Any],
48
+ ) -> Tuple[Any, ...]:
49
+ assert isinstance(op, torch._ops.OpOverload)
50
+ assert not has_aliasing(op), "Ops with aliasing is not supported"
51
+ assert has_effects(op, args, kwargs)
52
+ assert isinstance(kwargs, dict)
53
+ return super().__call__(token, op, *args, **kwargs)
54
+
55
+
56
+ with_effects = WithEffects()
57
+
58
+
59
+ def has_aliasing(op: torch._ops.OpOverload):
60
+ for arg in op._schema.arguments:
61
+ if arg.alias_info is not None:
62
+ return True
63
+ for arg in op._schema.returns:
64
+ if arg.alias_info is not None:
65
+ return True
66
+ return False
67
+
68
+
69
+ def has_effects(op, args, kwargs) -> bool:
70
+ return (
71
+ isinstance(op, torch._ops.OpOverload)
72
+ and not has_aliasing(op)
73
+ and get_effect_key(op, args, kwargs) is not None
74
+ )
75
+
76
+
77
+ def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
78
+ if op in SIDE_EFFECTS:
79
+ return SIDE_EFFECTS[op]
80
+
81
+ for arg in args:
82
+ if isinstance(arg, torch.ScriptObject):
83
+ return _EffectType.ORDERED
84
+
85
+ return None
86
+
87
+
88
+ @with_effects.py_impl(DispatchKey.CompositeExplicitAutograd)
89
+ def with_effects_dense(
90
+ token: torch.Tensor,
91
+ op: torch._ops.OpOverload,
92
+ *args: Tuple[Any, ...],
93
+ **kwargs: Dict[str, Any],
94
+ ) -> Tuple[torch.Tensor, ...]:
95
+ out = op(*args, **kwargs)
96
+ new_token = torch.tensor([])
97
+ if isinstance(out, tuple):
98
+ return (new_token, *out)
99
+ return (new_token, out)
100
+
101
+
102
+ @with_effects.py_impl(FakeTensorMode)
103
+ def with_effects_fake(
104
+ mode,
105
+ token: torch.Tensor,
106
+ op: torch._ops.OpOverload,
107
+ *args: Tuple[Any, ...],
108
+ **kwargs: Dict[str, Any],
109
+ ) -> Tuple[torch.Tensor, ...]:
110
+ with mode:
111
+ result = with_effects_dense(token, op, *args, **kwargs)
112
+ return result
113
+
114
+
115
+ @with_effects.py_impl(ProxyTorchDispatchMode)
116
+ def with_effects_proxy(
117
+ mode,
118
+ token: torch.Tensor,
119
+ op: torch._ops.OpOverload,
120
+ *args: Tuple[Any, ...],
121
+ **kwargs: Dict[str, Any],
122
+ ) -> Tuple[torch.Tensor, ...]:
123
+ if not mode.enable_tracing:
124
+ return with_effects(token, op, *args, **kwargs)
125
+
126
+ with disable_proxy_modes_tracing():
127
+ out = with_effects(token, op, *args, **kwargs)
128
+
129
+ proxy_token = mode.tracer.unwrap_proxy(token)
130
+ proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
131
+ proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
132
+
133
+ out_proxy = mode.tracer.create_proxy(
134
+ "call_function",
135
+ with_effects,
136
+ (proxy_token, op, *proxy_args),
137
+ proxy_kwargs,
138
+ )
139
+ result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
140
+ return result
141
+
142
+
143
+ with_effects.fallthrough(DispatchKey.AutogradCPU)
144
+ with_effects.fallthrough(DispatchKey.AutogradCUDA)
145
+
146
+
147
+ def handle_effects(
148
+ allow_token_discovery: bool,
149
+ tokens: Dict[_EffectType, torch.Tensor],
150
+ op: torch._ops.OpOverload,
151
+ args: Tuple[Any, ...],
152
+ kwargs: Dict[str, Any],
153
+ ) -> Any:
154
+ """
155
+ Args:
156
+ allow_token_discovery: Whether or not we are discovering tokens. If this
157
+ is true, we will create a token for every side effect type seen that
158
+ does not have a token assigned yet. If this is false, the tokens
159
+ should've all been created ahead of time, so we will error if there is
160
+ no token mapping to every effect type.
161
+
162
+ tokens: Map of effect type to tokens. This is to chain operators of the
163
+ same effects together so that they do not get reordered in later
164
+ optimization passes.
165
+ """
166
+
167
+ # Get a token. We can't do `tokens.get(op, torch.tensor([]))` because
168
+ # this will create an empty tensor during proxy mode tracing if the token
169
+ # doesn't exist. But the tokens should always exist during proxy mode tracing.
170
+ key = get_effect_key(op, args, kwargs)
171
+ assert key is not None
172
+ if key not in tokens:
173
+ assert allow_token_discovery, f"Could not find a token for effect {key}"
174
+ tokens[key] = torch.tensor([])
175
+ token = tokens[key]
176
+
177
+ from torch._subclasses.functional_tensor import PythonFunctionalizeAPI
178
+
179
+ ctx = PythonFunctionalizeAPI()
180
+
181
+ unwrapped_token = ctx.unwrap_tensors([token])[0] # type: ignore[arg-type]
182
+ unwrapped_args = ctx.unwrap_tensors(args) # type: ignore[arg-type]
183
+ unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type]
184
+ with ctx.redispatch_to_next():
185
+ (new_token, *unwrapped_outs) = with_effects(
186
+ unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs # type: ignore[arg-type]
187
+ )
188
+
189
+ if len(op._schema.returns) == 0:
190
+ assert unwrapped_outs[0] is None
191
+ unwrapped_outs = None # type: ignore[assignment]
192
+ elif len(op._schema.returns) == 1:
193
+ assert len(unwrapped_outs) == 1
194
+ unwrapped_outs = unwrapped_outs[0]
195
+ else:
196
+ assert len(unwrapped_outs) == len(op._schema.returns)
197
+
198
+ # Add the newly created token into the tokens map for a following call to
199
+ # use this token.
200
+ wrapped_token = ctx.wrap_tensors(new_token)
201
+ assert isinstance(wrapped_token, torch.Tensor)
202
+ tokens[key] = wrapped_token
203
+
204
+ return ctx.wrap_tensors(unwrapped_outs) # type: ignore[arg-type]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/map.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.utils._pytree as pytree
3
+ from torch._C import DispatchKey
4
+ from torch._dispatch.python import suspend_functionalization
5
+ from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun
6
+
7
+ from torch._higher_order_ops.utils import (
8
+ _has_potential_branch_input_alias,
9
+ _has_potential_branch_input_mutation,
10
+ reenter_make_fx,
11
+ UnsupportedAliasMutationException,
12
+ )
13
+ from torch._ops import HigherOrderOperator
14
+ from torch._subclasses.fake_tensor import FakeTensorMode
15
+ from torch._subclasses.functional_tensor import (
16
+ disable_functional_mode,
17
+ FunctionalTensor,
18
+ )
19
+ from torch.fx.experimental.proxy_tensor import (
20
+ disable_proxy_modes_tracing,
21
+ make_fx,
22
+ ProxyTorchDispatchMode,
23
+ track_tensor_tree,
24
+ )
25
+ from torch.multiprocessing.reductions import StorageWeakRef
26
+
27
+
28
+ # TODO: We add this to prevent dymamo from tracing into map_wrapper,
29
+ # remove the wrapper call when it's ready.
30
+ class MapWrapper(HigherOrderOperator):
31
+ def __call__(self, xs, *args):
32
+ return map_wrapper(xs, *args)
33
+
34
+
35
+ map = MapWrapper("map")
36
+ map_impl = HigherOrderOperator("map_impl")
37
+
38
+ dummy_aot_config = AOTConfig(
39
+ fw_compiler=None, # type: ignore[arg-type]
40
+ bw_compiler=None, # type: ignore[arg-type]
41
+ partition_fn=None, # type: ignore[arg-type]
42
+ decompositions={},
43
+ num_params_buffers=0,
44
+ aot_id=0,
45
+ keep_inference_input_mutations=False,
46
+ )
47
+
48
+
49
+ def create_fw_bw_graph(f, num_mapped_args, *args):
50
+ mapped_xs = args[:num_mapped_args]
51
+ pos_args = args[num_mapped_args:]
52
+
53
+ # Note: We create "clean" environments for make_fx by suspending all dispatch keys
54
+ # between Autograd and Python key. Currently, we only suspend functionalization but more can be
55
+ # added when required. Will encounter two problems if we don't suspend functionalization:
56
+ #
57
+ # 1. make_fx fails to capture operations on input: the inputs are wrapped as _to_functional_tensor_wrapper,
58
+ # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
59
+ # However, it's the outside wrapper that tracer creates proxies for. This casuses tracer fail to
60
+ # fetch the proxy for the inputs and fail to capture any operations on them.
61
+ #
62
+ # 2. make_fx fails to capture output: the outputs after ProxyTorchDispatchMode are further
63
+ # wrapped as FunctionalTensorWrapper in Functionalize key after return. However, the tracer
64
+ # only associates the inner tensor with proxy in ProxyTorchDispatchMode. Therefore,
65
+ # when creating the output node, it fails to associate the wrapped tensor with its proxy.
66
+ # Instead, it will create _tensor_constant as output.
67
+
68
+ with suspend_functionalization(), disable_functional_mode():
69
+ with disable_proxy_modes_tracing():
70
+
71
+ def _from_fun(t):
72
+ if isinstance(t, torch.Tensor):
73
+ if t.dtype != torch.bool:
74
+ return torch.empty_strided(
75
+ t.size(),
76
+ t.stride(),
77
+ dtype=t.dtype,
78
+ requires_grad=t.requires_grad,
79
+ )
80
+ else:
81
+ # clone of a functional tensor produces a functional tensor
82
+ # but we want to avoid it so we clone a non-functional version
83
+ maybe_unfunc_t = t
84
+ if isinstance(t, FunctionalTensor):
85
+ torch._sync(t)
86
+ maybe_unfunc_t = from_fun(t)
87
+ elif torch._is_functional_tensor(t):
88
+ # need to handle both types of functionalization here:
89
+ # these are the tensors that came from the user,
90
+ # which could be either FunctionalTensorWrapper or FunctionalTensor
91
+ torch._sync(t)
92
+ maybe_unfunc_t = torch._from_functional_tensor(t)
93
+ return maybe_unfunc_t.clone()
94
+ return t
95
+
96
+ unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs)
97
+ example_xs = _unstack_pytree(unwrapped_mapped_xs)[0]
98
+
99
+ example_pos_args = [
100
+ _from_fun(arg) if isinstance(arg, torch.Tensor) else arg
101
+ for arg in pos_args
102
+ ]
103
+ example_flat_out = pytree.tree_map(
104
+ _from_fun, f(*example_xs, *example_pos_args)
105
+ )
106
+ if any(
107
+ not isinstance(out, torch.Tensor)
108
+ for out in example_flat_out
109
+ if out is not None
110
+ ):
111
+ raise RuntimeError(
112
+ "Expect outputs of map only contains tensors or None. "
113
+ f"Got types {[type(out) for out in example_flat_out]}."
114
+ )
115
+ example_grad = [_from_fun(out) for out in example_flat_out]
116
+
117
+ fw_graph = make_fx(f)(*example_xs, *example_pos_args)
118
+
119
+ def joint_f(*example_args):
120
+ joint_mapped_args = example_args[:joint_num_mapped]
121
+ args = example_args[joint_num_mapped:]
122
+
123
+ mapped_input = joint_mapped_args[:num_mapped_args]
124
+ mapped_grads = joint_mapped_args[num_mapped_args:]
125
+
126
+ def fw_with_masks(*args):
127
+ fw_out = f(*args)
128
+ return fw_out, [
129
+ True
130
+ if isinstance(ret, torch.Tensor) and ret.requires_grad
131
+ else False
132
+ for ret in fw_out
133
+ ]
134
+
135
+ joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
136
+ _, grads = joint(
137
+ list(mapped_input) + list(args),
138
+ [
139
+ grad
140
+ for grad in mapped_grads
141
+ if grad is not None and grad.requires_grad
142
+ ],
143
+ )
144
+
145
+ # In order to keep map functional for backward graph,
146
+ # we clone outputs that are aliasing inputs
147
+ input_storage = {
148
+ StorageWeakRef(arg._typed_storage())
149
+ for arg in example_args
150
+ if isinstance(arg, torch.Tensor)
151
+ }
152
+
153
+ def maybe_clone(t):
154
+ if (
155
+ isinstance(t, torch.Tensor)
156
+ and StorageWeakRef(t._typed_storage()) in input_storage
157
+ ):
158
+ return t.clone()
159
+ return t
160
+
161
+ return pytree.tree_map(maybe_clone, grads)
162
+
163
+ joint_num_mapped = len(example_grad) + len(example_xs)
164
+ joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args)
165
+ return fw_graph, joint_graph
166
+
167
+
168
+ def map_wrapper(f, xs, *args):
169
+ flat_xs, xs_spec = pytree.tree_flatten(xs)
170
+ if not all(isinstance(t, torch.Tensor) for t in flat_xs):
171
+ raise RuntimeError(f"Mapped xs can only consist of tensors. Got xs {flat_xs}.")
172
+
173
+ num_mapped_args = len(flat_xs)
174
+ shapes = [xs.shape for xs in flat_xs]
175
+ leading_dim_size = shapes[0][0]
176
+ if leading_dim_size == 0:
177
+ raise RuntimeError("Leading dimensions of mapped xs cannot be 0.")
178
+
179
+ if any(cur_shape[0] != leading_dim_size for cur_shape in shapes):
180
+ raise RuntimeError(
181
+ f"Leading dimensions of mapped xs must be consistent. Got shapes {shapes}."
182
+ )
183
+
184
+ out_spec = None
185
+
186
+ def flat_fn(*flat_args):
187
+ xs = pytree.tree_unflatten(list(flat_args[:num_mapped_args]), xs_spec)
188
+ unflattened_out = f(xs, *flat_args[num_mapped_args:])
189
+ flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out)
190
+
191
+ nonlocal out_spec
192
+ out_spec = tmp_out_spec
193
+ return flat_out
194
+
195
+ return pytree.tree_unflatten(
196
+ map_impl(flat_fn, flat_xs, args), out_spec # type: ignore[arg-type]
197
+ )
198
+
199
+
200
+ class MapAutogradOp(torch.autograd.Function):
201
+ @staticmethod
202
+ def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args):
203
+ ctx.save_for_backward(*flat_args)
204
+ ctx._joint_graph = joint_graph
205
+ ctx._num_mapped_args = num_mapped_args
206
+ with torch._C._AutoDispatchBelowAutograd():
207
+ return (
208
+ *map_impl(
209
+ fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:]
210
+ ),
211
+ )
212
+
213
+ @staticmethod
214
+ def backward(ctx, *flat_grads):
215
+ fw_args = ctx.saved_tensors
216
+ fw_mapped_args = fw_args[: ctx._num_mapped_args]
217
+ pos_args = fw_args[ctx._num_mapped_args :]
218
+
219
+ grads = map_impl(
220
+ ctx._joint_graph,
221
+ fw_mapped_args + flat_grads,
222
+ pos_args,
223
+ )
224
+ return None, None, None, *grads
225
+
226
+
227
+ def trace_map(proxy_mode, func_overload, f, xs, pos_args):
228
+ leading_dim_size = xs[0].shape[0]
229
+
230
+ example_input = _unstack_pytree(xs)[0]
231
+ body_graph = f
232
+
233
+ pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)
234
+ body_graph = reenter_make_fx(body_graph, pre_dispatch)(*example_input, *pos_args)
235
+
236
+ next_name = None
237
+ i = 0
238
+ while not next_name:
239
+ candidate = f"body_graph_{i}"
240
+ if hasattr(proxy_mode.tracer.root, candidate):
241
+ i += 1
242
+ else:
243
+ next_name = candidate
244
+
245
+ proxy_mode.tracer.root.register_module(next_name, body_graph)
246
+
247
+ with disable_proxy_modes_tracing():
248
+ example_outs = body_graph(*example_input, *pos_args)
249
+
250
+ def expand_tensor(t):
251
+ if isinstance(t, torch.Tensor):
252
+ return t.expand(leading_dim_size, *t.shape)
253
+ return t
254
+
255
+ expanded_outs = pytree.tree_map(expand_tensor, example_outs)
256
+
257
+ node_args = (body_graph, list(xs), list(pos_args))
258
+ proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
259
+ out_proxy = proxy_mode.tracer.create_proxy(
260
+ "call_function", func_overload, proxy_args, {}, name="map_impl"
261
+ )
262
+ return track_tensor_tree(
263
+ expanded_outs, out_proxy, constant=None, tracer=proxy_mode.tracer
264
+ )
265
+
266
+
267
+ def _unstack_pytree(xs):
268
+ flat_xs, inspec = pytree.tree_flatten(xs)
269
+ if not all(isinstance(xs, torch.Tensor) for xs in flat_xs):
270
+ raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")
271
+
272
+ if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs):
273
+ raise RuntimeError(
274
+ f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}"
275
+ )
276
+
277
+ a = zip(*flat_xs)
278
+
279
+ pytrees = []
280
+ for tuple in a:
281
+ pytrees.append(pytree.tree_unflatten(tuple, inspec))
282
+ return pytrees
283
+
284
+
285
+ def _stack_pytree(pytrees):
286
+ flat_out = []
287
+ out_spec = None
288
+ for pt in pytrees:
289
+ flat_pt, out_spec = pytree.tree_flatten(pt)
290
+ flat_out.append(flat_pt)
291
+ assert out_spec is not None
292
+ b = zip(*flat_out)
293
+ stacked_out = []
294
+ for leaves in b:
295
+ if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
296
+ stacked_out.append(torch.stack(leaves))
297
+ elif all(leaf is None for leaf in leaves):
298
+ # Backward graph can return None output when forward inputs doesn't require grad.
299
+ # When we eagerly execute backward graph, we need to call _stack_pytree on its output,
300
+ # therefore we need to deal with None output.
301
+ stacked_out.append(None) # type: ignore[arg-type]
302
+ else:
303
+ raise RuntimeError(f"Cannot stack {leaves}.")
304
+ return pytree.tree_unflatten(stacked_out, out_spec)
305
+
306
+
307
+ @map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
308
+ def map_dense(f, xs, pos_args):
309
+ pytrees = []
310
+ for inp in _unstack_pytree(xs):
311
+ pytrees.append(f(*inp, *pos_args))
312
+ return _stack_pytree(pytrees)
313
+
314
+
315
+ @map_impl.py_impl(DispatchKey.Autograd)
316
+ def map_autograd(f, xs, pos_args):
317
+ num_mapped_args = len(xs)
318
+ fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args)
319
+ flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args)
320
+ return flat_out
321
+
322
+
323
+ @map_impl.py_impl(ProxyTorchDispatchMode)
324
+ def map_proxy_torch_dispatch_mode(mode, f, xs, args):
325
+ if mode.enable_tracing:
326
+ return trace_map(mode, map_impl, f, xs, args)
327
+ else:
328
+ return map_impl(f, xs, args)
329
+
330
+
331
+ @map_impl.py_impl(FakeTensorMode)
332
+ def map_fake_tensor_mode(mode, f, xs, args):
333
+ with mode:
334
+ return map_dense(f, xs, args)
335
+
336
+
337
+ @map_impl.py_functionalize_impl
338
+ def map_functionalize(ctx, f, xs, pos_args):
339
+ unwrapped_xs = ctx.unwrap_tensors(xs)
340
+ unwrapped_args = ctx.unwrap_tensors(pos_args)
341
+ wrapped_fn = ctx.functionalize(f)
342
+
343
+ with ctx.redispatch_to_next():
344
+ with disable_proxy_modes_tracing():
345
+ example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
346
+ pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
347
+ if _has_potential_branch_input_mutation(
348
+ f, example_inputs, pre_dispatch=pre_dispatch
349
+ ):
350
+ raise UnsupportedAliasMutationException("torch.map is mutating the input!")
351
+
352
+ if _has_potential_branch_input_alias(
353
+ f, example_inputs, pre_dispatch=pre_dispatch
354
+ ):
355
+ raise UnsupportedAliasMutationException("torch.map is aliasing the input!")
356
+
357
+ map_return = map_impl(wrapped_fn, unwrapped_xs, unwrapped_args)
358
+ return ctx.wrap_tensors(map_return)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/strict_mode.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch._subclasses.functional_tensor
3
+
4
+ import torch.utils._pytree as pytree
5
+
6
+ from torch._C import DispatchKey
7
+ from torch._functorch.utils import exposed_in
8
+
9
+ from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_implemented
10
+ from torch._ops import HigherOrderOperator
11
+ from torch._subclasses.fake_tensor import FakeTensorMode
12
+ from torch.fx.experimental.proxy_tensor import (
13
+ disable_proxy_modes_tracing,
14
+ make_fx,
15
+ ProxyTorchDispatchMode,
16
+ track_tensor_tree,
17
+ )
18
+ from torch.utils._python_dispatch import _get_current_dispatch_mode
19
+
20
+
21
+ @exposed_in("torch")
22
+ def strict_mode(callable, operands):
23
+ if torch.compiler.is_dynamo_compiling():
24
+ return strict_mode_op(callable, operands)
25
+
26
+ with _set_compilation_env():
27
+ with torch._dynamo.utils.disable_cache_limit():
28
+ return torch.compile(strict_mode_op, backend="eager", fullgraph=True)(
29
+ callable, operands
30
+ )
31
+
32
+
33
+ strict_mode_op = HigherOrderOperator("strict_mode")
34
+
35
+
36
+ @strict_mode_op.py_impl(DispatchKey.CompositeExplicitAutograd)
37
+ def strict_mode_op_dense(callable, operands):
38
+ mode = _get_current_dispatch_mode()
39
+ assert mode is None, "Mode should never be enabled for CPU/CUDA key"
40
+ return callable(*operands)
41
+
42
+
43
+ strict_mode_op.py_impl(DispatchKey.Autograd)(
44
+ autograd_not_implemented(strict_mode_op, deferred_error=True)
45
+ )
46
+
47
+
48
+ @strict_mode_op.py_impl(ProxyTorchDispatchMode)
49
+ def inner(mode, callable, operands):
50
+ if mode.enable_tracing:
51
+ return trace_strict_mode(mode, strict_mode_op, callable, operands)
52
+ else:
53
+ return strict_mode_op(callable, operands)
54
+
55
+
56
+ def trace_strict_mode(mode, strict_mode_op, callable, operands):
57
+ pre_dispatch = getattr(mode, "pre_dispatch", False)
58
+
59
+ with disable_proxy_modes_tracing():
60
+ graph = make_fx(callable, pre_dispatch=pre_dispatch)(*operands)
61
+
62
+ next_name = None
63
+ i = 0
64
+ while not next_name:
65
+ candidate = f"strict_graph_{i}"
66
+ if hasattr(mode.tracer.root, candidate):
67
+ i += 1
68
+ else:
69
+ next_name = candidate
70
+
71
+ graph_name = next_name
72
+ mode.tracer.root.register_module(graph_name, graph)
73
+
74
+ args = (graph, operands)
75
+
76
+ proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
77
+
78
+ out_proxy = mode.tracer.create_proxy(
79
+ "call_function", strict_mode_op, proxy_args, {}, name="strict_mode"
80
+ )
81
+
82
+ out = graph(*operands)
83
+ return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
84
+
85
+
86
+ @strict_mode_op.py_impl(FakeTensorMode)
87
+ def strict_mode_fake_tensor_mode(mode, callable, operands):
88
+ with mode:
89
+ true_outs = callable(*operands)
90
+ return true_outs
91
+
92
+
93
+ @strict_mode_op.py_functionalize_impl
94
+ def strict_mode_func(ctx, callable, inputs):
95
+ unwrapped_inputs = ctx.unwrap_tensors(inputs)
96
+ with ctx.redispatch_to_next():
97
+ functional_callable = ctx.functionalize(callable)
98
+
99
+ cond_return = strict_mode_op(functional_callable, unwrapped_inputs)
100
+ return ctx.wrap_tensors(cond_return)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/torchbind.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+
3
+ import torch
4
+ from torch._C import DispatchKey # @manual
5
+ from torch._functorch._aot_autograd.utils import KNOWN_TYPES
6
+ from torch._higher_order_ops.utils import autograd_not_implemented
7
+ from torch._ops import HigherOrderOperator
8
+ from torch._subclasses.fake_tensor import FakeTensorMode
9
+ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
10
+ from torch.fx.node import has_side_effect
11
+ from torch.utils import _pytree as pytree
12
+
13
+ # The call_torchbind operator represents a method invocation on a torchbind
14
+ # object. The calling convention is:
15
+ # call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs)
16
+ # We do not expect users to write this operator directly. Instead it will be
17
+ # emitted by Dynamo when tracing encounters a torchbind object.
18
+ call_torchbind = HigherOrderOperator("call_torchbind")
19
+
20
+ # Register this operator as side-effectful with FX.
21
+ # TODO: this is not really sufficient. While passes (hopefully) check
22
+ # Node.is_impure() and make good decisions, we also assume we can execute the
23
+ # graph as many times as we want without changing behavior, which is NOT true of
24
+ # ops that mutate torchbind object state.
25
+ has_side_effect(call_torchbind)
26
+
27
+ _orig_scriptmethod_call = torch.ScriptMethod.__call__
28
+
29
+
30
+ def torchbind_method_redispatch(self, *args, **kwargs):
31
+ if isinstance(self.raw_owner, torch.ScriptObject):
32
+ return call_torchbind(self.raw_owner, self.name, *args, **kwargs)
33
+ return _orig_scriptmethod_call(self, *args, **kwargs)
34
+
35
+
36
+ @contextmanager
37
+ def enable_torchbind_tracing():
38
+ """Context manager that acts as a feature flag to enable torchbind tracing
39
+ behavior. Once torchbind tracing has been stabilized, we can remove this and
40
+ turn it always on.
41
+ """
42
+ try:
43
+ KNOWN_TYPES.append(torch.ScriptObject)
44
+ torch.ScriptMethod.__call__ = torchbind_method_redispatch # type: ignore[method-assign]
45
+ yield
46
+ finally:
47
+ assert (
48
+ KNOWN_TYPES.pop() is torch.ScriptObject
49
+ ), "Someone else messed with KNOWN_TYPES during tracing, exploding."
50
+ torch.ScriptMethod.__call__ = _orig_scriptmethod_call # type: ignore[method-assign]
51
+
52
+
53
+ @call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd)
54
+ def call_torchbind_impl(obj, method, *args, **kwargs):
55
+ return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs)
56
+
57
+
58
+ @call_torchbind.py_impl(ProxyTorchDispatchMode)
59
+ def inner(mode, *args, **kwargs):
60
+ if mode.enable_tracing:
61
+ proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
62
+ proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
63
+
64
+ out_proxy = mode.tracer.create_proxy(
65
+ "call_function",
66
+ call_torchbind,
67
+ proxy_args,
68
+ proxy_kwargs,
69
+ )
70
+ out = call_torchbind_impl(*args, **kwargs)
71
+
72
+ return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
73
+ else:
74
+ return call_torchbind(*args, **kwargs)
75
+
76
+
77
+ # TODO: currently we just run the C++ implementation with fake tensors.
78
+ # But we should make it possible to register a fake torchbind implementation.
79
+ @call_torchbind.py_impl(FakeTensorMode)
80
+ def call_torchbind_fake(mode, *args, **kwargs):
81
+ with mode:
82
+ return call_torchbind_impl(*args, **kwargs)
83
+
84
+
85
+ call_torchbind.py_impl(DispatchKey.Autograd)(
86
+ autograd_not_implemented(call_torchbind, deferred_error=True)
87
+ )
88
+
89
+
90
+ @call_torchbind.py_functionalize_impl
91
+ def call_torchbind_func(ctx, *args, **kwargs):
92
+ args = ctx.unwrap_tensors(args)
93
+ with ctx.redispatch_to_next():
94
+ return ctx.wrap_tensors(call_torchbind(*args, **kwargs))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py ADDED
@@ -0,0 +1,842 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import logging
3
+ import threading
4
+ import warnings
5
+ from collections import defaultdict
6
+ from typing import Any, Dict, List, Optional, Union
7
+
8
+ import torch.utils._pytree as pytree
9
+ from torch import Tensor
10
+ from torch._C import DispatchKey
11
+ from torch._ops import HigherOrderOperator
12
+ from torch._prims_common import clone_preserve_strides
13
+ from torch._subclasses.fake_tensor import FakeTensorMode
14
+ from torch.fx.experimental.proxy_tensor import (
15
+ disable_proxy_modes_tracing,
16
+ ProxyTorchDispatchMode,
17
+ track_tensor_tree,
18
+ )
19
+
20
+ log = logging.getLogger("torch._dynamo")
21
+
22
+
23
+ ###############################################################################
24
+ # Kernel Side Table
25
+
26
+
27
+ # We cannot put Triton Kernels into the FX graph as the graph nodes
28
+ # do not support arbitrary functions.
29
+ # Use a side table.
30
+ # We use two dicts so that fetching both the kernel and id are O(1)
31
class KernelSideTable:
    """Bidirectional map between Triton kernels and small integer ids.

    Triton kernels cannot be stored directly in FX graph nodes (nodes do
    not support arbitrary Python callables), so graphs record an integer
    id and this side table resolves it back to the kernel. Two dicts are
    kept so that both kernel->id and id->kernel lookups are O(1).
    """

    def __init__(self) -> None:
        # State is per-instance. (Previously these were class-level mutable
        # dicts, so every instance shared one table until reset_table()
        # shadowed them with instance attributes.)
        self.id_to_kernel: Dict[int, Any] = {}
        self.kernel_to_id: Dict[Any, int] = {}
        self.lock = threading.Lock()

    def add_kernel(self, kernel) -> int:
        """Register ``kernel`` (idempotently) and return its table index."""
        with self.lock:
            if kernel in self.kernel_to_id:
                return self.kernel_to_id[kernel]

            idx = len(self.id_to_kernel)
            self.id_to_kernel[idx] = kernel
            self.kernel_to_id[kernel] = idx
            return idx

    def get_kernel(self, idx: int):
        """Return the triton kernel stored at ``idx``."""
        # No need to lock here as fetching from dict is atomic
        assert idx in self.id_to_kernel
        return self.id_to_kernel[idx]

    def reset_table(self) -> None:
        """Clear the table (unit tests only; assumes single-threaded use)."""
        self.id_to_kernel = {}
        self.kernel_to_id = {}


kernel_side_table = KernelSideTable()
61
+
62
+
63
+ ###############################################################################
64
+ # Mutation Tracker
65
+
66
+
67
@dataclasses.dataclass(frozen=True)
class Param:
    """Reference to a kernel function parameter by positional index."""

    idx: int
70
+
71
+
72
@dataclasses.dataclass(frozen=True)
class Intermediate:
    """Reference to an op result (SSA value) in the mined TTIR graph.

    Negative indices denote "fake" intermediates synthesized during
    parsing (e.g. for ops that produce no result) rather than values
    actually present in the TTIR.
    """

    idx: int

    def fake(self):
        # Fake intermediates are allocated counting down from -1.
        return self.idx < 0
78
+
79
+
80
@dataclasses.dataclass(frozen=True)
class Op:
    """A single TTIR operation in the mined per-function op graph."""

    name: str
    # Symbol name of the callee; set iff this is a "tt.call" op.
    fn_call_name: Optional[str]
    args: List[Union[Param, Intermediate]]
    # Result value; excluded from repr to keep debug output readable.
    ret: Intermediate = dataclasses.field(repr=False)

    def __post_init__(self):
        # Invariant: a callee name is present exactly when this is a call.
        if self.name == "tt.call":
            assert self.fn_call_name is not None
        else:
            assert self.fn_call_name is None
92
+
93
+
94
def generate_ttir(kernel, kwargs):
    """
    Uses Triton's internal code generation to create TTIR.

    Returns a tuple ``(ttir_module, ordered_tensor_names)`` where
    ``ordered_tensor_names`` lists the kernel argument names (in signature
    order) whose values are tensors.

    Raises if ``kwargs`` does not match the kernel's signature or if the
    generated TTIR module fails verification.
    """
    from triton.compiler.compiler import ASTSource
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    import torch
    from torch._subclasses.fake_tensor import FakeTensor

    if isinstance(kernel, Autotuner):
        if len(kernel.configs) > 0:
            # If we are autotuning, then it doesn't matter which version gets
            # picked for tracing purposes, so lets pick the first one
            kwargs = {**kwargs, **kernel.configs[0].kwargs}
        kernel = kernel.fn

    assert isinstance(kernel, JITFunction)

    if len(kwargs) != len(kernel.arg_names):
        raise Exception("Incorrect number of arguments passed to kernel")

    # Replace all SymExprs with a regular value for TTIR generation
    # Replace all FakeTensor with real tensors
    # These replacements are needed for triton's type, key and config functions
    ordered_args: Dict[str, Any] = {}
    for name in kernel.arg_names:
        a = kwargs[name]
        if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            # Any concrete value works; only the type matters for codegen.
            ordered_args[name] = 2
        elif isinstance(a, FakeTensor):
            # A tiny real tensor with the same dtype stands in for the fake.
            ordered_args[name] = torch.empty(2, dtype=a.dtype)
        else:
            ordered_args[name] = a

    ordered_tensor_names = [
        name for name, arg in ordered_args.items() if isinstance(arg, Tensor)
    ]
    specialization = kernel._get_config(*ordered_args.values())
    # Non-tensor arguments are treated as compile-time constants by Triton.
    constants = {
        i: arg
        for i, arg in enumerate(ordered_args.values())
        if not isinstance(arg, Tensor)
    }

    # Build kernel signature -- doesn't include constexpr arguments.
    signature = {
        i: kernel._type_of(kernel._key_of(arg))
        for i, arg in enumerate(ordered_args.values())
        if i not in kernel.constexprs
    }

    def get_backend():
        # NOTE(review): hard-codes the CUDA backend; this mirrors the only
        # target Triton supported via this internal API at the time.
        from triton.compiler.backends.cuda import CUDABackend
        from triton.runtime.driver import driver

        target = driver.get_current_target()
        return CUDABackend(target)

    backend = get_backend()

    options = backend.parse_options(dict())
    # triton._C.libtriton.triton.ir.load_dialects(context)
    # backend.load_dialects(context)

    src = ASTSource(kernel, signature, constants, specialization)
    ttir_module = src.make_ir(options)
    if not ttir_module.verify():
        raise Exception("Verification for TTIR module has failed")

    return ttir_module, ordered_tensor_names
166
+
167
+
168
def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]:
    """
    Walk the `ttir_module` bottom up to mine the `functions` from
    the structured MLIR entities representing the Triton kernel
    (mlir::Operation, mlir::Block, mlir::Region).

    Returns a mapping from function symbol name to that function's op
    graph: result value -> list of ops producing it.
    """
    functions: Dict[str, Dict[Intermediate, List[Op]]] = {}

    # block id --> op result (Intermediate) --> one or more ops
    op_stack: Dict[int, Dict[Intermediate, List[Op]]] = defaultdict(
        lambda: defaultdict(list)
    )
    region_id_to_block_ids: Dict[int, List[int]] = defaultdict(list)
    block_id_to_block_arg_ids: Dict[int, List[int]] = {}
    replacements: Dict[int, Union[Intermediate, Param]] = {}
    # MLIR value ids are arbitrary; renumber them densely from 0.
    reindex_map: Dict[int, int] = {}
    # Counter for synthesized results; counts down so fake ids are negative.
    next_fake_intermediate = 0

    def reindex(idx):
        if idx not in reindex_map:
            reindex_map[idx] = len(reindex_map)
        return reindex_map[idx]

    def mlir_to_functions(op) -> None:
        # Visitor invoked for every MLIR operation, bottom-up.
        name: str = op.get_name()
        if name == "builtin.module":
            # this wraps all tt.func ops
            return

        operand_ids: List[int] = [
            reindex(op.get_operand(i).id()) for i in range(op.get_num_operands())
        ]
        result_ids: List[int] = [
            reindex(op.get_result(i).id()) for i in range(op.get_num_results())
        ]

        child_block_ids: List[int] = []
        for i in [op.get_region(i).id() for i in range(op.get_num_regions())]:
            # as the walk is bottom-up, the region_id_to_block_ids[i]
            # must be populated by the time we process the enclosing op
            child_block_ids.extend(region_id_to_block_ids[i])

        parent_block_id = -1
        parent_block = op.get_block()
        if parent_block is not None:
            parent_block_id = parent_block.id()
            if parent_block_id not in block_id_to_block_arg_ids:
                block_id_to_block_arg_ids[parent_block_id] = []
                for i in range(parent_block.get_num_arguments()):
                    block_id_to_block_arg_ids[parent_block_id].append(
                        reindex(parent_block.get_argument(i).id()),
                    )
                # the region info is collected via ops' parent blocks to be
                # used later when the region's encloding op is traversed
                parent_region = parent_block.get_parent()
                if parent_region is not None:
                    region_id_to_block_ids[parent_region.id()].append(parent_block_id)

        nonlocal next_fake_intermediate

        if name == "tt.func":
            # for function ops: gather and inline
            # the ops from all child blocks
            fn_ops = defaultdict(list)
            for child_block_id in child_block_ids:
                for result, block_fn_ops in op_stack.pop(child_block_id).items():
                    for block_fn_op in block_fn_ops:
                        fn_ops[result].append(block_fn_op)

            # replace the corresponding Intermediates in the
            # child op args with the function args (Params)
            for i, idx in enumerate(block_id_to_block_arg_ids[child_block_ids[0]]):
                replacements[idx] = Param(i)

            for fn_op_list in fn_ops.values():
                for fn_op in fn_op_list:
                    for i in range(len(fn_op.args)):
                        arg = fn_op.args[i]
                        if isinstance(arg, Intermediate) and arg.idx in replacements:
                            fn_op.args[i] = replacements[arg.idx]

            # next function capture starts
            # with empty replacements
            replacements.clear()

            fn_name = op.get_str_attr("sym_name")
            functions[fn_name] = fn_ops
        elif child_block_ids:
            if name in ("scf.if", "scf.for", "scf.while"):
                # for blocked control flow ops: inline the enclosed
                # ops into the parent block + rewire the last op in
                # each child block (yield) to return the scf result
                yield_ops = []
                for block_id in child_block_ids:
                    # the block args used as operands of the ops in the block
                    # (and nested blocks inlined in the current block by now)
                    # are replaced by new fake Intermediates to avoid "this
                    # operand is not returned by anything other op in the fn"
                    # error in the downstream analysis
                    for idx in block_id_to_block_arg_ids[block_id]:
                        next_fake_intermediate -= 1
                        replacements[idx] = Intermediate(next_fake_intermediate)

                    if block_id in op_stack:
                        block_ops = op_stack.pop(block_id)
                        if not block_ops:
                            continue
                        last_ret, last_ops = block_ops.popitem()
                        if all(op.name == "scf.yield" for op in last_ops):
                            # if last_ops are scf.yield, treat them separately
                            yield_ops.extend(last_ops)
                        else:
                            # otherwise, return last_ops to the block
                            block_ops[last_ret] = last_ops
                        for op_result, child_ops in block_ops.items():
                            op_stack[parent_block_id][op_result].extend(child_ops)

                scf_results = [Intermediate(idx) for idx in result_ids]
                for scf_result in scf_results:
                    for yield_op in yield_ops:
                        op_stack[parent_block_id][scf_result].append(yield_op)
            else:
                # TODO(oulgen): add support for tt.reduce
                raise Exception(
                    f"Unknown blocked function: {name}. Can't capture the TTIR."
                )
        else:
            # Plain (region-free) op: record it under each of its results,
            # or under a fresh fake result if it produces none.
            callee = None
            if name == "tt.call":
                callee = op.get_flat_symbol_ref_attr("callee")
            args: List[Union[Param, Intermediate]] = [
                Intermediate(operand) for operand in operand_ids
            ]
            block_ops = op_stack[parent_block_id]
            if result_ids:
                for result_id in result_ids:
                    res = Intermediate(result_id)
                    block_ops[res].append(Op(name, callee, args, res))
            else:
                next_fake_intermediate -= 1
                fake_res = Intermediate(next_fake_intermediate)
                block_ops[fake_res].append(Op(name, callee, args, fake_res))

    ttir_module.walk(mlir_to_functions)

    return functions
314
+
315
+
316
def parse_ttir(ttir, kwargs):
    """
    Given a Triton emitted TTIR text, this function lexes and parses the
    code using a minimal grammar defined inside. During the lexing/parsing,
    we drop any constant value and type information as they are not
    necessary to us.
    Being able to choose what we need makes this not a general purpose TTIR
    parser which further makes parsing much simpler.

    Returns the same ``functions`` mapping shape as ``ttir_to_functions``:
    function name -> (result value -> list of ops).

    Raises ``ModuleNotFoundError`` if ``lark`` is not installed (the caller
    falls back to assuming every input is mutated).
    """
    # TODO(oulgen):
    # - Support closures (e.g. "tt.reduce")

    try:
        import lark  # type: ignore[import-not-found]
        from lark import Lark, Transformer, v_args
    except ModuleNotFoundError:
        warnings.warn(
            "Using slow path for user-defined Triton kernels. `pip install lark` to fix this."
        )
        raise

    # Ops looks like one of the following forms:
    #
    # %14 = tt.addptr %13, %4 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>
    # tt.store %14, %12, %5 {cache = 1 : i32, evict = 1 : i32} : tensor<4xf32>
    # %15 = "tt.atomic_rmw"(%14, %12, %5) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 4 : i32}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xf32>, tensor<4xi1>) -> tensor<4xf32> # noqa: B950
    grammar = """
        start: (module_block | loc_line)+

        loc_line: "#loc" /.+/ NEWLINE

        module_block: "module" "{" func_block+ "}" LOC

        func_block: "tt.func" ("public"|"private") FN_NAME "(" /.+/ NEWLINE stmt* "}" LOC -> process_func

        ?stmt: op | if | for | while | condition_stmt | label_stmt | cf_stmt

        if: [assign_lhs "="] "scf.if" args rest stmt* "}" "else" "{" stmt* "}" LOC -> process_if
        for: [assign_lhs "="] "scf.for" args rest stmt* "}" divisibility_annot? LOC -> process_for
        while: [assign_lhs "="] "scf.while" args rest stmt* "}" "do" "{" stmt* "}" LOC -> process_while

        condition_stmt: "scf.condition" "(" arg ")" args rest
        label_stmt: LABEL ":" "// pred:" LABEL
            | LABEL "(" /.+/ NEWLINE
        cf_stmt: "cf" "." NAME /.+/ NEWLINE

        op: OP_NAME LOC
            | [assign_lhs "="] OP_NAME [FN_NAME] args rest? -> process_op

        ?rest: (":" | "{" | "\\"" | "->" | "<" | "=") /.+/ NEWLINE
        divisibility_annot: "{" "tt.divisibility_arg1" /[^}]+/ "}"

        args: | "(" ")" | "("? arg ("," arg)* ")"?

        ?arg: INTERMEDIATE
            | INTERMEDIATE_CONSTANT
            | CONSTANT
            | PARAM
            | "[" args "]"
            | arg_with_index

        ?arg_with_index: arg "#" DIGIT+

        ?assign_lhs: (INTERMEDIATE | INTERMEDIATE_CONSTANT) [":" DIGIT+]

        PARAM.5: "%arg" DIGIT+
        INTERMEDIATE.4: "%" DIGIT+
        INTERMEDIATE_CONSTANT.3: "%" NAME
        CONSTANT: FLOAT | DIGIT+ | NAME ("<" DIGIT+ ">")?
        LABEL: "^bb" DIGIT+

        NAME: (LETTER | DIGIT | "_")+
        NON_CF_NAME: /(?!(cf))/ NAME
        FN_NAME: "@" (NAME | ESCAPED_STRING)
        OP_NAME: "\\""? NON_CF_NAME ("." NAME)+ "\\""?

        LOC.5: "loc(#loc" DIGIT* ")"

        %import common.LETTER
        %import common.DIGIT
        %import common.WS
        %import common.NEWLINE
        %import common.ESCAPED_STRING
        %import common.FLOAT
        %ignore WS
    """

    # Counter for synthesized values; counts down so fake ids are negative.
    next_fake_intermediate = 0

    def convert(token):
        # Convert a lark token/tree into Param/Intermediate (or a list of
        # them for "args" trees), dropping constants and type info.
        if isinstance(token, lark.tree.Tree):
            if token.data == "args":
                res = []
                for a in token.children:
                    c = convert(a)
                    if isinstance(c, list):
                        res.extend(c)
                    else:
                        res.append(c)
                return res
            elif token.data in {"assign_lhs", "arg_with_index"}:
                # Drop length/index qualifier
                return convert(token.children[0])
            else:
                raise AssertionError(f"Tree node with {token.data}")

        if token is None or (
            isinstance(token, lark.lexer.Token)
            and token.type in ("CONSTANT", "INTERMEDIATE_CONSTANT")
        ):
            # Constants (and missing results) become fresh fake values.
            nonlocal next_fake_intermediate
            next_fake_intermediate -= 1
            return Intermediate(next_fake_intermediate)

        assert isinstance(token, lark.lexer.Token)

        if token.type == "INTERMEDIATE":
            return Intermediate(int(token.value[len("%") :]))
        if token.type == "PARAM":
            return Param(int(token.value[len("%arg") :]))

        raise AssertionError(f"{type(token.type)} => {token.value} invalid")

    # In alternative representation, function names are quoted.
    # It should be possible to move this into the grammar alltogether.
    def convert_name(token):
        if token is None:
            return None
        s = token.value
        if len(s) > 2 and s[0] == '"' and s[-1] == '"':
            return s[1:-1]
        return s

    functions: Dict[str, Dict[Intermediate, List[Op]]] = {}

    def extend_dict_list(d1, d2):
        # Merge d2 into d1, concatenating the per-key op lists.
        for key, values in d2.items():
            d1[key].extend(values)

    @v_args(inline=True)
    class TransformOps(Transformer):
        def process_op(self, ret, op_name, fn_name, args, *rest):
            return Op(
                convert_name(op_name),
                convert_name(fn_name),
                convert(args),
                convert(ret),
            )

        def process_func(self, name, _args, *stmts):
            ops: Dict[Intermediate, List[Op]] = defaultdict(list)
            for e in stmts:
                if isinstance(e, Op):
                    ops[e.ret].append(e)
                elif isinstance(e, dict):
                    # Nested scf block already reduced to an op-graph dict.
                    extend_dict_list(ops, e)
            functions[name.value] = ops

        def _process_scf(self, ret, stmts):
            # Rewire scf.yield ops so they produce the scf op's result.
            ret = convert(ret)
            ops: Dict[Intermediate, List[Op]] = defaultdict(list)
            for e in stmts:
                if isinstance(e, Op):
                    if e.name == "scf.yield":
                        ops[ret].append(Op(e.name, None, e.args, ret))
                    else:
                        ops[e.ret].append(e)
                elif isinstance(e, dict):
                    extend_dict_list(ops, e)
            return ops

        def process_if(self, ret, _args, _rest, *stmts):
            return self._process_scf(ret, stmts)

        def process_for(self, ret, _args, _rest, *stmts):
            return self._process_scf(ret, stmts)

        def process_while(self, ret, _args, _rest, *stmts):
            return self._process_scf(ret, stmts)

    parser = Lark(
        grammar, parser="lalr", maybe_placeholders=True, transformer=TransformOps()
    )
    parser.parse(ttir)
    return functions
501
+
502
+
503
class MemoizeWithCycleCheck:
    """Memoizing decorator that also detects (and rejects) recursion.

    While the wrapped function is being evaluated for a key, a ``None``
    sentinel is stored under that key; a re-entrant call with the same key
    observes the sentinel and raises instead of recursing forever.
    """

    def __init__(self, fn):
        self.fn = fn
        self.reset()

    def __call__(self, functions, fn_name, num_args):
        key = (fn_name, num_args)
        try:
            cached = self.cache[key]
        except KeyError:
            # Mark the key as "in progress" before evaluating so a
            # recursive call with the same key is caught below.
            self.cache[key] = None
            cached = self.cache[key] = self.fn(functions, fn_name, num_args)
        if cached is None:
            raise Exception("Recursion is not supported")
        return cached

    def reset(self):
        """Drop all memoized results (and any in-progress sentinels)."""
        self.cache = {}
519
+
520
+
521
@MemoizeWithCycleCheck
def analyze_kernel_mutations(functions, fn_name, num_args):
    """
    Analyzes the graph to detect all sinks from a predefined list of sinks
    by using triton's MemWrite trait list. NOTE: What if triton exposed this?
    From each sink, it traverses the CFG backwards to identify all the input
    pointers that are mutated.

    Returns a list of ``num_args`` booleans: ``mutated[i]`` is True iff the
    i-th parameter of ``fn_name`` may be written to.

    Raises on ops we cannot analyze; memoization raises on recursive calls.
    """
    # Name of mutation op to mutated parameter indices
    # List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td
    # All the OPs that have MemWrite trait.
    # What if Triton exposed this?
    MUTATION_OPS = {"tt.store": [0], "tt.atomic_cas": [0], "tt.atomic_rmw": [0]}
    # Ops that we want to bail out on
    UNKNOWN_OPS = {"tt.elementwise_inline_asm"}

    stack: List[Union[Param, Intermediate]] = []
    visited = set()
    ops = functions[fn_name]
    # Seed the stack with every value that flows into a memory write.
    for op_list in ops.values():
        for op in op_list:
            if op.name in UNKNOWN_OPS:
                raise Exception(
                    f"ttir analysis hit an op we do not know how to analyze: {op.name}"
                )

            if op.name == "tt.call":
                # Recurse into the callee; only its mutated params taint us.
                assert op.fn_call_name in functions
                mutations = analyze_kernel_mutations(
                    functions, op.fn_call_name, len(op.args)
                )
                stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated)
            else:
                for idx in MUTATION_OPS.get(op.name, []):
                    stack.append(op.args[idx])

    # The following is an iterative DFS algorithm
    mutated = [False] * num_args
    while stack:
        arg = stack.pop()
        if arg in visited:
            continue

        visited.add(arg)

        if isinstance(arg, Param):
            if arg.idx >= num_args:
                # This is an argument defined in the kernel, not passed in
                continue
            mutated[arg.idx] = True
        elif isinstance(arg, Intermediate) and not arg.fake():
            for op in ops[arg]:
                # Skip arguments to load
                if op.name != "tt.load":
                    stack.extend(op.args)
    return mutated
577
+
578
+
579
def identify_mutated_tensors(kernel, kwargs):
    """
    Given a triton kernel and the arguments for this kernel, this function
    1) Retrieves the TTIR converted version of the kernel from Triton's API.
    2) Parses the TTIR and creates a control flow graph
    3) Analyzes the graph to detect all input tensor mutations

    Returns the list of kwarg names whose tensors may be mutated by the
    kernel. On ANY failure of the analysis pipeline this conservatively
    falls back to "every tensor argument is mutated".
    """

    ttir_module = None
    functions = None
    try:
        from torch._dynamo import config

        if not config.optimize_user_defined_triton_kernels:
            # Raising takes us into the conservative fallback below.
            raise Exception("optimize_user_defined_triton_kernels is False")

        ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)

        # extract functions from TTIR
        if hasattr(ttir_module, "walk"):
            # use MLIR bindings exposed by Triton code
            functions = ttir_to_functions(ttir_module)
        else:
            # parse string representation of Triton IR
            functions = parse_ttir(str(ttir_module), kwargs)

        assert functions is not None
        kernel_name = next(iter(functions.keys()))
        # Triton codegen modifies the name
        assert kernel.fn.__name__ in kernel_name
        # Reset the cache between top level invocations
        # The cache for analyze kernel mutations is mainly used for cycle
        # detection, so each top level invocation needs a clean cache
        analyze_kernel_mutations.reset()
        mutations = analyze_kernel_mutations(
            functions, kernel_name, len(ordered_tensor_names)
        )

        return [
            ordered_tensor_names[i] for i, mutated in enumerate(mutations) if mutated
        ]
    except Exception as e:
        import traceback

        # Conservative fallback: warn, dump debug state, and report every
        # tensor kwarg as mutated (correct but pessimizes functionalization).
        warnings.warn(
            "Encountered an exception in identify_mutated_tensors, "
            "assuming every input is mutated:\n"
            "".join(
                traceback.TracebackException.from_exception(e).format()  # noqa: G001
            )
        )
        if ttir_module is not None:
            log.debug("TTIR:\n%s", str(ttir_module))
        if functions is not None:
            log.debug("functions:")
            for name, fn in functions.items():
                log.debug("===\t%s\t===", name)
                for ret, ops in fn.items():
                    log.debug("%s\t=>\t%s", ret, ops)
        return [key for key, value in kwargs.items() if isinstance(value, Tensor)]
639
+
640
+
641
+ ###############################################################################
642
+ # Triton Kernel Wrappers
643
+
644
+
645
+ # Used for wrapping a Triton Kernel
646
class TritonKernelWrapperMutation(HigherOrderOperator):
    """HOP wrapping a user-defined Triton kernel that mutates its inputs."""

    def __init__(self):
        super().__init__("triton_kernel_wrapper_mutation")


# Module-level singleton; py_impl registrations below attach to it.
triton_kernel_wrapper_mutation = TritonKernelWrapperMutation()
652
+
653
+
654
+ # Used for wrapping a Triton Kernel in a functional manner
655
class TritonKernelWrapperFunctional(HigherOrderOperator):
    """Functional HOP form: clones mutated inputs and returns the clones."""

    def __init__(self):
        super().__init__("triton_kernel_wrapper_functional")


# Module-level singleton; py_impl registrations below attach to it.
triton_kernel_wrapper_functional = TritonKernelWrapperFunctional()
661
+
662
+
663
@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_mutation_dense(*, kernel_idx, grid, kwargs):
    """Eager implementation: resolve the kernel from the side table and
    actually launch it, mutating tensors in ``kwargs`` in place.

    ``grid`` with a single entry is used directly; multiple entries mean
    one grid per autotuner config, so a selector function is generated.
    """
    from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code

    kernel = kernel_side_table.get_kernel(kernel_idx)

    if len(grid) == 1:
        grid_fn = grid[0]
    else:
        # Generate (and exec) a helper that picks the grid matching the
        # autotuner's chosen config at call time.
        fn_name, code = user_defined_kernel_grid_fn_code(
            kernel.fn.__name__, kernel.configs, grid
        )
        namespace: Dict[str, Any] = {}
        exec(code, namespace)
        grid_fn = namespace[fn_name]

    kernel[grid_fn](**kwargs)
680
+
681
+
682
@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode)
def triton_kernel_wrapper_mutation_fake_tensor_mode(mode, *, kernel_idx, grid, kwargs):
    # Under fake tensors there is nothing to compute: the op only mutates
    # its inputs in place and produces no outputs.
    with mode:
        return None
686
+
687
+
688
def trace_triton_kernel_wrapper(proxy_mode, func_overload, node_args):
    """Record ``func_overload(**node_args)`` as a call_function node.

    The op is first executed for real (with proxy tracing disabled) so its
    concrete output is available; a proxy node is then emitted and the
    output is associated with it via ``track_tensor_tree``.
    """
    with disable_proxy_modes_tracing():
        real_out = func_overload(**node_args)

    tracer = proxy_mode.tracer
    unwrapped_args = pytree.tree_map(tracer.unwrap_proxy, node_args)
    proxy_out = tracer.create_proxy(
        "call_function",
        func_overload,
        (),
        unwrapped_args,
        name=func_overload.__name__ + "_proxy",
    )
    return track_tensor_tree(real_out, proxy_out, constant=None, tracer=tracer)
701
+
702
+
703
@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
    mode, *, kernel_idx, grid, kwargs
):
    """Proxy-mode implementation: record the mutating op into the FX graph."""
    if mode.enable_tracing:
        trace_triton_kernel_wrapper(
            mode,
            triton_kernel_wrapper_mutation,
            {"kernel_idx": kernel_idx, "grid": grid, "kwargs": kwargs},
        )
    else:
        # Tracing disabled: just re-dispatch without recording a node.
        triton_kernel_wrapper_mutation(kernel_idx=kernel_idx, grid=grid, kwargs=kwargs)

    return None
717
+
718
+
719
@triton_kernel_wrapper_mutation.py_functionalize_impl
def triton_kernel_wrapper_mutation_functionalize(ctx, kernel_idx, grid, kwargs):
    """Functionalization: rewrite the mutating op as the functional op and
    propagate the returned clones back onto the original input tensors."""
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    kernel = kernel_side_table.get_kernel(kernel_idx)
    # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each
    # other, and one gets mutated in kernel, and later another gets mutated,
    # they are no longer equal. Fix this by graph breaking on this condition
    # earlier in dynamo.
    tensors_to_clone = identify_mutated_tensors(kernel, unwrapped_kwargs)
    with ctx.redispatch_to_next():
        unwrapped_outputs = triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            grid=grid,
            kwargs=unwrapped_kwargs,
            tensors_to_clone=tensors_to_clone,
        )

    # The functional op returns one output per cloned (i.e. mutated) input;
    # reflect each of them back onto the corresponding wrapped input.
    assert set(unwrapped_outputs.keys()).issubset(set(kwargs.keys()))
    for key, output_arg in unwrapped_outputs.items():
        if not isinstance(output_arg, Tensor):
            continue
        input_arg = kwargs[key]
        assert isinstance(input_arg, Tensor)

        ctx.replace(input_arg, output_arg)
        # indicate that above replace is hidden from autograd
        ctx.mark_mutation_hidden_from_autograd(input_arg)
        ctx.commit_update(input_arg)
        ctx.sync(input_arg)
        # sync calls replace_ under the hood, so again indicate that
        # this indirect replace is hidden from autograd
        ctx.mark_mutation_hidden_from_autograd(input_arg)
    return None
752
+
753
+
754
@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd)
def triton_kernel_wrapper_functional_dense(
    *, kernel_idx, grid, kwargs, tensors_to_clone
):
    """Functional form: clone the to-be-mutated inputs, launch the mutating
    op on the clones, and return the clones keyed by kwarg name."""
    # TODO(oulgen): For performance reasons, we want to ensure that these
    # `clone_preserve_strides` calls are never executed at runtime
    # (inductor should always optimize them away).
    # Requires https://github.com/pytorch/pytorch/issues/109240
    kwargs = {
        key: (clone_preserve_strides(val) if key in tensors_to_clone else val)
        for key, val in kwargs.items()
    }
    triton_kernel_wrapper_mutation(kernel_idx=kernel_idx, grid=grid, kwargs=kwargs)
    return {key: val for key, val in kwargs.items() if key in tensors_to_clone}
768
+
769
+
770
@triton_kernel_wrapper_functional.py_impl(FakeTensorMode)
def triton_kernel_wrapper_functional_fake_tensor_mode(
    mode, *, kernel_idx, grid, kwargs, tensors_to_clone
):
    """Fake-tensor form: only the output clones need to be faked; the
    kernel itself is never run."""
    # TODO(oulgen): For performance reasons, we want to ensure that these
    # `clone_preserve_strides` calls are never executed at runtime
    # (inductor should always optimize them away).
    # Requires https://github.com/pytorch/pytorch/issues/109240
    with mode:
        return {
            key: clone_preserve_strides(val)
            for key, val in kwargs.items()
            if key in tensors_to_clone
        }
784
+
785
+
786
@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode)
def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
    mode, *, kernel_idx, grid, kwargs, tensors_to_clone
):
    """Proxy-mode implementation: record the functional op into the FX
    graph (or just re-dispatch it when tracing is disabled)."""
    if mode.enable_tracing:
        return trace_triton_kernel_wrapper(
            mode,
            triton_kernel_wrapper_functional,
            {
                "kernel_idx": kernel_idx,
                "grid": grid,
                "kwargs": kwargs,
                "tensors_to_clone": tensors_to_clone,
            },
        )
    else:
        return triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            grid=grid,
            kwargs=kwargs,
            tensors_to_clone=tensors_to_clone,
        )
808
+
809
+
810
@triton_kernel_wrapper_functional.py_functionalize_impl
def triton_kernel_wrapper_functional_functionalize(
    ctx, kernel_idx, grid, kwargs, tensors_to_clone
):
    """Functionalization of the already-functional op: just unwrap the
    inputs, re-dispatch, and re-wrap the outputs (no mutation to handle)."""
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        outputs = triton_kernel_wrapper_functional(
            kernel_idx=kernel_idx,
            grid=grid,
            kwargs=unwrapped_kwargs,
            tensors_to_clone=tensors_to_clone,
        )
        return ctx.wrap_tensors(outputs)
823
+
824
+
825
# Route these dispatch keys straight through: the wrappers have no special
# behavior for them. Autograd keys are deliberate fallthroughs because the
# functionalize impls above hide the mutation from autograd.
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCPU)

triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.ADInplaceOrView)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.BackendSelect)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCUDA)  # type: ignore[attr-defined]
# NOTE(review): AutogradCUDA was registered twice here; deduplicated.
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA)
triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCPU)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/while_loop.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.utils._pytree as pytree
3
+
4
+ from torch._C import DispatchKey
5
+
6
+ from torch._higher_order_ops.utils import (
7
+ _has_potential_branch_input_alias,
8
+ _has_potential_branch_input_mutation,
9
+ _set_compilation_env,
10
+ autograd_not_implemented,
11
+ reenter_make_fx,
12
+ UnsupportedAliasMutationException,
13
+ )
14
+ from torch._ops import HigherOrderOperator
15
+ from torch._subclasses.fake_tensor import FakeTensorMode
16
+ from torch.fx.experimental.proxy_tensor import (
17
+ disable_proxy_modes_tracing,
18
+ ProxyTorchDispatchMode,
19
+ track_tensor_tree,
20
+ )
21
+
22
+
23
+ class WhileLoopOp(HigherOrderOperator):
24
+ def __call__(self, cond_fn, body_fn, operands):
25
+ if not isinstance(cond_fn, torch.fx.GraphModule) or not isinstance(
26
+ body_fn, torch.fx.GraphModule
27
+ ):
28
+ raise RuntimeError(
29
+ "cond_fn and body_fn must be torch.fx.GraphModule, got "
30
+ f"{type(cond_fn)} and {type(body_fn)}"
31
+ )
32
+ if not isinstance(operands, tuple):
33
+ raise RuntimeError("operands must be a tuple, got " f"{type(operands)}")
34
+ if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in operands):
35
+ raise RuntimeError(
36
+ "operands must be a tuple of tensors, ints, floats, or bools, got "
37
+ f"{operands}"
38
+ )
39
+ return super().__call__(cond_fn, body_fn, operands)
40
+
41
+
42
+ while_loop_op = HigherOrderOperator("while_loop")
43
+
44
+
45
+ def while_loop(cond_fn, body_fn, operands):
46
+ r"""
47
+ Run body_fn(*operands) while cond_fn(*operands) returns a True scalar tensor. Returns the output of body_fn or
48
+ initial operands.
49
+
50
+ .. warning::
51
+ `torch.while_loop` is a prototype feature in PyTorch. It has limited support for input and output types and
52
+ doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
53
+ Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
54
+
55
+ `while_loop` is a structured control flow operator. It preserves the loop semantic across the torch.compile and torch.export.
56
+
57
+ `while_loop` is equivalent to the following:
58
+
59
+ def while_loop(cond_fn, body_fn, operands):
60
+ val = operands
61
+ while cond_fn(*val):
62
+ val = body_fn(*val)
63
+ return val
64
+
65
+ Args:
66
+ cond_fn (Callable): A callable function that returns a boolean Scalar tensor.
67
+
68
+ body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors
69
+
70
+ operands (Tuple of possibly nested dict/list/tuple of tensors): A tuple of inputs to cond_fn and body_fn. It's also
71
+ the initial value of states that are carried across iterations.
72
+
73
+ Example:
74
+
75
+ def cond_fn(iter, x):
76
+ return iter.sum() < 10
77
+
78
+ def body_fn(iter, x):
79
+ return iter + 1, x.sin()
80
+
81
+ while_loop(cond_fn, body_fn, (torch.zeros(1), torch.randn(3, 4)))
82
+
83
+ Restrictions:
84
+
85
+ - body_fn must return tensors with the same metadata (e.g.shape, dtype) as inputs.
86
+
87
+ - body_fn and cond_fn must not in-place mutate the operands. A clone before the mutation is required.
88
+
89
+ - body_fn and cond_fn must not mutate python varialbles (e.g. list/dict) created outside of the body_fn.
90
+
91
+ - body_fn and cond_fn's output cannot aliase any of the inputs. A clone is required.
92
+
93
+ .. warning::
94
+ Temporal Limitations:
95
+
96
+ - 'while_loop' only supports **inference** right now. Autograd will be supported in the future.
97
+
98
+ """
99
+ if torch.compiler.is_dynamo_compiling():
100
+ return while_loop_op(cond_fn, body_fn, operands)
101
+
102
+ def _validate_input(cond_fn, body_fn, operands):
103
+ if not callable(cond_fn) or not callable(body_fn):
104
+ raise RuntimeError("Expect cond_fn and body_fn to be callbale.")
105
+
106
+ if not isinstance(operands, (tuple, list)) or pytree.tree_any(
107
+ lambda t: not isinstance(t, torch.Tensor), operands
108
+ ):
109
+ raise RuntimeError(
110
+ "Expect operands to be a tuple of possibly nested dict/list/tuple that only"
111
+ f"consists of tensor leaves, but got {operands}."
112
+ )
113
+
114
+ _validate_input(cond_fn, body_fn, operands)
115
+
116
+ with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
117
+ return torch.compile(while_loop_op, backend="eager", fullgraph=True)(
118
+ cond_fn, body_fn, operands
119
+ )
120
+
121
+
122
@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def while_loop_dense(cond_fn, body_fn, operands):
    """Eager (dense tensor) implementation: literally loop until cond_fn is falsy."""

    def _require_bool_scalar(pred):
        # Each iteration's predicate must be a 0-dim bool tensor.
        ok = (
            isinstance(pred, torch.Tensor)
            and pred.size() == torch.Size([])
            and pred.dtype == torch.bool
        )
        if not ok:
            raise RuntimeError(
                f"cond_fn must return a boolean scalar tensor but got {pred}"
            )

    if not isinstance(operands, tuple):
        raise RuntimeError(f"operands must be a tuple but got {type(operands)}")

    carry = operands
    while True:
        pred = cond_fn(*carry)
        # Truthiness is evaluated first (matching the original walrus loop);
        # the bool-scalar-tensor shape check only fires for truthy predicates.
        if not pred:
            break
        _require_bool_scalar(pred)
        step = body_fn(*carry)
        assert isinstance(
            step, tuple
        ), f"body_fn should return a tuple but got {type(step)}"
        assert len(step) == len(
            carry
        ), "body_fn should return the same number of elements as operands"
        carry = step
    return carry
150
+
151
+
152
# Training through while_loop is not implemented yet; register an Autograd
# kernel built by autograd_not_implemented (deferred_error=True presumably
# delays the "not implemented" error until backward is actually attempted —
# TODO confirm against autograd_not_implemented's contract).
while_loop_op.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(while_loop_op, deferred_error=True)
)
155
+
156
+
157
@while_loop_op.py_impl(ProxyTorchDispatchMode)
def while_loop_tracing(mode, cond_fn, body_fn, operands):
    """Proxy-mode implementation: trace cond_fn/body_fn into subgraphs and
    emit a single call_function node for the while_loop op."""

    def _trace(proxy_mode, op, cond_fn, body_fn, operands):
        pre_dispatch = getattr(proxy_mode, "pre_dispatch", False)
        with disable_proxy_modes_tracing():
            cond_graph = reenter_make_fx(cond_fn, pre_dispatch)(*operands)
            body_graph = reenter_make_fx(body_fn, pre_dispatch)(*operands)

        # Pick the first unused index so repeated while_loops on the same
        # tracer root get distinct submodule names.
        idx = 0
        while hasattr(proxy_mode.tracer.root, f"while_loop_cond_graph_{idx}"):
            idx += 1
        cond_graph_name = f"while_loop_cond_graph_{idx}"
        body_graph_name = f"while_loop_body_graph_{idx}"
        assert not hasattr(proxy_mode.tracer.root, body_graph_name)

        proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph)
        proxy_mode.tracer.root.register_module(body_graph_name, body_graph)

        node_args = (cond_graph, body_graph, operands)
        proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
        out_proxy = proxy_mode.tracer.create_proxy(
            "call_function", op, proxy_args, {}, name="while_loop"
        )

        # body_fn returns output with the same pytree and tensor metadata as
        # operands, so one concrete application supplies the tensors to track.
        out = body_fn(*operands)
        return track_tensor_tree(
            out, out_proxy, constant=None, tracer=proxy_mode.tracer
        )

    if mode.enable_tracing:
        return _trace(mode, while_loop_op, cond_fn, body_fn, operands)
    return while_loop_op(cond_fn, body_fn, operands)
199
+
200
+
201
@while_loop_op.py_impl(FakeTensorMode)
def while_loop_fake_tensor_mode(mode, cond_fn, body_fn, operands):
    # The trip count is data-dependent and unknowable under fake tensors.
    # Since body_fn is required to return outputs with the same metadata as
    # its inputs, a single application yields fake outputs of the correct
    # shape/dtype regardless of how many iterations would actually run.
    return body_fn(*operands)
204
+
205
+
206
@while_loop_op.py_functionalize_impl
def while_loop_func(ctx, cond_fn, body_fn, operands):
    """Functionalization implementation.

    Rejects cond_fn/body_fn that mutate or alias their inputs (raising
    UnsupportedAliasMutationException), then redispatches to the next key
    with functionalized callables and unwrapped tensors.
    """
    unwrapped_operands = ctx.unwrap_tensors(operands)
    with ctx.redispatch_to_next():
        functional_cond_fn = ctx.functionalize(cond_fn)
        functional_body_fn = ctx.functionalize(body_fn)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        for fn, fn_name in [
            (functional_cond_fn, "cond_fn"),
            (functional_body_fn, "body_fn"),
        ]:
            if _has_potential_branch_input_mutation(
                fn, unwrapped_operands, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    f"torch.while_loop's {fn_name} might be modifying the input!"
                )

        # Bug fix: the original second loop iterated only over `fn`
        # (`for fn in [...]`) while still formatting the stale `fn_name`
        # left over from the loop above, so the alias error always blamed
        # "body_fn". Iterate (fn, fn_name) pairs so the message names the
        # function that actually aliases.
        for fn, fn_name in [
            (functional_cond_fn, "cond_fn"),
            (functional_body_fn, "body_fn"),
        ]:
            if _has_potential_branch_input_alias(
                fn, unwrapped_operands, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    f"torch.while_loop's {fn_name} might be aliasing the input!"
                )
        ret = while_loop_op(functional_cond_fn, functional_body_fn, unwrapped_operands)
        return ctx.wrap_tensors(ret)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/_nnapi/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (222 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/mkl/__init__.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch


def is_available():
    r"""Return whether PyTorch is built with MKL support."""
    return torch._C.has_mkl


# Verbosity levels accepted by `verbose` (mirroring oneMKL's MKL_VERBOSE).
VERBOSE_OFF = 0
VERBOSE_ON = 1


class verbose:
    """
    On-demand oneMKL verbosing functionality.

    To make it easier to debug performance issues, oneMKL can dump verbose
    messages containing execution information like duration while executing
    the kernel. The verbosing functionality can be invoked via an environment
    variable named `MKL_VERBOSE`. However, this methodology dumps messages in
    all steps. Those are a large amount of verbose messages. Moreover, for
    investigating the performance issues, generally taking verbose messages
    for one single iteration is enough. This on-demand verbosing functionality
    makes it possible to control scope for verbose message dumping. In the
    following example, verbose messages will be dumped out for the second
    inference only.

    .. highlight:: python
    .. code-block:: python

        import torch
        model(data)
        with torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON):
            model(data)

    Args:
        level: Verbose level
            - ``VERBOSE_OFF``: Disable verbosing
            - ``VERBOSE_ON``:  Enable verbosing
    """

    def __init__(self, enable):
        self.enable = enable

    def __enter__(self):
        if self.enable == VERBOSE_OFF:
            # Bug fix: the original returned None here, so
            # `with verbose(VERBOSE_OFF) as v:` bound None while the enabled
            # path bound the manager. Return self on both paths.
            return self
        st = torch._C._verbose.mkl_set_verbose(self.enable)
        assert (
            st
        ), "Failed to set MKL into verbose mode. Please consider to disable this verbose scope."
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always restore the quiet default; never suppress exceptions.
        torch._C._verbose.mkl_set_verbose(VERBOSE_OFF)
        return False
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/mps/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.84 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/nnpack/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.19 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/backends/openmp/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
import torch


def is_available():
    r"""Return whether PyTorch is built with OpenMP support."""
    # Delegates to the compile-time flag exposed by the C extension.
    return torch._C.has_openmp
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (723 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (291 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/autograd/__init__.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
import sys
import torch


def is_available():
    # The native dist-autograd bindings exist only when PyTorch was built
    # with distributed support.
    return hasattr(torch._C, "_dist_autograd_init")


# Initialize the native module eagerly at import time: failing loudly here is
# preferable to obscure errors on first use.
if is_available() and not torch._C._dist_autograd_init():
    raise RuntimeError("Failed to initialize torch.distributed.autograd")

if is_available():
    # Re-export the native API at this module's top level.
    from torch._C._distributed_autograd import (
        get_gradients,
        backward,
        _init,
        _new_context,
        _release_context,
        _get_max_id,
        _is_valid_context,
        _retrieve_context,
        _current_context,
        _get_debug_info,
        DistAutogradContext,
    )
27
+
28
+
29
class context:
    """
    Context manager wrapping forward and backward passes when using
    distributed autograd.  Entering the ``with`` block creates a fresh
    distributed autograd context and yields its ``context_id``, which
    uniquely identifies the distributed backward pass on all workers; each
    worker stores metadata under this id that is required to correctly
    execute the pass.  The context is released on exit.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.distributed.autograd as dist_autograd
        >>> with dist_autograd.context() as context_id:
        >>>     t1 = torch.rand((3, 3), requires_grad=True)
        >>>     t2 = torch.rand((3, 3), requires_grad=True)
        >>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
        >>>     dist_autograd.backward(context_id, [loss])
    """

    def __enter__(self):
        ctx = _new_context()
        self.autograd_context = ctx
        return ctx._context_id()

    def __exit__(self, type, value, traceback):
        _release_context(self.autograd_context._context_id())
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-311.pyc ADDED
Binary file (3.51 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/events/api.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ # All rights reserved.
5
+ #
6
+ # This source code is licensed under the BSD-style license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ import json
10
+ from dataclasses import asdict, dataclass, field
11
+ from enum import Enum
12
+ from typing import Dict, Union, Optional
13
+
14
+ __all__ = ['EventSource', 'Event', 'NodeState', 'RdzvEvent']
15
+
16
+ EventMetadataValue = Union[str, int, float, bool, None]
17
+
18
+
19
class EventSource(str, Enum):
    """Known identifiers of the event producers."""

    # str-valued members so they serialize as plain strings via json and can
    # be recovered with EventSource[name] on deserialization.
    AGENT = "AGENT"
    WORKER = "WORKER"
24
+
25
+
26
@dataclass
class Event:
    """
    The class represents the generic event that occurs during the torchelastic job execution.

    The event can be any kind of meaningful action.

    Args:
        name: event name.
        source: the event producer, e.g. agent or worker
        timestamp: timestamp in milliseconds when event occurred.
        metadata: additional data that is associated with the event.
    """

    name: str
    source: EventSource
    timestamp: int = 0
    metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)

    def __str__(self):
        return self.serialize()

    @staticmethod
    def deserialize(data: Union[str, "Event"]) -> "Event":
        """Parse a JSON string into an Event; pass an Event through unchanged.

        Raises:
            TypeError: if ``data`` is neither a JSON string nor an Event.
                (The original implementation silently fell through and
                returned None in that case.)
        """
        if isinstance(data, Event):
            return data
        if isinstance(data, str):
            data_dict = json.loads(data)
            # "source" round-trips by member name.
            data_dict["source"] = EventSource[data_dict["source"]]
            return Event(**data_dict)
        raise TypeError(f"Cannot deserialize {type(data)} into an Event")

    def serialize(self) -> str:
        return json.dumps(asdict(self))
59
+
60
+
61
class NodeState(str, Enum):
    """The states that a node can be in rendezvous."""

    # str-valued so the state serializes as a plain string and is recovered
    # with NodeState[name] on deserialization.
    INIT = "INIT"
    RUNNING = "RUNNING"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"
68
+
69
+
70
@dataclass
class RdzvEvent:
    """
    Dataclass to represent any rendezvous event.

    Args:
        name: Event name. (E.g. Current action being performed)
        run_id: The run id of the rendezvous
        message: The message describing the event
        hostname: Hostname of the node
        pid: The process id of the node
        node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED)
        master_endpoint: The master endpoint for the rendezvous store, if known
        rank: The rank of the node, if known
        local_id: The local_id of the node, if defined in dynamic_rendezvous.py
        error_trace: Error stack trace, if this is an error event.
    """

    name: str
    run_id: str
    message: str
    hostname: str
    pid: int
    node_state: NodeState
    master_endpoint: str = ""
    rank: Optional[int] = None
    local_id: Optional[int] = None
    error_trace: str = ""

    def __str__(self):
        return self.serialize()

    @staticmethod
    def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent":
        """Parse a JSON string into an RdzvEvent; pass an RdzvEvent through.

        Raises:
            TypeError: if ``data`` is neither a JSON string nor an RdzvEvent.
                (The original implementation silently fell through and
                returned None in that case.)
        """
        if isinstance(data, RdzvEvent):
            return data
        if isinstance(data, str):
            data_dict = json.loads(data)
            # "node_state" round-trips by member name.
            data_dict["node_state"] = NodeState[data_dict["node_state"]]
            return RdzvEvent(**data_dict)
        raise TypeError(f"Cannot deserialize {type(data)} into an RdzvEvent")

    def serialize(self) -> str:
        return json.dumps(asdict(self))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/events/handlers.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ # All rights reserved.
5
+ #
6
+ # This source code is licensed under the BSD-style license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ import logging
10
+ from typing import Dict
11
+
12
+
13
# Registry of shared, preconstructed logging handlers keyed by destination.
_log_handlers: Dict[str, logging.Handler] = {
    "console": logging.StreamHandler(),
    "dynamic_rendezvous": logging.NullHandler(),
    "null": logging.NullHandler(),
}


def get_logging_handler(destination: str = "null") -> logging.Handler:
    """Return the shared handler registered under *destination*.

    Raises KeyError for unknown destinations.
    """
    return _log_handlers[destination]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/metrics/api.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ # All rights reserved.
5
+ #
6
+ # This source code is licensed under the BSD-style license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ import abc
10
+ import time
11
+ import warnings
12
+ from collections import namedtuple
13
+ from functools import wraps
14
+ from typing import Dict, Optional
15
+
16
+ __all__ = ['MetricsConfig', 'MetricHandler', 'ConsoleMetricHandler', 'NullMetricHandler', 'MetricStream',
17
+ 'configure', 'getStream', 'prof', 'profile', 'put_metric', 'publish_metric', 'get_elapsed_time_ms',
18
+ 'MetricData']
19
+
20
+ MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"])
21
+
22
+
23
+ class MetricsConfig:
24
+ __slots__ = ["params"]
25
+
26
+ def __init__(self, params: Optional[Dict[str, str]] = None):
27
+ self.params = params
28
+ if self.params is None:
29
+ self.params = {}
30
+
31
+
32
+ class MetricHandler(abc.ABC):
33
+ @abc.abstractmethod
34
+ def emit(self, metric_data: MetricData):
35
+ pass
36
+
37
+
38
+ class ConsoleMetricHandler(MetricHandler):
39
+ def emit(self, metric_data: MetricData):
40
+ print(
41
+ f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}"
42
+ )
43
+
44
+
45
+ class NullMetricHandler(MetricHandler):
46
+ def emit(self, metric_data: MetricData):
47
+ pass
48
+
49
+
50
+ class MetricStream:
51
+ def __init__(self, group_name: str, handler: MetricHandler):
52
+ self.group_name = group_name
53
+ self.handler = handler
54
+
55
+ def add_value(self, metric_name: str, metric_value: int):
56
+ self.handler.emit(
57
+ MetricData(time.time(), self.group_name, metric_name, metric_value)
58
+ )
59
+
60
+
61
+ _metrics_map: Dict[str, MetricHandler] = {}
62
+ _default_metrics_handler: MetricHandler = NullMetricHandler()
63
+
64
+
65
+ # pyre-fixme[9]: group has type `str`; used as `None`.
66
+ def configure(handler: MetricHandler, group: Optional[str] = None):
67
+ if group is None:
68
+ global _default_metrics_handler
69
+ # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used
70
+ # as `MetricHandler`.
71
+ _default_metrics_handler = handler
72
+ else:
73
+ _metrics_map[group] = handler
74
+
75
+
76
+ def getStream(group: str):
77
+ if group in _metrics_map:
78
+ handler = _metrics_map[group]
79
+ else:
80
+ handler = _default_metrics_handler
81
+ return MetricStream(group, handler)
82
+
83
+
84
+ def _get_metric_name(fn):
85
+ qualname = fn.__qualname__
86
+ split = qualname.split(".")
87
+ if len(split) == 1:
88
+ module = fn.__module__
89
+ if module:
90
+ return module.split(".")[-1] + "." + split[0]
91
+ else:
92
+ return split[0]
93
+ else:
94
+ return qualname
95
+
96
+
97
+ def prof(fn=None, group: str = "torchelastic"):
98
+ r"""
99
+ @profile decorator publishes duration.ms, count, success, failure metrics for the function that it decorates.
100
+
101
+ The metric name defaults to the qualified name (``class_name.def_name``) of the function.
102
+ If the function does not belong to a class, it uses the leaf module name instead.
103
+
104
+ Usage
105
+
106
+ ::
107
+
108
+ @metrics.prof
109
+ def x():
110
+ pass
111
+
112
+ @metrics.prof(group="agent")
113
+ def y():
114
+ pass
115
+ """
116
+
117
+ def wrap(f):
118
+ @wraps(f)
119
+ def wrapper(*args, **kwargs):
120
+ key = _get_metric_name(f)
121
+ try:
122
+ start = time.time()
123
+ result = f(*args, **kwargs)
124
+ put_metric(f"{key}.success", 1, group)
125
+ except Exception:
126
+ put_metric(f"{key}.failure", 1, group)
127
+ raise
128
+ finally:
129
+ put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group) # type: ignore[possibly-undefined]
130
+ return result
131
+
132
+ return wrapper
133
+
134
+ if fn:
135
+ return wrap(fn)
136
+ else:
137
+ return wrap
138
+
139
+
140
+ def profile(group=None):
141
+ """
142
+ @profile decorator adds latency and success/failure metrics to any given function.
143
+
144
+ Usage
145
+
146
+ ::
147
+
148
+ @metrics.profile("my_metric_group")
149
+ def some_function(<arguments>):
150
+ """
151
+ warnings.warn("Deprecated, use @prof instead", DeprecationWarning)
152
+
153
+ def wrap(func):
154
+ @wraps(func)
155
+ def wrapper(*args, **kwargs):
156
+ try:
157
+ start_time = time.time()
158
+ result = func(*args, **kwargs)
159
+ publish_metric(group, f"{func.__name__}.success", 1)
160
+ except Exception:
161
+ publish_metric(group, f"{func.__name__}.failure", 1)
162
+ raise
163
+ finally:
164
+ publish_metric(
165
+ group,
166
+ f"{func.__name__}.duration.ms",
167
+ get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined]
168
+ )
169
+ return result
170
+
171
+ return wrapper
172
+
173
+ return wrap
174
+
175
+
176
+ def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"):
177
+ """
178
+ Publish a metric data point.
179
+
180
+ Usage
181
+
182
+ ::
183
+
184
+ put_metric("metric_name", 1)
185
+ put_metric("metric_name", 1, "metric_group_name")
186
+ """
187
+ getStream(metric_group).add_value(metric_name, metric_value)
188
+
189
+
190
+ def publish_metric(metric_group: str, metric_name: str, metric_value: int):
191
+ warnings.warn(
192
+ "Deprecated, use put_metric(metric_group)(metric_name, metric_value) instead"
193
+ )
194
+ metric_stream = getStream(metric_group)
195
+ metric_stream.add_value(metric_name, metric_value)
196
+
197
+
198
+ def get_elapsed_time_ms(start_time_in_seconds: float):
199
+ """Return the elapsed time in millis from the given start time."""
200
+ end_time = time.time()
201
+ return int((end_time - start_time_in_seconds) * 1000)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-311.pyc ADDED
Binary file (4.54 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ # All rights reserved.
5
+ #
6
+ # This source code is licensed under the BSD-style license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ """
10
+ Each host in a distributed PyTorch job runs with a single TorchElastic agent,
11
+ and multiple workers (as children processes of the TorchElastic agent).
12
+ Since the workers are user-provided (your PyTorch script/job), TorchElastic
13
+ has a way to propagate errors on the trainers through the agent and up to the
14
+ scheduler, which ultimately informs the end-user about the state of the job
15
+ and applies any retry policies.
16
+
17
+ TorchElastic categorizes errors into 3 categories:
18
+
19
+ +----------------+----------------+--------------------------------------------------------------+
20
+ | Category | Sub-Category | Description |
21
+ +================+================+==============================================================+
22
+ | User Error | Input Error | invalid inputs to TorchElastic APIs (e.g. min > max nodes) |
23
+ | +----------------+--------------------------------------------------------------+
24
+ | | Worker Failure | any failures on the worker child process |
25
+ +----------------+----------------+--------------------------------------------------------------+
26
+ | Platform Error | n/a | failures caused by the agent |
27
+ +----------------+----------------+--------------------------------------------------------------+
28
+ | Infra Error | n/a | failures outside the domain of the agent and workers |
29
+ | | | (e.g. host failures) |
30
+ +----------------+----------------+--------------------------------------------------------------+
31
+
32
+ All errors other than "Worker Failure" are either raised canonically from the
33
+ agent process or implicitly or explicitly crash the agent process. So the
34
+ standard language (python) provided exception handling strategies apply.
35
+
36
+ Worker Failures are special because the exception/failure originates on a different
37
+ process from the agent so the error needs to be propagated inter-process
38
+ (e.g. the agent cannot simply ``try-catch`` an exception raised on the worker process).
39
+
40
+ TorchElastic agents use :func:`torch.distributed.elastic.multiprocessing.start_processes`
41
+ to launch the workers which has a simple file based inter-process error propagation
42
+ built-in.
43
+
44
+ Any function or binary entrypoint decorated with :func:`record`
45
+ will write uncaught exceptions (with the trace information) to a file specified by the
46
+ environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent)
47
+ sets this env var on each child it launches, then aggregates the error files for all
48
+ children, and propagates the one with the **smallest** timestamp (e.g. the **first** error).
49
+ """
50
+
51
+ import json
52
+ import os
53
+ import signal
54
+ import socket
55
+ import time
56
+ import warnings
57
+ from dataclasses import dataclass, field
58
+ from datetime import datetime
59
+ from functools import wraps
60
+ from string import Template
61
+ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
62
+
63
+ from torch.distributed.elastic.utils.logging import get_logger
64
+
65
+ from .error_handler import ErrorHandler # noqa: F401
66
+ from .handlers import get_error_handler # noqa: F401
67
+
68
+ __all__ = ["ProcessFailure", "ChildFailedError", "record", "ErrorHandler", "get_error_handler"]
69
+
70
+ log = get_logger(__name__)
71
+
72
+
73
+ JSON = Dict
74
+
75
+ _EMPTY_ERROR_DATA = {"message": "<NONE>"}
76
+ _NOT_AVAILABLE = "<N/A>"
77
+
78
+ T = TypeVar("T")
79
+
80
+
81
@dataclass
class ProcessFailure:
    """
    Represent the failed process result. When the worker process fails, it may record failure root cause into the file.

    Tries to read the failure timestamp from the provided ``error_file``,
    if the ``error_file`` does not exist, the timestamp is the current
    timestamp (seconds since epoch).

    The ``message`` field is a concise explanation of the failure. If
    the error file exists then the message is obtained from the error file.
    Otherwise one is generated based on the failure signature.

    .. note:: It is assumed that the ``error_file`` is written by
       ``torch.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``.
       Otherwise the behavior is undefined.

    """

    local_rank: int
    pid: int
    exitcode: int
    error_file: str
    # The three fields below are derived in __post_init__ from error_file
    # (or set to defaults when the file is absent/unreadable).
    error_file_data: JSON = field(init=False)
    message: str = field(init=False)
    timestamp: int = field(init=False)

    def __post_init__(self):
        self.error_file_data = _EMPTY_ERROR_DATA
        if os.path.isfile(self.error_file):
            try:
                with open(self.error_file) as fp:
                    self.error_file_data = json.load(fp)
                log.debug(
                    "User process failed with error data: %s", json.dumps(self.error_file_data, indent=2)
                )
                self.message, self.timestamp = self._get_error_data(
                    self.error_file_data
                )
            except Exception:
                # A present-but-unparsable reply file is a hard error: log it
                # and re-raise rather than guessing at a failure cause.
                log.exception("Failed to parse reply file: %s", self.error_file)
                raise
        else:
            self._set_no_reply_file()

        # make up an informative message if not already present
        if not self.message:
            # signals typically do not generate an error file message
            if self.exitcode < 0:
                self.message = (
                    f"Signal {-self.exitcode} ({self.signal_name()})"
                    f" received by PID {self.pid}"
                )
            else:
                self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html"

    def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]:
        # "message" is either a plain string (then the timestamp sits next to
        # it under "timestamp") or a structured dict that carries it under
        # "extraInfo".
        message = error_file_data["message"]
        if isinstance(message, str):
            timestamp = int(error_file_data.get("timestamp", 0))
        else:
            timestamp = int(message["extraInfo"]["timestamp"])
        return (message, timestamp)

    def _set_no_reply_file(self):
        # No reply file on disk: use placeholders and the current time.
        self.error_file = _NOT_AVAILABLE
        self.error_file_data = _EMPTY_ERROR_DATA
        self.message = ""
        self.timestamp = int(time.time())

    def signal_name(self) -> str:
        if self.exitcode < 0:
            # We don't want to kill the parent process trying to find the signal name.
            # if the signal doesn't map to a known name, use not available.
            try:
                return signal.Signals(-self.exitcode).name
            except Exception:
                return _NOT_AVAILABLE
        else:
            return _NOT_AVAILABLE

    def timestamp_isoformat(self):
        """Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS)."""
        return datetime.fromtimestamp(self.timestamp).isoformat(sep="_")
165
+
166
+
167
+ GlobalRank = int
168
+
169
+ _FAILURE_FORMAT_TEMPLATE = """[${idx}]:
170
+ time : ${time}
171
+ host : ${hostname}
172
+ rank : ${rank} (local_rank: ${local_rank})
173
+ exitcode : ${exitcode} (pid: ${pid})
174
+ error_file: ${error_file}
175
+ traceback : ${message}"""
176
+
177
+ # extra new lines before and after are intentional
178
+ _MSG_FORMAT_TEMPLATE = """
179
+ ${boarder}
180
+ ${title}
181
+ ${section}
182
+ Failures:
183
+ ${other_failures}
184
+ ${section}
185
+ Root Cause (first observed failure):
186
+ ${root_failure}
187
+ ${boarder}"""
188
+
189
+
190
+ class ChildFailedError(Exception):
191
+ """
192
+ Special exception type that can be raised from a function annotated with the
193
+ ``@record`` decorator to have the child process' (root exception) propagate
194
+ up the stack as-is (e.g. without being wrapped in the parent's traceback).
195
+
196
+ Useful in cases where the parent is a simple nanny process
197
+ and the child (worker) processes are actually doing meaningful compute.
198
+ In this case, errors typically occur on the child process as the parent
199
+ is not doing anything non-trivial, and child errors should be propagated
200
+ to the scheduler for accurate root cause diagnostics.
201
+
202
+ .. note:: The propagation relies on error files rather than exception handling to
203
+ support both function and binary launches.
204
+
205
+ Example:
206
+ ::
207
+
208
+ # process tree on a host (container)
209
+ 0: scheduler-init-process:
210
+ |- 1: torchelastic_agent:
211
+ |- 2: trainer_0 (ok)
212
+ |- 3: trainer_1 (fail) -> error.json
213
+ |- ...
214
+ |- n+2: trainer_n (ok)
215
+ |- n+3: other processes
216
+ |- ...
217
+
218
+ In the example above, trainer 1's failure (written into error.json) is
219
+ the root cause and should be reported to the scheduler's init process.
220
+ The torchelastic agent raises a ``ChildFailedError("trainer", {1: "trainer_1/error.json"})``
221
+ upon detecting trainer 1's failure which would propagate the contents
222
+ of trainer 1's error file to the scheduler's init process.
223
+ """
224
+
225
+ def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]):
226
+ self.name = name
227
+ self.failures = failures
228
+ assert (
229
+ self.failures
230
+ ) # does not make sense to create a ChildFaileError with no failures
231
+ super().__init__(self.format_msg())
232
+
233
+ def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]:
234
+ rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp)
235
+ return rank, self.failures[rank]
236
+
237
+ def format_msg(self, boarder_delim="=", section_delim="-"):
238
+ title = f"{self.name} FAILED"
239
+ root_rank, root_failure = self.get_first_failure()
240
+
241
+ root_failure_fmt: str = ""
242
+ other_failures_fmt: List[str] = []
243
+ width = len(title)
244
+ for idx, (rank, failure) in enumerate(self.failures.items()):
245
+ fmt, w = self._format_failure(idx, rank, failure)
246
+ width = max(width, w)
247
+ if rank == root_rank:
248
+ root_failure_fmt = fmt
249
+ else:
250
+ other_failures_fmt.append(fmt)
251
+
252
+ # upper boundary on width
253
+ width = min(width, 60)
254
+
255
+ return Template(_MSG_FORMAT_TEMPLATE).substitute(
256
+ boarder=boarder_delim * width,
257
+ title=title,
258
+ section=section_delim * width,
259
+ root_failure=root_failure_fmt,
260
+ other_failures="\n".join(other_failures_fmt or [" <NO_OTHER_FAILURES>"]),
261
+ )
262
+
263
+ def _format_failure(
264
+ self, idx: int, rank: int, failure: ProcessFailure
265
+ ) -> Tuple[str, int]:
266
+
267
+ # failure.message is either a str (when the failure does not generate a traceback - e.g. signals)
268
+ # or a dict (json) of the form
269
+ # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}}
270
+ # so the display logic is:
271
+ # 1. if failure.message is not a dict (it is a str) just show it as is
272
+ # 2. else try to get the traceback (py_callstack)
273
+ # 3. if the traceback is not there, use the message
274
+ # 4. if the message is not there show <N/A>
275
+ msg = failure.message
276
+ if isinstance(failure.message, dict):
277
+ msg = (
278
+ failure.message.get("extraInfo", {})
279
+ .get("py_callstack", failure.message.get("message", "<N/A>"))
280
+ .replace("\n", "\n ") # to properly indent the traceback
281
+ )
282
+
283
+ fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute(
284
+ idx=idx,
285
+ time=failure.timestamp_isoformat(),
286
+ hostname=socket.getfqdn(),
287
+ rank=rank,
288
+ local_rank=failure.local_rank,
289
+ exitcode=failure.exitcode,
290
+ pid=failure.pid,
291
+ error_file=failure.error_file,
292
+ message=msg,
293
+ )
294
+ width = 0
295
+ for line in fmt.split("\n"):
296
+ width = max(width, len(line))
297
+ return fmt, width
298
+
299
+
300
+ def record(
301
+ fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None
302
+ ) -> Callable[..., T]:
303
+ """
304
+ Syntactic sugar to record errors/exceptions that happened in the decorated
305
+ function using the provided ``error_handler``.
306
+
307
+ Using this decorator is equivalent to:
308
+
309
+ ::
310
+
311
+ error_handler = get_error_handler()
312
+ error_handler.initialize()
313
+ try:
314
+ foobar()
315
+ except ChildFailedError as e:
316
+ _, failure = e.get_first_failure()
317
+ error_handler.dump_error_file(failure.error_file, failure.exitcode)
318
+ raise
319
+ except Exception as e:
320
+ error_handler.record(e)
321
+ raise
322
+
323
+ .. important:: use this decorator once per process at the top level method,
324
+ typically this is the main method.
325
+
326
+ Example
327
+
328
+ ::
329
+
330
+ @record
331
+ def main():
332
+ pass
333
+
334
+ if __name__=="__main__":
335
+ main()
336
+
337
+ """
338
+ if not error_handler:
339
+ error_handler = get_error_handler()
340
+
341
+ def wrap(f):
342
+ @wraps(f)
343
+ def wrapper(*args, **kwargs):
344
+ assert error_handler is not None # assertion for mypy type checker
345
+ error_handler.initialize()
346
+ try:
347
+ return f(*args, **kwargs)
348
+ except SystemExit as se:
349
+ # For run_path based entrypoints, SystemExit with code = 0 will never exit.
350
+ # Handling it here by returning a value:
351
+ if se.code == 0:
352
+ return None
353
+ else:
354
+ raise
355
+ except ChildFailedError as e:
356
+ rank, failure = e.get_first_failure()
357
+ if failure.error_file != _NOT_AVAILABLE:
358
+ error_handler.dump_error_file(failure.error_file, failure.exitcode)
359
+ else:
360
+ log.info(
361
+ (
362
+ "local_rank %s FAILED with no error file."
363
+ " Decorate your entrypoint fn with @record for traceback info."
364
+ " See: https://pytorch.org/docs/stable/elastic/errors.html",
365
+ rank
366
+ )
367
+ )
368
+ raise
369
+ except Exception as e:
370
+ error_handler.record_exception(e)
371
+ raise
372
+
373
+ return wrapper
374
+
375
+ return wrap(fn)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/handlers.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ # All rights reserved.
5
+ #
6
+ # This source code is licensed under the BSD-style license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+ # Multiprocessing error-reporting module
9
+
10
+
11
+ from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler
12
+
13
+ __all__ = ['get_error_handler']
14
+
15
+ def get_error_handler():
16
+ return ErrorHandler()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-311.pyc ADDED
Binary file (937 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-311.pyc ADDED
Binary file (3.72 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ # All rights reserved.
5
+ #
6
+ # This source code is licensed under the BSD-style license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+ from typing import Dict, Tuple
9
+
10
+ from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
11
+ SubprocessHandler,
12
+ )
13
+
14
+ __all__ = ["get_subprocess_handler"]
15
+
16
+
17
+ def get_subprocess_handler(
18
+ entrypoint: str,
19
+ args: Tuple,
20
+ env: Dict[str, str],
21
+ stdout: str,
22
+ stderr: str,
23
+ local_rank_id: int,
24
+ ):
25
+ return SubprocessHandler(
26
+ entrypoint=entrypoint,
27
+ args=args,
28
+ env=env,
29
+ stdout=stdout,
30
+ stderr=stderr,
31
+ local_rank_id=local_rank_id,
32
+ )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__init__.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Expiration timers are set up on the same process as the agent and
9
+ used from your script to deal with stuck workers. When you go into
10
+ a code-block that has the potential to get stuck you can acquire
11
+ an expiration timer, which instructs the timer server to kill the
12
+ process if it does not release the timer by the self-imposed expiration
13
+ deadline.
14
+
15
+ Usage::
16
+
17
+ import torchelastic.timer as timer
18
+ import torchelastic.agent.server as agent
19
+
20
+ def main():
21
+ start_method = "spawn"
22
+ message_queue = mp.get_context(start_method).Queue()
23
+ server = timer.LocalTimerServer(message, max_interval=0.01)
24
+ server.start() # non-blocking
25
+
26
+ spec = WorkerSpec(
27
+ fn=trainer_func,
28
+ args=(message_queue,),
29
+ ...<OTHER_PARAMS...>)
30
+ agent = agent.LocalElasticAgent(spec, start_method)
31
+ agent.run()
32
+
33
+ def trainer_func(message_queue):
34
+ timer.configure(timer.LocalTimerClient(message_queue))
35
+ with timer.expires(after=60): # 60 second expiry
36
+ # do some work
37
+
38
+ In the example above if ``trainer_func`` takes more than 60 seconds to
39
+ complete, then the worker process is killed and the agent retries the worker group.
40
+ """
41
+
42
+ from .api import TimerClient, TimerRequest, TimerServer, configure, expires # noqa: F401
43
+ from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401
44
+ from .file_based_local_timer import FileTimerClient, FileTimerServer, FileTimerRequest # noqa: F401
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-311.pyc ADDED
Binary file (7.63 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/elastic/timer/local_timer.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import logging
7
+ import multiprocessing as mp
8
+ import os
9
+ import signal
10
+ import time
11
+ from queue import Empty
12
+ from typing import Any, Dict, List, Set, Tuple
13
+
14
+ from .api import RequestQueue, TimerClient, TimerRequest, TimerServer
15
+
16
+ __all__ = ['LocalTimerClient', 'MultiprocessingRequestQueue', 'LocalTimerServer']
17
+
18
+ log = logging.getLogger(__name__)
19
+
20
+ class LocalTimerClient(TimerClient):
21
+ """
22
+ Client side of ``LocalTimerServer``. This client is meant to be used
23
+ on the same host that the ``LocalTimerServer`` is running on and uses
24
+ pid to uniquely identify a worker. This is particularly useful in situations
25
+ where one spawns a subprocess (trainer) per GPU on a host with multiple
26
+ GPU devices.
27
+ """
28
+
29
+ def __init__(self, mp_queue):
30
+ super().__init__()
31
+ self._mp_queue = mp_queue
32
+
33
+ def acquire(self, scope_id, expiration_time):
34
+ pid = os.getpid()
35
+ acquire_request = TimerRequest(pid, scope_id, expiration_time)
36
+ self._mp_queue.put(acquire_request)
37
+
38
+ def release(self, scope_id):
39
+ pid = os.getpid()
40
+ release_request = TimerRequest(pid, scope_id, -1)
41
+ self._mp_queue.put(release_request)
42
+
43
+
44
+ class MultiprocessingRequestQueue(RequestQueue):
45
+ """
46
+ A ``RequestQueue`` backed by python ``multiprocessing.Queue``
47
+ """
48
+
49
+ def __init__(self, mp_queue: mp.Queue):
50
+ super().__init__()
51
+ self._mp_queue = mp_queue
52
+
53
+ def size(self) -> int:
54
+ return self._mp_queue.qsize()
55
+
56
+ def get(self, size, timeout: float) -> List[TimerRequest]:
57
+ requests = []
58
+ wait = timeout
59
+ for _ in range(0, size):
60
+ start = time.time()
61
+
62
+ try:
63
+ r = self._mp_queue.get(block=True, timeout=wait)
64
+ except Empty:
65
+ break
66
+
67
+ requests.append(r)
68
+ wait = wait - (time.time() - start)
69
+ if wait <= 0:
70
+ break
71
+
72
+ return requests
73
+
74
+
75
+ class LocalTimerServer(TimerServer):
76
+ """
77
+ Server that works with ``LocalTimerClient``. Clients are expected to be
78
+ subprocesses to the parent process that is running this server. Each host
79
+ in the job is expected to start its own timer server locally and each
80
+ server instance manages timers for local workers (running on processes
81
+ on the same host).
82
+ """
83
+
84
+ def __init__(
85
+ self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
86
+ ):
87
+ super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
88
+ self._timers: Dict[Tuple[Any, str], TimerRequest] = {}
89
+
90
+ def register_timers(self, timer_requests: List[TimerRequest]) -> None:
91
+ for request in timer_requests:
92
+ pid = request.worker_id
93
+ scope_id = request.scope_id
94
+ expiration_time = request.expiration_time
95
+
96
+ # negative expiration is a proxy for a release call
97
+ if expiration_time < 0:
98
+ self._timers.pop((pid, scope_id), None)
99
+ else:
100
+ self._timers[(pid, scope_id)] = request
101
+
102
+ def clear_timers(self, worker_ids: Set[int]) -> None:
103
+ for (pid, scope_id) in list(self._timers.keys()):
104
+ if pid in worker_ids:
105
+ self._timers.pop((pid, scope_id))
106
+
107
+ def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
108
+ # pid -> [timer_requests...]
109
+ expired_timers: Dict[Any, List[TimerRequest]] = {}
110
+ for request in self._timers.values():
111
+ if request.expiration_time <= deadline:
112
+ expired_scopes = expired_timers.setdefault(request.worker_id, [])
113
+ expired_scopes.append(request)
114
+ return expired_timers
115
+
116
+ def _reap_worker(self, worker_id: int) -> bool:
117
+ try:
118
+ os.kill(worker_id, signal.SIGKILL)
119
+ return True
120
+ except ProcessLookupError:
121
+ log.info("Process with pid=%s does not exist. Skipping", worker_id)
122
+ return True
123
+ except Exception:
124
+ log.exception("Error terminating pid=%s", worker_id)
125
+ return False
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import torch
2
+ if torch.distributed.rpc.is_available():
3
+ from .api.remote_module import RemoteModule
4
+ from .functional import * # noqa: F403
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/api/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (225 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/nn/jit/templates/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-311.pyc ADDED
Binary file (6.34 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/copy.cpython-311.pyc ADDED
Binary file (5.95 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/pipeline/sync/__pycache__/worker.cpython-311.pyc ADDED
Binary file (7.46 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (225 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-311.pyc ADDED
Binary file (5.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-311.pyc ADDED
Binary file (21 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-311.pyc ADDED
Binary file (28 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/distributed/tensor/parallel/api.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ from typing import Dict, Union
3
+
4
+ import torch
5
+ import torch.distributed._tensor.random as random
6
+ import torch.nn as nn
7
+ from torch.distributed._tensor import (
8
+ DeviceMesh,
9
+ )
10
+ from torch.distributed._tensor.random import (
11
+ is_rng_supported_mesh,
12
+ TensorParallelRNGTracker,
13
+ )
14
+ from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
15
+ from torch.distributed.tensor.parallel.style import (
16
+ ParallelStyle,
17
+ )
18
+
19
+
20
+ __all__ = [
21
+ "parallelize_module",
22
+ ]
23
+
24
+
25
+ def parallelize_module( # type: ignore[return]
26
+ module: nn.Module,
27
+ device_mesh: DeviceMesh,
28
+ parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
29
+ ) -> nn.Module:
30
+ """
31
+ Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.
32
+
33
+ We parallelize module or sub_modules based on a parallelize_plan. The parallelize_plan contains
34
+ :class:`ParallelStyle`, which indicates how user wants the module or sub_module
35
+ to be parallelized.
36
+
37
+ User can also specify different parallel style per module fully qualified name (FQN).
38
+
39
+ Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`, if you have a 2-D or N-D :class:`DeviceMesh`,
40
+ slice the DeviceMesh to a 1-D sub DeviceMesh first then pass to this API(i.e. ``device_mesh[\"tp\"]``)
41
+
42
+ Args:
43
+ module (:class:`nn.Module`):
44
+ Module to be parallelized.
45
+ device_mesh (:class:`DeviceMesh`):
46
+ Object which describes the mesh topology
47
+ of devices for the DTensor.
48
+ parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
49
+ The plan used to parallelize the module. It can be either a
50
+ :class:`ParallelStyle` object which contains how
51
+ we prepare input/output for Tensor Parallelism or it can be a
52
+ dict of module FQN and its corresponding :class:`ParallelStyle` object.
53
+ Return:
54
+ A :class:`nn.Module` object parallelized.
55
+
56
+ Example::
57
+ >>> # xdoctest: +SKIP("distributed")
58
+ >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
59
+ >>> from torch.distributed.device_mesh import init_device_mesh
60
+ >>>
61
+ >>> # Define the module.
62
+ >>> m = Model(...)
63
+ >>> tp_mesh = init_device_mesh("cuda", (8,))
64
+ >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
65
+ >>>
66
+
67
+ .. note:: For complex module architecture like Attention, MLP layers, we recommend composing
68
+ different ParallelStyles together (i.e. ``ColwiseParallel`` and ``RowwiseParallel``) and pass
69
+ as a parallelize_plan, to achieves the desired sharding computation.
70
+ """
71
+ torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")
72
+
73
+ _validate_tp_mesh_dim(device_mesh)
74
+
75
+ # instantiate a TP RNG state tracker if it's not there
76
+ if is_rng_supported_mesh(device_mesh) and not isinstance(
77
+ random._rng_tracker, TensorParallelRNGTracker
78
+ ):
79
+ random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
80
+ # TODO: we should allow user to pass in the default seed from a config
81
+ random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
82
+ # By default we execute random ops in non-tensor-parallel region. If users want
83
+ # to execute in tensor-parallel region, they can manually set this field to True
84
+ # after parallelizing the model.
85
+ random._rng_tracker.distribute_region_enabled = False
86
+
87
+ if isinstance(parallelize_plan, ParallelStyle):
88
+ return parallelize_plan._apply(module, device_mesh)
89
+ elif isinstance(parallelize_plan, dict):
90
+ for module_path, parallelize_style in parallelize_plan.items():
91
+ sub_module = module.get_submodule(module_path)
92
+ parent_module = module
93
+ if "." in module_path:
94
+ parent_module_path = ".".join(module_path.split(".")[:-1])
95
+ parent_module = module.get_submodule(parent_module_path)
96
+ module_path = module_path.split(".")[-1]
97
+ parent_module.register_module( # type: ignore[call-arg] # pyre-ignore[20]
98
+ module_path,
99
+ parallelize_module( # type: ignore[arg-type]
100
+ sub_module, device_mesh, parallelize_style # type: ignore[arg-type] # pyre-ignore[6]
101
+ ),
102
+ )
103
+ return module
104
+ else:
105
+ raise RuntimeError( # pyre-ignore[7]
106
+ "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
107
+ f" parallelize_plan, {type(parallelize_plan)} found!"
108
+ )