koichi12 commited on
Commit
e7b25d3
·
verified ·
1 Parent(s): 055b29c

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/__init__.cpython-311.pyc +0 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__init__.py +55 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/tensor_factory_functions.py +48 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__init__.py +308 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__init__.py +236 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/nvtx.py +91 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc +0 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py +52 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc +0 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc +0 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-311.pyc +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_module.cpython-311.pyc +0 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-311.pyc +0 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-311.pyc +0 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__init__.py +2 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-311.pyc +0 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-311.pyc +0 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_base.py +75 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_manager.py +303 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/reinplace.py +675 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__init__.py +0 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-311.pyc +0 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/source_matcher_utils.py +144 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__init__.py +78 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-311.pyc +0 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-311.pyc +0 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-311.pyc +0 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py +33 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__init__.py +117 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/__init__.cpython-311.pyc +0 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/bias.cpython-311.pyc +0 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/bias.py +353 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/common_types.py +42 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/grad.py +189 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/__init__.py +1 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__init__.py +31 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (217 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__init__.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+
3
+ import torch._C._lazy
4
+ from torch.utils._pytree import tree_flatten, tree_unflatten
5
+
6
+ from .closure import add_step_closure, run_step_closures
7
+
8
+
9
+ def mark_step(device: str = "", wait=False):
10
+ """Triggers a mark step, which amounts to
11
+ - collecting a group of 'live' lazy tensors to index into the compilation cache
12
+ (lowering/compiling their IR graphs if not cached)
13
+ - kicking off execution of the compiled function
14
+ - (optionally, wait=True) waiting for cpu-side execution to complete (does not sync the accelerator)
15
+ """
16
+ # TODO(whc) expand this to include backend hooks and align with XLA backend needs
17
+ torch._C._lazy._mark_step(device, [], wait=wait)
18
+
19
+ run_step_closures()
20
+
21
+
22
+ def wait_device_ops(devices=None):
23
+ """Waits for all the async operations on the given devices to complete.
24
+ Args:
25
+ devices (string..., optional): The devices whose async ops need to be waited
26
+ for. If empty, all the local devices will be waited for.
27
+ """
28
+ if devices is None:
29
+ devices = []
30
+ torch._C._lazy._wait_device_ops(devices=devices)
31
+
32
+
33
+ def sync_multi(tensors, devices):
34
+ """
35
+ Sync the list of lazy tensors so there IR get lowered for the activate backend
36
+ and the compiled computation graph get cached.
37
+ """
38
+ torch._C._lazy._sync_multi(tensors, devices)
39
+
40
+
41
+ def get_tensor_id(tensor):
42
+ """Return a unique id of the lazy tensor maintained by LTC"""
43
+ return torch._C._lazy._get_tensor_id(tensor)
44
+
45
+
46
+ def to_cpu(tensors, devices=None):
47
+ devices = devices or ["lazy"]
48
+
49
+ flattened, spec = tree_flatten(tensors)
50
+ sync_multi(flattened, devices)
51
+ return tree_unflatten([t.to("cpu") for t in flattened], spec)
52
+
53
+
54
+ def save(tensors, *args, **kwargs):
55
+ torch.save(to_cpu(tensors), *args, **kwargs)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc ADDED
Binary file (1.33 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc ADDED
Binary file (887 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (1.42 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc ADDED
Binary file (1.09 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/tensor_factory_functions.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ """
4
+ tensor_factory_functions defines the list of torch functions that create tensors.
5
+ The list is grabbed by searching thru native_functions.yaml by the following
6
+ regular expression:
7
+
8
+ cat native_functions.yaml | grep 'func:' | grep -v "Tensor.*->" | grep "[-]>.*Tensor"
9
+
10
+ It's possible that new tensor factory functions are added making this list stale.
11
+ Use at your own risk or regenerate the list.
12
+ """
13
+ tensor_factory_functions = (
14
+ torch._cudnn_init_dropout_state,
15
+ torch.arange,
16
+ torch.bartlett_window,
17
+ torch.blackman_window,
18
+ torch._empty_affine_quantized,
19
+ torch.empty_strided,
20
+ torch.eye,
21
+ torch.full,
22
+ torch.from_file,
23
+ torch.hann_window,
24
+ torch.hamming_window,
25
+ torch.kaiser_window,
26
+ torch.linspace,
27
+ torch.logspace,
28
+ torch.ones,
29
+ torch.scalar_tensor,
30
+ torch.rand,
31
+ torch.randint,
32
+ torch.randn,
33
+ torch.randperm,
34
+ torch.range,
35
+ torch._efficientzerotensor,
36
+ torch.zeros,
37
+ torch.tril_indices,
38
+ torch.triu_indices,
39
+ # Note: the following functions match the regular expression search above but
40
+ # they are not available in the torch module. Comment out.
41
+ # torch._sparse_coo_tensor_with_dims,
42
+ # torch.fft_fftfreq,
43
+ # torch.fft_rfftfreq,
44
+ ) + (
45
+ # torch.tensor is special since it's not in native_functions.yaml
46
+ # add it separately
47
+ torch.tensor,
48
+ )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__init__.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+
7
+ import torch._prims as prims
8
+
9
+ import torch._prims_common as utils
10
+ import torch._refs as refs
11
+ import torch._refs.linalg as linalg
12
+ from torch import Tensor
13
+ from torch._prims_common import (
14
+ check_fp_or_complex,
15
+ check_is_matrix,
16
+ Dim,
17
+ DimsType,
18
+ ELEMENTWISE_TYPE_PROMOTION_KIND,
19
+ IntLike,
20
+ NumberType,
21
+ TensorLikeType,
22
+ )
23
+ from torch._prims_common.wrappers import (
24
+ _maybe_convert_to_dtype,
25
+ elementwise_type_promotion_wrapper,
26
+ out_wrapper,
27
+ )
28
+
29
+
30
+ __all__ = [
31
+ "diagonal",
32
+ "matrix_norm",
33
+ "norm",
34
+ "svd",
35
+ "svdvals",
36
+ "vector_norm",
37
+ "vecdot",
38
+ "cross",
39
+ ]
40
+
41
+
42
+ def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str):
43
+ """
44
+ Checks related to the dtype kwarg in `linalg.*norm` functions
45
+ """
46
+ if dtype is not None:
47
+ torch._check(
48
+ utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
49
+ lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
50
+ )
51
+ torch._check(
52
+ utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
53
+ lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
54
+ fn_name=fn_name,
55
+ d="complex" if utils.is_complex_dtype(x_dtype) else "real",
56
+ dtype=dtype,
57
+ ),
58
+ )
59
+ torch._check(
60
+ utils.get_higher_dtype(dtype, x_dtype) == dtype,
61
+ lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
62
+ "without narrowing to the specified dtype ({dtype})",
63
+ )
64
+
65
+
66
+ # Utilities should come BEFORE this import
67
+ from torch._decomp import register_decomposition
68
+ from torch._decomp.decompositions import pw_cast_for_opmath
69
+
70
+
71
+ @register_decomposition(torch._ops.ops.aten.linalg_cross)
72
+ @out_wrapper()
73
+ @pw_cast_for_opmath
74
+ def cross(a: Tensor, b: Tensor, dim: int = -1):
75
+ torch._check(
76
+ a.ndim == b.ndim,
77
+ lambda: "linalg.cross: inputs must have the same number of dimensions.",
78
+ )
79
+ torch._check(
80
+ a.size(dim) == 3 and b.size(dim) == 3,
81
+ lambda: f"linalg.cross: inputs dim {dim} must have length 3, got {a.size(dim)} and {b.size(dim)}",
82
+ )
83
+ a, b = torch.broadcast_tensors(a, b)
84
+ dim = utils.canonicalize_dim(a.ndim, dim)
85
+ idx = torch.arange(3, device=a.device)
86
+ return a.index_select(dim, (idx + 1) % 3) * b.index_select(
87
+ dim, (idx + 2) % 3
88
+ ) - a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3)
89
+
90
+
91
+ def diagonal(
92
+ input: TensorLikeType,
93
+ *,
94
+ offset: int = 0,
95
+ dim1: int = -2,
96
+ dim2: int = -1,
97
+ ) -> TensorLikeType:
98
+ return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2)
99
+
100
+
101
+ @register_decomposition(torch._ops.ops.aten.linalg_vector_norm)
102
+ @out_wrapper(exact_dtype=True)
103
+ def vector_norm(
104
+ x: TensorLikeType,
105
+ ord: Union[float, int] = 2,
106
+ dim: Optional[DimsType] = None,
107
+ keepdim: bool = False,
108
+ *,
109
+ dtype: Optional[torch.dtype] = None,
110
+ ) -> Tensor:
111
+ # Checks
112
+ check_fp_or_complex(x.dtype, "linalg.vector_norm")
113
+
114
+ if isinstance(dim, Dim):
115
+ dim = [dim] # type: ignore[assignment]
116
+
117
+ if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
118
+ torch._check(
119
+ dim is not None and len(dim) != 0,
120
+ lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
121
+ "because the operation does not have an identity",
122
+ )
123
+ shape = x.shape
124
+ assert dim is not None # mypy does not seem to be able to see through check?
125
+ for d in dim:
126
+ torch._check(
127
+ shape[d] != 0,
128
+ lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
129
+ f"dimension {d} because this dimension is empty and the "
130
+ "operation does not have an identity",
131
+ )
132
+ _check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")
133
+
134
+ computation_dtype, result_dtype = utils.reduction_dtypes(
135
+ x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
136
+ )
137
+
138
+ to_result_dtype = partial(_maybe_convert_to_dtype, dtype=result_dtype)
139
+
140
+ # Implementation
141
+ if ord == 0.0:
142
+ return torch.sum(torch.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype)
143
+ elif ord == float("inf"):
144
+ return to_result_dtype(torch.amax(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type]
145
+ elif ord == float("-inf"):
146
+ return to_result_dtype(torch.amin(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type]
147
+ else:
148
+ # From here on the computation dtype is important as the reduction is non-trivial
149
+ x = _maybe_convert_to_dtype(x, computation_dtype) # type: ignore[assignment]
150
+ reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim)
151
+
152
+ is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0
153
+ if not (is_ord_even and utils.is_float_dtype(x.dtype)):
154
+ x = torch.abs(x)
155
+ return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) # type: ignore[return-value]
156
+
157
+
158
+ def _backshift_permutation(dim0, dim1, ndim):
159
+ # Auxiliary function for matrix_norm
160
+ # Computes the permutation that moves the two given dimensions to the back
161
+ ret = [i for i in range(ndim) if i != dim0 and i != dim1]
162
+ ret.extend((dim0, dim1))
163
+ return ret
164
+
165
+
166
+ def _inverse_permutation(perm):
167
+ # Given a permutation, returns its inverse. It's equivalent to argsort on an array
168
+ return [i for i, j in sorted(enumerate(perm), key=lambda i_j: i_j[1])]
169
+
170
+
171
+ # CompositeImplicitAutograd
172
+ @out_wrapper(exact_dtype=True)
173
+ def matrix_norm(
174
+ A: TensorLikeType,
175
+ ord: Union[float, str] = "fro",
176
+ dim: DimsType = (-2, -1),
177
+ keepdim: bool = False,
178
+ *,
179
+ dtype: Optional[torch.dtype] = None,
180
+ ) -> TensorLikeType:
181
+ # shape
182
+ check_is_matrix(A, "linalg.matrix_norm")
183
+ # dim
184
+ dim = utils.canonicalize_dims(A.ndim, dim)
185
+ if isinstance(dim, Dim):
186
+ dim = (dim,) # type: ignore[assignment]
187
+ torch._check(
188
+ len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
189
+ )
190
+ torch._check(
191
+ dim[0] != dim[1],
192
+ lambda: "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
193
+ )
194
+ # dtype arg
195
+ _check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm")
196
+
197
+ if isinstance(ord, str):
198
+ # ord
199
+ torch._check(
200
+ ord in ("fro", "nuc"),
201
+ lambda: "linalg.matrix_norm: Order {ord} not supported.",
202
+ )
203
+ # dtype
204
+ check_fp_or_complex(
205
+ A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc"
206
+ )
207
+
208
+ if ord == "fro":
209
+ return vector_norm(A, 2, dim, keepdim, dtype=dtype)
210
+ else: # ord == "nuc"
211
+ if dtype is not None:
212
+ A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
213
+ perm = _backshift_permutation(dim[0], dim[1], A.ndim)
214
+ result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
215
+ if keepdim:
216
+ inv_perm = _inverse_permutation(perm)
217
+ result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
218
+ return result
219
+ else:
220
+ # ord
221
+ abs_ord = abs(ord)
222
+ torch._check(
223
+ abs_ord in (2, 1, float("inf")),
224
+ lambda: "linalg.matrix_norm: Order {ord} not supported.",
225
+ )
226
+ # dtype
227
+ check_fp_or_complex(
228
+ A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2
229
+ )
230
+
231
+ max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim)
232
+
233
+ if abs_ord == 2.0:
234
+ if dtype is not None:
235
+ A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
236
+ perm = _backshift_permutation(dim[0], dim[1], A.ndim)
237
+ result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
238
+ if keepdim:
239
+ inv_perm = _inverse_permutation(perm)
240
+ result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
241
+ return result
242
+ else: # 1, -1, inf, -inf
243
+ dim0, dim1 = dim
244
+ if abs_ord == float("inf"):
245
+ dim0, dim1 = dim1, dim0
246
+ if not keepdim and (dim0 < dim1):
247
+ dim1 -= 1
248
+ return max_min(
249
+ vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1
250
+ )
251
+
252
+
253
+ # CompositeImplicitAutograd
254
+ @out_wrapper(exact_dtype=True)
255
+ def norm(
256
+ A: TensorLikeType,
257
+ ord: Optional[Union[float, str]] = None,
258
+ dim: Optional[DimsType] = None,
259
+ keepdim: bool = False,
260
+ *,
261
+ dtype: Optional[torch.dtype] = None,
262
+ ) -> TensorLikeType:
263
+ if dim is not None:
264
+ if isinstance(dim, Dim):
265
+ dim = (dim,) # type: ignore[assignment]
266
+ torch._check(
267
+ len(dim) in (1, 2),
268
+ lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
269
+ )
270
+ elif ord is not None:
271
+ torch._check(
272
+ A.ndim in (1, 2),
273
+ lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
274
+ )
275
+
276
+ if ord is not None and (
277
+ (dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2)
278
+ ):
279
+ if dim is None:
280
+ dim = (0, 1)
281
+ return matrix_norm(A, ord, dim, keepdim, dtype=dtype)
282
+ else:
283
+ if ord is None:
284
+ ord = 2.0
285
+ return vector_norm(A, ord, dim, keepdim, dtype=dtype)
286
+
287
+
288
+ # CompositeImplicitAutograd
289
+ @out_wrapper("U", "S", "Vh", exact_dtype=True)
290
+ def svd(A: TensorLikeType, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
291
+ return prims.svd(A, full_matrices=full_matrices)
292
+
293
+
294
+ # CompositeImplicitAutograd
295
+ @out_wrapper(exact_dtype=True)
296
+ def svdvals(A: TensorLikeType) -> Tensor:
297
+ return svd(A, full_matrices=False)[1]
298
+
299
+
300
+ # CompositeImplicitAutograd
301
+ @out_wrapper()
302
+ @elementwise_type_promotion_wrapper(
303
+ type_promoting_args=("x", "y"),
304
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
305
+ )
306
+ def vecdot(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
307
+ check_fp_or_complex(x.dtype, "linalg.vecdot")
308
+ return (x.conj() * y).sum(dim=dim)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__init__.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Union
3
+
4
+ import torch
5
+ import torch._prims as prims
6
+ import torch._prims_common as utils
7
+ import torch._refs as refs
8
+
9
+ from torch import Tensor
10
+ from torch._decomp import register_decomposition
11
+ from torch._prims_common import (
12
+ ELEMENTWISE_TYPE_PROMOTION_KIND,
13
+ Number,
14
+ NumberType,
15
+ TensorLike,
16
+ TensorLikeType,
17
+ )
18
+ from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, out_wrapper
19
+ from torch._refs import (
20
+ _make_alias,
21
+ _make_elementwise_binary_reference,
22
+ _make_elementwise_unary_reference,
23
+ )
24
+
25
+
26
+ __all__ = [
27
+ "bessel_j0",
28
+ "bessel_j1",
29
+ "entr",
30
+ "erfcx",
31
+ "expit",
32
+ "i0e",
33
+ "i1",
34
+ "i1e",
35
+ "log_ndtr",
36
+ "logit",
37
+ "log_softmax",
38
+ "multigammaln",
39
+ "ndtr",
40
+ "ndtri",
41
+ "softmax",
42
+ "spherical_bessel_j0",
43
+ "xlog1py",
44
+ "zeta",
45
+ ]
46
+ aten = torch._ops.ops.aten
47
+
48
+
49
+ @_make_elementwise_unary_reference(
50
+ ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
51
+ )
52
+ def bessel_j0(a: TensorLikeType) -> TensorLikeType:
53
+ return prims.bessel_j0(a)
54
+
55
+
56
+ @_make_elementwise_unary_reference(
57
+ ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
58
+ )
59
+ def bessel_j1(a: TensorLikeType) -> TensorLikeType:
60
+ return prims.bessel_j1(a)
61
+
62
+
63
+ @register_decomposition(aten.special_entr)
64
+ @out_wrapper()
65
+ @elementwise_type_promotion_wrapper(
66
+ type_promoting_args=("a",),
67
+ type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
68
+ )
69
+ def entr(a: TensorLikeType) -> TensorLikeType:
70
+ return torch.where(
71
+ torch.isnan(a),
72
+ a,
73
+ torch.where(a > 0, -a * torch.log(a), torch.where(a == 0, 0, -torch.inf)),
74
+ )
75
+
76
+
77
+ @register_decomposition(aten.special_erfcx)
78
+ @out_wrapper()
79
+ @elementwise_type_promotion_wrapper(
80
+ type_promoting_args=("a",),
81
+ type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
82
+ )
83
+ def erfcx(a: TensorLikeType) -> TensorLikeType:
84
+ return prims.erfcx(a)
85
+
86
+
87
+ # alias for sigmoid
88
+ expit = _make_alias(torch.sigmoid, "expit")
89
+
90
+
91
+ @_make_elementwise_unary_reference(
92
+ ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
93
+ )
94
+ def i0e(a: TensorLikeType) -> TensorLikeType:
95
+ return prims.bessel_i0e(a)
96
+
97
+
98
+ @_make_elementwise_unary_reference(
99
+ ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
100
+ )
101
+ def i1(a: TensorLikeType) -> TensorLikeType:
102
+ return prims.bessel_i1(a)
103
+
104
+
105
+ @_make_elementwise_unary_reference(
106
+ ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
107
+ )
108
+ def i1e(a: TensorLikeType) -> TensorLikeType:
109
+ return prims.bessel_i1e(a)
110
+
111
+
112
+ @register_decomposition(aten.special_log_ndtr)
113
+ @out_wrapper()
114
+ @elementwise_type_promotion_wrapper(
115
+ type_promoting_args=("a",),
116
+ type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
117
+ )
118
+ def log_ndtr(a: TensorLikeType) -> TensorLikeType:
119
+ # Note: M_SQRT1_2 is the value of 1 / √2
120
+ M_SQRT1_2 = 0.707106781186547524400844362104849039
121
+ t = a * M_SQRT1_2
122
+ return torch.where(
123
+ a < 1.0,
124
+ torch.log(torch.special.erfcx(-t) / 2) - t * t,
125
+ torch.log1p(-torch.erfc(t) / 2),
126
+ )
127
+
128
+
129
+ @register_decomposition(aten.logit)
130
+ @out_wrapper()
131
+ @elementwise_type_promotion_wrapper(
132
+ type_promoting_args=("self",),
133
+ type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
134
+ )
135
+ def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
136
+ if eps is None:
137
+ eps = -1.0
138
+ lo = eps
139
+ hi = 1 - eps
140
+ self = torch.clamp(self, lo, hi)
141
+ return torch.log(torch.true_divide(self, torch.sub(1, self)))
142
+
143
+
144
+ @register_decomposition(aten.special_xlog1py)
145
+ @out_wrapper()
146
+ @elementwise_type_promotion_wrapper(
147
+ type_promoting_args=("a", "b"),
148
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
149
+ )
150
+ def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
151
+ torch._check(
152
+ isinstance(a, TensorLike) or isinstance(b, TensorLike),
153
+ lambda: 'Expected either argument a or b to be a Tensor"',
154
+ )
155
+
156
+ # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
157
+ if isinstance(a, TensorLike) and isinstance(b, Number):
158
+ b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device)
159
+ elif isinstance(b, TensorLike) and isinstance(a, Number):
160
+ a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device)
161
+
162
+ # mypy: expected "Tensor"
163
+ assert isinstance(a, TensorLike)
164
+ assert isinstance(b, TensorLike)
165
+ rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log1p(b)))
166
+ return torch.where(torch.isnan(b), float("nan"), rhs)
167
+
168
+
169
+ @register_decomposition(aten.mvlgamma)
170
+ @out_wrapper()
171
+ @elementwise_type_promotion_wrapper(
172
+ type_promoting_args=("a",),
173
+ type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
174
+ )
175
+ def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType:
176
+ c = 0.25 * p * (p - 1) * math.log(math.pi)
177
+ b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device)
178
+ return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c
179
+
180
+
181
+ @register_decomposition(aten.special_ndtr)
182
+ @out_wrapper()
183
+ @elementwise_type_promotion_wrapper(
184
+ type_promoting_args=("a",),
185
+ type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
186
+ )
187
+ def ndtr(a: TensorLikeType) -> TensorLikeType:
188
+ # Note: M_SQRT1_2 is the value of 1 / √2
189
+ M_SQRT1_2 = 0.707106781186547524400844362104849039
190
+ a_sqrt_2 = a * M_SQRT1_2
191
+ return (1 + torch.erf(a_sqrt_2)) * 0.5
192
+
193
+
194
+ @register_decomposition(aten.special_ndtri)
195
+ @out_wrapper()
196
+ @elementwise_type_promotion_wrapper(
197
+ type_promoting_args=("a",),
198
+ type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
199
+ )
200
+ def ndtri(a: TensorLikeType) -> TensorLikeType:
201
+ return prims.ndtri(a)
202
+
203
+
204
+ # Forwarding alias: the special variant doesn't support the out kwarg
205
+ # CompositeImplicitAutograd - don't register decomp
206
+ def log_softmax(
207
+ a: TensorLikeType,
208
+ dim: int,
209
+ dtype: Optional[torch.dtype] = None,
210
+ ) -> TensorLikeType:
211
+ return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
212
+
213
+
214
+ # Forwarding alias: the special variant doesn't support the out kwarg
215
+ # CompositeImplicitAutograd - don't register decomp
216
+ def softmax(
217
+ a: TensorLikeType,
218
+ dim: int,
219
+ dtype: Optional[torch.dtype] = None,
220
+ ) -> TensorLikeType:
221
+ return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
222
+
223
+
224
+ @_make_elementwise_unary_reference(
225
+ ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
226
+ )
227
+ def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType:
228
+ return prims.spherical_bessel_j0(a)
229
+
230
+
231
+ # TODO: add docstring
232
+ @_make_elementwise_binary_reference(
233
+ type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
234
+ )
235
+ def zeta(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
236
+ return prims.zeta(a, b)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/nvtx.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profiling."""
2
+
3
+ from contextlib import contextmanager
4
+
5
+ try:
6
+ from torch._C import _nvtx
7
+ except ImportError:
8
+
9
+ class _NVTXStub:
10
+ @staticmethod
11
+ def _fail(*args, **kwargs):
12
+ raise RuntimeError(
13
+ "NVTX functions not installed. Are you sure you have a CUDA build?"
14
+ )
15
+
16
+ rangePushA = _fail
17
+ rangePop = _fail
18
+ markA = _fail
19
+
20
+ _nvtx = _NVTXStub() # type: ignore[assignment]
21
+
22
+ __all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"]
23
+
24
+
25
+ def range_push(msg):
26
+ """
27
+ Push a range onto a stack of nested range span. Returns zero-based depth of the range that is started.
28
+
29
+ Args:
30
+ msg (str): ASCII message to associate with range
31
+ """
32
+ return _nvtx.rangePushA(msg)
33
+
34
+
35
+ def range_pop():
36
+ """Pop a range off of a stack of nested range spans. Returns the zero-based depth of the range that is ended."""
37
+ return _nvtx.rangePop()
38
+
39
+
40
+ def range_start(msg) -> int:
41
+ """
42
+ Mark the start of a range with string message. It returns an unique handle
43
+ for this range to pass to the corresponding call to rangeEnd().
44
+
45
+ A key difference between this and range_push/range_pop is that the
46
+ range_start/range_end version supports range across threads (start on one
47
+ thread and end on another thread).
48
+
49
+ Returns: A range handle (uint64_t) that can be passed to range_end().
50
+
51
+ Args:
52
+ msg (str): ASCII message to associate with the range.
53
+ """
54
+ return _nvtx.rangeStartA(msg)
55
+
56
+
57
+ def range_end(range_id) -> None:
58
+ """
59
+ Mark the end of a range for a given range_id.
60
+
61
+ Args:
62
+ range_id (int): an unique handle for the start range.
63
+ """
64
+ _nvtx.rangeEnd(range_id)
65
+
66
+
67
+ def mark(msg):
68
+ """
69
+ Describe an instantaneous event that occurred at some point.
70
+
71
+ Args:
72
+ msg (str): ASCII message to associate with the event.
73
+ """
74
+ return _nvtx.markA(msg)
75
+
76
+
77
+ @contextmanager
78
+ def range(msg, *args, **kwargs):
79
+ """
80
+ Context manager / decorator that pushes an NVTX range at the beginning
81
+ of its scope, and pops it at the end. If extra arguments are given,
82
+ they are passed as arguments to msg.format().
83
+
84
+ Args:
85
+ msg (str): message to associate with the range
86
+ """
87
+ range_push(msg.format(*args, **kwargs))
88
+ try:
89
+ yield
90
+ finally:
91
+ range_pop()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_pytree.cpython-311.pyc ADDED
Binary file (5.88 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph_module.cpython-311.pyc ADDED
Binary file (41.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/immutable_collections.cpython-311.pyc ADDED
Binary file (4.59 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/proxy.cpython-311.pyc ADDED
Binary file (32.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/tensor_type.cpython-311.pyc ADDED
Binary file (5.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \
2
+ BVar
3
+ from torch.fx.experimental.migrate_gradual_types.operation import op_leq
4
+
5
+
6
+ def gen_tvar(curr):
7
+ """
8
+ Generate a tensor variable
9
+ :param curr: The current counter
10
+ :return: a tensor variable and the updated counter
11
+ """
12
+ curr += 1
13
+ return TVar(curr), curr
14
+
15
+
16
+ def gen_dvar(curr):
17
+ """
18
+ Generate a dimension variable
19
+ :param curr: the current counter
20
+ :return: a dimension variable and an updated counter
21
+ """
22
+ curr += 1
23
+ return DVar(curr), curr
24
+
25
+ def gen_bvar(curr):
26
+ """
27
+ Generate a boolean variable
28
+ :param curr: the current counter
29
+ :return: a boolean variable and an updated counter
30
+ """
31
+ curr += 1
32
+ return BVar(curr), curr
33
+
34
+ def gen_tensor_dims(n, curr):
35
+ """
36
+ Generate a list of tensor dimensions
37
+ :param n: the number of dimensions
38
+ :param curr: the current counter
39
+ :return: a list of dimension variables and an updated counter
40
+ """
41
+ dims = []
42
+ for _ in range(n):
43
+ dvar, curr = gen_dvar(curr)
44
+ dims.append(dvar)
45
+ return dims, curr
46
+
47
+
48
+ def gen_nat_constraints(list_of_dims):
49
+ """
50
+ Generate natural number constraints for dimensions
51
+ """
52
+ return [BinConstraintD(0, d, op_leq) for d in list_of_dims]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc ADDED
Binary file (5.28 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc ADDED
Binary file (4.45 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (781 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-311.pyc ADDED
Binary file (2.32 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-311.pyc ADDED
Binary file (21.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-311.pyc ADDED
Binary file (5.94 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-311.pyc ADDED
Binary file (11.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-311.pyc ADDED
Binary file (9.07 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/split_module.cpython-311.pyc ADDED
Binary file (25 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-311.pyc ADDED
Binary file (3.27 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-311.pyc ADDED
Binary file (6.11 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+ from . import pass_manager
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-311.pyc ADDED
Binary file (17.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-311.pyc ADDED
Binary file (14.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_base.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from collections import namedtuple
3
+ from typing import Optional
4
+
5
+ from torch.fx.graph_module import GraphModule
6
+ from torch.fx._compatibility import compatibility
7
+
8
+
9
+ __all__ = ['PassResult', 'PassBase']
10
+
11
+ @compatibility(is_backward_compatible=False)
12
+ class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
13
+ """
14
+ Result of a pass:
15
+ graph_module: The modified graph module
16
+ modified: A flag for if the pass has modified the graph module
17
+ """
18
+ def __new__(cls, graph_module, modified):
19
+ return super().__new__(cls, graph_module, modified)
20
+
21
+ @compatibility(is_backward_compatible=False)
22
+ class PassBase(abc.ABC):
23
+ """
24
+ Base interface for implementing passes.
25
+
26
+ It is required to implement the `call` function so that we can directly
27
+ pass instances of the Pass directly to the PassManager and call them as a
28
+ function.
29
+
30
+ We can directly pass an instance of a class implementing this interface into
31
+ the PassManager's `passes` attribute.
32
+ """
33
+
34
+ def __call__(self, graph_module: GraphModule) -> Optional[PassResult]:
35
+ """
36
+ Runs the precondition check, the pass itself, and the postcondition check.
37
+ """
38
+
39
+ self.requires(graph_module)
40
+ res = self.call(graph_module)
41
+ self.ensures(graph_module)
42
+ return res
43
+
44
+ @abc.abstractmethod
45
+ def call(self, graph_module: GraphModule) -> Optional[PassResult]:
46
+ """
47
+ The pass that is run through the given graph module. To implement a
48
+ pass, it is required to implement this function.
49
+
50
+ Args:
51
+ graph_module: The graph module we will run a pass on
52
+ """
53
+ pass
54
+
55
+ def requires(self, graph_module: GraphModule) -> None: # noqa: B027
56
+ """
57
+ This function will be called before the pass is run and will check that
58
+ the given graph module contains the preconditions needed to run the
59
+ pass. It is not required to implement this function.
60
+
61
+ Args:
62
+ graph_module: The graph module we will run checks on
63
+ """
64
+ pass
65
+
66
+ def ensures(self, graph_module: GraphModule) -> None: # noqa: B027
67
+ """
68
+ This function will be called after the pass is run and will check that
69
+ the given graph module contains the postconditions needed to run the
70
+ pass. It is not required to implement this function.
71
+
72
+ Args:
73
+ graph_module: The graph module we will run checks on
74
+ """
75
+ pass
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_manager.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import logging
3
+ from queue import Queue
4
+ from functools import wraps
5
+ from typing import Callable, Dict, List
6
+
7
+ import torch.nn as nn
8
+ from torch.fx.graph_module import GraphModule
9
+ from torch.fx._compatibility import compatibility
10
+ from torch.fx.passes.infra.pass_base import PassResult
11
+
12
+ logger = logging.getLogger(__name__)
13
+ logger.setLevel(logging.WARNING)
14
+
15
+ __all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager']
16
+
17
+ @compatibility(is_backward_compatible=False)
18
+ def pass_result_wrapper(fn: Callable) -> Callable:
19
+ """
20
+ Wrapper for passes which currently do not return a PassResult.
21
+ This wrapper makes them return a PassResult containing the modified object
22
+ and True for the "modified" flag.
23
+
24
+ Args:
25
+ fn (Callable[Module, Any])
26
+
27
+ Returns:
28
+ wrapped_fn (Callable[Module, PassResult])
29
+ """
30
+ if fn is None:
31
+ return None
32
+
33
+ @wraps(fn)
34
+ def wrapped_fn(gm):
35
+ res = fn(gm)
36
+ if res is None:
37
+ return PassResult(gm, True)
38
+ if isinstance(res, PassResult):
39
+ return res
40
+ elif isinstance(res, nn.Module):
41
+ return PassResult(res, True)
42
+
43
+ if not inspect.isfunction(fn):
44
+ wrapped_fn.__name__ = type(fn).__name__
45
+
46
+ return wrapped_fn
47
+
48
+ def _validate_pass_schedule_constraint(
49
+ constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
50
+ ) -> None:
51
+ for i, a in enumerate(passes):
52
+ for j, b in enumerate(passes[i + 1 :]):
53
+ if constraint(a, b):
54
+ continue
55
+ raise RuntimeError(
56
+ f"pass schedule constraint violated. Expected {a} before {b}"
57
+ f" but found {a} at index {i} and {b} at index{j} in pass"
58
+ f" list."
59
+ )
60
+
61
+ def _topological_sort_passes(
62
+ passes: List[Callable], constraints: List[Callable]
63
+ ) -> List[Callable]:
64
+ """
65
+ Args
66
+ passes: Passes that we are ordering
67
+ constraints: Constraints applied on these passes
68
+
69
+ Returns
70
+ A sorted list of callables and a boolean of if a circular dependency
71
+ existed
72
+ """
73
+ if len(constraints) == 0:
74
+ return passes
75
+
76
+ # Contruct a graph mapping nodes to a list of their users
77
+ graph: Dict[Callable, List[Callable]] = {p : [] for p in passes}
78
+ indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0)
79
+ candidates: Queue = Queue()
80
+ for a in passes:
81
+ for b in passes:
82
+ if a == b:
83
+ continue
84
+
85
+ for constraint in constraints:
86
+ if not constraint(a, b):
87
+ graph[b].append(a)
88
+ indegree_map[a] += 1
89
+
90
+ if indegree_map[a] == 0:
91
+ candidates.put(a)
92
+
93
+ visited: Dict[Callable, bool] = dict.fromkeys(passes, False)
94
+ sorted_passes: List[Callable] = []
95
+
96
+ while not candidates.empty():
97
+ p = candidates.get()
98
+ sorted_passes.append(p)
99
+ visited[p] = True
100
+
101
+ for n in graph[p]:
102
+ if not visited[n]:
103
+ indegree_map[n] -= 1
104
+ if indegree_map[n] == 0:
105
+ candidates.put(n)
106
+
107
+ # Check if there are unvisited nodes (aka cycles in the graph)
108
+ cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys()))
109
+ if len(cycle_passes) != 0:
110
+ error = f"Circular dependency detected within the following passes: {cycle_passes}"
111
+ raise RuntimeError(error)
112
+
113
+ return sorted_passes
114
+
115
+ @compatibility(is_backward_compatible=False)
116
+ def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
117
+ """
118
+ Defines a partial order ('depends on' function) where `this` must occur
119
+ before `that`.
120
+
121
+ For example, the following pass list and constraint list would be invalid.
122
+ ```
123
+ passes = [pass_b, pass_a]
124
+
125
+ constraints = [
126
+ this_before_that_pass_constraint(pass_a, pass_b)
127
+ ]
128
+ ```
129
+
130
+ Args:
131
+ this (Callable): pass which should occur first
132
+ that (Callable): pass which should occur later
133
+
134
+ Returns:
135
+ depends_on (Callable[[Object, Object], bool]
136
+ """
137
+
138
+ def depends_on(a: Callable, b: Callable):
139
+ if a == that and b == this:
140
+ return False
141
+ return True
142
+
143
+ return depends_on
144
+
145
+
146
+ @compatibility(is_backward_compatible=False)
147
+ class PassManager:
148
+ """
149
+ Construct a PassManager.
150
+
151
+ Collects passes and constraints. This defines the pass schedule, manages
152
+ pass constraints and pass execution.
153
+
154
+ Args:
155
+ passes (Optional[List[Callable]]): List of passes. A pass is a
156
+ callable which modifies an object and returns a PassResult
157
+ constraint (Optional[List[Callable]]): List of constraints. A
158
+ constraint is a callable which takes two passes (A, B) and returns
159
+ True if A depends on B and False otherwise. See implementation of
160
+ `this_before_that_pass_constraint` for example.
161
+ steps (int): Max number of times we run the passes (default = 1).
162
+ run_checks_after_each_pass (bool): Whether to run checks and linting
163
+ after each pass
164
+ suppress_check_failures (bool): Whether to raise errors when running
165
+ checks
166
+ """
167
+
168
+ passes: List[Callable[[nn.Module], PassResult]]
169
+ constraints: List[Callable[[Callable, Callable], bool]]
170
+ _validated: bool = False
171
+ steps: int = 1
172
+
173
+ def __init__(
174
+ self,
175
+ passes=None,
176
+ constraints=None,
177
+ steps=None,
178
+ run_checks_after_each_pass: bool = False,
179
+ suppress_check_failures: bool = False,
180
+ ):
181
+ self.passes = passes or []
182
+ self.constraints = constraints or []
183
+ if steps:
184
+ self.steps = steps
185
+
186
+ self.run_checks_after_each_pass = run_checks_after_each_pass
187
+ self.suppress_check_failures = suppress_check_failures
188
+
189
+ def add_pass(self, _pass: Callable):
190
+ """
191
+ Adds a pass into the current list of passes.
192
+ """
193
+ self.passes.append(_pass)
194
+ self._validated = False
195
+
196
+ def add_constraint(self, constraint: Callable):
197
+ """
198
+ Adds a constraint into the current list of constraints.
199
+ """
200
+ self.constraints.append(constraint)
201
+ self._validated = False
202
+
203
+ def validate_constraints(self):
204
+ """
205
+ Validates that current pass schedule defined by `self.passes` is valid
206
+ according to all constraints in `self.constraints`
207
+ """
208
+ if self._validated:
209
+ return
210
+ for constraint in self.constraints:
211
+ _validate_pass_schedule_constraint(constraint, self.passes)
212
+ self._validated = True
213
+
214
+ def solve_constraints(self):
215
+ """
216
+ Finds a valid traversal order based on the given constraints and orders
217
+ the passes based on this order.
218
+
219
+ If a circular dependency exists between the constraints and steps = 1,
220
+ then we will raise an error because if steps != 1 this means that we
221
+ will re-run the passes, allowing for circular dependencies.
222
+ """
223
+ self.passes = _topological_sort_passes(self.passes, self.constraints)
224
+ self._validated = True
225
+
226
+ def add_checks(self, check: Callable) -> None:
227
+ """
228
+ Adds a function which takes runs various checks on a given graph module.
229
+ This function is run before and after each pass if the
230
+ `run_checks_after_each_pass` flag is enabled.
231
+ """
232
+ sig = inspect.signature(check)
233
+
234
+ if len(list(sig.parameters.values())) != 1:
235
+ raise TypeError("PassManager check function should only take in one variable, a module")
236
+
237
+ setattr(self, "check", check) # noqa: B010
238
+
239
+ def check(self, module: nn.Module) -> None:
240
+ pass
241
+
242
+ def __call__(self, module: nn.Module) -> PassResult:
243
+ """
244
+ Runs a list of passes in the order based on `self.passes` on the given
245
+ graph module. Each time a pass is run, checks and linting will be run on
246
+ the graph module if `run_checks_after_each_pass` is set.
247
+
248
+ If the module is a graph module, we will run the list of passes until
249
+ the graph stops changing, or until `steps` number of times.
250
+ """
251
+ # Order the passes based on the constraints
252
+ if not self._validated:
253
+ self.solve_constraints()
254
+
255
+ # Check graph invariants
256
+ self.check(module)
257
+
258
+ # Run the set of passes `steps` number of times or until the graph stops
259
+ # changing
260
+ overall_modified = False
261
+ for _ in range(self.steps):
262
+ modified = False
263
+
264
+ # Run the set of passes on the graph module
265
+ for i, fn in enumerate(self.passes):
266
+ fn_name = fn.__name__ if inspect.isfunction(fn) else type(fn).__name__
267
+ logger.debug("Running pass '%s'", fn_name)
268
+
269
+ try:
270
+ res = fn(module)
271
+
272
+ if not isinstance(res, PassResult) and not hasattr(
273
+ res, "graph_module"
274
+ ):
275
+ raise TypeError(
276
+ f"The result of the pass {fn_name} should be type PassResult."
277
+ + "Please wrap it with pass_result_wrapper()"
278
+ )
279
+ module = res.graph_module
280
+ modified = modified or res.modified
281
+
282
+ if isinstance(module, GraphModule):
283
+ logger.debug("Graph after pass '%s': %s", fn_name, module.graph)
284
+ module.recompile()
285
+
286
+ # Check graph invariants
287
+ if self.run_checks_after_each_pass:
288
+ self.check(module)
289
+
290
+ except Exception as e:
291
+ prev_pass_names = [
292
+ p.__name__ if inspect.isfunction(p) else type(p).__name__
293
+ for p in self.passes[:i]
294
+ ]
295
+ msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}"
296
+ raise Exception(msg) from e
297
+
298
+ # If the graph no longer changes, then we can stop running these passes
299
+ overall_modified = overall_modified or modified
300
+ if not modified:
301
+ break
302
+
303
+ return PassResult(module, overall_modified)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/reinplace.py ADDED
@@ -0,0 +1,675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.fx import Node
3
+ from torch.fx._compatibility import compatibility
4
+ from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
5
+ from torch.utils._pytree import tree_map_only
6
+ from torch.utils import _pytree as pytree
7
+ from torch.multiprocessing.reductions import StorageWeakRef
8
+
9
+ import _operator
10
+ from enum import Enum
11
+ import itertools
12
+ from typing import Set, Dict
13
+ from collections import defaultdict
14
+
15
+ __all__ = ['reinplace']
16
+
17
+ class _ViewType(Enum):
18
+ NonView = 0
19
+ SingleOutputView = 1
20
+ MultiOutputView = 2
21
+
22
+ def _is_view_op(tgt):
23
+ if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
24
+ schema = tgt._schema
25
+ if len(schema.arguments) > 0:
26
+ first_arg = schema.arguments[0]
27
+ # check if op is a view
28
+ return first_arg.alias_info is not None and not first_arg.alias_info.is_write
29
+
30
+ def _get_view_type(tgt) -> _ViewType:
31
+ if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
32
+ schema = tgt._schema
33
+ if len(schema.arguments) > 0:
34
+ first_arg = schema.arguments[0]
35
+ # check if op is a view
36
+ if first_arg.alias_info is not None and not first_arg.alias_info.is_write:
37
+ # check if op is a multi-output view
38
+ if '*' in first_arg.alias_info.after_set:
39
+ return _ViewType.MultiOutputView
40
+ else:
41
+ return _ViewType.SingleOutputView
42
+ return _ViewType.NonView
43
+
44
+
45
+ # Stores a bunch of metadata related to functionalization each node.
46
+ # Relevant metadata:
47
+ # n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTenors instead of Tensors)
48
+ # The fake tensor output from running the current node
49
+ # n.meta['view_of']: Node
50
+ # If the current node n is a view of some base tensor, the 'view_of' field tells us which
51
+ # view node was used to generate the current node (a view tensor).
52
+ # This information actually makes `fake_result` redundant, but we can use `fake_result`
53
+ # to sanity check that our aliasing information is correct.
54
+ @compatibility(is_backward_compatible=False)
55
+ class _FunctionalizationMetadataProp(torch.fx.Interpreter):
56
+
57
+ def run_node(self, node: Node):
58
+ self.node_counter += 1
59
+ result = super().run_node(node)
60
+ node.meta['fake_result'] = result
61
+ node.meta['node_idx'] = self.node_counter
62
+
63
+ # (1) Update metadata with the list of nodes that are used by this node
64
+ # copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
65
+ # We don't want to treat it as "being used as an input".
66
+ node_args = node.args
67
+ if node.target is torch.ops.aten.copy_.default:
68
+ node_args = node_args[1:]
69
+
70
+ # (2) Update metadata to track aliasing information about view tensor nodes.
71
+ if node.op == 'call_function':
72
+ view_type = _get_view_type(node.target)
73
+ if view_type == _ViewType.SingleOutputView:
74
+ assert isinstance(node.args[0], Node)
75
+ node.meta['view_of'] = node.args[0]
76
+ elif view_type == _ViewType.MultiOutputView:
77
+ self.multi_output_view_nodes[node] = node.args[0]
78
+
79
+ # Check if we returned a multi-output view,
80
+ # and we're now grabbing the individual views from the output.
81
+ #
82
+ # For multi-output views, we want to map each output view to the base,
83
+ # but this mapping involves two separate nodes in FX IR.
84
+ # e.g. "a, b = x_1.split(...)" becomes:
85
+ # %split_tensor : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
86
+ # %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
87
+ # %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
88
+ # And we'd like to set:
89
+ # getitem1.meta['view_of'] = x_1
90
+ elif node.target is _operator.getitem:
91
+ list_arg = node.args[0]
92
+ maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None)
93
+ if maybe_base_of_view is not None:
94
+ # Note: we could also track indexing info here for multi-output views.
95
+ # I don't think this metadata is strictly needed for de-functionalization.
96
+ assert isinstance(maybe_base_of_view, Node)
97
+ node.meta['view_of'] = maybe_base_of_view
98
+
99
+ if 'view_of' in node.meta:
100
+ # We're linking the current node with its first argument as views.
101
+ # Assert here that this is actually the case, and their storages are the same.
102
+ assert isinstance(node.meta['fake_result'], FakeTensor)
103
+ assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
104
+ view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
105
+ base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage())
106
+ assert view_storage == base_storage
107
+ return result
108
+
109
+
110
+
111
+ def propagate(self, *args):
112
+ self.multi_output_view_nodes = {}
113
+ self.node_counter = -1
114
+
115
+ with FakeTensorMode() as mode:
116
+ fake_args = [mode.from_tensor(a) for a in args]
117
+ return super().run(*fake_args)
118
+
119
+ def _schemas_match(functional_schema, inplace_schema):
120
+ names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name
121
+ arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all(
122
+ a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments))
123
+ # for the inplace op, its first argument should be mutable
124
+ assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write
125
+ # and its remaining arguments shouldn't be.
126
+ assert all(a.alias_info is None for a in inplace_schema.arguments[1:])
127
+ return names_match and arg_types_match
128
+
129
+ # TODO: this should be beefed up to be able to properly re-inplace with:
130
+ # - mutating ops (e.g. _fused_moving_avg_obs_fq_helper)
131
+ # - out= ops (e.g. angle -> angle.out)
132
+ # TODO: we should also figure this info out using torchgen.
133
+ def _maybe_get_inplace_op(op):
134
+ # __module__ seems broken; it returns torch._ops.aten which doesn't exist
135
+ if not isinstance(op, torch._ops.OpOverload):
136
+ return None
137
+ # Some view ops have inplace variants (as_strided_, etc),
138
+ # but we do NOT want the reinplacing pass to directly add these into the program.
139
+ # (they'll require extra special handling, aren't aren't really useful for perf anyway)
140
+ if _is_view_op(op):
141
+ return None
142
+ op_namespace = op.__module__.split(".")[-1]
143
+ op_base_name = op.overloadpacket.__name__
144
+ maybe_namespace_module = getattr(torch.ops, op_namespace)
145
+ maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None)
146
+ if maybe_inplace_op is None:
147
+ return None
148
+
149
+ inplace_overloads = [
150
+ getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads()
151
+ ]
152
+ inplace_overloads_with_matching_schemas = [
153
+ f
154
+ for f in inplace_overloads
155
+ if _schemas_match(op._schema, f._schema)
156
+ ]
157
+ # Just because foo() and foo_() are both existing operators,
158
+ # They aren't guaranteed to have compatible schemas.
159
+ # For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant,
160
+ # Even though several overloads of pow_ exist.
161
+ if len(inplace_overloads_with_matching_schemas) == 0:
162
+ return None
163
+ assert len(inplace_overloads_with_matching_schemas) == 1
164
+ inplace_op = inplace_overloads_with_matching_schemas[0]
165
+ return inplace_op
166
+
167
+ _VIEW_INVERSE_MAP = {
168
+ torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
169
+ torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
170
+ torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
171
+ torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
172
+ }
173
+
174
+ # This function, given a set of set of (aliased) tensor nodes,
175
+ # Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index
176
+ # in the node ordering.
177
+ def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
178
+ def _add_if_tensor(x, set_):
179
+ if isinstance(x, FakeTensor):
180
+ set_.add(StorageWeakRef(x._typed_storage()))
181
+
182
+ nodes_used_after = set()
183
+ for t in tensor_aliases:
184
+ # get all nodes that use the current alias
185
+ usage_nodes = t.users
186
+ for n in usage_nodes:
187
+ # We only care about usages after the current node
188
+ if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index:
189
+ continue
190
+ # We also don't care about intermediate view ops.
191
+ # They only matter if their output is then used elsewhere
192
+ # (either in an out-of-place op, or as an output to the function).
193
+ if n in tensor_aliases:
194
+ if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem:
195
+ continue
196
+ nodes_used_after.add(n)
197
+ return nodes_used_after
198
+
199
+ # Given an op that we're trying to re-inplace, "b = foo(a)",
200
+ # And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)"
201
+ # Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF:
202
+ # If there are any aliases in the alias_set(a) that satisfy:
203
+ # (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base"
204
+ # (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
205
+ # as "alias"
206
+ def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]:
207
+ def matching_view_metadata(a, b):
208
+ return a.size() == b.size() and \
209
+ a.stride() == b.stride() and \
210
+ a.storage_offset() == b.storage_offset()
211
+
212
+ view_inverse_nodes = set()
213
+ # Go through them in node order, so we can see chains of view_scatter ops.
214
+ for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']):
215
+ if n.target not in _VIEW_INVERSE_MAP:
216
+ continue
217
+ base = n.args[0]
218
+ mutated_view = n.args[1]
219
+ assert isinstance(base, Node)
220
+ assert isinstance(base.meta['fake_result'], FakeTensor)
221
+ assert isinstance(mutated_view, Node)
222
+ assert isinstance(mutated_view.meta['fake_result'], FakeTensor)
223
+ # Check that this view_inverse op actually corresponds to taking doing the inverse
224
+ # of one of our existing self_alias nodes.
225
+ original_view = _VIEW_INVERSE_MAP[n.target]
226
+ for self_alias in self_aliases:
227
+ # We're looking for some alias of the self arg, "alias",
228
+ # that was created from some op `alias = foo(base, args...)`
229
+ # such that the current _scatter op "inverts" that foo call.
230
+ # We can check that by running the original op again, and checking that the strides match.
231
+ if 'view_of' not in self_alias.meta:
232
+ continue
233
+ self_alias_base = self_alias.meta['view_of']
234
+ try:
235
+ # The we're trying to re-use the args from the view_scatter call inside of the corresponding
236
+ # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse
237
+ # of the current alias we're looking at.
238
+ view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs)
239
+ expected_metadata = self_alias.meta['fake_result']
240
+ # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace.
241
+ if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \
242
+ matching_view_metadata(view_replay_metadata, expected_metadata):
243
+ view_inverse_nodes.add(n)
244
+ except Exception:
245
+ continue
246
+
247
+ return view_inverse_nodes
248
+
249
+
250
+ @compatibility(is_backward_compatible=True)
251
+ def reinplace(gm, *sample_args):
252
+ """
253
+ Given an fx.GraphModule, modifies it to perform "reinplacing",
254
+ mutating the nodes of the graph.
255
+ We look for out-of-place op call sites like `b = a.add(...)`,
256
+ and convert them to be inplace (`b = a.add_(...)`),
257
+ as long as the input to the current operator ("a") isn't re-used
258
+ anywhere later in the graph.
259
+
260
+ This pass currently expects to operate on a **functional, ATen** graph.
261
+ This can be obtained by running `make_fx(functionalize(f))`.
262
+
263
+ Sample inputs are needed to determine aliasing relationships of the inputs.
264
+ In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the
265
+ inputs to the program.
266
+
267
+ Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows:
268
+
269
+ (1) Perform some initial checks on the metadata of "a" and "args..."
270
+ that can disqualify them from being reinplaced.
271
+
272
+ (1a) Check that the self argument we're attempting to reinplace
273
+ has acceptable dtype/size metadata to reinplace with.
274
+
275
+ For example, if we have:
276
+ a = torch.ones(1)
277
+ b = torch.ones(10)
278
+ out = torch.add(a, b)
279
+ We can't turn that into
280
+ a.add_(b)
281
+ Because that would require resizing "a".
282
+
283
+ Similarly, we can't convert torch.ge(a, b) into a.ge_(b),
284
+ because that would require changing a's dtype (from e.g. float32 to bool).
285
+ Note that in this specific example, we could technically do better..
286
+
287
+ If we see the pattern:
288
+ a_1 = a.ge(b)
289
+ a_2 = aten._to_copy(a_1, a.dtype)
290
+ Then this should be valid to completely re-inplace
291
+ (this is exactly what functionalization will emit when it sees a.ge_(b)).
292
+
293
+ This optimization is only really important for user programs
294
+ that directly use inplace comparison ops though.
295
+
296
+ We also cannot re-inplace on tensors that have overlapping memory,
297
+ e.g. torch.ones(1).expand(4, 4).add_(1)
298
+
299
+ (1b) Check if "a" is an alias of any of the program inputs.
300
+
301
+ If it is, skip and move to the next node.
302
+ Inplace'ing an op that would cause it to mutate a program input is not sound,
303
+ because that would be a side effect visible to the user.
304
+
305
+ NOTE: there's a future optimization that we should make:
306
+ if "a" is a (alias of a) program input, but later in the program
307
+ there is a node that looks like "a.copy_(...)",
308
+ Then re-inplacing is ok to do - we are temporarily re-using a's buffer,
309
+ which will later be overwritten by the copy_() call.
310
+
311
+ This will be an important optimization to have for programs that mutate
312
+ their inputs. It currently isn't implemented though.
313
+
314
+ (1c) Check if "a" and "args..." alias
315
+
316
+ For example, re-inplacing to create code like the below
317
+ isn't guaranteed to be sound:
318
+
319
+ aten.mul_(a, a)
320
+
321
+ (2) Check that "a" and all of its outstanding aliases are not used anywhere
322
+ later in the graph. If this is the case, then it's safe to re-inplace
323
+ to "b = foo_(a)".
324
+
325
+ There are a few caveats to this, explained in more detail below:
326
+ (a) If "a" is used later as an argument to a view op, that is okay.
327
+ It's only a problem if "a" (or that view) is later passed
328
+ into a normal operator, or if it is returned as the program output.
329
+ (b) If "a" is a repeat argument in `foo()`, then don't reinplace.
330
+ Most ATen kernels don't make any guarantees that this is sound,
331
+ e.g. if you do aten.mul_(a, a).
332
+ So we'll just ban re-inplacing in this case.
333
+ It's only a problem if "a" (or that view) is later passed
334
+ (c) If "a" is used as an input into a view "inverse" / "scatter"
335
+ operator, it is potentially fine to re-inplace
336
+ (and remove that scatter operator from the graph).
337
+ See below for a more detailed example.
338
+
339
+ NOTE: there is an optimization in this step that is crucial
340
+ to fully recovering performance from functionalization.
341
+
342
+ Given this program:
343
+ def f(x):
344
+ a = torch.ops.aten.add(x, x)
345
+ b = torch.ops.aten.diagonal(a)
346
+ torch.ops.aten.fill_(b, 0)
347
+ return a
348
+
349
+ Functionalization will emit the following:
350
+ def f(x):
351
+ a = torch.ops.aten.add(x, x)
352
+ b = torch.ops.aten.diagonal(a, 0, 1)
353
+ b_updated = torch.ops.aten.fill(b, 0)
354
+ a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1)
355
+ return a_updated
356
+
357
+ Ordinarily, we would not be able to reinplace the fill,
358
+ because "b" aliases with "a" which is used by the diagonal_scatter call.
359
+
360
+ "re-inplacing" is on the hook for figuring out that it is ok to
361
+ completely remove the expensive diagonal_scatter call, if we re-inplace the add().
362
+
363
+ So, for every `alias in alias_set(a)`, instead of checking
364
+ that "alias" is not used anywhere later in the graph,
365
+ we check that
366
+ EITHER:
367
+ (a) alias is not used anywhere later in the graph
368
+ OR:
369
+ (b) alias is used exactly once later on in the graph,
370
+ in the following op:
371
+
372
+ out = foo_scatter(alias, x, args...)
373
+
374
+ where the following must hold:
375
+ (i) "foo_scatter" is the "inverse" operator for foo.
376
+ This only applies to "foo" ops that are view operators,
377
+ which view into a subset of the original tensor's memory.
378
+ In practice, there are ~4 operators where this applies:
379
+ diagonal -> diagonal_scatter
380
+ slice -> slice_scatter
381
+ select -> select_scatter
382
+ as_strided -> as_strided_scatter
383
+ (ii) "args..." are the same between the foo() and foo_scatter() calls.
384
+
385
+ (3) Perform the actual re-inplacing on foo!
386
+
387
+ (3b) is the common case, but special care is needed for {view}_scatter (3a)
388
+
389
+ (3a) {view}_scatter ops.
390
+
391
+ Consider this program:
392
+ a = torch.zeros(2, 2)
393
+ b = torch.ones(2)
394
+ a[0] = b
395
+
396
+ Post functionalization, that will look like:
397
+ a = torch.zeros(2)
398
+ b = torch.ones(1)
399
+ a_updated = torch.select_scatter(a, b, 0, 0)
400
+
401
+ In this case though, there is no "functional" op to re-inplace!
402
+ Instead, we'd like to directly remove the select_scatter call.
403
+ We already know from (3) that this is valid,
404
+ because "a" has no later usages in the graph.
405
+
406
+ We perform the re-inplacing on the {view}_scatter op like so
407
+ Before:
408
+ a_updated = torch.select_scatter(a, b, args...)
409
+ After:
410
+ a_slice = a.select(a, args...)
411
+ a_slice.copy_(b)
412
+
413
+ (3b) Otherwise, replace the functional op with its inplace variant.
414
+ Before:
415
+ b = foo(a, args...)
416
+ After:
417
+ a.foo_(args...)
418
+
419
+ (4) Finally, after converting either:
420
+ Before:
421
+ b = foo(a)
422
+ After:
423
+ foo_(a)
424
+ or
425
+ Before:
426
+ b = {slice}_scatter(a, mutated_slice, args...)
427
+ After:
428
+ slice = {slice}(a, args...)
429
+ slice.copy_(mutated_slice)
430
+
431
+ We now need to find all later nodes that use "b" as an argument
432
+ and update them to take in "a" instead.
433
+
434
+ Note that for the majority of inplace ops, this isn't actually necessary
435
+ (because most inplace ops return "self" as their output).
436
+ This isn't generally true for all mutable ops though, which is why
437
+ we need to actually replace all of the arguments.
438
+
439
+ We also need to update our metadata of Dict[StorageWeakRef, Set[Node]],
440
+ That maps a given tensor storage to the set of all nodes that take in that storage
441
+ as an input.
442
+ Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused
443
+ together.
444
+
445
+ (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them"
446
+ during step (3) get manually deleted from the graph.
447
+ Their outputs are no longer used, so technically standard DCE would be able
448
+ to do this, but we can no longer run FX's DCE pass now that we have mutable
449
+ ops in the graph.
450
+ """
451
+ _FunctionalizationMetadataProp(gm).propagate(*sample_args)
452
+
453
+ # Useful debug printing
454
+ # def _print(x):
455
+ # if isinstance(x, FakeTensor):
456
+ # print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}')
457
+
458
+ # for n in gm.graph.nodes:
459
+ # print(n.format_node())
460
+ # if hasattr(n, 'meta'):
461
+ # print(f'node_idx: {n.meta["node_idx"]}')
462
+ # if 'fake_result' in n.meta:
463
+ # tree_map(_print, n.meta['fake_result'])
464
+ # if 'view_of' in n.meta:
465
+ # print(f'view_of: {str(n.meta["view_of"])}')
466
+ # print()
467
+
468
+ # We need to know which nodes correspond to inputs (or their aliases)
469
+ # so we know not to re-inplace them.
470
+ # NOTE: later, we'll need to add an optimization for fully recovering performance
471
+ # on programs that mutate inputs.
472
+ input_storages = {
473
+ StorageWeakRef(
474
+ node.meta['fake_result']._typed_storage()
475
+ ) for node in gm.graph.nodes if node.op == 'placeholder'}
476
+
477
+
478
+ # We also need to know for a given node, what are all of its aliasing nodes.
479
+ storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
480
+ for n in gm.graph.nodes:
481
+ if 'fake_result' in n.meta:
482
+ # Tree-mapping because some ops can return lists of tensors.
483
+ def _add_to_map(x):
484
+ if isinstance(x, FakeTensor):
485
+ storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)
486
+ pytree.tree_map_(_add_to_map, n.meta['fake_result'])
487
+
488
+ # inplace-ify functional ops, subject to the constraints written below.
489
+ all_later_view_inverse_nodes_to_delete = set()
490
+ for idx, node in enumerate(gm.graph.nodes):
491
+ if node.op == 'call_function':
492
+
493
+ # Today, the re-inplace pass directly acts on:
494
+ # - functional ops with an inplace variant
495
+ # - {view}_scatter ops that can be potentially removed from the graph.
496
+ # Both of these ops take in tensor first args, so filtering on this condition
497
+ # makes the later code simpler.
498
+ # We should revisit this at some point though, particularly when we also want
499
+ # the reinplacer to be able to handle out= and mutable operators
500
+ # and tensorlist first args (like `_foreach_` ops).
501
+ if not isinstance(node.target, torch._ops.OpOverload):
502
+ continue
503
+ if len(node.target._schema.arguments) < 1:
504
+ continue
505
+ if type(node.target._schema.arguments[0].type) != torch.TensorType:
506
+ continue
507
+
508
+ # Step 1a: Check that the self argument we're attempting to reinplace
509
+ # has the same size/stride as the output.
510
+ # For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor)
511
+ # As it would require resizing scalar_tensor.
512
+ # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor),
513
+ # this is probably an optimization to revisit later).
514
+ self_arg = node.args[0]
515
+ self_flattened = pytree.tree_leaves(self_arg.meta['fake_result'])
516
+ node_flattened = pytree.tree_leaves(node.meta['fake_result'])
517
+ self_has_wrong_metadata = False
518
+ if len(self_flattened) == len(node_flattened):
519
+ for self_meta, node_meta in zip(self_flattened, node_flattened):
520
+ if self_meta.numel() != node_meta.numel():
521
+ self_has_wrong_metadata = True
522
+ if self_meta.dtype != node_meta.dtype:
523
+ self_has_wrong_metadata = True
524
+ # We also cannot re-inplace on tensors that have internal memory overlap.
525
+ # e.g. torch.ones(1).expand(4, 4).add_(1)
526
+ if torch._debug_has_internal_overlap(self_meta) == 1:
527
+ self_has_wrong_metadata = True
528
+ # Here, we (optimistically) assume that a.resize(b) is valid to re-inplace,
529
+ # Since users should never really be calling the functional "torch.ops.aten.resize"
530
+ # op directly in their programs.
531
+ if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default:
532
+ continue
533
+
534
+ # Step 1b: ensure that the op we're trying to re-inplace isn't a program input
535
+ self_arg_name = self_arg.name
536
+ self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
537
+ if self_arg_storage in input_storages:
538
+ # TODO: later, add the optimization for handling `copy_()` calls in the graph.
539
+ continue
540
+ if len([x for x in node.args if x is self_arg]) > 1:
541
+ # Step 1c:
542
+ # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound,
543
+ # so we prevent re-inplacing in this case.
544
+ continue
545
+
546
+ self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
547
+ self_aliases = storage_to_nodes[self_arg_storage]
548
+
549
+ # First, we find all later usages of any of the aliases of self_arg.
550
+ later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx'])
551
+ # Then, we check if any of those later usages are actually view_scatter ops
552
+ # that are safe to fully remove.
553
+ later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases)
554
+
555
+ # Step 2: Check to see if the input to the op is re-used later in the graph.
556
+ # If not (same goes for its aliases), then this op is safe to re-in place.
557
+ # This is a slightly roundabout way to check that there are no later usages of the current self argument.
558
+ # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete)
559
+ can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0
560
+ if not can_reinplace:
561
+ continue
562
+
563
+ # Step 3a: Special handling for when we see *_scatter operators.
564
+ # When we see an operator like `b = torch.slice_scatter(a, ...)`,
565
+ # instead of trying to "inplace" it into a.slice_scatter_(...),
566
+ # we would prefer to remove it from the graph entirely,
567
+ # and instead copy_() the slice directly into the larger tensor.
568
+ # See the description of the algorithm for a full example.
569
+ if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete:
570
+ view_op = _VIEW_INVERSE_MAP[node.target]
571
+ # Before:
572
+ # base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...)
573
+ # After:
574
+ # slice = torch.ops.aten.slice.default(base, args...)
575
+ # slice.copy_(mutated_slice)
576
+ with gm.graph.inserting_before(node):
577
+ mutated_slice_node = node.args[1]
578
+ remaining_slice_args = node.args[2:]
579
+ slice_node = gm.graph.create_node(
580
+ 'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs)
581
+ copy_node = gm.graph.create_node(
582
+ 'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {})
583
+ # Add the slice_scatter node to our "nodes to delete" list.
584
+ all_later_view_inverse_nodes_to_delete.add(node)
585
+
586
+
587
+ else:
588
+ # Step 3b: Check to see if this operator has an inplace variant.
589
+ maybe_inplace_op = _maybe_get_inplace_op(node.target)
590
+ if maybe_inplace_op is None:
591
+ continue
592
+ # And if so, replace it with its inplace variant.
593
+ node.target = maybe_inplace_op
594
+
595
+ # At this point, 'storage_to_nodes' will be stale.
596
+ # Now that we're inplacing `b = foo(a)`, we need to effectively
597
+ # union together the dict values for b and a's storage.
598
+ # Hmm... morally I think we also want to keep the `fake_result` metadata
599
+ # up to date here, but I'm not sure how easy it is to do.
600
+ # Maybe it's fine to wait until the end of the pass to update it.
601
+ curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
602
+ storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
603
+ storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])
604
+
605
+ # Need to remember the view_scatter view nodes we found so we can remove them later.
606
+ all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages)
607
+
608
+ # Step 4:
609
+ # Now that we've replaced b = a.foo() with a.foo_(),
610
+ # We need to replace any later usages of "b" with "a"
611
+ for old in itertools.chain([node], later_view_inverse_node_usages):
612
+ new = old.args[0]
613
+ nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
614
+ for node_to_update in nodes_to_update:
615
+ new_args = []
616
+ args = node_to_update.args
617
+
618
+ def replace_arg(a):
619
+ if a == old:
620
+ return new
621
+ return a
622
+
623
+ # First, replace usages of "b" with "a"
624
+ node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args)
625
+ node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs)
626
+
627
+ # Second, update our storage_to_nodes data structure.
628
+ old_flattened_res = pytree.tree_leaves(old.meta['fake_result'])
629
+ node_flattened_res = pytree.tree_leaves(node_to_update.meta['fake_result'])
630
+
631
+ old_res_storage = {
632
+ StorageWeakRef(
633
+ x._typed_storage()
634
+ ) for x in old_flattened_res if isinstance(x, FakeTensor)}
635
+ node_res_storage = {
636
+ StorageWeakRef(
637
+ x._typed_storage()
638
+ ) for x in node_flattened_res if isinstance(x, FakeTensor)}
639
+
640
+ # This will happen if we're updating a view op, e.g.
641
+ # e.g. replacing
642
+ # x = view(old)
643
+ # x = view(new)
644
+ # When that happens, we need to make sure to keep our
645
+ # storage mapping up to date.
646
+ #
647
+ # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor,
648
+ # or multiple tensors that all share the same storage.
649
+ # We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
650
+ if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage:
651
+ new_flattened_res = pytree.tree_leaves(new.meta['fake_result'])
652
+ new_res_storage = {
653
+ StorageWeakRef(
654
+ x._typed_storage()
655
+ ) for x in new_flattened_res if isinstance(x, FakeTensor)}
656
+ assert len(new_res_storage) == 1
657
+ (old_ref,) = old_res_storage
658
+ (new_ref,) = new_res_storage
659
+ (node_ref,) = node_res_storage
660
+ # Technically, "old_ref" and all its aliases will remain
661
+ # in our mapping.
662
+ # That should be fine though, since we deleted "old"
663
+ # from the graph at this point.
664
+ storage_to_nodes[node_ref].update(storage_to_nodes[new_ref])
665
+ storage_to_nodes[new_ref].update(storage_to_nodes[node_ref])
666
+
667
+ # Step 4: delete any _scatter nodes that we de-functionalized
668
+ # Need to take care not to delete any of these nodes until after *all* modifications
669
+ # to the graph are finished.
670
+ for to_delete in all_later_view_inverse_nodes_to_delete:
671
+ gm.graph.erase_node(to_delete)
672
+
673
+
674
+ gm.recompile()
675
+ return gm
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tests/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-311.pyc ADDED
Binary file (22.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/source_matcher_utils.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from torch.fx.graph import Graph
3
+ from torch.fx.node import Node
4
+ from torch.fx._compatibility import compatibility
5
+ from typing import Dict, List, Any, Type, Optional, Callable
6
+ import logging
7
+ import os
8
+
9
+
10
+ __all__ = ['get_source_partitions', 'check_subgraphs_connected', 'SourcePartition']
11
+
12
+ # Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
13
+ def _init_logger():
14
+ logger = logging.getLogger(__name__)
15
+
16
+ level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
17
+ logger.setLevel(level)
18
+ console = logging.StreamHandler()
19
+ formatter = logging.Formatter("%(filename)s > %(message)s")
20
+ console.setFormatter(formatter)
21
+ console.setLevel(level)
22
+ # add the handlers to the logger
23
+ logger.addHandler(console)
24
+ logger.propagate = False
25
+ return logger
26
+
27
+ logger = _init_logger()
28
+
29
+
30
+ @compatibility(is_backward_compatible=False)
31
+ @dataclass
32
+ class SourcePartition:
33
+ # Nodes in a particular partition
34
+ nodes: List[Node]
35
+
36
+ # The source these nodes decomposed from
37
+ source: Any
38
+
39
+ # Nodes in the graph that are needed as inputs to the partition
40
+ input_nodes: List[Node] = field(default_factory=list)
41
+
42
+ # Nodes in the partition that are being used by nodes outside of the
43
+ # partition
44
+ output_nodes: List[Node] = field(default_factory=list)
45
+
46
+ # Parameters that are being used
47
+ params: List[Node] = field(default_factory=list)
48
+
49
+
50
+ @compatibility(is_backward_compatible=False)
51
+ def get_source_partitions(
52
+ graph: Graph,
53
+ wanted_sources: List[Any],
54
+ filter_fn: Optional[Callable[[Node], bool]] = None,
55
+ ) -> Dict[Any, List[SourcePartition]]:
56
+ """
57
+ Args:
58
+ graph: The graph we want to partition
59
+ wanted_sources: List of sources of nodes that were decomposed from this
60
+ source. This can be a function (ex. torch.nn.functional.linear) or a
61
+ leaf module type (ex. torch.nn.Linear).
62
+
63
+ Returns:
64
+ Dictionary mapping sources that were given to a list of SourcePartitions
65
+ that correspond to the list of nodes that were decomposed from the given
66
+ source.
67
+ """
68
+ modules: Dict[Type, Dict[str, List[Node]]] = {}
69
+
70
+ for node in graph.nodes:
71
+ # The metadata source_fn should contain a tuple of a unique name for the
72
+ # source, and the source function if the node is decomposed from a
73
+ # function, or the type of module if the node is decomposed from a leaf
74
+ # module
75
+
76
+ if (source_fn_st := node.meta.get("source_fn_stack", None)) is None:
77
+ continue
78
+
79
+ source_fn = source_fn_st[-1]
80
+ if source_fn[1] not in wanted_sources:
81
+ continue
82
+
83
+ diff_modules = modules.setdefault(source_fn[1], {})
84
+ partition = diff_modules.setdefault(source_fn[0], [])
85
+ partition.append(node)
86
+
87
+ def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition:
88
+ input_nodes = set()
89
+ output_nodes = set()
90
+ params = set()
91
+ for node in nodes:
92
+ for arg in node.args:
93
+ if isinstance(arg, Node) and arg not in nodes:
94
+ input_nodes.add(arg)
95
+
96
+ if node.op == "get_attr":
97
+ params.add(node)
98
+
99
+ for user in node.users.keys():
100
+ if user not in nodes:
101
+ output_nodes.add(node)
102
+
103
+ return SourcePartition(
104
+ nodes,
105
+ module_type,
106
+ list(input_nodes),
107
+ list(output_nodes),
108
+ list(params), # type: ignore[arg-type]
109
+ )
110
+
111
+ ret: Dict[Type[Any], List[SourcePartition]] = {}
112
+
113
+ if filter_fn:
114
+ # for each partition, we apply filter_fn to filter out all partitions that doesn't satisfy the
115
+ # filter condition
116
+ filtered_modules = {}
117
+ for tp, name_to_partition in modules.items():
118
+ filtered_name_to_partition = {
119
+ name: partition
120
+ for name, partition in name_to_partition.items()
121
+ if all(map(filter_fn, partition))
122
+ }
123
+ filtered_modules[tp] = filtered_name_to_partition
124
+ modules = filtered_modules
125
+
126
+ for k, v in modules.items():
127
+ ret[k] = [make_partition(partition, k) for partition in v.values()]
128
+
129
+ return ret
130
+
131
+
132
+ @compatibility(is_backward_compatible=False)
133
+ def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool:
134
+ """
135
+ Given two subgraphs A and B (in the form of a list of nodes), checks if
136
+ A has nodes connecting to at least one node in B -- aka there exists a node
137
+ in B that uses a node in A (not the other way around).
138
+ """
139
+
140
+ for node in reversed(subgraph1.nodes):
141
+ for user in node.users.keys():
142
+ if user in subgraph2.nodes:
143
+ return True
144
+ return False
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__init__.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """torch.multiprocessing is a wrapper around the native :mod:`multiprocessing` module.
2
+
3
+ It registers custom reducers, that use shared memory to provide shared
4
+ views on the same data in different processes. Once the tensor/storage is moved
5
+ to shared_memory (see :func:`~torch.Tensor.share_memory_`), it will be possible
6
+ to send it to other processes without making any copies.
7
+
8
+ The API is 100% compatible with the original module - it's enough to change
9
+ ``import multiprocessing`` to ``import torch.multiprocessing`` to have all the
10
+ tensors sent through the queues or shared via other mechanisms, moved to shared
11
+ memory.
12
+
13
+ Because of the similarity of APIs we do not document most of this package
14
+ contents, and we recommend referring to very good docs of the original module.
15
+ """
16
+ import multiprocessing
17
+ import sys
18
+
19
+ import torch
20
+ from .reductions import init_reductions
21
+
22
+ __all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"]
23
+
24
+
25
+ from multiprocessing import * # noqa: F403
26
+
27
+
28
+ __all__ += multiprocessing.__all__ # noqa: PLE0605 type: ignore[attr-defined]
29
+
30
+
31
+ # This call adds a Linux specific prctl(2) wrapper function to this module.
32
+ # See https://github.com/pytorch/pytorch/pull/14391 for more information.
33
+ torch._C._multiprocessing_init()
34
+
35
+
36
+ """Add helper function to spawn N processes and wait for completion of any of
37
+ them. This depends on `mp.get_context` which was added in Python 3.4."""
38
+ from .spawn import (
39
+ ProcessContext,
40
+ ProcessExitedException,
41
+ ProcessRaisedException,
42
+ spawn,
43
+ SpawnContext,
44
+ start_processes,
45
+ )
46
+
47
+
48
+ if sys.platform == "darwin" or sys.platform == "win32":
49
+ _sharing_strategy = "file_system"
50
+ _all_sharing_strategies = {"file_system"}
51
+ else:
52
+ _sharing_strategy = "file_descriptor"
53
+ _all_sharing_strategies = {"file_descriptor", "file_system"}
54
+
55
+
56
+ def set_sharing_strategy(new_strategy):
57
+ """Set the strategy for sharing CPU tensors.
58
+
59
+ Args:
60
+ new_strategy (str): Name of the selected strategy. Should be one of
61
+ the values returned by :func:`get_all_sharing_strategies()`.
62
+ """
63
+ global _sharing_strategy
64
+ assert new_strategy in _all_sharing_strategies
65
+ _sharing_strategy = new_strategy
66
+
67
+
68
+ def get_sharing_strategy():
69
+ """Return the current strategy for sharing CPU tensors."""
70
+ return _sharing_strategy
71
+
72
+
73
+ def get_all_sharing_strategies():
74
+ """Return a set of sharing strategies supported on a current system."""
75
+ return _all_sharing_strategies
76
+
77
+
78
+ init_reductions()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.72 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/reductions.cpython-311.pyc ADDED
Binary file (19.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/spawn.cpython-311.pyc ADDED
Binary file (13.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ __all__ = ["register_after_fork"]
4
+
5
+ if sys.platform == "win32":
6
+ import multiprocessing.util as _util
7
+
8
+ def _register(func):
9
+ def wrapper(arg):
10
+ func()
11
+
12
+ _util.register_after_fork(_register, wrapper)
13
+
14
+ else:
15
+ import os
16
+
17
+ def _register(func):
18
+ os.register_at_fork(after_in_child=func)
19
+
20
+
21
+ def register_after_fork(func):
22
+ """Register a callable to be executed in the child process after a fork.
23
+
24
+ Note:
25
+ In python < 3.7 this will only work with processes created using the
26
+ ``multiprocessing`` module. In python >= 3.7 it also works with
27
+ ``os.fork()``.
28
+
29
+ Args:
30
+ func (function): Function taking no arguments to be called in the child after fork
31
+
32
+ """
33
+ _register(func)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__init__.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """
2
+ import contextlib
3
+ from typing import List, Union
4
+ from warnings import warn
5
+
6
+ from torch.backends.cuda import (
7
+ can_use_efficient_attention,
8
+ can_use_flash_attention,
9
+ enable_flash_sdp,
10
+ enable_math_sdp,
11
+ enable_mem_efficient_sdp,
12
+ flash_sdp_enabled,
13
+ math_sdp_enabled,
14
+ mem_efficient_sdp_enabled,
15
+ SDPAParams,
16
+ )
17
+
18
+ __all__: List[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]
19
+
20
+ # Note: [SDPA warnings]
21
+ # TODO: Consider using this for sdpa regardless of subclasses
22
+ # This only effects users of bias subclasses
23
+ # If this is set to True, we will warn the user if they are not using the fused kernels
24
+ # As well, it will raise warnings for all the reasons why the fused kernels can't be run.
25
+ # To set this to True, run
26
+ # torch.nn.attention.WARN_FOR_UNFUSED_KERNELS = True
27
+ WARN_FOR_UNFUSED_KERNELS = False
28
+
29
+
30
+ from torch._C import _SDPBackend as SDPBackend
31
+
32
+ # Hacks for Sphinx documentation:
33
+ # https://stackoverflow.com/questions/38765577/overriding-sphinx-autodoc-alias-of-for-import-of-private-class
34
+ SDPBackend = SDPBackend
35
+ r"""An enum-like class that contains the different backends for scaled dot product attention.
36
+ This backend class is designed to be used with the sdpa_kernel context manager.
37
+
38
+ The following Enums are available:
39
+ - ERROR: An error occurred when trying to determine the backend.
40
+ - MATH: The math backend for scaled dot product attention.
41
+ - FLASH_ATTENTION: The flash attention backend for scaled dot product attention.
42
+ - EFFICIENT_ATTENTION: The efficient attention backend for scaled dot product attention.
43
+ - CUDNN_ATTENTION: The cuDNN backend for scaled dot product attention.
44
+
45
+ See :func:`torch.nn.attention.sdpa_kernel` for more details.
46
+
47
+ .. warning:: This class is in beta and subject to change.
48
+ """
49
+ SDPBackend.__module__ = __name__
50
+ SDPBackend.__name__ = "SDPBackend"
51
+
52
+
53
+ def _raise_kernel_warnings(params: SDPAParams) -> None:
54
+ """
55
+ If WARN_FOR_UNFUSED_KERNELS is set to True, this will raise warnings
56
+ for all the reasons why the fused kernels can't be run when using bias subclasses.
57
+ """
58
+ if WARN_FOR_UNFUSED_KERNELS:
59
+ if not can_use_efficient_attention(params):
60
+ warn("Efficient attention can't be used because:")
61
+ can_use_efficient_attention(params, True)
62
+ if not can_use_flash_attention(params):
63
+ warn("Flash attention can't be used because:")
64
+ can_use_flash_attention(params, True)
65
+
66
+
67
+ @contextlib.contextmanager
68
+ def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]):
69
+ r"""
70
+ Context manager to select which backend to use for scaled dot product attention.
71
+
72
+ .. warning:: This function is beta and subject to change.
73
+
74
+ Args:
75
+ backend (Union[List[SDPBackend], SDPBackend]): A backend or list of backends for scaled dot product attention.
76
+
77
+ Example:
78
+
79
+ .. code-block:: python
80
+
81
+ from torch.nn.functional import scaled_dot_product_attention
82
+ from torch.nn.attention import SDPBackend, sdpa_kernel
83
+ # Only enable flash attention backend
84
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
85
+ scaled_dot_product_attention(...)
86
+
87
+ # Enable the Math or Efficient attention backends
88
+ with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
89
+ scaled_dot_product_attention(...)
90
+
91
+ This context manager can be used to select which backend to use for scaled dot product attention.
92
+ Upon exiting the context manager, the previous state of the flags will be restored, enabling all backends.
93
+ """
94
+ assert isinstance(
95
+ backends, (list, SDPBackend)
96
+ ), "Backend must be an instance of SDPBackend or a list of SDPBackend instances"
97
+
98
+ if isinstance(backends, SDPBackend):
99
+ backends = [backends]
100
+
101
+ backends = set(backends)
102
+ previous_flash: bool = flash_sdp_enabled()
103
+ previous_mem_efficient: bool = mem_efficient_sdp_enabled()
104
+ previous_math: bool = math_sdp_enabled()
105
+ try:
106
+ enable_flash = SDPBackend.FLASH_ATTENTION in backends
107
+ enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION in backends
108
+ enable_math = SDPBackend.MATH in backends
109
+
110
+ enable_flash_sdp(enable_flash)
111
+ enable_mem_efficient_sdp(enable_mem_efficient)
112
+ enable_math_sdp(enable_math)
113
+ yield {}
114
+ finally:
115
+ enable_flash_sdp(previous_flash)
116
+ enable_mem_efficient_sdp(previous_mem_efficient)
117
+ enable_math_sdp(previous_math)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.52 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/bias.cpython-311.pyc ADDED
Binary file (15.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/bias.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Defines bias subclasses that work with scaled_dot_product_attention"""
2
+ from enum import auto, IntEnum
3
+ from typing import Optional
4
+ from warnings import warn
5
+
6
+ import torch
7
+ from torch.backends.cuda import (
8
+ can_use_efficient_attention,
9
+ can_use_flash_attention,
10
+ SDPAParams,
11
+ )
12
+ from torch.nn.attention import _raise_kernel_warnings
13
+ from torch.nn.attention._utils import (
14
+ _calculate_scale,
15
+ _input_requires_grad,
16
+ _postprocess_flash_output,
17
+ _validate_sdpa_input,
18
+ )
19
+ from torch.nn.functional import scaled_dot_product_attention
20
+
21
+ __all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"]
22
+
23
+
24
+ torch._dynamo.allow_in_graph(can_use_flash_attention)
25
+ torch._dynamo.allow_in_graph(can_use_efficient_attention)
26
+ torch._dynamo.allow_in_graph(SDPAParams)
27
+
28
+
29
class CausalVariant(IntEnum):
    r"""
    Enumerates the causal-mask layouts usable with scaled dot product attention.

    Two layouts are defined:

    `UPPER_LEFT`: the standard causal mask, with the allowed (lower-triangular)
    entries anchored at the upper-left corner of the (query, key/value) score
    matrix. Constructing it by hand is equivalent to:

    .. code-block:: python

        torch.tril(torch.ones(size, dtype=torch.bool))

    e.g. with `shape=(3,4)` the materialized bias tensor is:

    .. code-block:: text

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]


    `LOWER_RIGHT`: a causal mask whose allowed entries are anchored at the
    lower-right corner of the score matrix. Constructing it by hand is
    equivalent to:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    e.g. with `shape=(3,4)` the materialized bias tensor is:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    When the query and key/value sequence lengths are equal the score matrix is
    square, so the two layouts coincide.

    .. warning:: This enum is a prototype and subject to change.
    """

    UPPER_LEFT = auto()
    LOWER_RIGHT = auto()
80
+
81
+
82
class CausalBias(torch.Tensor):
    """
    A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum.

    This class is used for defining causal (triangular) attention biases. For constructing the bias, there exist
    two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`.

    Example:

    .. code-block:: python

        from torch.nn.attention.bias import causal_lower_right

        bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

        # Create a lower-right causal bias
        attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

        q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
        v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

        out = F.scaled_dot_product_attention(q, k, v, attn_bias)

    .. warning:: This class is a prototype and subject to change.
    """

    def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int):
        """
        Initializes the CausalBias instance with a specified variant and sequence lengths.

        Args:
            variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT).
            seq_len_q (int): The sequence length of the query tensor.
            seq_len_kv (int): The sequence length of the key/value tensor.

        Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs.
        """
        assert isinstance(variant, CausalVariant)
        self.variant = variant
        self.seq_len_q = seq_len_q
        self.seq_len_kv = seq_len_kv
        # A lower-right mask with more queries than keys leaves fully-masked
        # rows; warn up front instead of failing later.
        if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT:
            warn(
                "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!"
            )

    def _upper_left(self, device: torch.device) -> torch.Tensor:
        """Upper left causal bias"""
        return torch.tril(
            torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool)
        )

    def _lower_right(self, device: torch.device) -> torch.Tensor:
        """Lower right causal bias"""
        # Shift the diagonal so the triangle is anchored at the bottom-right
        # corner of the (seq_len_q, seq_len_kv) matrix.
        diagonal_offset = self.seq_len_kv - self.seq_len_q
        return torch.tril(
            torch.ones(
                self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool
            ),
            diagonal=diagonal_offset,
        )

    def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """
        Materializes the causal bias into a tensor form.

        Depending on the variant, this method generates either an upper-left or lower-right
        triangular matrix to represent the causal bias.

        Args:
            device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU.

        Returns:
            torch.Tensor: The materialized bias tensor.
        """
        if device is None:
            device = torch.device("cpu")
        # Only the two enum members exist (asserted in __init__), so one of
        # these branches is always taken.
        if self.variant == CausalVariant.UPPER_LEFT:
            return self._upper_left(device)
        elif self.variant == CausalVariant.LOWER_RIGHT:
            return self._lower_right(device)

    @staticmethod
    def _dispatch(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: "CausalBias",
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
    ) -> torch.Tensor:
        r"""
        Handles the logic for computing attention with the specified causal bias.

        Args:
            query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
            key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
            value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
            attn_mask (CausalBias): The type of causal attention to apply.
                A boolean mask where a value of True indicates that the element *should* take part in attention.
                A float mask of the same type as query, key, value that is added to the attention score.
            dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
            is_causal (bool): If true, assumes upper left causal attention masking and errors if both attn_mask and is_causal
                are set.
            scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
                to :math:`\frac{1}{\sqrt{E}}`.

        Returns:
            output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.

        Raises:
            ValueError: If the causal bias variant is not a CausalVariant type.

        """
        # The bias itself encodes causality; passing is_causal=True as well
        # would be contradictory.
        if is_causal:
            raise ValueError("CausalBias should not be used with causal=True")

        # For square masks, or the upper-left variant, the mask matches the
        # semantics of is_causal=True, so defer to the fused causal path.
        if (
            attn_mask.seq_len_q == attn_mask.seq_len_kv
            or attn_mask.variant == CausalVariant.UPPER_LEFT
        ):
            return scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True,
                scale=scale,
            )
        elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
            _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
            sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal)
            if can_use_flash_attention(sdpa_params):
                # Flash requires the head dim to be a multiple of 8; pad if
                # needed, but compute the scale from the original head size so
                # results are unchanged.
                needs_padding = query.size(-1) % 8 != 0
                og_head_size = query.size(-1)
                og_scale = _calculate_scale(og_head_size, scale)
                if needs_padding:
                    query = torch.nn.functional.pad(query, (0, 8 - query.size(-1) % 8))
                    key = torch.nn.functional.pad(key, (0, 8 - key.size(-1) % 8))
                    value = torch.nn.functional.pad(value, (0, 8 - value.size(-1) % 8))
                out = torch.ops.aten._scaled_dot_product_flash_attention(
                    query,
                    key,
                    value,
                    dropout_p,
                    is_causal=True,  # TODO: Flash accepts causal = True and for this particular op it means lower right
                    return_debug_mask=False,
                    scale=og_scale,
                )[0]
                # Presumably strips the head-dim padding back off -- see
                # _postprocess_flash_output.
                return _postprocess_flash_output(out, og_head_size)
            if can_use_efficient_attention(sdpa_params):
                compute_log_sumexp = False
                if _input_requires_grad(query, key, value):
                    compute_log_sumexp = True
                # NOTE(review): the transposes swap dims 1 and 2 before and
                # after the op -- presumably _efficient_attention_forward
                # expects a (B, S, H, D) layout; verify against the op schema.
                return torch.ops.aten._efficient_attention_forward(
                    query.transpose(1, 2),
                    key.transpose(1, 2),
                    value.transpose(1, 2),
                    bias=None,
                    cu_seqlens_q=None,
                    cu_seqlens_k=None,
                    max_seqlen_q=None,
                    max_seqlen_k=None,
                    dropout_p=dropout_p,
                    custom_mask_type=int(attn_mask.variant),
                    compute_log_sumexp=compute_log_sumexp,
                    scale=scale,
                    causal_diagonal=None,
                    seqlen_k=None,
                )[0].transpose(1, 2)
            else:
                _raise_kernel_warnings(sdpa_params)
                # We cant use efficient attention the only support for lower right is via materialization
                return scaled_dot_product_attention(
                    query,
                    key,
                    value,
                    attn_mask=attn_mask._materialize(query.device),
                    dropout_p=dropout_p,
                    is_causal=False,
                    scale=scale,
                )
        else:
            raise ValueError(
                f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}"
            )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias"""
        if kwargs is None:
            kwargs = {}
        # This subclass only overrides SDPA; any other torch function on a
        # CausalBias is rejected explicitly.
        if func != torch.nn.functional.scaled_dot_product_attention:
            raise NotImplementedError(
                "CausalBias only supports scaled_dot_product_attention"
            )
        return cls._dispatch(*args, **kwargs)

    def __repr__(self):
        # Show the materialized (CPU) boolean mask rather than tensor-subclass
        # internals.
        return self._materialize().__repr__()
285
+
286
+
287
def causal_upper_left(*size) -> CausalBias:
    """
    Build an UPPER_LEFT :class:`CausalBias` of the given 2D shape.

    The allowed entries form a lower triangle anchored at the upper-left
    corner of the matrix, which matches the semantics of the
    ``is_causal=True`` argument to ``scaled_dot_product_attention``.
    Constructing the mask by hand is equivalent to:

    .. code-block:: python

        torch.tril(torch.ones(size, dtype=torch.bool))

    For instance, with `shape=(3,4)` the materialized bias tensor is:

    .. code-block:: text

        [[1, 0, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 0]]

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The UPPER_LEFT triangular causal bias variant.
    """
    assert len(size) == 2, "causal_upper_left only supports 2D tensors"
    q_len, kv_len = size
    return CausalBias(CausalVariant.UPPER_LEFT, q_len, kv_len)
318
+
319
+
320
def causal_lower_right(*size) -> CausalBias:
    """
    Build a LOWER_RIGHT :class:`CausalBias` of the given 2D shape.

    The allowed entries form a triangle anchored at the lower-right corner of
    the matrix. Constructing the mask by hand is equivalent to:

    .. code-block:: python

        diagonal_offset = size[1] - size[0]
        torch.tril(
            torch.ones(size, dtype=torch.bool),
            diagonal=diagonal_offset,
        )

    For instance, with `shape=(3,4)` the materialized bias tensor is:

    .. code-block:: text

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [1, 1, 1, 1]]

    Args:
        size: The size of the bias matrix.

    Returns:
        CausalBias: The LOWER_RIGHT triangular causal bias variant.
    """
    assert len(size) == 2, "causal_lower_right only supports 2D tensors"
    q_len, kv_len = size
    return CausalBias(CausalVariant.LOWER_RIGHT, q_len, kv_len)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/common_types.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Shared type aliases for torch.nn module signatures (kernel sizes, strides,
# paddings, ratios, etc.).
from typing import TypeVar, Union, Tuple, Optional
from .. import Tensor

# Create some useful type aliases

# Template for arguments which can be supplied as a tuple, or which can be a scalar which PyTorch will internally
# broadcast to a tuple.
# Comes in several variants: A tuple of unknown size, and a fixed-size tuple for 1d, 2d, or 3d operations.
T = TypeVar('T')
_scalar_or_tuple_any_t = Union[T, Tuple[T, ...]]
_scalar_or_tuple_1_t = Union[T, Tuple[T]]
_scalar_or_tuple_2_t = Union[T, Tuple[T, T]]
_scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]]
_scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]]
_scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]]
_scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]]

# For arguments which represent size parameters (eg, kernel size, padding)
_size_any_t = _scalar_or_tuple_any_t[int]
_size_1_t = _scalar_or_tuple_1_t[int]
_size_2_t = _scalar_or_tuple_2_t[int]
_size_3_t = _scalar_or_tuple_3_t[int]
_size_4_t = _scalar_or_tuple_4_t[int]
_size_5_t = _scalar_or_tuple_5_t[int]
_size_6_t = _scalar_or_tuple_6_t[int]

# For arguments which represent optional size parameters (eg, adaptive pool parameters)
_size_any_opt_t = _scalar_or_tuple_any_t[Optional[int]]
_size_2_opt_t = _scalar_or_tuple_2_t[Optional[int]]
_size_3_opt_t = _scalar_or_tuple_3_t[Optional[int]]

# For arguments that represent a ratio to adjust each dimension of an input with (eg, upsampling parameters)
_ratio_2_t = _scalar_or_tuple_2_t[float]
_ratio_3_t = _scalar_or_tuple_3_t[float]
_ratio_any_t = _scalar_or_tuple_any_t[float]

# A single Tensor or an arbitrary-length tuple of Tensors.
_tensor_list_t = _scalar_or_tuple_any_t[Tensor]

# For the return value of max pooling operations that may or may not return indices.
# With the proposed 'Literal' feature to Python typing, it might be possible to
# eventually eliminate this.
_maybe_indices_t = _scalar_or_tuple_2_t[Tensor]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/grad.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradient interface."""
2
+
3
+ import torch
4
+ from .modules.utils import _single, _pair, _triple
5
+
6
+
7
def conv1d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""Compute the gradient of conv1d with respect to the input of the convolution.

    Equivalent to a 1D transposed convolution under the hood, except that the
    shape of the gradient w.r.t. the input must be supplied explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv1d_input(input.shape, weight, grad_output)

    """
    # convolution_backward only needs the input's metadata, so a storage-free
    # expanded placeholder of the right shape is enough.
    input_placeholder = grad_output.new_empty(1).expand(input_size)

    grads = torch.ops.aten.convolution_backward(
        grad_output,
        input_placeholder,
        weight,
        None,
        _single(stride),
        _single(padding),
        _single(dilation),
        False,
        [0],
        groups,
        (True, False, False),  # request only the gradient w.r.t. the input
    )
    return grads[0]
37
+
38
+
39
def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""Compute the gradient of conv1d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> # xdoctest: +SKIP
        >>> grad_weight = torch.autograd.grad(output, filter, grad_output)
        >>> F.grad.conv1d_weight(input, weight.shape, grad_output)

    """
    # Only the weight's metadata is consumed, so an expanded 1-element
    # placeholder of the right shape is enough.
    weight_placeholder = grad_output.new_empty(1).expand(weight_size)

    grads = torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight_placeholder,
        None,
        _single(stride),
        _single(padding),
        _single(dilation),
        False,
        [0],
        groups,
        (False, True, False),  # request only the gradient w.r.t. the weight
    )
    return grads[1]
67
+
68
+
69
def conv2d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""Compute the gradient of conv2d with respect to the input of the convolution.

    Equivalent to a 2D transposed convolution under the hood, except that the
    shape of the gradient w.r.t. the input must be supplied explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv2d_input(input.shape, weight, grad_output)

    """
    # convolution_backward only needs the input's metadata, so a storage-free
    # expanded placeholder of the right shape is enough.
    input_placeholder = grad_output.new_empty(1).expand(input_size)

    grads = torch.ops.aten.convolution_backward(
        grad_output,
        input_placeholder,
        weight,
        None,
        _pair(stride),
        _pair(padding),
        _pair(dilation),
        False,
        [0],
        groups,
        (True, False, False),  # request only the gradient w.r.t. the input
    )
    return grads[0]
99
+
100
+
101
def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""Compute the gradient of conv2d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> # xdoctest: +SKIP
        >>> grad_weight = torch.autograd.grad(output, filter, grad_output)
        >>> F.grad.conv2d_weight(input, weight.shape, grad_output)

    """
    # Only the weight's metadata is consumed, so an expanded 1-element
    # placeholder of the right shape is enough.
    weight_placeholder = grad_output.new_empty(1).expand(weight_size)

    grads = torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight_placeholder,
        None,
        _pair(stride),
        _pair(padding),
        _pair(dilation),
        False,
        [0],
        groups,
        (False, True, False),  # request only the gradient w.r.t. the weight
    )
    return grads[1]
129
+
130
+
131
def conv3d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""Compute the gradient of conv3d with respect to the input of the convolution.

    Equivalent to a 3D transposed convolution under the hood, except that the
    shape of the gradient w.r.t. the input must be supplied explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weights tensor (out_channels x in_channels/groups x kT x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv3d_input(input.shape, weight, grad_output)

    """
    # convolution_backward only needs the input's metadata, so a storage-free
    # expanded placeholder of the right shape is enough.
    input_placeholder = grad_output.new_empty(1).expand(input_size)

    grads = torch.ops.aten.convolution_backward(
        grad_output,
        input_placeholder,
        weight,
        None,
        _triple(stride),
        _triple(padding),
        _triple(dilation),
        False,
        [0],
        groups,
        (True, False, False),  # request only the gradient w.r.t. the input
    )
    return grads[0]
161
+
162
+
163
def conv3d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""Compute the gradient of conv3d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iT x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv3d_weight(input, weight.shape, grad_output)

    """
    # Only the weight's metadata is consumed, so an expanded 1-element
    # placeholder of the right shape is enough.
    weight_placeholder = grad_output.new_empty(1).expand(weight_size)

    grads = torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight_placeholder,
        None,
        _triple(stride),
        _triple(padding),
        _triple(dilation),
        False,
        [0],
        groups,
        (False, True, False),  # request only the gradient w.r.t. the weight
    )
    return grads[1]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modules import * # noqa: F403
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Re-export the fused QAT (quantization-aware training) intrinsic modules and
# declare the package's public API via __all__.
from .linear_relu import LinearReLU
from .linear_fused import LinearBn1d
from .conv_fused import (
    ConvBn1d,
    ConvBn2d,
    ConvBn3d,
    ConvBnReLU1d,
    ConvBnReLU2d,
    ConvBnReLU3d,
    ConvReLU1d,
    ConvReLU2d,
    ConvReLU3d,
    update_bn_stats,
    freeze_bn_stats,
)

__all__ = [
    "LinearReLU",
    "LinearBn1d",
    "ConvReLU1d",
    "ConvReLU2d",
    "ConvReLU3d",
    "ConvBn1d",
    "ConvBn2d",
    "ConvBn3d",
    "ConvBnReLU1d",
    "ConvBnReLU2d",
    "ConvBnReLU3d",
    "update_bn_stats",
    "freeze_bn_stats",
]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (846 Bytes). View file