koichi12 commited on
Commit
445c885
·
verified ·
1 Parent(s): a8eed2c

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64-arm.exe +3 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codecache.py +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py +1451 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_utils.py +105 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/dependencies.py +506 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/exc.py +98 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_utils.py +220 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/index_propagation.py +277 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/metrics.py +419 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/triton_helpers.py +344 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocator.h +401 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h +61 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSEvent.h +100 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSProfiler.h +393 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h +321 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/EmbeddingBag.h +139 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Fill.h +21 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LossMulti.h +72 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Normalization.h +11 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pow.h +69 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOps.h +56 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h +55 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorCompare.h +49 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIterator.h +2 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TriangularOpsUtils.h +57 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IsContiguous.h +62 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h +28 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh +296 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.h +32 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh +384 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh +379 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Resize.h +61 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Sort.h +17 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh +40 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh +38 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh +40 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adamw_impl.cuh +38 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh +680 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/thread_constants.h +22 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/OperationUtils.h +394 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/TensorFactory.h +12 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h +103 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h +130 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/ConvUtils.h +62 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/IndexKernel.h +14 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/PackedParams.h +147 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h +29 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h +527 -0
.gitattributes CHANGED
@@ -77,3 +77,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/_
77
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text
78
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
79
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 filter=lfs diff=lfs merge=lfs -text
 
 
77
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text
78
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
79
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn.so.8 filter=lfs diff=lfs merge=lfs -text
80
+ tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64-arm.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ebc4c06b7d95e74e315419ee7e88e1d0f71e9e9477538c00a93a9ff8c66a6cfc
3
+ size 182784
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-311.pyc ADDED
Binary file (13.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/codecache.py ADDED
The diff for this file is too large to render. See raw diff
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/compile_fx.py ADDED
@@ -0,0 +1,1451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import functools
3
+ import logging
4
+ import os
5
+ import sys
6
+ import time
7
+ import warnings
8
+ from itertools import count
9
+
10
+ from typing import (
11
+ Any,
12
+ Callable,
13
+ Dict,
14
+ FrozenSet,
15
+ List,
16
+ Optional,
17
+ Sequence,
18
+ Tuple,
19
+ Union,
20
+ )
21
+ from unittest import mock
22
+
23
+ from functorch.compile import min_cut_rematerialization_partition
24
+
25
+ import torch.fx
26
+ import torch.utils._pytree as pytree
27
+ from torch._dynamo import (
28
+ compiled_autograd,
29
+ config as dynamo_config,
30
+ logging as dynamo_logging,
31
+ utils as dynamo_utils,
32
+ )
33
+ from torch._dynamo.utils import (
34
+ counters,
35
+ detect_fake_mode,
36
+ lazy_format_graph_code,
37
+ optimus_scuba_log,
38
+ )
39
+ from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
40
+ from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache
41
+ from torch._inductor.cudagraph_utils import BoxedDeviceIndex
42
+
43
+ from torch._inductor.debug import save_args_for_compile_fx_inner
44
+ from torch._inductor.utils import BoxedBool, count_tangents
45
+ from torch._logging import trace_structured
46
+ from torch._ops import OpOverload
47
+ from torch._subclasses.fake_tensor import FakeTensor
48
+ from torch._utils_internal import signpost_event
49
+ from torch.fx.passes.fake_tensor_prop import FakeTensorProp
50
+
51
+ from .._dynamo.backends.common import aot_autograd
52
+ from ..fx._lazy_graph_module import _use_lazy_graph_module # type: ignore[attr-defined]
53
+ from ..fx.graph import _PyTreeCodeGen
54
+ from . import config, metrics
55
+ from .debug import DebugContext
56
+ from .decomposition import select_decomp_table
57
+ from .fx_passes.joint_graph import joint_graph_passes
58
+ from .fx_passes.post_grad import post_grad_passes, view_to_reshape
59
+ from .fx_passes.pre_grad import pre_grad_passes
60
+ from .graph import GraphLowering
61
+ from .ir import ExternKernelNode
62
+ from .utils import get_dtype_size, has_incompatible_cudagraph_ops, output_node
63
+ from .virtualized import V
64
+
65
+ if config.is_fbcode():
66
+ from torch._inductor.fb.utils import time_and_log
67
+ else:
68
+ # no-op decorator
69
+ def time_and_log(attr: str, extra_loggings: Optional[Dict[str, str]] = None):
70
+ return dynamo_utils.identity
71
+
72
+
73
+ log = logging.getLogger(__name__)
74
+ perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
75
+ post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs")
76
+ ALIGNMENT = 16
77
+
78
+
79
+ # copy_ fails when trying to write to tensors with memory overlap,
80
+ # for expanded dimensions (a dimension which used to have size 1 -> ?)
81
+ # we can select one element from that dimension and write to it
82
+ # to achieve writing to all values of that dimension of the input tensor
83
+ def get_expanded_dims(t):
84
+ if not isinstance(t, torch.Tensor):
85
+ return None
86
+ return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]
87
+
88
+
89
+ def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor:
90
+ for expanded_dim in expanded_dims:
91
+ t = torch.ops.aten.slice(t, expanded_dim, 0, 1)
92
+ return t
93
+
94
+
95
+ def complex_memory_overlap(t: torch.Tensor) -> bool:
96
+ # if torch._debug_has_internal_overlap thinks this tensor potentially has
97
+ # memory overlap internally, let's dig deeper to find out whether it's true.
98
+ t = index_expanded_dims(t, get_expanded_dims(t))
99
+ if torch._debug_has_internal_overlap(t) != 0:
100
+ strides = t.stride()
101
+ sizes = t.shape
102
+ indices = list(range(len(strides)))
103
+ indices = [x for _, x in sorted(zip(strides, indices))]
104
+ for i in range(len(strides)):
105
+ prev_stride = 1 if i == 0 else strides[indices[i - 1]]
106
+ prev_size = 1 if i == 0 else sizes[indices[i - 1]]
107
+ if strides[indices[i]] < prev_stride * prev_size:
108
+ return True
109
+ return False
110
+
111
+
112
+ @functools.lru_cache(None)
113
+ def _step_logger():
114
+ return dynamo_logging.get_step_logger(log)
115
+
116
+
117
+ @functools.lru_cache(None)
118
+ def _warn_tf32_disabled():
119
+ if (
120
+ torch.cuda.is_available()
121
+ and not torch.backends.cuda.matmul.allow_tf32
122
+ and torch.cuda.get_device_capability() >= (8, 0)
123
+ ):
124
+ warnings.warn(
125
+ "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
126
+ "Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
127
+ )
128
+
129
+
130
+ def _unlift_graph(mod, gm, graph_signature):
131
+ from torch.export.unflatten import _assign_attr, _AttrKind
132
+
133
+ state_dict = {}
134
+ for name, param in mod.named_parameters(remove_duplicate=False):
135
+ state_dict[name] = param
136
+ _assign_attr(
137
+ param,
138
+ gm,
139
+ name,
140
+ attr_kind=_AttrKind.PARAMETER,
141
+ )
142
+ for name, buffer in mod.named_buffers(remove_duplicate=False):
143
+ state_dict[name] = buffer
144
+ _assign_attr(
145
+ buffer,
146
+ gm,
147
+ name,
148
+ attr_kind=_AttrKind.BUFFER,
149
+ )
150
+
151
+ placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
152
+ lifted_inputs = []
153
+ for node in placeholder_nodes:
154
+ node_name = node.name
155
+ if node_name in graph_signature.inputs_to_parameters:
156
+ lifted_inputs.append(graph_signature.inputs_to_parameters[node_name])
157
+ elif node_name in graph_signature.inputs_to_buffers:
158
+ lifted_inputs.append(graph_signature.inputs_to_buffers[node_name])
159
+ else:
160
+ assert node_name in graph_signature.user_inputs
161
+ lifted_inputs.append(None)
162
+
163
+ from torch.export._unlift import _unlift
164
+
165
+ outputs = list(gm.graph.nodes)[-1].args[0]
166
+ mutated_outputs = []
167
+ for out in outputs:
168
+ if out in graph_signature.buffers_to_mutate:
169
+ mutated_outputs.append(graph_signature.buffers_to_mutate[out.name])
170
+ else:
171
+ mutated_outputs.append(None)
172
+
173
+ unlifted_gm = _unlift(
174
+ gm,
175
+ lifted_inputs,
176
+ mutated_outputs,
177
+ pytree.LeafSpec(),
178
+ None,
179
+ state_dict,
180
+ {},
181
+ )
182
+ return unlifted_gm
183
+
184
+
185
+ def _get_subgraph_names(gm):
186
+ for node in gm.graph.nodes:
187
+ if node.target == torch.ops.higher_order.cond:
188
+ true_subgraph_name = node.args[1].name
189
+ false_subgraph_name = node.args[2].name
190
+ yield true_subgraph_name
191
+ yield false_subgraph_name
192
+
193
+
194
+ def _recursive_pre_grad_passes(gm, example_inputs):
195
+ for subgraph_name in _get_subgraph_names(gm):
196
+ subgraph = getattr(gm, subgraph_name)
197
+ # as we don't have recursive example inputs, passing None here
198
+ new_subgraph = _recursive_pre_grad_passes(subgraph, example_inputs=None)
199
+ setattr(gm, subgraph_name, new_subgraph)
200
+ return pre_grad_passes(gm, example_inputs)
201
+
202
+
203
+ def _recursive_joint_graph_passes(gm):
204
+ for subgraph_name in _get_subgraph_names(gm):
205
+ subgraph = getattr(gm, subgraph_name)
206
+ _recursive_joint_graph_passes(subgraph)
207
+ joint_graph_passes(gm)
208
+
209
+
210
+ def _recursive_post_grad_passes(gm, is_inference: bool = False):
211
+ for subgraph_name in _get_subgraph_names(gm):
212
+ subgraph = getattr(gm, subgraph_name)
213
+ _recursive_post_grad_passes(subgraph, is_inference)
214
+ post_grad_passes(gm, is_inference)
215
+
216
+
217
+ def split_const_gm(
218
+ gm: torch.fx.GraphModule,
219
+ ) -> Tuple[torch.fx.GraphModule, Dict[str, int]]:
220
+ """
221
+ This function takes an GraphModule input "gm".
222
+ The gm will be split into 2 components,
223
+ 1) const_gm, which consists the subgraph of gm that can be constant folded.
224
+ 2) gm (being inplace modified,) which returns the graph after constant folding.
225
+
226
+ const_output_index is a mapping of corresponding node name from gm to the
227
+ output index of const_gm.
228
+ Returns (const_gm, const_output_index)
229
+ """
230
+ from torch._inductor.constant_folding import (
231
+ CONST_MODULE_TAG,
232
+ META_TAG,
233
+ MODULE_TAG,
234
+ replace_node_with_constant,
235
+ run_and_get_constant_graph,
236
+ )
237
+
238
+ const_gm = run_and_get_constant_graph(gm)
239
+ const_result = const_gm()
240
+
241
+ const_outputs = {
242
+ x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0])
243
+ }
244
+
245
+ to_erase_node = []
246
+ to_replace_node = []
247
+ const_output_index = {}
248
+ for node in gm.graph.nodes:
249
+ if node.name in const_outputs:
250
+ to_replace_node.append(node)
251
+ elif node.meta[META_TAG] == CONST_MODULE_TAG:
252
+ to_erase_node.append(node)
253
+
254
+ for node in to_replace_node:
255
+ new_const_name = "_FOLDED_CONST_" + node.name
256
+ replace_node_with_constant(
257
+ gm,
258
+ node,
259
+ const_result[const_outputs[node.name]],
260
+ new_const_name,
261
+ )
262
+ const_output_index[new_const_name] = const_outputs[node.name]
263
+ for node in to_erase_node[::-1]:
264
+ if node.users:
265
+ for n in node.users:
266
+ assert n.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty."
267
+ else:
268
+ gm.graph.erase_node(node)
269
+ gm.recompile()
270
+
271
+ return const_gm, const_output_index
272
+
273
+
274
+ def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
275
+ aten = torch.ops.aten
276
+ tf32_ops = {
277
+ aten.mm.default,
278
+ aten.addmm.default,
279
+ aten.bmm.default,
280
+ aten.baddbmm.default,
281
+ }
282
+ for node in gm.graph.nodes:
283
+ if (
284
+ node.op == "call_function"
285
+ and node.target in tf32_ops
286
+ and isinstance(node.meta.get("val", None), torch.Tensor)
287
+ and node.meta["val"].dtype == torch.float32
288
+ and node.meta["val"].device.type == "cuda"
289
+ ):
290
+ return True
291
+ return False
292
+
293
+
294
+ @DebugContext.wrap
295
+ def count_bytes_inner(
296
+ gm: torch.fx.GraphModule,
297
+ example_inputs: List[torch.Tensor],
298
+ num_fixed: int = 0,
299
+ **kwargs,
300
+ ):
301
+ shape_env = _shape_env_from_inputs(example_inputs)
302
+ fake_mode = fake_tensor_prop(gm, example_inputs)
303
+
304
+ with V.set_fake_mode(fake_mode):
305
+ _recursive_post_grad_passes(gm, False)
306
+
307
+ graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
308
+ with V.set_graph_handler(graph), V.set_real_inputs(example_inputs):
309
+ graph.run(*example_inputs)
310
+ num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
311
+ metrics.num_bytes_accessed += num_bytes
312
+ metrics.nodes_num_elem += nodes_num_elem
313
+ metrics.node_runtimes += node_runtimes
314
+ return make_boxed_func(gm.forward)
315
+
316
+
317
+ def fake_tensor_prop(
318
+ gm: torch.fx.GraphModule,
319
+ example_inputs: List[torch.Tensor],
320
+ force_allow_non_fake_inputs: bool = False,
321
+ ):
322
+ """
323
+ If we can not detect fake mode from the context of inputs, create one.
324
+
325
+ The created fake mode will be returned.
326
+ """
327
+ fake_mode = detect_fake_mode(example_inputs)
328
+ if not fake_mode:
329
+ fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
330
+ FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
331
+ else:
332
+ ctx = (
333
+ contextlib.nullcontext()
334
+ if not force_allow_non_fake_inputs
335
+ else mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
336
+ )
337
+ with ctx: # type: ignore[attr-defined]
338
+ FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
339
+ *example_inputs
340
+ )
341
+
342
+ return fake_mode
343
+
344
+
345
+ # pass config dict back to user
346
+ def get_patched_config_dict(config_patches=None) -> Dict[str, Any]:
347
+ with config.patch(config_patches):
348
+ return config.get_config_copy()
349
+
350
+
351
+ @DebugContext.wrap
352
+ @torch.utils._python_dispatch._disable_current_modes()
353
+ @time_and_log(
354
+ attr="compilation time (in seconds)",
355
+ extra_loggings={"config_dict": str(get_patched_config_dict())},
356
+ )
357
+ # Need this decorator for compile_fx_inner even if we already have one for
358
+ # compile_fx. The reason is the compilation for backward graph may happen after
359
+ # compile_fx return and we may want to use the _LazyGraphModule for compiling
360
+ # the backward graph as well.
361
+ @_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
362
+ @dynamo_utils.dynamo_timed(phase_name="inductor_compile")
363
+ def compile_fx_inner(
364
+ gm: torch.fx.GraphModule,
365
+ example_inputs: List[torch.Tensor],
366
+ cudagraphs: Optional[BoxedBool] = None,
367
+ num_fixed: int = 0,
368
+ is_backward: bool = False,
369
+ graph_id: Optional[int] = None,
370
+ cpp_wrapper: bool = False,
371
+ aot_mode: bool = False,
372
+ is_inference: bool = False,
373
+ boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
374
+ user_visible_outputs: FrozenSet[str] = frozenset(),
375
+ layout_opt: Optional[bool] = None,
376
+ extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
377
+ ) -> Union[CompiledFxGraph, str]:
378
+ """
379
+ Inductor API that compiles a single graph.
380
+
381
+ If you change the argument list for this function, make sure you
382
+ also update the call to save_args_for_compile_fx_inner below accordingly.
383
+ """
384
+ if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode:
385
+ # trigger the real recompilation for _LazyGraphModule before returning
386
+ # the forward method.
387
+ from torch.fx._lazy_graph_module import _LazyGraphModule
388
+
389
+ _LazyGraphModule.force_recompile(gm)
390
+ return make_boxed_func(gm.forward)
391
+
392
+ assert isinstance(
393
+ next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
394
+ ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"
395
+
396
+ if config.save_args:
397
+ save_args_for_compile_fx_inner(
398
+ gm,
399
+ example_inputs,
400
+ cudagraphs=cudagraphs,
401
+ num_fixed=num_fixed,
402
+ is_backward=is_backward,
403
+ graph_id=graph_id,
404
+ cpp_wrapper=cpp_wrapper,
405
+ aot_mode=aot_mode,
406
+ is_inference=is_inference,
407
+ boxed_forward_device_index=boxed_forward_device_index,
408
+ user_visible_outputs=user_visible_outputs,
409
+ layout_opt=layout_opt,
410
+ )
411
+
412
+ if cudagraphs is None:
413
+ cudagraphs = BoxedBool(config.triton.cudagraphs)
414
+
415
+ # Inputs to fx_codegen_and_compile
416
+ # Anything that affects codegen should go here, so if the signature
417
+ # of fx_codegen_and_compile changes, the dict should be updated accordingly
418
+ graph_kwargs = {
419
+ "cudagraphs": cudagraphs,
420
+ "num_fixed": num_fixed,
421
+ "is_backward": is_backward,
422
+ "graph_id": graph_id,
423
+ "cpp_wrapper": cpp_wrapper,
424
+ "aot_mode": aot_mode,
425
+ "is_inference": is_inference,
426
+ "user_visible_outputs": user_visible_outputs,
427
+ "layout_opt": layout_opt,
428
+ "extern_node_serializer": extern_node_serializer,
429
+ }
430
+
431
+ start = time.time()
432
+
433
+ if config.fx_graph_cache and not aot_mode:
434
+ compiled_graph = FxGraphCache.load(
435
+ fx_codegen_and_compile, gm, example_inputs, graph_kwargs
436
+ )
437
+ else:
438
+ compiled_graph = fx_codegen_and_compile(
439
+ gm, example_inputs, **graph_kwargs # type: ignore[arg-type]
440
+ )
441
+
442
+ log.debug("FX codegen and compilation took %.3fs", time.time() - start)
443
+
444
+ # check cudagraph disabling reasons from inductor lowering
445
+ if cudagraphs and compiled_graph.disabled_cudagraphs_reason:
446
+ perf_hint_log.warning(
447
+ "skipping cudagraphs due to %s", compiled_graph.disabled_cudagraphs_reason
448
+ )
449
+ BoxedBool.disable(cudagraphs)
450
+
451
+ # Return the output strides to the caller via TracingContext
452
+ context = torch._guards.TracingContext.try_get()
453
+ if context is not None and context.output_strides is not None:
454
+ assert len(context.output_strides) == 0
455
+ context.output_strides.extend(compiled_graph.output_strides)
456
+
457
+ if aot_mode:
458
+ return compiled_graph
459
+
460
+ if cudagraphs:
461
+ # output args are tuple of first argument
462
+ output = output_node(gm)
463
+ assert len(output.args) == 1
464
+ stack_traces = [
465
+ (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
466
+ for arg in output.args[0]
467
+ ]
468
+
469
+ complex_memory_overlap_inputs = any(
470
+ complex_memory_overlap(t)
471
+ for t in example_inputs
472
+ if isinstance(t, torch.Tensor)
473
+ )
474
+
475
+ from torch._inductor.cudagraph_utils import check_for_mutation
476
+
477
+ has_mutation_str = check_for_mutation(gm, compiled_graph, num_fixed)
478
+ has_mutation = has_mutation_str is not None
479
+
480
+ if has_mutation:
481
+ compiled_graph.disabled_cudagraphs_reason = has_mutation_str
482
+
483
+ cudagraph_tests = [
484
+ (not has_mutation, "mutated inputs"),
485
+ (not has_incompatible_cudagraph_ops(gm), "incompatible ops"),
486
+ (not complex_memory_overlap_inputs, "complex memory overlap"),
487
+ (
488
+ all(
489
+ isinstance(t, (torch.Tensor, torch.SymInt)) for t in example_inputs
490
+ ),
491
+ "non-Tensor inputs",
492
+ ),
493
+ ]
494
+ cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b]
495
+
496
+ if not cudagraph_fail_reasons:
497
+ if not config.triton.cudagraph_trees:
498
+ # Force specialize all inputs so that CUDA graphs will work
499
+ for t in example_inputs:
500
+ if isinstance(t, torch.SymInt):
501
+ int(t) # guard
502
+
503
+ if (
504
+ boxed_forward_device_index is not None
505
+ and not is_inference
506
+ and not is_backward
507
+ ):
508
+ boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))
509
+
510
+ compiled_graph.current_callable = cudagraphify(
511
+ compiled_graph.get_current_callable(),
512
+ example_inputs,
513
+ static_input_idxs=range(num_fixed),
514
+ device_index=next(iter(compiled_graph.device_idxs)),
515
+ stack_traces=stack_traces,
516
+ is_backward=is_backward,
517
+ is_inference=is_inference,
518
+ constants=tuple(compiled_graph.constants.values()),
519
+ )
520
+ else:
521
+ BoxedBool.disable(cudagraphs)
522
+
523
+ # See [Backward Generation Handling]
524
+ # if cudagraph'd the forward and set the device, we need to let the cudagraph manager
525
+ # know we are we running the backward even if we will not run it in cudagraphs
526
+ if is_backward and config.triton.cudagraph_trees:
527
+ assert boxed_forward_device_index is not None
528
+ assert boxed_forward_device_index.value is not None
529
+ compiled_graph_callable = compiled_graph.get_current_callable()
530
+
531
+ manager = torch._inductor.cudagraph_trees.get_manager(
532
+ boxed_forward_device_index.value, create_if_none_exists=False
533
+ )
534
+ # should already exist from forward
535
+ assert manager is not None
536
+
537
+ def compiled_artifact(new_inputs):
538
+ manager.set_to_running_backward()
539
+ return compiled_graph_callable(new_inputs)
540
+
541
+ compiled_graph.current_callable = compiled_artifact
542
+
543
+ if "cuda" in compiled_graph.device_types:
544
+ # prefer better disable_cudagraphs_reason bc stack trace
545
+ # TODO: migrate all disable reasons to stack trace, refactor
546
+ if compiled_graph.disabled_cudagraphs_reason:
547
+ perf_hint_log.warning(compiled_graph.disabled_cudagraphs_reason)
548
+ else:
549
+ perf_hint_log.warning(
550
+ "skipping cudagraphs due to %s", cudagraph_fail_reasons
551
+ )
552
+
553
+ # cudagraphs does its own aligning of inputs
554
+ if not cudagraphs:
555
+ new_callable = align_inputs(
556
+ compiled_graph.get_current_callable(), example_inputs, range(num_fixed)
557
+ )
558
+ if new_callable is not compiled_graph.get_current_callable():
559
+ compiled_graph.current_callable = new_callable
560
+
561
+ _step_logger()(
562
+ logging.INFO,
563
+ "torchinductor done compiling "
564
+ f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
565
+ f"graph {graph_id}",
566
+ )
567
+
568
+ # aot autograd needs to know to pass in inputs as a list
569
+ compiled_graph._boxed_call = True
570
+ return compiled_graph
571
+
572
+
573
+ def fx_codegen_and_compile(
574
+ gm: torch.fx.GraphModule,
575
+ example_inputs: List[torch.Tensor],
576
+ cudagraphs: Optional[BoxedBool] = None,
577
+ num_fixed: int = 0,
578
+ is_backward: bool = False,
579
+ graph_id: Optional[int] = None,
580
+ cpp_wrapper: bool = False,
581
+ aot_mode: bool = False,
582
+ is_inference: bool = False,
583
+ user_visible_outputs: FrozenSet[str] = frozenset(),
584
+ layout_opt: Optional[bool] = None,
585
+ extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
586
+ ) -> Union[CompiledFxGraph, str]:
587
+ if is_tf32_warning_applicable(gm):
588
+ _warn_tf32_disabled()
589
+
590
+ # lift the maximum depth of the Python interpreter stack
591
+ # to adapt large/deep models
592
+ sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))
593
+
594
+ _step_logger()(
595
+ logging.INFO,
596
+ "torchinductor compiling "
597
+ f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
598
+ f"graph {graph_id}",
599
+ )
600
+ V.debug.fx_graph(gm, example_inputs)
601
+ # TODO: Should we actually dump this? It should be redundant with the aot
602
+ # structured logs...
603
+ # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False))
604
+
605
+ shape_env = _shape_env_from_inputs(example_inputs)
606
+
607
+ # Convert view to reshape in the graph. This is necessary primarily for
608
+ # layout optimization. Do it unconditionally for uniformity.
609
+ #
610
+ # It's needed because when we do layout optimization, an contiguous tensor
611
+ # in eager mode may becomes a channels last tensor. A view op previously
612
+ # can be applied to the contiguous tensor may not be able to be applied
613
+ # on the channels tensor any more. An error like
614
+ # RuntimeError: view size is not compatible with input tensor's size and stride
615
+ # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
616
+ # will be printed.
617
+ #
618
+ # Replace view op to reshape op in this case.
619
+ # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this.
620
+ #
621
+ # Also this has to be done before FakeTensorProp below to avoid the failed
622
+ # .view() call.
623
+ view_to_reshape(gm)
624
+
625
+ # It is safe to run FakeTensorProp under no_grad because by the time
626
+ # we're in inductor, we assume that AOTAutograd has already "taken care"
627
+ # of autograd, so there should be no more autograd-related API's in the
628
+ # graph.
629
+ with torch.no_grad():
630
+ fake_mode = fake_tensor_prop(gm, example_inputs)
631
+
632
+ # pattern matcher passes might not preserve striding information
633
+ # on node.meta["val"]. if in the future we rely on these being
634
+ # correct we will need to fix.
635
+
636
+ with V.set_fake_mode(fake_mode):
637
+ # has some issues with memory in training
638
+ _recursive_post_grad_passes(gm, is_inference=is_inference)
639
+ V.debug.fx_graph_transformed(gm, example_inputs)
640
+ post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm))
641
+ trace_structured(
642
+ "inductor_post_grad_graph",
643
+ payload_fn=lambda: gm.print_readable(print_output=False),
644
+ )
645
+ optimus_scuba_log["inductor_post_grad"] = counters["inductor"]
646
+ signpost_event(
647
+ "optimus",
648
+ "compile_fx.post_grad_passes",
649
+ optimus_scuba_log,
650
+ )
651
+
652
+ with V.set_fake_mode(fake_mode):
653
+ const_output_index = None
654
+ const_graph = None
655
+ const_code = None
656
+
657
+ if aot_mode and config.aot_inductor.use_runtime_constant_folding:
658
+ const_gm, const_output_index = split_const_gm(gm)
659
+
660
+ const_graph = GraphLowering(
661
+ const_gm,
662
+ example_inputs=[],
663
+ shape_env=shape_env,
664
+ num_static_inputs=num_fixed,
665
+ graph_id=graph_id,
666
+ cpp_wrapper=cpp_wrapper,
667
+ aot_mode=aot_mode,
668
+ user_visible_outputs=user_visible_outputs,
669
+ extern_node_serializer=extern_node_serializer,
670
+ is_inference=is_inference,
671
+ is_const_graph=True,
672
+ )
673
+ with V.set_graph_handler(const_graph):
674
+ assert cpp_wrapper, "AOT mode only supports C++ wrapper"
675
+ const_graph.run()
676
+
677
+ const_code, _ = const_graph.codegen_with_cpp_wrapper()
678
+
679
+ graph = GraphLowering(
680
+ gm,
681
+ # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
682
+ # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
683
+ # we currently use fake tensors and defake them later.
684
+ example_inputs=example_inputs,
685
+ shape_env=shape_env,
686
+ num_static_inputs=num_fixed,
687
+ graph_id=graph_id,
688
+ cpp_wrapper=cpp_wrapper,
689
+ aot_mode=aot_mode,
690
+ user_visible_outputs=user_visible_outputs,
691
+ extern_node_serializer=extern_node_serializer,
692
+ is_inference=is_inference,
693
+ const_output_index=const_output_index,
694
+ const_code=const_code,
695
+ const_module=const_graph,
696
+ )
697
+ with V.set_graph_handler(graph):
698
+ graph.run(*example_inputs)
699
+ output_strides: List[Optional[Tuple[int, ...]]] = []
700
+ if graph.graph_outputs is not None:
701
+ # We'll put the output strides in the compiled graph so we
702
+ # can later return them to the caller via TracingContext
703
+ for out in graph.graph_outputs:
704
+ if hasattr(out, "layout"):
705
+ output_strides.append(
706
+ tuple(
707
+ V.graph.sizevars.size_hint(s) for s in out.layout.stride
708
+ )
709
+ )
710
+ else:
711
+ output_strides.append(None)
712
+
713
+ metrics_helper = metrics.CachedMetricsHelper()
714
+ compiled_fn = graph.compile_to_fn()
715
+
716
+ if V.aot_compilation is True:
717
+ return compiled_fn
718
+
719
+ if cudagraphs and not V.graph.disable_cudagraphs_reason:
720
+ from torch._inductor.cudagraph_utils import (
721
+ check_lowering_disable_cudagraph,
722
+ )
723
+
724
+ V.graph.disable_cudagraphs_reason = check_lowering_disable_cudagraph(
725
+ V.graph.device_node_mapping
726
+ )
727
+
728
+ compiled_graph = CompiledFxGraph(
729
+ compiled_fn,
730
+ graph,
731
+ output_strides,
732
+ V.graph.disable_cudagraphs_reason,
733
+ metrics_helper.get_deltas(),
734
+ )
735
+
736
+ return compiled_graph
737
+
738
+
739
+ def clone_preserve_strides(x: torch.Tensor):
740
+ needed_size = (
741
+ sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
742
+ )
743
+ buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
744
+ return torch.as_strided(buffer, x.size(), x.stride())
745
+
746
+
747
+ def copy_misaligned_inputs(
748
+ new_inputs: List[torch.Tensor], check_inputs_idxs: Sequence[int]
749
+ ) -> None:
750
+ for i in check_inputs_idxs:
751
+ if new_inputs[i].data_ptr() % ALIGNMENT:
752
+ new_inputs[i] = clone_preserve_strides(new_inputs[i])
753
+
754
+
755
+ def get_input_idxs_to_check(
756
+ inputs: Union[List[torch.Tensor], Sequence[int]],
757
+ static_input_idxs: Sequence[int],
758
+ ) -> Sequence[int]:
759
+ def is_aligned(storage_offset, dtype):
760
+ return (storage_offset * get_dtype_size(dtype)) % ALIGNMENT == 0
761
+
762
+ ids_to_check = []
763
+ for i, input in enumerate(inputs):
764
+ if (
765
+ isinstance(input, torch.Tensor)
766
+ and (
767
+ i not in static_input_idxs
768
+ or not is_aligned(input.storage_offset(), input.dtype)
769
+ )
770
+ and input.device.type == "cuda"
771
+ ):
772
+ ids_to_check.append(i)
773
+ return ids_to_check
774
+
775
+
776
+ def align_inputs_from_check_idxs(
777
+ model: Callable[[List[torch.Tensor]], Any], inputs_to_check: Sequence[int]
778
+ ):
779
+ if len(inputs_to_check) == 0:
780
+ return model
781
+
782
+ def run(new_inputs):
783
+ copy_misaligned_inputs(new_inputs, inputs_to_check)
784
+ return model(new_inputs)
785
+
786
+ return run
787
+
788
+
789
+ def align_inputs(
790
+ model: Callable[[List[torch.Tensor]], Any],
791
+ inputs: List[torch.Tensor],
792
+ static_input_idxs: Sequence[int] = (),
793
+ ):
794
+ inputs_to_check = get_input_idxs_to_check(inputs, static_input_idxs)
795
+ return align_inputs_from_check_idxs(model, inputs_to_check)
796
+
797
+
798
+ @dynamo_utils.dynamo_timed
799
+ def cudagraphify(
800
+ model: torch.fx.GraphModule,
801
+ inputs: List[torch.Tensor],
802
+ static_input_idxs: Sequence[int] = (),
803
+ *,
804
+ device_index: int,
805
+ stack_traces: List[Optional[str]],
806
+ is_backward: bool,
807
+ is_inference: bool,
808
+ constants: Tuple[torch.Tensor, ...] = (),
809
+ ):
810
+ from torch._inductor.cudagraph_trees import (
811
+ cudagraphify_impl as new_cudagraphify_impl,
812
+ )
813
+
814
+ cudagraphify_fn: Callable[..., Any]
815
+ if config.triton.cudagraph_trees:
816
+ cudagraphify_fn = functools.partial(
817
+ new_cudagraphify_impl,
818
+ device_index=device_index,
819
+ stack_traces=stack_traces,
820
+ is_backward=is_backward,
821
+ is_inference=is_inference,
822
+ constants=constants,
823
+ )
824
+ else:
825
+ cudagraphify_fn = cudagraphify_impl
826
+
827
+ # if using fake tensors, defer cudagraphs until we get real inputs at runtime
828
+ if not any(isinstance(inp, FakeTensor) for inp in inputs):
829
+ return cudagraphify_fn(model, inputs, static_input_idxs)
830
+
831
+ compiled_fn = None
832
+
833
+ def run(new_inputs):
834
+ nonlocal compiled_fn
835
+ if compiled_fn is None:
836
+ with dynamo_utils.preserve_rng_state():
837
+ compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
838
+ return compiled_fn(new_inputs)
839
+
840
+ return run
841
+
842
+
843
+ def remove_unaligned_input_idxs(
844
+ inputs: Union[List[torch.Tensor], Sequence[int]],
845
+ static_input_idxs: Sequence[int],
846
+ ):
847
+ """
848
+ We require all inputs to be aligned, so introduce a copy for any
849
+ that aren't.
850
+ """
851
+ aligned_static_input_idxs = []
852
+ for idx, input in zip(static_input_idxs, inputs):
853
+ if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0:
854
+ aligned_static_input_idxs.append(idx)
855
+ if len(aligned_static_input_idxs) != len(static_input_idxs):
856
+ return aligned_static_input_idxs
857
+ return static_input_idxs
858
+
859
+
860
+ def static_input(x: torch.Tensor):
861
+ """
862
+ Copy and input while preserving strides
863
+ """
864
+ # TODO(jansel): figure out why this version doesn't work:
865
+ # return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
866
+ needed_size = (
867
+ sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
868
+ )
869
+ buffer = torch.empty(needed_size, dtype=x.dtype, device=x.device)
870
+ return torch.as_strided(buffer, x.size(), x.stride())
871
+
872
+
873
+ def index_expanded_dims_and_copy_(
874
+ dst: torch.Tensor,
875
+ src: torch.Tensor,
876
+ expanded_dims: List[int],
877
+ ):
878
+ "Index into expanded dimensions of both dst and src then copy_"
879
+ dst = index_expanded_dims(dst, expanded_dims)
880
+ src = index_expanded_dims(src, expanded_dims)
881
+ dst.copy_(src)
882
+
883
+
884
+ def cudagraphify_impl(
885
+ model: torch.fx.GraphModule,
886
+ inputs: List[torch.Tensor],
887
+ static_input_idxs: Sequence[int] = (),
888
+ ):
889
+ """
890
+ Assumes inputs[static_input_idxs[i]] are always the same memory address
891
+ """
892
+ check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs)
893
+ static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
894
+ copy_misaligned_inputs(inputs, check_input_idxs)
895
+
896
+ assert isinstance(inputs, list)
897
+
898
+ inps_expanded_dims = [
899
+ get_expanded_dims(x) if idx not in static_input_idxs else []
900
+ for idx, x in enumerate(inputs)
901
+ ]
902
+
903
+ # allocate static tensor inputs
904
+ static_inputs = [
905
+ x
906
+ if not isinstance(x, torch.Tensor)
907
+ else static_input(x)
908
+ if idx not in static_input_idxs
909
+ else x.detach()
910
+ for idx, x in enumerate(inputs)
911
+ ]
912
+
913
+ # copy over input values for fresh allocations
914
+ for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)):
915
+ if isinstance(x, torch.Tensor) and idx not in static_input_idxs:
916
+ index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims)
917
+
918
+ # warmup
919
+ torch.cuda.synchronize()
920
+ stream = torch.cuda.Stream()
921
+ stream.wait_stream(torch.cuda.current_stream())
922
+ # copy static_inputs because it will be cleared in model
923
+ with torch.cuda.stream(stream):
924
+ model(list(static_inputs))
925
+ stream.synchronize()
926
+ torch.cuda.current_stream().wait_stream(stream)
927
+ torch.cuda.synchronize()
928
+
929
+ # record
930
+ graph = torch.cuda.CUDAGraph()
931
+ with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):
932
+ static_outputs = model(list(static_inputs))
933
+ if not isinstance(static_outputs, (list, tuple)):
934
+ static_outputs = (static_outputs,)
935
+
936
+ if config.size_asserts:
937
+
938
+ def run(new_inputs):
939
+ assert len(static_inputs) == len(new_inputs)
940
+ for idx, (dst, src, expanded_dims) in enumerate(
941
+ zip(static_inputs, new_inputs, inps_expanded_dims)
942
+ ):
943
+ if not isinstance(dst, torch.Tensor):
944
+ pass
945
+ elif idx in static_input_idxs:
946
+ assert dst.data_ptr() == src.data_ptr()
947
+ else:
948
+ # TODO - could make one single op of multiple slices
949
+ # and avoid dispatch.
950
+ # Could also pre-index the `dst` tensors
951
+ index_expanded_dims_and_copy_(dst, src, expanded_dims)
952
+ new_inputs.clear()
953
+ graph.replay()
954
+ return static_outputs
955
+
956
+ else:
957
+ copy_indices = [
958
+ idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
959
+ ]
960
+
961
+ def run(new_inputs):
962
+ for idx in copy_indices:
963
+ expanded_dims = inps_expanded_dims[idx]
964
+ index_expanded_dims_and_copy_(
965
+ static_inputs[idx], new_inputs[idx], expanded_dims
966
+ )
967
+ new_inputs.clear()
968
+ graph.replay()
969
+ return static_outputs
970
+
971
+ return align_inputs_from_check_idxs(run, check_input_idxs)
972
+
973
+
974
+ def compile_fx_aot(
975
+ model_: torch.fx.GraphModule,
976
+ example_inputs_: List[torch.Tensor],
977
+ inner_compile: Callable[..., Any] = compile_fx_inner,
978
+ config_patches: Optional[Dict[str, Any]] = None,
979
+ ):
980
+ config_patches: Dict[str, Any] = (
981
+ {"cpp_wrapper": True}
982
+ if config_patches is None
983
+ else {**config_patches, "cpp_wrapper": True}
984
+ )
985
+ if (
986
+ "aot_inductor.output_path" not in config_patches
987
+ and not config.aot_inductor.output_path
988
+ ):
989
+ config_patches = {
990
+ **config_patches,
991
+ "aot_inductor.output_path": code_hash(model_.code),
992
+ }
993
+
994
+ extern_node_serializer = config_patches.pop("extern_node_serializer", None)
995
+ with V.set_aot_compilation(True):
996
+ compiled_lib_path = compile_fx(
997
+ model_,
998
+ example_inputs_,
999
+ inner_compile=functools.partial(
1000
+ inner_compile,
1001
+ aot_mode=True,
1002
+ extern_node_serializer=extern_node_serializer,
1003
+ ),
1004
+ config_patches=config_patches,
1005
+ )
1006
+ assert os.path.exists(
1007
+ compiled_lib_path
1008
+ ), f"AOTInductor compiled library does not exist at {compiled_lib_path}"
1009
+ return compiled_lib_path
1010
+
1011
+
1012
+ _graph_counter = count(0)
1013
+
1014
+
1015
+ def fw_compiler_freezing(
1016
+ aot_autograd_model: torch.fx.GraphModule,
1017
+ aot_example_inputs: List[torch.Tensor],
1018
+ dynamo_model: torch.fx.GraphModule,
1019
+ num_example_inputs: int,
1020
+ inner_compile: Callable[..., Any],
1021
+ cudagraphs: BoxedBool,
1022
+ graph_id: int,
1023
+ forward_device: BoxedDeviceIndex,
1024
+ ):
1025
+ from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze
1026
+
1027
+ # partition_fn won't be called
1028
+ _recursive_joint_graph_passes(aot_autograd_model)
1029
+
1030
+ layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True)
1031
+ if layout_opt:
1032
+ # make sure meta['val'] is properly setup
1033
+ fake_tensor_prop(aot_autograd_model, aot_example_inputs, True)
1034
+ convert_conv_weights_to_channels_last(aot_autograd_model)
1035
+
1036
+ opt_model, preserved_arg_indices = freeze(
1037
+ dynamo_model,
1038
+ aot_autograd_model,
1039
+ aot_example_inputs, # type: ignore[arg-type]
1040
+ )
1041
+
1042
+ aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
1043
+ num_fixed = len(preserved_arg_indices) - num_example_inputs
1044
+
1045
+ fake_mode = detect_fake_mode(aot_example_inputs)
1046
+
1047
+ # for freezing, all graph outputs should be user visible
1048
+ *_, model_outputs_node = opt_model.graph.nodes
1049
+ model_outputs = model_outputs_node.args[0]
1050
+ user_visible_outputs = [
1051
+ n.name for n in model_outputs if isinstance(n, torch.fx.Node)
1052
+ ]
1053
+
1054
+ # constant params will be real tensors, not fake
1055
+ tracing_context = torch._guards.TracingContext.try_get()
1056
+ if tracing_context is not None:
1057
+ params_flat = tracing_context.params_flat
1058
+ assert params_flat is not None
1059
+ for i in range(len(params_flat)):
1060
+ if i not in preserved_arg_indices:
1061
+ params_flat[i] = None
1062
+
1063
+ with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
1064
+ optimized_function = inner_compile(
1065
+ opt_model,
1066
+ aot_example_inputs,
1067
+ num_fixed=num_fixed,
1068
+ cudagraphs=cudagraphs,
1069
+ graph_id=graph_id,
1070
+ is_inference=True,
1071
+ boxed_forward_device_index=forward_device,
1072
+ layout_opt=layout_opt,
1073
+ user_visible_outputs=user_visible_outputs,
1074
+ )
1075
+
1076
+ # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper
1077
+ # that drops constant-ified params
1078
+ if V.aot_compilation is True:
1079
+ return optimized_function
1080
+
1081
+ def wrapper(args):
1082
+ args_new = [args[i] for i in preserved_arg_indices]
1083
+ args.clear()
1084
+ return optimized_function(args_new)
1085
+
1086
+ wrapper._boxed_call = True # type: ignore[attr-defined]
1087
+
1088
+ return wrapper
1089
+
1090
+
1091
+ @_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
1092
+ def compile_fx(
1093
+ model_: torch.fx.GraphModule,
1094
+ example_inputs_: List[torch.Tensor],
1095
+ inner_compile: Callable[..., Any] = compile_fx_inner,
1096
+ config_patches: Optional[Dict[str, Any]] = None,
1097
+ decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
1098
+ ):
1099
+ """Main entrypoint to a compile given FX graph"""
1100
+ if config_patches:
1101
+ with config.patch(config_patches):
1102
+ return compile_fx(
1103
+ model_,
1104
+ example_inputs_,
1105
+ # need extra layer of patching as backwards is compiled out of scope
1106
+ inner_compile=config.patch(config_patches)(inner_compile),
1107
+ decompositions=decompositions,
1108
+ )
1109
+
1110
+ if config.cpp_wrapper:
1111
+ with config.patch(
1112
+ {
1113
+ "cpp_wrapper": False,
1114
+ "triton.autotune_cublasLt": False,
1115
+ "triton.cudagraphs": False,
1116
+ "triton.store_cubin": True,
1117
+ }
1118
+ ), V.set_real_inputs(example_inputs_):
1119
+ inputs_ = example_inputs_
1120
+ if isinstance(model_, torch.fx.GraphModule):
1121
+ fake_inputs = [
1122
+ node.meta.get("val")
1123
+ for node in model_.graph.nodes
1124
+ if node.op == "placeholder"
1125
+ ]
1126
+ if all(v is not None for v in fake_inputs):
1127
+ # Validate devices before switching to fake tensors.
1128
+ for idx, fi, i in zip(count(), fake_inputs, inputs_):
1129
+ if fi.device != i.device:
1130
+ raise ValueError(
1131
+ f"Device mismatch between fake input and example input at position #{idx}: "
1132
+ f"{fi.device} vs {i.device}. If the model was exported via torch.export(), "
1133
+ "make sure torch.export() and torch.aot_compile() run on the same device."
1134
+ )
1135
+ inputs_ = fake_inputs
1136
+ return compile_fx(
1137
+ model_,
1138
+ inputs_,
1139
+ inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
1140
+ decompositions=decompositions,
1141
+ )
1142
+
1143
+ recursive_compile_fx = functools.partial(
1144
+ compile_fx,
1145
+ inner_compile=inner_compile,
1146
+ decompositions=decompositions,
1147
+ )
1148
+
1149
+ if not graph_returns_tuple(model_):
1150
+ return make_graph_return_tuple(
1151
+ model_,
1152
+ example_inputs_,
1153
+ recursive_compile_fx,
1154
+ )
1155
+
1156
+ if isinstance(model_, torch.fx.GraphModule):
1157
+ if isinstance(model_.graph._codegen, _PyTreeCodeGen):
1158
+ # this graph is the result of dynamo.export()
1159
+ return handle_dynamo_export_graph(
1160
+ model_,
1161
+ example_inputs_,
1162
+ recursive_compile_fx,
1163
+ )
1164
+
1165
+ model_ = _recursive_pre_grad_passes(model_, example_inputs_)
1166
+ optimus_scuba_log["inductor_pre_grad"] = counters["inductor"]
1167
+ signpost_event(
1168
+ "optimus",
1169
+ "compile_fx.pre_grad_passes",
1170
+ optimus_scuba_log,
1171
+ )
1172
+
1173
+ if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
1174
+ return flatten_graph_inputs(
1175
+ model_,
1176
+ example_inputs_,
1177
+ recursive_compile_fx,
1178
+ )
1179
+
1180
+ assert not config._raise_error_for_testing
1181
+ num_example_inputs = len(example_inputs_)
1182
+ cudagraphs = BoxedBool(config.triton.cudagraphs)
1183
+ forward_device = BoxedDeviceIndex(None)
1184
+
1185
+ graph_id = next(_graph_counter)
1186
+
1187
+ decompositions = (
1188
+ decompositions if decompositions is not None else select_decomp_table()
1189
+ )
1190
+
1191
+ @dynamo_utils.dynamo_timed
1192
+ def fw_compiler_base(
1193
+ model: torch.fx.GraphModule,
1194
+ example_inputs: List[torch.Tensor],
1195
+ is_inference: bool,
1196
+ ):
1197
+ if is_inference:
1198
+ # partition_fn won't be called
1199
+ _recursive_joint_graph_passes(model)
1200
+
1201
+ fixed = torch._inductor.utils.num_fw_fixed_arguments(
1202
+ num_example_inputs, len(example_inputs)
1203
+ )
1204
+ user_visible_outputs = set()
1205
+
1206
+ if config.keep_output_stride:
1207
+ *_, model_outputs_node = model.graph.nodes
1208
+ assert model_outputs_node.op == "output"
1209
+ model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
1210
+ num_model_outputs = len(model_outputs)
1211
+
1212
+ context = torch._guards.TracingContext.try_get()
1213
+ # See Note [User Outputs in the inductor graph]
1214
+ if context is not None and context.fw_metadata and not is_inference:
1215
+ original_output_start_index = (
1216
+ context.fw_metadata.num_mutated_inp_runtime_indices
1217
+ )
1218
+ else:
1219
+ original_output_start_index = 0
1220
+
1221
+ if isinstance(model_, torch.fx.GraphModule):
1222
+ *_, orig_model_outputs_node = model_.graph.nodes
1223
+ assert orig_model_outputs_node.op == "output"
1224
+ orig_model_outputs, _ = pytree.tree_flatten(
1225
+ orig_model_outputs_node.args
1226
+ )
1227
+ num_orig_model_outputs = len(orig_model_outputs)
1228
+ else:
1229
+ num_orig_model_outputs = num_model_outputs
1230
+
1231
+ assert num_orig_model_outputs <= num_model_outputs
1232
+
1233
+ # Note [User Outputs in the inductor graph]
1234
+ # We makes the following assumption
1235
+ # For inference
1236
+ # len(orig_model_outputs) == len(model_outputs)
1237
+ # For training
1238
+ # len(orig_model_outputs) <= len(model_outputs)
1239
+ # During training, most of the time the model_outputs starts with
1240
+ # original module's outputs followed by saved activations.
1241
+ # But this can be not true if the model have inplace updated tensors.
1242
+ # AOTAutograd will make those tensors being returned before the original
1243
+ # module's output.
1244
+ # To make things safe, we'll use original_output_start_index field
1245
+ # set by AOTAutograd to decide where the original module outputs start.
1246
+ orig_output_end_idx = original_output_start_index + num_orig_model_outputs
1247
+ # Sanity chec: we are about to splice out the "user" outputs from the full set
1248
+ # of "graph" outputs. Make sure we're within bounds.
1249
+ assert orig_output_end_idx <= num_model_outputs
1250
+
1251
+ user_visible_outputs = {
1252
+ n.name
1253
+ for n in model_outputs[original_output_start_index:orig_output_end_idx]
1254
+ if isinstance(n, torch.fx.Node)
1255
+ }
1256
+
1257
+ return inner_compile(
1258
+ model,
1259
+ example_inputs,
1260
+ num_fixed=fixed,
1261
+ cudagraphs=cudagraphs,
1262
+ graph_id=graph_id,
1263
+ is_inference=is_inference,
1264
+ boxed_forward_device_index=forward_device,
1265
+ user_visible_outputs=user_visible_outputs,
1266
+ )
1267
+
1268
+ fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
1269
+
1270
+ if config.freezing and not torch.is_grad_enabled():
1271
+ inference_compiler = functools.partial(
1272
+ fw_compiler_freezing,
1273
+ dynamo_model=model_,
1274
+ num_example_inputs=num_example_inputs,
1275
+ inner_compile=inner_compile,
1276
+ cudagraphs=cudagraphs,
1277
+ graph_id=graph_id,
1278
+ forward_device=forward_device,
1279
+ )
1280
+ else:
1281
+ inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
1282
+
1283
+ def partition_fn(graph, joint_inputs, **kwargs):
1284
+ _recursive_joint_graph_passes(graph)
1285
+ return min_cut_rematerialization_partition(
1286
+ graph, joint_inputs, **kwargs, compiler="inductor"
1287
+ )
1288
+
1289
+ @dynamo_utils.dynamo_timed
1290
+ @dynamo_utils.maybe_cprofile
1291
+ def bw_compiler(model: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
1292
+ fixed = count_tangents(model)
1293
+ return inner_compile(
1294
+ model,
1295
+ example_inputs,
1296
+ num_fixed=fixed,
1297
+ cudagraphs=cudagraphs,
1298
+ is_backward=True,
1299
+ graph_id=graph_id,
1300
+ boxed_forward_device_index=forward_device,
1301
+ )
1302
+
1303
+ # TODO: can add logging before/after the call to create_aot_dispatcher_function
1304
+ # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
1305
+ # once torchdynamo is merged into pytorch
1306
+
1307
+ fake_mode = detect_fake_mode(example_inputs_) or torch._subclasses.FakeTensorMode(
1308
+ allow_non_fake_inputs=True
1309
+ )
1310
+ tracing_context = (
1311
+ torch._guards.TracingContext.try_get()
1312
+ or torch._guards.TracingContext(fake_mode)
1313
+ )
1314
+
1315
+ if V.aot_compilation is True:
1316
+ gm, graph_signature = aot_export_module(
1317
+ model_, example_inputs_, trace_joint=False, decompositions=decompositions
1318
+ )
1319
+ unlifted_gm = _unlift_graph(model_, gm, graph_signature)
1320
+ if "dynamo_flat_name_to_original_fqn" in model_.meta:
1321
+ unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[
1322
+ "dynamo_flat_name_to_original_fqn"
1323
+ ]
1324
+ with V.set_fake_mode(fake_mode), compiled_autograd.disable():
1325
+ return inference_compiler(unlifted_gm, example_inputs_)
1326
+
1327
+ with V.set_fake_mode(fake_mode), torch._guards.tracing(
1328
+ tracing_context
1329
+ ), compiled_autograd.disable():
1330
+ return aot_autograd(
1331
+ fw_compiler=fw_compiler,
1332
+ bw_compiler=bw_compiler,
1333
+ inference_compiler=inference_compiler,
1334
+ decompositions=decompositions,
1335
+ partition_fn=partition_fn,
1336
+ keep_inference_input_mutations=True,
1337
+ )(model_, example_inputs_)
1338
+
1339
+
1340
+ def _shape_env_from_inputs(inputs: List[torch.Tensor]):
1341
+ shape_env = None
1342
+ fake_mode = detect_fake_mode(inputs)
1343
+
1344
+ # TODO(voz): It would be nice to enable this assert, but there are lots of tests that
1345
+ # pass in real inputs for now.
1346
+ # if len(inputs) > 0:
1347
+ # assert fake_mode is not None, breakpoint()
1348
+
1349
+ if fake_mode is not None:
1350
+ return fake_mode.shape_env
1351
+
1352
+ # When there are no tensor inputs, get shape_env from the first SymInt.
1353
+ for input in inputs:
1354
+ if isinstance(input, torch.SymInt):
1355
+ return input.node.shape_env
1356
+
1357
+ # TODO(voz): Should we always have one anyway?
1358
+ return None
1359
+
1360
+
1361
+ def graph_returns_tuple(gm: torch.fx.GraphModule):
1362
+ """True if a FX graph returns a tuple"""
1363
+ if not isinstance(gm, torch.fx.GraphModule):
1364
+ return True # can't check this, assume true
1365
+ (rv,) = output_node(gm).args
1366
+ if isinstance(rv, (list, tuple)):
1367
+ return True
1368
+ if (
1369
+ isinstance(rv, torch.fx.node.Node)
1370
+ and hasattr(rv.target, "_schema")
1371
+ and len(rv.target._schema.returns) > 1
1372
+ and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns)
1373
+ ):
1374
+ # for graphs whose result is one node with multiple outputs
1375
+ return True
1376
+ return False
1377
+
1378
+
1379
+ def make_graph_return_tuple(
1380
+ gm: torch.fx.GraphModule,
1381
+ inputs: List[torch.Tensor],
1382
+ compile_gm: Callable[..., Any],
1383
+ ):
1384
+ """
1385
+ Mutate gm so it returns a tuple. This is only needed for graphs
1386
+ not created by torchdynamo that return non-tuples.
1387
+ """
1388
+ node = output_node(gm)
1389
+ (rv,) = node.args
1390
+ rv, spec = pytree.tree_flatten(rv)
1391
+ with gm.graph.inserting_before(node):
1392
+ gm.graph.output(rv)
1393
+ gm.graph.erase_node(node)
1394
+ assert graph_returns_tuple(gm)
1395
+
1396
+ compiled_fn = compile_gm(gm, inputs)
1397
+
1398
+ @functools.wraps(compiled_fn)
1399
+ def wrapper(*args, **kwargs):
1400
+ return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec)
1401
+
1402
+ return wrapper
1403
+
1404
+
1405
+ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
1406
+ """
1407
+ Mutate inputs so that they are flat and wrap gm such that it
1408
+ accepts those inputs. This is only needed for graphs not created
1409
+ by torchdynamo that take bumpy inputs.
1410
+ """
1411
+ inputs, spec = pytree.tree_flatten(inputs)
1412
+
1413
+ class GmWrapper(torch.nn.Module):
1414
+ def __init__(self):
1415
+ super().__init__()
1416
+ self.gm = gm
1417
+
1418
+ def forward(self, *args):
1419
+ args: List[Any] = list(args)
1420
+ return self.gm(*pytree.tree_unflatten(args, spec))
1421
+
1422
+ compiled_fn = compile_gm(GmWrapper(), inputs)
1423
+
1424
+ @functools.wraps(compiled_fn)
1425
+ def wrapper(*args):
1426
+ # note this doesn't check the spec, assuming it is the same
1427
+ return compiled_fn(*pytree.arg_tree_leaves(*args))
1428
+
1429
+ return wrapper
1430
+
1431
+
1432
+ def handle_dynamo_export_graph(
1433
+ gm: torch.fx.GraphModule,
1434
+ inputs: List[torch.Tensor],
1435
+ compile_gm: Callable[..., Any],
1436
+ ):
1437
+ """
1438
+ `torch._dynamo.export` embeds pytrees in the FX graph codegen object,
1439
+ convert that to a normal FX graph so inductor can compile it.
1440
+ """
1441
+ codegen = gm.graph._codegen
1442
+ gm.graph._codegen = torch.fx.graph.CodeGen()
1443
+ gm.recompile()
1444
+
1445
+ compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs))
1446
+
1447
+ @functools.wraps(compiled_fn)
1448
+ def wrapper(*args):
1449
+ return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args)))
1450
+
1451
+ return wrapper
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/cudagraph_utils.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from typing import Dict, Iterable, Optional
3
+
4
+ import torch
5
+ from torch._inductor.codecache import CompiledFxGraph
6
+
7
+
8
+ def get_mutating_use_stack_trace(placeholder_node: torch.fx.Node) -> Optional[str]:
9
+ # reinplaced uses might have a single, non-copy_ use
10
+ if len(placeholder_node.users) == 1:
11
+ return next(iter(placeholder_node.users)).meta.get("stack_trace", None)
12
+
13
+ for use in placeholder_node.users:
14
+ if use.target == torch.ops.aten.copy_.default:
15
+ if stack_trace := use.meta.get("stack_trace", None):
16
+ return stack_trace
17
+
18
+ return None
19
+
20
+
21
+ def format_default_skip_message(reason: str) -> str:
22
+ return f"skipping cudagraphs due to {reason}"
23
+
24
+
25
+ def get_mutation_stack_trace(
26
+ gm: torch.fx.GraphModule, mutation_indices: Iterable[int]
27
+ ) -> str:
28
+ stack_trace: Optional[str] = ""
29
+ placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
30
+
31
+ for idx in mutation_indices:
32
+ placeholder = placeholders[idx]
33
+ if stack_trace := get_mutating_use_stack_trace(placeholder):
34
+ break
35
+
36
+ if stack_trace:
37
+ msg = f"skipping cudagraphs due to mutation on input. Found from : \n {stack_trace}"
38
+ return msg
39
+
40
+ return format_default_skip_message("mutated inputs")
41
+
42
+
43
+ def check_for_mutation(
44
+ gm: torch.fx.GraphModule, compiled_graph: CompiledFxGraph, num_fixed: int
45
+ ) -> Optional[str]:
46
+ default_msg = format_default_skip_message("mutated inputs")
47
+
48
+ # doesnt work for non-trees because the warmup run would apply mutation twice
49
+ if torch._inductor.config.triton.cudagraph_trees:
50
+ # checking if mutation is only on parameters/static inputs
51
+ mutation_indices = [
52
+ idx for idx in compiled_graph.mutated_input_idxs if idx >= num_fixed
53
+ ]
54
+ has_mutation = len(mutation_indices) != 0
55
+ if not has_mutation:
56
+ return None
57
+
58
+ return get_mutation_stack_trace(gm, mutation_indices)
59
+
60
+ else:
61
+ has_mutation = len(compiled_graph.mutated_inputs) != 0
62
+ return None if not has_mutation else default_msg
63
+
64
+
65
+ def get_use_stack_trace(node) -> Optional[str]:
66
+ for use in node.users:
67
+ if stack_trace := use.meta.get("stack_trace", None):
68
+ return stack_trace
69
+ return None
70
+
71
+
72
+ def check_multiple_devices_or_any_cpu_nodes(
73
+ device_node_mapping: Dict[torch.device, torch.fx.Node]
74
+ ) -> Optional[str]:
75
+ if cpu_node := device_node_mapping.get(torch.device("cpu")):
76
+ if stack_trace := get_use_stack_trace(cpu_node):
77
+ return format_default_skip_message(
78
+ f"cpu device. Found from : \n {stack_trace}"
79
+ )
80
+
81
+ return format_default_skip_message("cpu device")
82
+
83
+ if (
84
+ len(device_node_mapping) == 1
85
+ and next(iter(device_node_mapping.keys())).type == "cuda"
86
+ ):
87
+ return None
88
+
89
+ keys_repr = (repr(key) for key in device_node_mapping.keys())
90
+ return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}")
91
+
92
+
93
+ def check_lowering_disable_cudagraph(
94
+ device_node_mapping: Dict[torch.device, torch.fx.Node]
95
+ ):
96
+ return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)
97
+
98
+
99
+ @dataclasses.dataclass
100
+ class BoxedDeviceIndex:
101
+ value: Optional[int]
102
+
103
+ def set(self, device_idx: Optional[int]):
104
+ assert device_idx is None or isinstance(device_idx, int)
105
+ self.value = device_idx
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/dependencies.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import dataclasses
3
+ import itertools
4
+ import logging
5
+ import re
6
+ import typing
7
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
8
+ from unittest.mock import patch
9
+
10
+ import sympy
11
+
12
+ import torch
13
+ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
14
+
15
+ from .codegen.common import index_prevent_reordering
16
+ from .utils import (
17
+ get_dtype_size,
18
+ reduction_num_outputs,
19
+ sympy_index_symbol,
20
+ sympy_str,
21
+ sympy_subs,
22
+ VarRanges,
23
+ )
24
+ from .virtualized import OpsHandler, ReductionType, V
25
+
26
+ log = logging.getLogger(__name__)
27
+ is_indirect = re.compile(r"indirect|tmp").search
28
+ Dep = Union["MemoryDep", "StarDep", "WeakDep"]
29
+
30
+
31
+ class MemoryDep(typing.NamedTuple):
32
+ name: str
33
+ index: sympy.Expr # type: ignore[assignment]
34
+ var_names: Tuple[sympy.Symbol, ...]
35
+ size: Tuple[sympy.Expr, ...]
36
+
37
+ def __repr__(self):
38
+ return f"MemoryDep({self.name!r}, {self.index}, {self.ranges})"
39
+
40
+ @property
41
+ def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]:
42
+ """{c0: 128, c1: 512, ...}"""
43
+ return dict(zip(self.var_names, self.size))
44
+
45
+ def get_numel(self) -> sympy.Expr:
46
+ if self.is_indirect():
47
+ numel = V.graph.get_numel(self.name)
48
+ else:
49
+ vars = set(self.index.free_symbols)
50
+ numel = sympy.Integer(1)
51
+ for var, size in zip(self.var_names, self.size):
52
+ if var in vars:
53
+ numel = numel * size
54
+ return numel
55
+
56
+ def rename(self, renames: Dict[str, str]) -> "MemoryDep":
57
+ if self.name in renames:
58
+ return MemoryDep(
59
+ renames[self.name], self.index, var_names=self.var_names, size=self.size
60
+ )
61
+ return self
62
+
63
+ def numbytes_hint(self):
64
+ return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
65
+ V.graph.get_dtype(self.name)
66
+ )
67
+
68
+ def has_unbacked_symbols(self):
69
+ return len(free_unbacked_symbols(self.get_numel())) > 0
70
+
71
+ def is_contiguous(self) -> bool:
72
+ return isinstance(self.index, sympy.Symbol) and self.index in self.var_names
73
+
74
+ def is_scalar(self) -> bool:
75
+ if isinstance(self.index, sympy.Symbol):
76
+ return self.index not in self.var_names and not self.is_indirect()
77
+ return isinstance(self.index, (int, sympy.Integer))
78
+
79
+ def is_indirect(self) -> bool:
80
+ return any(is_indirect(v.name) for v in self.index.free_symbols) # type: ignore[attr-defined]
81
+
82
+
83
+ class StarDep(typing.NamedTuple):
84
+ # depends on the entire buffer
85
+ name: str
86
+
87
+ @property
88
+ def index(self):
89
+ raise NotImplementedError("StarDep does not have an index")
90
+
91
+ def get_numel(self) -> sympy.Expr:
92
+ return V.graph.get_numel(self.name)
93
+
94
+ def rename(self, renames: Dict[str, str]) -> "StarDep":
95
+ if self.name in renames:
96
+ return StarDep(renames[self.name])
97
+ return self
98
+
99
+ def numbytes_hint(self):
100
+ return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size(
101
+ V.graph.get_dtype(self.name)
102
+ )
103
+
104
+ def has_unbacked_symbols(self):
105
+ return len(free_unbacked_symbols(self.get_numel())) > 0
106
+
107
+ def is_contiguous(self) -> bool:
108
+ return False
109
+
110
+ def is_scalar(self) -> bool:
111
+ return False
112
+
113
+ def is_indirect(self) -> bool:
114
+ return False
115
+
116
+
117
+ # Used for tracking mutation ordering
118
+ # if A reads a buffer and B mutates it
119
+ # B must be ordered after A
120
+ #
121
+ # It is weak because if it turns out A's read is never used, we can still
122
+ # eliminate it
123
+ class WeakDep(typing.NamedTuple):
124
+ name: str
125
+
126
+ @property
127
+ def index(self):
128
+ raise NotImplementedError("WeakDep does not have an index")
129
+
130
+ def get_numel(self) -> sympy.Expr:
131
+ return sympy.Integer(1)
132
+
133
+ def rename(self, renames: Dict[str, str]) -> "WeakDep":
134
+ if self.name in renames:
135
+ return WeakDep(renames[self.name])
136
+ return self
137
+
138
+ def numbytes_hint(self):
139
+ return 1 # Purely inserted for ordering, not an actual dep
140
+
141
+ def has_unbacked_symbols(self):
142
+ return False
143
+
144
+ def is_contiguous(self) -> bool:
145
+ return False
146
+
147
+
148
+ class IndexExprDep(typing.NamedTuple):
149
+ index: sympy.Expr # type: ignore[assignment]
150
+ var_names: Tuple[sympy.Symbol, ...]
151
+ size: Tuple[sympy.Expr, ...]
152
+
153
+
154
+ @dataclasses.dataclass
155
+ class ReadWrites:
156
+ reads: Set[Dep]
157
+ writes: Set[Dep]
158
+ index_exprs: Set[IndexExprDep]
159
+ range_vars: Optional[List[sympy.Expr]] = None
160
+ var_ranges: Optional[VarRanges] = None
161
+ op_counts: typing.Counter[str] = dataclasses.field(
162
+ default_factory=collections.Counter
163
+ )
164
+
165
+ def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
166
+ return ReadWrites(
167
+ {dep.rename(renames) for dep in self.reads},
168
+ {dep.rename(renames) for dep in self.writes},
169
+ self.index_exprs,
170
+ self.range_vars,
171
+ self.var_ranges,
172
+ op_counts=self.op_counts,
173
+ )
174
+
175
+ def with_read(self, dep: Dep) -> "ReadWrites":
176
+ assert isinstance(dep, (WeakDep, StarDep))
177
+ return ReadWrites(
178
+ set.union(self.reads, {dep}),
179
+ self.writes,
180
+ self.index_exprs,
181
+ self.range_vars,
182
+ self.var_ranges,
183
+ op_counts=self.op_counts,
184
+ )
185
+
186
+ def merge(self, other: "ReadWrites"):
187
+ reads = set.union(self.reads, other.reads)
188
+ writes = set.union(self.writes, other.writes)
189
+ index_exprs = set.union(self.index_exprs, other.index_exprs)
190
+ op_counts = collections.Counter(self.op_counts)
191
+ op_counts.update(other.op_counts)
192
+ return ReadWrites(reads - writes, writes, index_exprs, op_counts=op_counts)
193
+
194
+ @staticmethod
195
+ def merge_list(read_writes: List["ReadWrites"]):
196
+ all_writes = set.union(*[rw.writes for rw in read_writes])
197
+ all_reads = set.union(*[rw.reads for rw in read_writes]) - all_writes
198
+ all_index_exprs = set.union(*[rw.index_exprs for rw in read_writes])
199
+
200
+ op_counts: typing.Counter[Any] = collections.Counter()
201
+ for rw in read_writes:
202
+ op_counts.update(rw.op_counts)
203
+
204
+ return ReadWrites(all_reads, all_writes, all_index_exprs, op_counts=op_counts)
205
+
206
+ def remove_reads(self, rem_reads):
207
+ return ReadWrites(
208
+ self.reads - rem_reads,
209
+ self.writes,
210
+ self.index_exprs,
211
+ self.range_vars,
212
+ self.var_ranges,
213
+ op_counts=self.op_counts,
214
+ )
215
+
216
+ def reads_and_writes(self):
217
+ return itertools.chain(self.reads, self.writes)
218
+
219
+
220
+ class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
221
+ def __init__(self, var_ranges: VarRanges, normalize: bool):
222
+ super().__init__()
223
+ self._reads: Set[Dep] = set()
224
+ self._writes: Set[MemoryDep] = set()
225
+ self._index_exprs: Set[IndexExprDep] = set()
226
+ self._var_ranges: VarRanges = var_ranges
227
+ self._normalize: bool = normalize
228
+
229
+ def canonicalize(
230
+ self, index: sympy.Expr
231
+ ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]:
232
+ if not self._normalize:
233
+ sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()]
234
+ var_names = tuple(
235
+ k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1
236
+ )
237
+ sizes = tuple(v for v in sizes if v != 1)
238
+ return index, var_names, sizes # type: ignore[return-value]
239
+
240
+ # Try to further simplify the indexes even if simplify_loops didn't
241
+ # convert it to the simplest form because of the interference from
242
+ # different indexing formulas.
243
+ free_symbols = index.free_symbols
244
+ var_ranges = {
245
+ k: V.graph.sizevars.simplify(v)
246
+ for k, v in self._var_ranges.items()
247
+ # TODO(jansel): explore this further normalization
248
+ # if k in free_symbols
249
+ }
250
+ index_vars = [*var_ranges.keys()]
251
+ sizes = tuple(var_ranges.values())
252
+ new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
253
+ index_vars,
254
+ sizes,
255
+ index_prevent_reordering([index], index_vars, sizes),
256
+ )
257
+
258
+ # assign new variables each dimension to deal with numbering mismatches
259
+ # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
260
+ new_vars, add_var = var_builder(canonicalization_prefix())
261
+ replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
262
+ index = sympy_subs(sympy.expand(index), replacement)
263
+
264
+ new_vars = [*new_vars.keys()]
265
+ new_sizes = [*new_sizes]
266
+ free_symbols = index.free_symbols
267
+ while new_vars and new_vars[-1] not in free_symbols:
268
+ # Reduction has last (reduced) dim in its sizes, but
269
+ # downstream users won't. Normalize this away.
270
+ new_vars.pop()
271
+ new_sizes.pop()
272
+ return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type]
273
+
274
+ def load(self, name: str, index: sympy.Expr) -> str:
275
+ self._reads.add(MemoryDep(name, *self.canonicalize(index)))
276
+ return f"load({name}, {sympy_str(index)})"
277
+
278
+ def load_seed(self, name: str, index: int):
279
+ assert isinstance(index, int)
280
+ return self.load(name, sympy.Integer(index))
281
+
282
+ def store(self, name: str, index: sympy.Expr, value: str, mode=None) -> str:
283
+ self._writes.add(MemoryDep(name, *self.canonicalize(index)))
284
+ return f"store({name}, {sympy_str(index)}, {value}, {mode})"
285
+
286
+ def store_reduction(self, name: str, index, value) -> str:
287
+ return self.store(name, index, f"store_reduction({value})")
288
+
289
+ def index_expr(self, index: sympy.Expr, dtype) -> str:
290
+ self._index_exprs.add(IndexExprDep(*self.canonicalize(index)))
291
+ return f"index_expr({sympy_str(index)}, {dtype})"
292
+
293
+ def bucketize(
294
+ self,
295
+ values,
296
+ offsets_name: str,
297
+ offsets_size: sympy.Expr,
298
+ indexing_dtype: torch.dtype,
299
+ right: bool,
300
+ ):
301
+ self._reads.add(StarDep(offsets_name))
302
+ return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})"
303
+
304
+
305
+ class _OpCounter:
306
+ """Shim to count how many times each op is used"""
307
+
308
+ def __init__(self, inner):
309
+ super().__init__()
310
+ self.parent_handler = inner
311
+ self._op_counts: typing.Counter[Any] = collections.Counter()
312
+
313
+ def __getattr__(self, name):
314
+ self._op_counts[name] += 1
315
+ return getattr(self.parent_handler, name)
316
+
317
+
318
+ class RecordLoadStore(V.KernelFormatterHandler): # type: ignore[name-defined]
319
+ def __init__(self, var_ranges: VarRanges, normalize: bool):
320
+ parent_handler = _RecordLoadStoreInner(
321
+ var_ranges=var_ranges, normalize=normalize
322
+ )
323
+ parent_handler = _OpCounter(parent_handler)
324
+ super().__init__(parent_handler=parent_handler)
325
+
326
+
327
+ def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
328
+ cnt = itertools.count()
329
+ var_ranges: VarRanges = dict()
330
+
331
+ def add_var(length: sympy.Expr) -> sympy.Symbol:
332
+ v = sympy_index_symbol(f"{prefix}{next(cnt)}")
333
+ var_ranges[v] = length
334
+ return v
335
+
336
+ return var_ranges, add_var
337
+
338
+
339
+ def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
340
+ var_ranges, add_var = var_builder(prefix)
341
+ args: List[List[sympy.Symbol]] = []
342
+ for size in argsizes:
343
+ args.append(list(map(add_var, size)))
344
+ return args, var_ranges
345
+
346
+
347
+ def index_vars_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str = "d"):
348
+ from .ir import SqueezeView
349
+
350
+ var_ranges, add_var = var_builder(prefix)
351
+ args: List[List[sympy.Expr]] = []
352
+ new_sizes: List[List[sympy.Expr]] = []
353
+ for size in argsizes:
354
+ new_size, reindex = SqueezeView.squeezer(size)
355
+ new_sizes.append(new_size)
356
+ args.append(reindex(list(map(add_var, new_size))))
357
+ return args, var_ranges
358
+
359
+
360
+ def extract_read_writes(
361
+ fn: Callable[..., Any],
362
+ *argsizes: Tuple[sympy.Expr, ...],
363
+ normalize: bool = False,
364
+ prefix: str = "d",
365
+ ):
366
+ args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
367
+ rw = RecordLoadStore(var_ranges, normalize=normalize)
368
+ with V.set_ops_handler(rw):
369
+ fn(*args)
370
+
371
+ if normalize:
372
+ range_vars = [] # Number of vars could differ due to normalization
373
+ else:
374
+ range_vars = list(itertools.chain.from_iterable(args))
375
+
376
+ inner = rw.parent_handler.parent_handler
377
+ return ReadWrites(
378
+ set(inner._reads),
379
+ set(inner._writes),
380
+ inner._index_exprs,
381
+ range_vars,
382
+ var_ranges,
383
+ rw.parent_handler._op_counts,
384
+ )
385
+
386
+
387
+ def extract_input_node_reduction_ranges(
388
+ input_node: "torch._inductor.ir.TensorBox",
389
+ ) -> Tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
390
+ """
391
+ Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
392
+ It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes.
393
+ In this case, reduction_sizes of the Reduction nodes need to be the same.
394
+ Otherwise returns (None, None).
395
+ """
396
+
397
+ from .ir import ComputedBuffer, Loops
398
+
399
+ if isinstance(input_node.data, ComputedBuffer):
400
+ # Input node has already been realized. Return its size and reduction_size.
401
+ size = input_node.get_size()
402
+ reduction_size = input_node.get_reduction_size()
403
+ if len(reduction_size) > 0:
404
+ return (size, reduction_size)
405
+ else:
406
+ return (None, None)
407
+
408
+ if not isinstance(input_node.data.data, Loops): # type: ignore[attr-defined]
409
+ # Other IRNodes do not have reduction_ranges.
410
+ return (None, None)
411
+
412
+ # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes?
413
+ # The current method still uses reduction ranges from the dependent realized node, which is not ideal.
414
+ # Is there a way to check whether there are permutations inbetween?
415
+ reads = input_node.get_reads()
416
+ reduction_size = None
417
+ size = None
418
+ while reduction_size is None and len(reads) > 0:
419
+ seen = set()
420
+ new_reads = []
421
+ for read in reads:
422
+ if not isinstance(read, MemoryDep):
423
+ continue
424
+ if read.name in seen:
425
+ continue
426
+ seen.add(read.name)
427
+ buffer = V.graph.get_buffer(read.name)
428
+ if buffer is None:
429
+ continue
430
+ if (
431
+ isinstance(buffer, ComputedBuffer)
432
+ and len(buffer.get_reduction_size()) > 0
433
+ ):
434
+ if reduction_size is None:
435
+ reduction_size = buffer.get_reduction_size()
436
+ size = buffer.get_size()
437
+ elif (
438
+ reduction_size != buffer.get_reduction_size()
439
+ or size != buffer.get_size()
440
+ ):
441
+ return (None, None)
442
+ else:
443
+ new_reads.extend(buffer.get_reads())
444
+ if reads == new_reads:
445
+ return (size, reduction_size)
446
+ else:
447
+ reads = new_reads
448
+ return (size, reduction_size)
449
+
450
+
451
+ def canonicalization_prefix():
452
+ return "c"
453
+
454
+
455
+ # ops handler which computes all the free unbacked symbols for an IR
456
+ class FreeUnbackedSymbolsOpsHandler:
457
+ symbols: Set[sympy.Symbol]
458
+
459
+ def __init__(self):
460
+ self.symbols = set()
461
+
462
+ def __getattr__(self, name: str) -> Callable[..., Any]:
463
+ def inner(*args, **kwargs):
464
+ for a in itertools.chain(args, kwargs.values()):
465
+ if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
466
+ self.symbols |= free_unbacked_symbols(a)
467
+
468
+ return inner
469
+
470
+ def indirect_indexing(self, index_var, size, check=True) -> sympy.Symbol:
471
+ assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean))
472
+ self.symbols |= free_unbacked_symbols(size)
473
+ return sympy_index_symbol(f"({str(index_var)})")
474
+
475
+ def frexp(self, x):
476
+ return (None,) * 2
477
+
478
+ def reduction(
479
+ self,
480
+ dtype: torch.dtype,
481
+ src_dtype: torch.dtype,
482
+ reduction_type: ReductionType,
483
+ value: Union[None, Tuple[None, ...]],
484
+ ) -> Union[None, Tuple[None, ...]]:
485
+ num_values = reduction_num_outputs(reduction_type)
486
+ return (None,) * num_values if num_values > 1 else None
487
+
488
+
489
+ def _typecheck_FreeUnbackedSymbolsOpsHandler(
490
+ h: FreeUnbackedSymbolsOpsHandler,
491
+ ) -> OpsHandler[None]:
492
+ return h
493
+
494
+
495
+ def extract_free_unbacked_symbols(fn: Callable[..., Any], index, rindex=None):
496
+ from .ir import FlexibleLayout
497
+
498
+ args = [index, rindex] if rindex is not None else [index]
499
+ handler = FreeUnbackedSymbolsOpsHandler()
500
+ # NB: I cargo culted the allow_indexing patch here, I don't understand why
501
+ # people do this all over
502
+ with V.set_ops_handler(handler), patch.object(
503
+ FlexibleLayout, "allow_indexing", True
504
+ ):
505
+ fn(*args)
506
+ return handler.symbols
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/exc.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import tempfile
5
+ import textwrap
6
+ from functools import lru_cache
7
+
8
+ if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1":
9
+
10
+ @lru_cache(None)
11
+ def _record_missing_op(target):
12
+ with open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd:
13
+ fd.write(str(target) + "\n")
14
+
15
+ else:
16
+
17
+ def _record_missing_op(target): # type: ignore[misc]
18
+ pass
19
+
20
+
21
+ class OperatorIssue(RuntimeError):
22
+ @staticmethod
23
+ def operator_str(target, args, kwargs):
24
+ lines = [f"target: {target}"] + [
25
+ f"args[{i}]: {arg}" for i, arg in enumerate(args)
26
+ ]
27
+ if kwargs:
28
+ lines.append(f"kwargs: {kwargs}")
29
+ return textwrap.indent("\n".join(lines), " ")
30
+
31
+
32
+ class MissingOperatorWithoutDecomp(OperatorIssue):
33
+ def __init__(self, target, args, kwargs):
34
+ _record_missing_op(target)
35
+ super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}")
36
+
37
+
38
+ class MissingOperatorWithDecomp(OperatorIssue):
39
+ def __init__(self, target, args, kwargs):
40
+ _record_missing_op(target)
41
+ super().__init__(
42
+ f"missing decomposition\n{self.operator_str(target, args, kwargs)}"
43
+ + textwrap.dedent(
44
+ f"""
45
+
46
+ There is a decomposition available for {target} in
47
+ torch._decomp.get_decompositions(). Please add this operator to the
48
+ `decompositions` list in torch._inductor.decompositions
49
+ """
50
+ )
51
+ )
52
+
53
+
54
+ class LoweringException(OperatorIssue):
55
+ def __init__(self, exc: Exception, target, args, kwargs):
56
+ super().__init__(
57
+ f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}"
58
+ )
59
+
60
+
61
+ class InvalidCxxCompiler(RuntimeError):
62
+ def __init__(self):
63
+ from . import config
64
+
65
+ super().__init__(
66
+ f"No working C++ compiler found in {config.__name__}.cpp.cxx: {config.cpp.cxx}"
67
+ )
68
+
69
+
70
+ class CppWrapperCodeGenError(RuntimeError):
71
+ def __init__(self, msg: str):
72
+ super().__init__(f"C++ wrapper codegen error: {msg}")
73
+
74
+
75
+ class CppCompileError(RuntimeError):
76
+ def __init__(self, cmd: list[str], output: str):
77
+ if isinstance(output, bytes):
78
+ output = output.decode("utf-8")
79
+
80
+ super().__init__(
81
+ textwrap.dedent(
82
+ """
83
+ C++ compile error
84
+
85
+ Command:
86
+ {cmd}
87
+
88
+ Output:
89
+ {output}
90
+ """
91
+ )
92
+ .strip()
93
+ .format(cmd=" ".join(cmd), output=output)
94
+ )
95
+
96
+
97
+ class CUDACompileError(CppCompileError):
98
+ pass
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/fx_utils.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+ from collections import defaultdict
3
+ from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type
4
+
5
+ import torch
6
+ import torch.fx
7
+ from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
8
+ from torch.utils import _pytree as pytree
9
+ from torch.utils._pytree import tree_map
10
+ from .virtualized import V
11
+
12
+
13
+ # Check the pattern: (nn.module, F.function/torch.Tensor.method) matched.
14
+ # Works for length 2 patterns with 1 module and 1 function/method.
15
+ def matches_module_function_pattern(
16
+ pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
17
+ node: torch.fx.node.Node,
18
+ modules: Dict[str, torch.nn.modules.Module],
19
+ ) -> bool:
20
+ if len(node.args) == 0:
21
+ return False
22
+ if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
23
+ node, torch.fx.Node
24
+ ):
25
+ return False
26
+ # the first node is call_module
27
+ if node.args[0].op != "call_module":
28
+ return False
29
+ if not isinstance(node.args[0].target, str):
30
+ return False
31
+ if node.args[0].target not in modules:
32
+ return False
33
+ if type(modules[node.args[0].target]) is not pattern[0]:
34
+ return False
35
+ # the second node is call_function or call_method
36
+ if node.op != "call_function" and node.op != "call_method":
37
+ return False
38
+ if node.target != pattern[1]:
39
+ return False
40
+ # make sure node.args[0] output is only used by current node.
41
+ if len(node.args[0].users) > 1:
42
+ return False
43
+ return True
44
+
45
+
46
+ class FakeTensorUpdater:
47
+ """
48
+ The main idea here is that it's difficult to maintain accurate fake
49
+ tensors (our primary form of metadata) for each node in our graph as we
50
+ transform it.
51
+
52
+ The most reliable way to obtain this information is by rerunning
53
+ faketensor propagation. However, in general, faketensor propagation is
54
+ fairly expensive. So, instead we'd like to only rerun faketensor
55
+ propagation on nodes that have changed.
56
+
57
+ In order to detect which nodes have changed, we first hash its node,
58
+ target, and argument lists (which are immutable in FX).
59
+
60
+ Then, whenever we call incremental_update, we check which FX nodes have a
61
+ new hash, and recompute the faketensor metadata for that node. Then, we
62
+ continue to recursively compute the faketensors for all users until the
63
+ fake tensors stop changing.
64
+ """
65
+
66
+ def __init__(self, graph: torch.fx.Graph):
67
+ self.processed_hashes = set()
68
+ self.graph = graph
69
+
70
+ for node in self.graph.nodes:
71
+ self.processed_hashes.add(self.hash_node(node))
72
+
73
+ def hash_node(self, node: torch.fx.Node):
74
+ # todo(chilli): Not a great hash function
75
+ return (node, node.target, id(node.args), id(node.kwargs))
76
+
77
+ def incremental_update(self):
78
+ processed = set()
79
+ existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
80
+ for node in self.graph.nodes:
81
+ existing_storages[get_node_storage(node)] += 1
82
+
83
+ def is_intlist_same(new, old):
84
+ return statically_known_true(sym_eq(new, old))
85
+
86
+ def is_fake_tensor_same(new, old):
87
+ if type(new) != type(old):
88
+ return False
89
+ if isinstance(new, (list, tuple)):
90
+ if len(new) != len(old):
91
+ return False
92
+ return all(
93
+ is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
94
+ )
95
+ assert isinstance(new, torch.Tensor)
96
+ if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
97
+ return False
98
+ if new.layout == torch.strided and (
99
+ not is_intlist_same(new.stride(), old.stride())
100
+ or not statically_known_true(
101
+ new.storage_offset() == old.storage_offset()
102
+ )
103
+ ):
104
+ return False
105
+
106
+ if get_storage(new) == get_storage(old):
107
+ return True
108
+
109
+ # This is the case where it returns a completely fresh storage that's used nowhere else.
110
+ if (
111
+ existing_storages[get_storage(old)] == 1
112
+ and get_storage(new) not in existing_storages
113
+ ):
114
+ return True
115
+ return False
116
+
117
+ for node in self.graph.nodes:
118
+ if self.hash_node(node) in self.processed_hashes:
119
+ continue
120
+
121
+ def is_aten_node(node):
122
+ return node.op == "call_function" and isinstance(
123
+ node.target, torch._ops.OpOverload
124
+ )
125
+
126
+ if not is_aten_node(node):
127
+ continue
128
+
129
+ processing = [node]
130
+ while len(processing) > 0:
131
+ updating_node = processing.pop()
132
+ if updating_node in processed:
133
+ continue
134
+ if is_aten_node(updating_node):
135
+ continue
136
+
137
+ is_valid, args, kwargs = get_fake_args_kwargs(updating_node)
138
+ if not is_valid:
139
+ continue
140
+ with V.fake_mode:
141
+ new_fake_tensor = updating_node.target(*args, **kwargs)
142
+ if "val" in updating_node.meta and is_fake_tensor_same(
143
+ new_fake_tensor, updating_node.meta["val"]
144
+ ):
145
+ continue
146
+ updating_node.meta["val"] = new_fake_tensor
147
+
148
+ # todo(chilli): This code path is not exercised by our existing
149
+ # tests - add a test
150
+ existing_storages[get_node_storage(new_fake_tensor)] += 1
151
+ processed.add(updating_node)
152
+ processing.extend(updating_node.users)
153
+
154
+ self.processed_hashes.add(self.hash_node(updating_node))
155
+
156
+
157
+ def get_storage(t: torch.Tensor) -> int:
158
+ return t.untyped_storage()._cdata
159
+
160
+
161
+ def get_node_storage(node: torch.fx.Node) -> Optional[int]:
162
+ if "val" not in node.meta:
163
+ return None
164
+ if not isinstance(node.meta["val"], torch.Tensor):
165
+ return None
166
+ if not torch._C._has_storage(node.meta["val"]):
167
+ return None
168
+ return get_storage(node.meta["val"])
169
+
170
+
171
+ def get_fake(x):
172
+ if isinstance(x, torch.fx.Node):
173
+ if "val" not in x.meta:
174
+ return x
175
+ return x.meta["val"]
176
+ return x
177
+
178
+
179
+ def get_fake_args_kwargs(x: torch.fx.Node) -> Tuple[bool, Tuple[Any], Dict[str, Any]]:
180
+ """
181
+ First value returns a boolean if any of the input nodes don't have a faketensor.
182
+ """
183
+ args, kwargs = tree_map(get_fake, (x.args, x.kwargs))
184
+ if any(
185
+ isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs)
186
+ ):
187
+ return False, args, kwargs
188
+ return True, args, kwargs
189
+
190
+
191
+ def is_node_realized(node: torch.fx.Node) -> bool:
192
+ """Returns true if a node is always realized when lowered to inductor IR.
193
+
194
+ NOTE: This may return some false negatives. e.g. it doesn't
195
+ handle buffers realized heuristically during lowering, or
196
+ buffers realized indirectly through view ops.
197
+ """
198
+ from torch._inductor.lowering import fallbacks, needs_realized_inputs
199
+
200
+ def is_buffer(node: torch.fx.Node) -> bool:
201
+ if node.op == "call_function" and node.target is operator.getitem:
202
+ # For nodes with multiple outputs, we get the fx graph:
203
+ # foo = torch.ops.aten.foo(...)
204
+ # getitem = foo[0]
205
+ # getitem_1 = foo[1]
206
+ # where we need to check if foo is a fallback kernel
207
+ return is_buffer(node.args[0]) # type: ignore[arg-type]
208
+ return node.op in ("placeholder", "output") or node.target in fallbacks
209
+
210
+ if is_buffer(node):
211
+ return True
212
+
213
+ def realizes_inputs(node: torch.fx.Node) -> bool:
214
+ return node.op == "output" or node.target in needs_realized_inputs
215
+
216
+ if any(realizes_inputs(user) for user in node.users):
217
+ return True
218
+
219
+ # Otherwise, assume node isn't realized
220
+ return False
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/index_propagation.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file implements the IndexPropagation ops handler, which wraps an
2
+ underlying handler to add a limited form of constant propagation, as well as
3
+ propagation of sympy expressions downstream of ops.index_expr calls.
4
+
5
+ For example, say we have the IR:
6
+
7
+ tmp0 = ops.index_expr(x, torch.int32)
8
+ tmp1 = ops.constant(2, torch.int32)
9
+ tmp2 = ops.mul(tmp0, tmp1)
10
+ tmp3 = ops.indirect_indexing(tmp2, x_size)
11
+ tmp4 = ops.load("buf0", tmp3)
12
+
13
+ The underlying handler would just see:
14
+
15
+ ops.load("buf0", x * 2)
16
+
17
+ This is limited by the set of operators handled in the sympy expression
18
+ printers. So simple operations like minimum and maximum cannot be translated to
19
+ SymPy expressions yet, despite sympy.Min and sympy.Max existing.
20
+
21
+ """
22
+ import itertools
23
+ from dataclasses import dataclass
24
+ from typing import Any, Callable, Dict, Literal, Optional, overload, Tuple, Union
25
+
26
+ import sympy
27
+
28
+ from typing_extensions import TypeAlias
29
+
30
+ import torch
31
+ from torch._prims_common import is_boolean_dtype, is_integer_dtype
32
+ from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
33
+
34
+
35
@dataclass
class TypedExpr:
    """A SymPy expression with associated type"""

    expr: sympy.Expr  # the symbolic value
    dtype: torch.dtype  # the torch dtype this expression is interpreted as
41
+
42
+
43
class SymPyOps:
    """An ops handler where all IR values are SymPy expressions

    When a value cannot be represented as a SymPy expression, the method is
    either not defined, or returns NotImplemented

    """

    @staticmethod
    def identity(value: Any) -> Any:
        # No-op, exists so the handler protocol is complete.
        return value

    @staticmethod
    def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr:
        # Coerce through the dtype's python type so e.g. a float constant
        # with an integer dtype is truncated here, at build time.
        if is_boolean_dtype(dtype):
            expr = sympy.Integer(bool(value))
        elif is_integer_dtype(dtype):
            expr = sympy.Integer(int(value))
        else:
            expr = sympy.Float(float(value))
        return TypedExpr(expr, dtype)

    @staticmethod
    def index_expr(value: sympy.Expr, dtype: torch.dtype) -> Union[int, TypedExpr]:
        # Plain python ints are normalized to sympy.Integer first.
        if isinstance(value, int):
            value = sympy.Integer(value)
        return TypedExpr(value, dtype)

    @staticmethod
    def to_dtype(
        value: Any, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None
    ) -> Union[int, TypedExpr]:
        # Constants can always be re-typed; int->int casts stay symbolic.
        if isinstance(value.expr, (sympy.Integer, sympy.Float)):
            return SymPyOps.constant(value.expr, dtype)
        elif is_integer_dtype(dtype) and is_integer_dtype(value.dtype):
            return SymPyOps.index_expr(value.expr, dtype)
        else:
            # TODO: Inductor doesn't handle floating point in sympy expressions well at the moment
            return NotImplemented

    @staticmethod
    def square(x: TypedExpr) -> TypedExpr:
        return TypedExpr(x.expr * x.expr, x.dtype)

    @staticmethod
    def add(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr + y.expr, result_type)

    @staticmethod
    def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr - y.expr, result_type)

    @staticmethod
    def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr * y.expr, result_type)

    @staticmethod
    def neg(x: TypedExpr) -> TypedExpr:
        return TypedExpr(-x.expr, x.dtype)

    @staticmethod
    def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        # Only integer floordiv is representable (FloorDiv helper).
        result_type = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(result_type):
            return NotImplemented

        return TypedExpr(FloorDiv(x.expr, y.expr), result_type)

    @staticmethod
    def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
        # ModularIndexing(x, 1, y) encodes x % y for integer expressions.
        result_type = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(result_type):
            return NotImplemented

        result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
        return TypedExpr(result_expr, result_type)

    @staticmethod
    def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
        result_type = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(result_type):
            return NotImplemented
        # In these cases, remainder in Python == remainder in C++, so this transformation
        # is sound
        if (
            x.expr.is_nonnegative is not None
            and x.expr.is_nonnegative == y.expr.is_positive
        ):
            result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
            return TypedExpr(result_expr, result_type)
        return NotImplemented

    @staticmethod
    def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(sympy.Min(x.expr, y.expr), result_type)

    @staticmethod
    def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(sympy.Max(x.expr, y.expr), result_type)
147
+
148
+
149
@dataclass
class IndexPropVar:
    # value is either an IR value (opaque handle from the wrapped ops
    # handler), or a TypedExpr when is_symbolic is true.
    value: Any  # Either an IR value, or TypedExpr if is_symbolic is true
    is_symbolic: bool = False

    @staticmethod
    def new_symbolic(expr: TypedExpr) -> "IndexPropVar":
        """Wrap a TypedExpr as a symbolic variable."""
        return IndexPropVar(expr, is_symbolic=True)

    def __post_init__(self):
        # Invariant: symbolic vars always carry a TypedExpr payload.
        assert not self.is_symbolic or isinstance(
            self.value, TypedExpr
        ), "Symbolic IndexPropVar must contain a TypedExpr"
162
+
163
+
164
+ IndexPropResult: TypeAlias = Union[IndexPropVar, Tuple["IndexPropResult", ...]]
165
+
166
+
167
class IndexPropagation:
    """Ops wrapper that tries to propagate constant and index_expr values through the computation.

    This aims to maximize the compile time simplification possible, and convert
    indirect indexing from arange into normal static indexing.

    """

    def __init__(self, inner: Any):
        # `inner` is the wrapped ops handler; anything we cannot keep
        # symbolic is forwarded to it.
        self._inner = inner

    def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any:
        """Lower a SymPy expression back into an inner-handler IR value."""
        # Construct a new constant/index_expr from the SymPy expression
        if isinstance(expr, sympy.Integer):
            return self._inner.constant(int(expr), dtype)
        elif expr.is_number:
            return self._inner.constant(float(expr), dtype)
        return self._inner.index_expr(expr, dtype)

    def unwrap(self, a: Union[Any, IndexPropVar]) -> Any:
        """Convert an IndexPropVar (or nested tuple of them) to inner IR values."""
        if isinstance(a, (list, tuple)):
            return tuple(self.unwrap(v) for v in a)

        if not isinstance(a, IndexPropVar):
            return a

        # Prefer the sympy representation if possible
        if a.is_symbolic:
            return self.materialize_expr(a.value.expr, a.value.dtype)

        return a.value

    def wrap(self, a) -> IndexPropResult:
        """Wrap an inner IR value (or nested tuple) in IndexPropVar."""
        if isinstance(a, (list, tuple)):
            return tuple(self.wrap(v) for v in a)
        return IndexPropVar(a)

    @overload
    def fallback(
        self,
        name: Literal["indirect_indexing"],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> IndexPropVar:
        ...

    @overload
    def fallback(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        ...

    def fallback(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        # Fallback to the wrapped handler
        new_args = [self.unwrap(a) for a in args]
        new_kwargs = {k: self.unwrap(v) for k, v in kwargs.items()}
        return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))

    def propagate_sympy(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        # Build a new SymPy expression from this ops call
        def unwrap(a: Union[Any, IndexPropVar]) -> Any:
            if not isinstance(a, IndexPropVar):
                return a
            return a.value

        new_args = [unwrap(a) for a in args]
        new_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
        new_expr = getattr(SymPyOps, name)(*new_args, **new_kwargs)
        is_valid_expr = new_expr is not NotImplemented and (
            # Inductor doesn't expect floating point in sympy expressions, but
            # allow floating point constants to be propagated
            isinstance(new_expr.expr, sympy.Number)
            or new_expr.expr.is_integer
        )
        if not is_valid_expr:
            # Cannot stay symbolic; evaluate through the wrapped handler.
            return self.fallback(name, args, kwargs)
        return IndexPropVar.new_symbolic(new_expr)

    def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
        # Intercept every ops call; keep it symbolic when SymPyOps supports
        # the op and all IndexPropVar inputs are symbolic.
        def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
            if not hasattr(SymPyOps, name):
                return self.fallback(name, args, kwargs)

            var_arguments = [
                a
                for a in itertools.chain(args, kwargs.values())
                if isinstance(a, IndexPropVar)
            ]
            if not all(v.is_symbolic for v in var_arguments):
                return self.fallback(name, args, kwargs)

            return self.propagate_sympy(name, args, kwargs)

        return inner

    def indirect_indexing(
        self, index: Union[Any, IndexPropVar], size: Any, check: bool = True
    ) -> Any:
        # nb. We do index + Where(...) rather than Where(idx >= 0, idx, idx + sz) because we don't have CSE
        # for SymPy expressions, so we don't want to repeat idx too much

        # indirect_indexing returns a sympy value, so no need to wrap in IndexPropVar here
        if isinstance(index, IndexPropVar) and index.is_symbolic:
            # If we are turning a indirect indexing into direct, we need to wrap it.
            index = index.value.expr
            return index + Where(index >= 0, 0, size)
        return self.fallback("indirect_indexing", (index, size, check), {}).value
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/metrics.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import inspect
5
+ import os
6
+ import re
7
+ from dataclasses import dataclass
8
+ from functools import lru_cache
9
+
10
+ from typing import Dict, List, Set, Tuple, TYPE_CHECKING, Union
11
+
12
+ from torch._inductor import config
13
+ from torch._inductor.utils import get_benchmark_name
14
+
15
# Prevent circular import
if TYPE_CHECKING:
    from torch._inductor.scheduler import (
        BaseSchedulerNode,
        ExternKernelSchedulerNode,
        NopKernelSchedulerNode,
        SchedulerNode,
    )

# counter for tracking how many kernels have been generated
generated_kernel_count = 0
# counter for cpp kernels that were emitted in vectorized form
generated_cpp_vec_kernel_count = 0
# total bytes read/written, accumulated elsewhere in inductor
num_bytes_accessed = 0
# (scheduler node, element count) pairs, appended during lowering
nodes_num_elem: List[
    Tuple[
        Union[NopKernelSchedulerNode, SchedulerNode, ExternKernelSchedulerNode],
        int,
    ]
] = []
# (scheduler node, estimated runtime) pairs
node_runtimes: List[Tuple[BaseSchedulerNode, float]] = []

# counters for tracking fusions
ir_nodes_pre_fusion = 0

# counters for tracking to_dtype inserted
cpp_to_dtype_count = 0

# counters for tracking cpp_wrapper disabled
disable_cpp_wrapper = 0
45
+
46
# reset all counters
def reset():
    """Reset every module-level metrics counter/accumulator to its initial state."""
    global generated_kernel_count
    global generated_cpp_vec_kernel_count
    global num_bytes_accessed, nodes_num_elem
    global ir_nodes_pre_fusion
    global cpp_to_dtype_count
    global disable_cpp_wrapper

    generated_kernel_count = 0
    generated_cpp_vec_kernel_count = 0
    num_bytes_accessed = 0
    # The lists are cleared in place (no rebinding), so node_runtimes needs
    # no `global` declaration here.
    nodes_num_elem.clear()
    node_runtimes.clear()
    ir_nodes_pre_fusion = 0
    cpp_to_dtype_count = 0
    disable_cpp_wrapper = 0
63
+
64
+
65
@dataclass
class CachedMetricsDeltas:
    """
    The subset of metrics we want to update across cache hits, e.g., the
    FxGraphCache.
    """

    generated_kernel_count: int
    generated_cpp_vec_kernel_count: int
    ir_nodes_pre_fusion: int
    cpp_to_dtype_count: int
76
+
77
+
78
class CachedMetricsHelper:
    """
    A helper class to help calculate and apply counter deltas for those
    metrics we want to save with cache entries (e.g., FxGraphCache) and
    apply on a cache hit.
    """

    # Module-level counters that are snapshotted and delta'd, listed in
    # CachedMetricsDeltas field order.
    _metric_names = (
        "generated_kernel_count",
        "generated_cpp_vec_kernel_count",
        "ir_nodes_pre_fusion",
        "cpp_to_dtype_count",
    )

    def __init__(self):
        # Snapshot the current value of each tracked module-level counter.
        for name in self._metric_names:
            setattr(self, name, globals()[name])

    def get_deltas(self) -> CachedMetricsDeltas:
        """Current counter values minus the snapshot taken in __init__."""
        return CachedMetricsDeltas(
            *(globals()[name] - getattr(self, name) for name in self._metric_names)
        )

    @staticmethod
    def apply_deltas(delta: CachedMetricsDeltas):
        """Re-apply previously recorded deltas to the live module counters."""
        for name in CachedMetricsHelper._metric_names:
            globals()[name] += getattr(delta, name)
120
+
121
+
122
+ REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {}
123
+
124
+
125
@dataclass
class MetricTable:
    """A CSV-backed metrics table, registered globally by table_name."""

    table_name: str
    column_names: List[str]

    num_rows_added: int = 0

    def add_row(self, row_fn):
        """Evaluate row_fn (only when this table is enabled) and append its row."""
        if self.table_name not in enabled_metric_tables():
            return

        row_dict = row_fn()
        assert len(self.column_names) == len(
            row_dict
        ), f"{len(self.column_names)} v.s. {len(row_dict)}"
        assert set(self.column_names) == set(
            row_dict.keys()
        ), f"{set(self.column_names)} v.s. {set(row_dict.keys())}"

        # The benchmark/model name is always the implicit first column.
        row = [get_benchmark_name()]
        row.extend(row_dict[column_name] for column_name in self.column_names)
        self._write_row(row)

    def output_filename(self):
        return f"metric_table_{self.table_name}.csv"

    def write_header(self):
        """(Re)create the CSV file containing only the header row."""
        with open(self.output_filename(), "w") as fd:
            writer = csv.writer(fd, lineterminator="\n")
            writer.writerow(["model_name"] + self.column_names)

    def _write_row(self, row):
        filename = self.output_filename()
        # Lazily create the file with a header on the first write.
        if self.num_rows_added == 0 and not os.path.exists(filename):
            self.write_header()

        self.num_rows_added += 1

        # Normalize cells: floats to 6 decimal places, None to empty string.
        def normalize(value):
            if isinstance(value, float):
                return f"{value:.6f}"
            if value is None:
                return ""
            return value

        with open(filename, "a") as fd:
            writer = csv.writer(fd, lineterminator="\n")
            writer.writerow([normalize(value) for value in row])

    @staticmethod
    def register_table(name, column_names):
        """Create a table and record it in the global registry."""
        REGISTERED_METRIC_TABLES[name] = MetricTable(name, column_names)
183
+
184
+
185
# compare the latency of the two pre-fusion kernels against the fused kernel
MetricTable.register_table(
    "slow_fusion",
    [
        "kernel1_path",
        "kernel1_latency",
        "kernel2_path",
        "kernel2_latency",
        "fused_kernel_path",
        "fused_kernel_latency",
        "slow_down_ratio",
    ],
)

# track the fusion statistics for each graph
MetricTable.register_table(
    "graph_stats",
    [
        "graph_id",
        "num_nodes_before_fusion",
        "num_nodes_after_fusion",
    ],
)

# track the perf difference between persistent reduction and non-persistent
# reductions
MetricTable.register_table(
    "persistent_red_perf",
    [
        "kernel1_name",
        "kernel2_name",
        "kernel1_latency",
        "kernel2_latency",
        "size_hints",
        "reduction_hint",
        "speedup",
    ],
)

# Log metadata for pointwise/reduction kernels. E.g., model name, kernel path, numel, rnumel, reduction hint
MetricTable.register_table(
    "kernel_metadata",
    [
        "kernel_name",
        "kernel_path",
        "kernel_category",  # pointwise/reduction/foreach etc.
        "size_hints",
        "reduction_hint",
        "line_of_code",
        "num_load",
        "num_store",
        "num_for_loop",
        "num_atomic_add",
        "num_args",
        # xyz numel can be different to size_hints since size_hints are rounded
        # up to the nearest power of 2.
        # Inductor kernel will burn in the xyz numel in kernel code for static
        # shape kernels.
        # Logging them will be helpful to find unaligned shape for reduction
        "xnumel",
        "ynumel",
        "rnumel",
        "kernel_args_num_gb",
    ],
)
249
+
250
+
251
def _parse_kernel_fn_code(kernel_module_code):
    """
    The kernel_module_code is the python module that contains kernel function code.
    kernel function is the proper triton kernel function annotated with
    @triton.jit
    """
    from .codecache import PyCodeCache
    from .wrapper_benchmark import get_triton_kernel

    kernel = get_triton_kernel(PyCodeCache.load(kernel_module_code))
    # kernel is a CachingAutotune; kernel.fn is the JITFunction;
    # kernel.fn.fn is the function being decorate by triton.jit
    return inspect.getsource(kernel.fn.fn)
265
+
266
+
267
+ def _parse_kernel_line_of_code(proper_kernel_fn_code):
268
+ """
269
+ Return the line of code for the kernel excluding the decorators.
270
+ """
271
+ return len(proper_kernel_fn_code.splitlines())
272
+
273
+
274
+ def _parse_size_hints(kernel_module_code, kernel_category):
275
+ if kernel_category == "foreach":
276
+ # foreach kernel does not have size_hints
277
+ return None
278
+ m = re.search(r"size_hints=(\[[0-9, ]*\]),", kernel_module_code)
279
+ assert m, "size_hints missing!"
280
+ return m.group(1)
281
+
282
+
283
+ def _parse_reduction_hint(kernel_category, kernel_module_code):
284
+ if kernel_category not in ("reduction", "persistent_reduction"):
285
+ return None
286
+ m = re.search(r"reduction_hint=ReductionHint\.(\w*),", kernel_module_code)
287
+ assert m, "reduction_hint not found in kernel source code!"
288
+ return m.group(1)
289
+
290
+
291
+ def _count_pattern(proper_kernel_fn_code, pattern):
292
+ return proper_kernel_fn_code.count(pattern)
293
+
294
+
295
+ def _count_args(proper_kernel_fn_code):
296
+ def_line = proper_kernel_fn_code.splitlines()[0]
297
+ assert def_line.startswith("def ")
298
+ start_idx = def_line.index("(")
299
+ end_idx = def_line.index("):")
300
+ decl_csv = def_line[start_idx + 1 : end_idx]
301
+ comps = decl_csv.split(",")
302
+ return len(comps)
303
+
304
+
305
+ def _parse_proper_kernel_fn_code(kernel_fn_code):
306
+ """
307
+ Skip decorators.
308
+ """
309
+ start_pos = kernel_fn_code.index("def ")
310
+ return kernel_fn_code[start_pos:]
311
+
312
+
313
+ def _parse_numel(proper_kernel_fn_code, numel_arg_name):
314
+ m = re.search(f"{numel_arg_name} = ([\\d]+)", proper_kernel_fn_code)
315
+ if m:
316
+ return int(m.group(1))
317
+ else:
318
+ return None
319
+
320
+
321
+ def _parse_kernel_args_num_gb(kernel_fn_code, kernel_category):
322
+ """
323
+ inductor meta looks like:
324
+ inductor_meta={... 'mutated_arg_names': [], 'no_x_dim': False, 'kernel_num_gb': 2.0},
325
+ """
326
+ m = re.search(r".kernel_num_gb.:\s*([0-9.]+)", kernel_fn_code)
327
+ if m:
328
+ return float(m.group(1))
329
+ else:
330
+ """
331
+ There are a few cases that kernel_num_gdb field can be missing:
332
+ 1. the field will be missing if config.benchmark_kernel and
333
+ config.profile_bandwidth are false
334
+ 2. even if config.benchmark_kernel or config.profile_bandwidth is true.
335
+ foreach kernel does not have kernel_num_gb field in the metadata
336
+ """
337
+ return None
338
+
339
+
340
def log_kernel_metadata(kernel_name, kernel_path, kernel_module_code):
    """
    An utility to log kernel metadata. We may parse metadata from kernel source code here.

    It's fine to parse the generated kernel code here since the logging is
    disabled by default. It would hurt compilation time.
    """
    from .wrapper_benchmark import get_kernel_category_by_source_code

    kernel_category = get_kernel_category_by_source_code(kernel_module_code)
    reduction_hint = _parse_reduction_hint(kernel_category, kernel_module_code)
    size_hints = _parse_size_hints(kernel_module_code, kernel_category)
    kernel_fn_code = _parse_kernel_fn_code(kernel_module_code)
    proper_kernel_fn_code = _parse_proper_kernel_fn_code(kernel_fn_code)
    # line count excludes the decorators
    kernel_line_of_code = _parse_kernel_line_of_code(proper_kernel_fn_code)

    def make_row():
        return {
            "kernel_name": kernel_name,
            "kernel_path": kernel_path,
            "kernel_category": kernel_category,
            "size_hints": size_hints,
            "reduction_hint": reduction_hint,
            "line_of_code": kernel_line_of_code,
            "num_load": _count_pattern(proper_kernel_fn_code, "tl.load"),
            "num_store": _count_pattern(proper_kernel_fn_code, "tl.store"),
            "num_for_loop": _count_pattern(proper_kernel_fn_code, "for "),
            "num_atomic_add": _count_pattern(proper_kernel_fn_code, "tl.atomic_add"),
            "num_args": _count_args(proper_kernel_fn_code),
            "xnumel": _parse_numel(proper_kernel_fn_code, "xnumel"),
            "ynumel": _parse_numel(proper_kernel_fn_code, "ynumel"),
            "rnumel": _parse_numel(proper_kernel_fn_code, "rnumel"),
            "kernel_args_num_gb": _parse_kernel_args_num_gb(
                kernel_fn_code, kernel_category
            ),
        }

    # add_row only evaluates make_row when the table is enabled
    get_metric_table("kernel_metadata").add_row(make_row)
380
+
381
+
382
def purge_old_log_files():
    """
    Purge the old log file at the beginning when the benchmark script runs.
    Should do it in the parent process rather than the child processes running
    each individual model.
    """
    for name, table in REGISTERED_METRIC_TABLES.items():
        if name not in enabled_metric_tables():
            continue
        filename = table.output_filename()
        if os.path.exists(filename):
            os.unlink(filename)
        # Always start enabled tables from a fresh header.
        table.write_header()
395
+
396
+
397
@lru_cache
def enabled_metric_tables() -> Set[str]:
    """Parse config.enabled_metric_tables (comma separated) into a set of names.

    Cached: the config value is read once on first call.
    """
    enabled = set()
    for raw_name in config.enabled_metric_tables.split(","):
        name = raw_name.strip()
        if not name:
            continue
        assert (
            name in REGISTERED_METRIC_TABLES
        ), f"Metric table name {name} is not registered"
        enabled.add(name)
    return enabled
411
+
412
+
413
def is_metric_table_enabled(name):
    """Whether the named metric table is enabled via config.enabled_metric_tables."""
    return name in enabled_metric_tables()
415
+
416
+
417
def get_metric_table(name):
    """Look up a registered MetricTable by name; asserts it was registered."""
    assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined"
    return REGISTERED_METRIC_TABLES[name]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/triton_helpers.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import triton
2
+ import triton.language as tl
3
+
4
# In the latest triton, math functions were shuffled around into different modules:
# https://github.com/openai/triton/pull/3172
if hasattr(tl.extra.cuda, "libdevice"):
    # New layout: libdevice intrinsics under tl.extra.cuda, math under tl.math.
    libdevice = tl.extra.cuda.libdevice
    math = tl.math
else:
    # Old layout: libdevice functions live on tl.math, basic math on tl itself.
    libdevice = tl.math
    math = tl
12
+
13
+
14
@triton.jit
def promote_to_tensor(x):
    # Addition promotes to tensor for us
    return x + tl.zeros((1,), tl.int1)


@triton.jit
def is_floating(x):
    # Promote first so scalar inputs also expose a .dtype.
    return promote_to_tensor(x).dtype.is_floating()
23
+
24
+
25
@triton.jit
def _prod_accumulate(a, b):
    # Combine function for the product reduction below.
    return a * b


@triton.jit
def prod(input, axis):
    """Product reduction along `axis`."""
    return tl.reduce(input, axis, _prod_accumulate)
33
+
34
+
35
@triton.jit
def minimum(a, b):
    """Elementwise min that propagates NaN (a != a is true only for NaN)."""
    mask = a < b
    if is_floating(a):
        mask |= a != a
    return tl.where(mask, a, b)


@triton.jit
def maximum(a, b):
    """Elementwise max that propagates NaN."""
    mask = a > b
    if is_floating(a):
        mask |= a != a
    return tl.where(mask, a, b)


@triton.jit
def min2(a, dim):
    """NaN-propagating min reduction along `dim`."""
    return tl.reduce(a, dim, minimum)


@triton.jit
def max2(a, dim):
    """NaN-propagating max reduction along `dim`."""
    return tl.reduce(a, dim, maximum)
59
+
60
+
61
@triton.jit
def minimum_with_index(a_value, a_index, b_value, b_index):
    """Combine fn for argmin-style reductions: pick the smaller value,
    preferring NaN over numbers and the lowest index on ties."""
    mask = a_value < b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        # NOTE(review): `and`/`not` on tensor values here rely on Triton's
        # handling of python bool ops inside @triton.jit — confirm against
        # the Triton version in use.
        mask |= a_isnan and not b_isnan
        # Consider NaNs as equal
        equal |= a_isnan and b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def maximum_with_index(a_value, a_index, b_value, b_index):
    """Combine fn for argmax-style reductions: pick the larger value,
    preferring NaN over numbers and the lowest index on ties."""
    mask = a_value > b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        # NOTE(review): see minimum_with_index about `and`/`not` semantics.
        mask |= a_isnan and not b_isnan
        # Consider NaNs as equal
        equal |= a_isnan and b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def min_with_index(value, index, dim):
    """Reduce (value, index) pairs along `dim`, returning min value and its index."""
    return tl.reduce((value, index), dim, minimum_with_index)


@triton.jit
def max_with_index(value, index, dim):
    """Reduce (value, index) pairs along `dim`, returning max value and its index."""
    return tl.reduce((value, index), dim, maximum_with_index)
101
+
102
+
103
@triton.jit
def welford_reduce(value, mean, m2, weight, first_iteration):
    """One step of Welford's online mean/variance update with a new sample."""
    if first_iteration:
        # Initialize the accumulators from the first sample.
        new_weight = tl.full(weight.shape, 1, weight.dtype)
        new_mean = value
        new_m2 = tl.zeros_like(m2)
    else:
        delta = value - mean
        new_weight = weight + 1
        new_mean = mean + delta / new_weight
        new_m2 = m2 + delta * (value - new_mean)
    return new_mean, new_m2, new_weight


@triton.jit
def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    """Merge two Welford accumulators (parallel/Chan et al. combination)."""
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    # Guard against dividing by a zero combined weight.
    w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,
    )


@triton.jit
def welford(mean, m2, weight, dim):
    """Reduce Welford accumulators along `dim`."""
    return tl.reduce((mean, m2, weight), dim, welford_combine)
132
+
133
+
134
@triton.jit
def device_assert_then(cond, msg, r):
    """Assert `cond` on device, then pass `r` through (forces ordering)."""
    tl.device_assert(cond, msg)
    return r


@triton.jit
def randint64(seed, offset, low, high):
    """Draw a pseudo-random int64 in [low, high) by combining two 32-bit draws.

    NOTE(review): the modulo mapping into the range is as written upstream;
    any modulo bias is inherited from that choice.
    """
    r0, r1, r2, r3 = tl.randint4x(seed, offset)
    r0 = r0.to(tl.uint64)
    r1 = r1.to(tl.uint64)
    # Pack two 32-bit values into one 64-bit value.
    result = r0 | (r1 << 32)
    size = high - low
    result = result % size.to(tl.uint64)
    result = result.to(tl.int64) + low
    return result
150
+
151
+
152
@triton.jit
def _any_combine(a, b):
    # Bitwise-or works as logical-or for the boolean inputs `any` reduces.
    return a | b


@triton.jit
def any(a, dim):
    """Logical-or reduction along `dim`."""
    return tl.reduce(a, dim, _any_combine)
160
+
161
+
162
@triton.jit
def bucketize_binary_search(
    values,  # 1D tensor
    offsets_ptr,
    indexing_dtype,
    right,  # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]
    OFFSETS_SIZE: int,
    BLOCK_SHAPE,  # tuple/list of block shape
):
    """
    See [Note: Inductor bucketize op]
    """

    # Per-element search bounds: bucket index lies in [low, high].
    low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)
    high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)

    # Iterate a fixed number of halvings so every lane runs the same loop.
    full_range = OFFSETS_SIZE + 1
    while full_range > 1:
        mid = (high + low) // 2
        # Guard loads at the end of the offsets array.
        mask = mid < OFFSETS_SIZE
        bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask)
        if right:
            is_above = values >= bucket_upper_bound
        else:
            is_above = values > bucket_upper_bound

        low = tl.where(is_above & mask, mid + 1, low)
        high = tl.where(is_above, high, mid)

        full_range = (full_range + 1) // 2

    return low
194
+
195
+
196
@triton.jit
def pack_value_flag(
    value,
    flag,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
    DTYPE_PACK: tl.constexpr,
):
    """Pack (value, flag) into one wide integer: flag in the low bits,
    the bit-cast value in the high bits."""
    # Workaround for triton bug, tensor.to doesn't unwrap constexpr values
    DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)
    return flag.to(DTYPE_PACK) | (uv << bitwidth)


@triton.jit
def unpack_value(
    pack,
    DTYPE_VALUE,
    DTYPE_VALUE_AS_UINT,
):
    """Recover the value from the high bits of a packed word."""
    # Workaround for triton bug, tensor.to doesn't unwrap constexpr values
    DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)
    DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)
    return value_uint.to(DTYPE_VALUE, bitcast=True)


@triton.jit
def unpack_flag(pack, DTYPE_FLAG):
    """Recover the flag from the low bits of a packed word."""
    return pack.to(DTYPE_FLAG)
227
+
228
+
229
@triton.jit
def exclusive_scan_decoupled_lookback(
    scratch_base,
    block_value,
    index,
    combine_fn,
    init,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
    DTYPE_PACK: tl.constexpr,
):
    """Compute exclusive scan of a scalar value between blocks

    Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back

    scratch_base: Pointer to scratch space in global memory
    block_value: Scalar value for this block
    index: Scalar index of this block relative to the current scan
    combine_fn: Function ``(value, value) -> value`` which is scanned over
    init: Scalar value equal to the identiy of combine_fn
    DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``
    DTYPE_PACK: Unsigned type twice the width of block_value

    NOTE: This function is limited to values which are 32-bits or less.
    """
    DTYPE_VALUE = block_value.dtype
    # Publish this block's local aggregate with flag=1 ("aggregate available").
    pack = pack_value_flag(
        block_value,
        tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),
        DTYPE_VALUE_AS_UINT,
        DTYPE_PACK,
    )
    tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")

    # Look back over preceding blocks, accumulating their contributions.
    exclusive_prefix = init
    test_target = index - 1
    while test_target >= 0:
        # tl.atomic_load
        # Spin until the predecessor has published something (flag != 0);
        # atomic_add of 0 is used as an atomic load.
        flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)
        while flag == 0:
            pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed")
            flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)

        value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)
        exclusive_prefix = combine_fn(value, exclusive_prefix)

        # flag == 2 means the predecessor published its full inclusive
        # prefix, so the look-back can stop here.
        if flag == 2:
            test_target = -1
        else:
            test_target = test_target - 1

    # Make inclusive block sum visible to other blocks
    inclusive_prefix = combine_fn(exclusive_prefix, block_value)
    pack = pack_value_flag(
        inclusive_prefix,
        tl.full([], 2, DTYPE_VALUE_AS_UINT),
        DTYPE_VALUE_AS_UINT,
        DTYPE_PACK,
    )
    tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")
    return exclusive_prefix
289
+
290
+
291
+ @triton.jit
292
+ def exclusive_scan_decoupled_lookback_64(
293
+ scratch_base, block_value, index, combine_fn, init
294
+ ):
295
+ """Compute exclusive scan of a scalar value between blocks
296
+
297
+ Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
298
+
299
+ scratch_base: Pointer to scratch space in global memory
300
+ block_value: Scalar value for this block, must be 64-bits wide
301
+ index: Scalar index of this block relative to the current scan
302
+ combine_fn: Function ``(value, value) -> value`` which is scanned over
303
+ init: Scalar value equal to the identiy of combine_fn
304
+ """
305
+ block_value_u64 = block_value.to(tl.uint64, bitcast=True)
306
+ tl.store(scratch_base + 3 * index + 1, block_value_u64)
307
+ tl.debug_barrier()
308
+ flag_one = tl.full([], 1, tl.uint64)
309
+ tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release")
310
+
311
+ exclusive_prefix = init
312
+ test_target = index - 1
313
+ while test_target >= 0:
314
+ flag = tl.full([], 0, tl.uint64)
315
+ while flag == 0:
316
+ flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire")
317
+
318
+ value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))
319
+ value = value_u64.to(block_value.dtype, bitcast=True)
320
+ exclusive_prefix = combine_fn(value, exclusive_prefix)
321
+
322
+ if flag == 2:
323
+ test_target = -1
324
+ else:
325
+ test_target = test_target - 1
326
+
327
+ # Make inclusive block sum visible to other blocks
328
+ inclusive_prefix = combine_fn(exclusive_prefix, block_value)
329
+ inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)
330
+ tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)
331
+ tl.debug_barrier()
332
+ flag_two = tl.full([], 2, tl.uint64)
333
+ tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release")
334
+
335
+ return exclusive_prefix
336
+
337
+
338
+ @triton.jit
339
+ def frexp(x):
340
+ # TODO(isuruf): use inline_asm_elementwise here
341
+ y = libdevice.ilogb(x) + 1
342
+ exponent = tl.where(x == 0, 0, y)
343
+ mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))
344
+ return mantissa, exponent
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocator.h ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <ATen/mps/MPSAllocatorInterface.h>
6
+ #include <ATen/mps/MPSEvent.h>
7
+ #include <ATen/mps/MPSStream.h>
8
+
9
+ #include <cstdio>
10
+ #include <mutex>
11
+ #include <set>
12
+ #include <unordered_set>
13
+ #include <mach/vm_page_size.h>
14
+ #include <c10/util/flat_hash_map.h>
15
+
16
+ // this implementation is based on CUDACachingAllocator.
17
+ // It utilizes Metal Heaps to improve the performance with buffer allocation.
18
+ // Do not include this header. Use MPSAllocatorInterface.h instead.
19
+ // TODO: Unify the logic with CUDACachingAllocator and remove redundant code.
20
+ namespace at::mps::HeapAllocator {
21
+
22
+ static const size_t kMaxSmallAlloc = MB(1); // largest "small" allocation is 1 MiB
23
+ static const size_t kMinLargeAlloc = MB(10); // allocations between 1 and 10 MiB may use kLargeHeap
24
+ static const size_t kRoundLarge = MB(2); // round up large allocations to 2 MiB
25
+ static const size_t kSmallHeap = MB(8); // "small" allocations are packed in 8 MiB heaps
26
+ static const size_t kLargeHeap = MB(32); // "large" allocations may be packed in 32 MiB heaps
27
+ static const size_t kXLargeHeapD = MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
28
+ static const size_t kXLargeHeapU = MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
29
+ static const size_t kMaxScalarAlloc = (sizeof(int64_t)); // largest "scalar" allocation
30
+
31
+ // buffer pools could be customized with a combination of usage flags
32
+ enum UsageFlags : uint32_t {
33
+ PRIVATE = 0,
34
+ SMALL = (1 << 0), // small heaps have sizes of kSmallHeap, and large ones kLargeHeap
35
+ SHARED = (1 << 1), // shared pools allocated on devices with unified memory; otherwise, private between host/device
36
+ MANAGED = (1 << 2), // managed storage mode
37
+ HAZARD = (1 << 3), // enables Automatic Hazard Tracking for the resources allocated on the pool
38
+ SCALAR = (1 << 4), // used to import CPU scalar values to GPU and use them in MPS Stream
39
+ };
40
+ // debug verbosity flags
41
+ enum DebugVerbosity : uint32_t {
42
+ SILENT = 0,
43
+ PROFILING = (1 << 0), // print generic profiling data for total system memory usage
44
+ ALLOCATIONS = (1 << 1), // print buffer allocations
45
+ RECYCLES = (1 << 2), // print buffer recycling
46
+ RELEASES = (1 << 3), // print buffer releases
47
+ LARGE_ONLY = (1 << 4), // only log large buffer pool transactions
48
+ };
49
+
50
+ struct HeapBlock;
51
+
52
+ struct BufferBlock {
53
+ id<MTLBuffer> buffer;
54
+ void* cpu_ptr = nullptr; // stores the pointer to CPU mapping of a Shared MTLBuffer
55
+ size_t size; // size after alignment
56
+ size_t requested_size; // requested size (before alignment)
57
+ // buffer shape is used for retrieving base of views in cached graphs
58
+ std::vector<int64_t> shape;
59
+ bool in_use = false;
60
+ HeapBlock* heap;
61
+ id_t buf_id;
62
+ // counter to candidate least recently used buffers for garbage collection
63
+ uint32_t gc_count = 0;
64
+ uint32_t use_count = 0;
65
+ // counter to assign unique ids to buffer blocks
66
+ static uint64_t buffer_counter;
67
+ // Metal events used to sync GPU/CPU operations on the shared-storage buffers
68
+ MPSEventPtr event;
69
+
70
+ BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr,
71
+ HeapBlock* Heap = nullptr) :
72
+ buffer(Buffer), size(Size), requested_size(RequestedSize),
73
+ heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) { }
74
+
75
+ static bool Comparator(const BufferBlock* a, const BufferBlock* b) {
76
+ return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer;
77
+ }
78
+ static size_t alignUp(size_t Size, size_t Alignment) {
79
+ assert(((Alignment - 1) & Alignment) == 0);
80
+ return ((Size + Alignment - 1) & ~(Alignment - 1));
81
+ }
82
+ uint32_t retainCount() const { return [buffer retainCount]; }
83
+ };
84
+ typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*);
85
+
86
+ struct BufferPool;
87
+ struct AllocParams {
88
+ AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool) :
89
+ search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) { }
90
+ size_t size() const { return search_key.size; }
91
+
92
+ BufferBlock search_key;
93
+ BufferPool* pool;
94
+ BufferBlock* buffer_block = nullptr;
95
+ size_t requested_size;
96
+ // true if we exceed the low watermark limit. In this case
97
+ // we apply strategies to relieve the pressure before allocation.
98
+ bool has_memory_pressure = false;
99
+ // true if we're allocating on a unified memory device
100
+ bool has_unified_memory = true;
101
+ };
102
+
103
+ struct HeapBlock {
104
+ id<MTLHeap> heap;
105
+ struct { size_t total, available; } size;
106
+ BufferPool* pool;
107
+ unsigned int n_buffers = 0;
108
+ id_t heap_id;
109
+ // indicates if we split this heap to sub-allocate 'several' buffers (otherwise single buffer)
110
+ bool is_split;
111
+ // counter to assign unique ids to heap blocks
112
+ static uint64_t heap_counter;
113
+
114
+ HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool *Pool = nullptr) :
115
+ heap(Heap), size({.total = Size, .available = Size}), pool(Pool),
116
+ heap_id(Heap ? ++heap_counter : 0), is_split(true) { }
117
+
118
+ static MTLResourceOptions getOptions(uint32_t usage) {
119
+ // TODO: check the caching performance of write-combined mode
120
+ MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache;
121
+
122
+ if (usage & UsageFlags::MANAGED)
123
+ options |= MTLResourceStorageModeManaged;
124
+ else if (usage & UsageFlags::SHARED)
125
+ options |= MTLResourceStorageModeShared;
126
+ else
127
+ options |= MTLResourceStorageModePrivate;
128
+
129
+ options |= (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
130
+
131
+ return options;
132
+ }
133
+
134
+ static HeapBlock* createHeapBlock(AllocParams& params, id<MTLDevice> device, uint32_t usage) {
135
+ HeapBlock *heapBlock = nullptr;
136
+ bool is_split = true;
137
+ const size_t size = params.size();
138
+ MTLHeapDescriptor *d = [MTLHeapDescriptor new];
139
+ if (d) {
140
+ const size_t kXLargeHeap = params.has_unified_memory ? kXLargeHeapU : kXLargeHeapD;
141
+ if (size <= kMaxSmallAlloc) {
142
+ d.size = kSmallHeap;
143
+ } else if (size < kMinLargeAlloc) {
144
+ d.size = kLargeHeap;
145
+ } else if (size < kXLargeHeap / 2 && !params.has_memory_pressure) {
146
+ d.size = kXLargeHeap;
147
+ } else {
148
+ d.size = kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
149
+ is_split = false;
150
+ }
151
+ d.storageMode = (usage & UsageFlags::SHARED) ? MTLStorageModeShared : MTLStorageModePrivate;
152
+ d.cpuCacheMode = MTLCPUCacheModeDefaultCache;
153
+ // this automatically handles Metal buffer access synchronizations at the
154
+ // cost of slightly lower performance.
155
+ d.hazardTrackingMode = (usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
156
+ d.resourceOptions = getOptions(usage);
157
+ d.type = MTLHeapTypeAutomatic;
158
+ id<MTLHeap> heap = [device newHeapWithDescriptor: d];
159
+ if (heap) {
160
+ [heap setPurgeableState:MTLPurgeableStateNonVolatile];
161
+ const size_t heap_size = heapAvailableSize(heap);
162
+ heapBlock = new HeapBlock(heap_size, heap, params.pool);
163
+ if (heapBlock) {
164
+ heapBlock->is_split = is_split;
165
+ }
166
+ }
167
+ [d release];
168
+ }
169
+ return heapBlock;
170
+ }
171
+ static bool Comparator(const HeapBlock* a, const HeapBlock* b) {
172
+ return (a->size.available != b->size.available) ? a->size.available < b->size.available :
173
+ (uintptr_t)a->heap < (uintptr_t)b->heap;
174
+ }
175
+ static NSUInteger heapAvailableSize(id<MTLHeap> heap, size_t Alignment = vm_page_size) {
176
+ return [heap maxAvailableSizeWithAlignment:Alignment];
177
+ }
178
+ NSUInteger Size() {
179
+ return [heap size];
180
+ }
181
+ id<MTLBuffer> newMTLBuffer(size_t length, uint32_t usage) {
182
+ id<MTLBuffer> buf = [heap newBufferWithLength:length options:getOptions(usage)];
183
+ if (buf) {
184
+ updateAvailableSize();
185
+ n_buffers++;
186
+ }
187
+ return buf;
188
+ }
189
+ // returns the retainCount before releasing the buffer
190
+ uint32_t releaseMTLBuffer(id<MTLBuffer>& buffer) {
191
+ const uint32_t retainCount = [buffer retainCount];
192
+ [buffer release];
193
+ buffer = nil;
194
+ updateAvailableSize();
195
+ n_buffers--;
196
+ return retainCount;
197
+ }
198
+ // returns the retainCount before releasing the heap
199
+ uint32_t releaseMTLHeap() {
200
+ const uint32_t retainCount = [heap retainCount];
201
+ TORCH_INTERNAL_ASSERT(!n_buffers); // assert if heap isn't empty
202
+ [heap setPurgeableState:MTLPurgeableStateEmpty];
203
+ [heap release];
204
+ heap = nil;
205
+ size.available = 0;
206
+ return retainCount;
207
+ }
208
+ uint32_t retainCount() const { return [heap retainCount]; }
209
+ void updateAvailableSize() { size.available = heapAvailableSize(heap); }
210
+ };
211
+ typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*);
212
+
213
+ struct BufferPool {
214
+ enum class Kind {
215
+ PRIVATE_SMALL,
216
+ PRIVATE_LARGE,
217
+ SHARED_SMALL,
218
+ SHARED_LARGE,
219
+ SCALAR,
220
+ };
221
+
222
+ BufferPool(const id<MTLDevice> Device, uint32_t Usage) :
223
+ device(Device), usage(Usage),
224
+ heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) { }
225
+
226
+ const id<MTLDevice> device;
227
+ // usage flags to customize the pool for various purposes (see UsageFlags enum)
228
+ const uint32_t usage;
229
+ // total number of buffers in the pool
230
+ uint32_t n_buffers = 0;
231
+ // total allocations size on this pool
232
+ size_t allocated_size = 0;
233
+ // total memory available in the pool
234
+ size_t available_size = 0;
235
+ // list of heaps ordered by their "available" (not total) memory size
236
+ std::set<HeapBlock*, HeapComparison> heaps;
237
+ // list of only "available" buffers in the pool (i.e., buffers not in-use)
238
+ std::set<BufferBlock*, BufferComparison> available_buffers;
239
+ // list of buffers that are in a state of "limbo" where they've already been freed
240
+ // from PyTorch-side, but were not returned to pool due to still being
241
+ // in-use by command buffers with retainCount > 1. In this state, the buffer is
242
+ // neither ready to be recycled, nor could be returned to pool as available.
243
+ // These buffers will be returned to pool once the command buffer's
244
+ // completionHandler callbacks are called.
245
+ std::unordered_set<BufferBlock*> buffers_pending_free;
246
+ // list of heaps pending size update
247
+ std::unordered_set<HeapBlock*> heaps_pending_update;
248
+ };
249
+
250
+ class MPSHeapAllocatorImpl {
251
+ public:
252
+ explicit MPSHeapAllocatorImpl() :
253
+ m_device(at::mps::MPSDevice::getInstance()->device()),
254
+ m_max_buffer_size([m_device maxBufferLength]),
255
+ m_stream(getDefaultMPSStream()),
256
+ m_event_pool(getMPSEventPool()) {
257
+ init_allocator();
258
+ }
259
+ ~MPSHeapAllocatorImpl() {
260
+ emptyCache();
261
+ }
262
+ // interface exposed to at::Allocator
263
+ id<MTLBuffer> malloc(size_t size, uint32_t usage);
264
+ // frees a buffer and returns it into buffer pool
265
+ void free(void* ptr);
266
+ // releases all the cached buffers and their associated heaps
267
+ void emptyCache();
268
+ // free inactive buffers that are pending to be freed
269
+ void freeInactiveBuffers();
270
+ // returns true if buffer was allocated from the shared pool
271
+ bool isSharedBuffer(const void* ptr);
272
+ // get the requested unaligned size of an MTLBuffer
273
+ ssize_t getUnalignedBufferSize(const void* ptr);
274
+ // set the shape of a base tensor from a view tensor
275
+ void setBufferShape(const void* ptr, const IntArrayRef& shape);
276
+ // retrieve the shape of a base tensor from a view tensor
277
+ IntArrayRef getBufferShape(const void* ptr);
278
+ // get the unique ID of the buffer
279
+ id_t getBufferId(const void* ptr);
280
+ // allocate a buffer from a specialized pool to import CPU scalars into GPU
281
+ id<MTLBuffer> allocScalarBufferWithValue(void* value, size_t size);
282
+ // returns a CPU-mapping of the input buffer and its retainCount,
283
+ // if only it has Shared storage-mode and allocated on MPSAllocator
284
+ std::pair<const void*, uint32_t> getSharedBufferPtr(const void* buffer);
285
+ // records events for a list of MTLBuffers (list is used to lock the mutex once)
286
+ // returns true if records any event (given if passed buffers exist and are shared-storage)
287
+ bool recordEvents(c10::ArrayRef<const void*> buffers);
288
+ // waits for the event to signal the completion of GPU execution
289
+ // on the passed shared buffers (list is used to lock the mutex once)
290
+ // returns true if actually waited on any event
291
+ bool waitForEvents(c10::ArrayRef<const void*> buffers);
292
+ // this indicates how far (in Megabytes) the current total allocations are from the
293
+ // low watermark limit which is used to detect if we're under memory pressure
294
+ // This returns zero if we've reached the low watermark limit
295
+ ssize_t getLowWatermarkValue();
296
+ // (see m_low_watermark_ratio for description)
297
+ void setLowWatermarkRatio(double ratio);
298
+ // (see m_high_watermark_ratio for description)
299
+ void setHighWatermarkRatio(double ratio);
300
+ // (see m_low_watermark_limit for description)
301
+ size_t getLowWatermarkLimit() const { return m_low_watermark_limit; }
302
+ // (see m_max_total_allowed_size for description)
303
+ size_t getHighWatermarkLimit() const { return m_max_total_allowed_size; }
304
+ // (see m_total_allocated_memory for description)
305
+ size_t getTotalAllocatedMemory() const { return m_total_allocated_memory; }
306
+ // (see m_current_allocated_memory for description)
307
+ size_t getCurrentAllocatedMemory() const { return m_current_allocated_memory; }
308
+ // total GPU memory allocated in the process by Metal driver; including
309
+ // implicit allocations from MPS/MPSGraph frameworks and MPSHeapAllocatorImpl.
310
+ size_t getDriverAllocatedMemory() const { return current_allocated_size(); }
311
+ // (see enum DebugVerbosity for description)
312
+ uint32_t getDebugVerbosity() const { return m_debug_verbosity; }
313
+ // returns the device that we allocate from
314
+ inline id<MTLDevice> Device() const { return m_device; }
315
+
316
+ // TODO: make a common function to do size unit conversions in PyTorch.
317
+ inline std::string format_size(uint64_t size) const;
318
+
319
+ private:
320
+ // (see m_high_watermark_ratio for description)
321
+ constexpr static double default_high_watermark_ratio = 1.7;
322
+ // we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize.
323
+ constexpr static double default_high_watermark_upper_bound = 2.0;
324
+ // (see m_low_watermark_ratio for description)
325
+ // on unified memory, we could allocate beyond the recommendedMaxWorkingSetSize
326
+ constexpr static double default_low_watermark_ratio_unified = 1.4;
327
+ constexpr static double default_low_watermark_ratio_discrete = 1.0;
328
+
329
+ const id<MTLDevice> m_device;
330
+ std::recursive_mutex m_mutex;
331
+ // allocated buffers by device pointer
332
+ ska::flat_hash_map<const void*, BufferBlock*> m_allocated_buffers;
333
+ // using a container for pools to simplify iterating them
334
+ ska::flat_hash_map<BufferPool::Kind, std::unique_ptr<BufferPool>> m_pools;
335
+ // total memory allocated by HeapAllocator (including blocks in pools)
336
+ size_t m_total_allocated_memory = 0;
337
+ // currently active memory allocations in use (i.e., blocks not in pools)
338
+ size_t m_current_allocated_memory = 0;
339
+ // max buffer size allowed by Metal
340
+ size_t m_max_buffer_size = 0;
341
+ // maximum total size allowed to be allocated
342
+ size_t m_max_total_allowed_size = 0;
343
+ // high watermark ratio is a hard limit for the total allowed allocations
344
+ // 0. : disables high watermark limit (may cause system failure if system-wide OOM occurs)
345
+ // 1. : recommended maximum allocation size (i.e., device.recommendedMaxWorkingSetSize)
346
+ // >1.: allows limits beyond the device.recommendedMaxWorkingSetSize
347
+ // e.g., value 0.95 means we allocate up to 95% of recommended maximum
348
+ // allocation size; beyond that, the allocations would fail with OOM error.
349
+ double m_high_watermark_ratio;
350
+ // low watermark ratio is a soft limit to attempt limiting memory allocations up to the lower watermark
351
+ // level by garbage collection or committing command buffers more frequently (a.k.a, adaptive commit).
352
+ // Value between 0 to m_high_watermark_ratio (setting 0.0 disables adaptive commit and garbage collection)
353
+ // e.g., value 0.9 means we 'attempt' to limit allocations up to 90% of recommended maximum
354
+ // allocation size.
355
+ double m_low_watermark_ratio;
356
+ // low watermark size limit (in Bytes) at the time we initialize the allocator
357
+ size_t m_low_watermark_limit;
358
+ // use "PYTORCH_DEBUG_MPS_ALLOCATOR" env-var to set debug verbosity
359
+ uint32_t m_debug_verbosity;
360
+ // default MPS stream
361
+ MPSStream* m_stream;
362
+ // we hold a reference to MPSEventPool so it could get destroyed after MPSAllocator
363
+ std::shared_ptr<MPSEventPool> m_event_pool;
364
+
365
+ void init_allocator();
366
+ void init_buffer_pools();
367
+ HeapBlock* get_free_heap(AllocParams& params);
368
+ bool get_free_buffer(AllocParams& params);
369
+ BufferBlock* get_allocated_buffer_block(const void* ptr);
370
+ BufferBlock* alloc_buffer_block(size_t size, uint32_t usage);
371
+ bool alloc_buffer(AllocParams& params);
372
+ void free_buffer(BufferBlock* buffer_block);
373
+ // returns true if the container heap is also released
374
+ bool release_buffer(BufferBlock* buffer_block, bool remove_empty_heap = true);
375
+ void release_buffers(BufferPool& pool);
376
+ bool release_available_cached_buffers(AllocParams& params);
377
+ bool release_cached_buffers();
378
+ // free unused cached blocks to reclaim GPU memory if memory pressure is high
379
+ void garbage_collect_cached_buffers(AllocParams& params);
380
+ // returns the suitable buffer pool type for the usage or
381
+ // requested/allocated sizes
382
+ BufferPool& get_pool(size_t requested_size, size_t aligned_size, uint32_t usage);
383
+ // returns the aligned allocation size that is optimized
384
+ // for the buffers to get reused frequently
385
+ size_t get_allocation_size(size_t size, uint32_t usage) const;
386
+ // maximum size of device memory available for allocation in current process
387
+ // Note: the recommendedMaxWorkingSetSize is typically 75% of the total system memory.
388
+ size_t max_device_size() const { return [m_device recommendedMaxWorkingSetSize]; }
389
+ // there are implicit allocations from MPS backend, so we need to query the 'device' for
390
+ // total allocated size instead of manually tracking in MPSAllocator
391
+ size_t current_allocated_size() const { return [m_device currentAllocatedSize]; }
392
+
393
+ bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const {
394
+ for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
395
+ MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block ? buffer_block->buffer : nullptr, event);
396
+ }
397
+ return true;
398
+ }
399
+ };
400
+
401
+ } // namespace at::mps::HeapAllocator
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSAllocatorInterface.h ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <c10/core/Allocator.h>
6
+ #include <c10/util/Registry.h>
7
+ #include <ATen/core/ATen_fwd.h>
8
+
9
+ #define MB(x) (x * 1048576UL)
10
+
11
+ namespace at::mps {
12
+
13
+ // this is a public interface to access MPSAllocator.
14
+ // Do not declare methods that would depend on MPS or Metal frameworks.
15
+ class IMPSAllocator : public c10::Allocator {
16
+ public:
17
+ // see the comments in MPSAllocator.h for the description of these methods.
18
+ virtual void emptyCache() const = 0;
19
+ virtual void freeInactiveBuffers() const = 0;
20
+ virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
21
+ virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
22
+ virtual id_t getBufferId(const void* ptr) const = 0;
23
+ virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) const = 0;
24
+ virtual bool isSharedBuffer(const void* ptr) const = 0;
25
+ virtual bool isSharedStorageSupported() const = 0;
26
+ virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;
27
+ virtual std::string formatSize(size_t size) const = 0;
28
+ virtual void setLowWatermarkRatio(double ratio) const = 0;
29
+ virtual void setHighWatermarkRatio(double ratio) const = 0;
30
+ virtual ssize_t getLowWatermarkValue() const = 0;
31
+ virtual size_t getLowWatermarkLimit() const = 0;
32
+ virtual size_t getHighWatermarkLimit() const = 0;
33
+ virtual size_t getTotalAllocatedMemory() const = 0;
34
+ virtual size_t getCurrentAllocatedMemory() const = 0;
35
+ virtual size_t getDriverAllocatedMemory() const = 0;
36
+ virtual std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const = 0;
37
+ virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
38
+ virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
39
+ };
40
+
41
+ class IMpsAllocatorCallback {
42
+ public:
43
+ enum class EventType {
44
+ ALLOCATED, // buffer got allocated to be used immediately
45
+ RECYCLED, // buffer pulled from free list to be reused
46
+ FREED, // buffer put to free list for future recycling
47
+ RELEASED, // buffer memory released
48
+ ALLOCATION_FAILED // buffer allocation failed
49
+ };
50
+ virtual ~IMpsAllocatorCallback() = default;
51
+ virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
52
+ };
53
+
54
+ // MPS allocator will execute every registered callback when a block of memory is freed.
55
+ C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
56
+ #define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
57
+ C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__);
58
+
59
+ IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false);
60
+
61
+ } // namespace at::mps
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSEvent.h ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <ATen/mps/MPSStream.h>
6
+ #include <ctime>
7
+ #include <stack>
8
+
9
+ namespace at::mps {
10
+
11
+ // NOTE: don't create instances of this class directly.
12
+ // Use MPSEventPool to acquire instances of MPSEvent.
13
+ class MPSEvent {
14
+ public:
15
+ explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
16
+ ~MPSEvent();
17
+
18
+ // records an event on the stream
19
+ void record(bool needsLock, bool syncEvent = false);
20
+ // makes all future work submitted to the stream wait for this event.
21
+ bool wait(bool needsLock, bool syncEvent = false);
22
+ // schedules a notifyListener callback for the event.
23
+ bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
24
+ // checks if events are already signaled.
25
+ bool query() const;
26
+ // blocks the CPU thread until all the GPU work that were scheduled
27
+ // prior to recording this event are completed.
28
+ bool synchronize();
29
+ // resets this event with new parameters in case it gets reused from the event pool
30
+ void reset(MPSStream* stream, bool enable_timing);
31
+ // returns the unique ID of the event instance
32
+ id_t getID() const { return m_id; }
33
+ // returns the completion timestamp of the event
34
+ uint64_t getCompletionTime() const { return m_completion_time; }
35
+ // if already recorded, waits for cpu_sync_cv to be signaled
36
+ void waitForCpuSync();
37
+
38
+ private:
39
+ id_t m_id;
40
+ // enables measuring the completion time of the notifyListener of this event
41
+ bool m_enable_timing;
42
+ uint64_t m_signalCounter = 0;
43
+ MPSStream* m_stream = nullptr;
44
+ MTLSharedEvent_t m_event = nullptr;
45
+ MTLSharedEventListener* m_listener = nullptr;
46
+ // used to sync the events created on this Stream with CPU
47
+ std::mutex m_cpu_sync_mutex{};
48
+ std::condition_variable m_cpu_sync_cv{};
49
+ // CondVar predicate to sync the events created on this Stream with CPU
50
+ bool m_cpu_sync_completed = false;
51
+ // used to compute elapsed time
52
+ uint64_t m_completion_time = 0;
53
+
54
+ void recordLocked(bool syncEvent);
55
+ bool waitLocked(bool syncEvent);
56
+ bool notifyLocked(MTLSharedEventNotificationBlock block);
57
+ void notifyCpuSync();
58
+ static uint64_t getTime() {
59
+ return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
60
+ }
61
+ };
62
+
63
+ typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;
64
+
65
+ class MPSEventPool {
66
+ public:
67
+ explicit MPSEventPool(MPSStream* default_stream);
68
+ ~MPSEventPool();
69
+
70
+ MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
71
+ void emptyCache();
72
+
73
+ // these are mainly used for MPSHooks and torch.mps.Event() bindings
74
+ id_t acquireEvent(bool enable_timing);
75
+ void releaseEvent(id_t event_id);
76
+ void recordEvent(id_t event_id, bool syncEvent);
77
+ void waitForEvent(id_t event_id, bool syncEvent);
78
+ void synchronizeEvent(id_t event_id);
79
+ bool queryEvent(id_t event_id);
80
+ // returns elapsed time between two recorded events in milliseconds
81
+ double elapsedTime(id_t start_event_id, id_t end_event_id);
82
+
83
+ private:
84
+ MPSStream* m_default_stream = nullptr;
85
+ std::recursive_mutex m_mutex;
86
+ std::stack<std::unique_ptr<MPSEvent>> m_pool{};
87
+ // dictionary to associate event IDs with event objects
88
+ // used to retain in-use events out of the pool
89
+ // for torch.mps.Event() bindings.
90
+ std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
91
+ uint64_t m_event_counter = 0;
92
+ std::function<void(MPSEvent*)> m_default_deleter;
93
+
94
+ MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
95
+ };
96
+
97
+ // shared_ptr is used to get MPSEventPool destroyed after dependent instances
98
+ std::shared_ptr<MPSEventPool> getMPSEventPool();
99
+
100
+ } // namespace at::mps
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/mps/MPSProfiler.h ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <ATen/Tensor.h>
6
+ #include <ATen/mps/MPSStream.h>
7
+ #include <ATen/mps/MPSAllocatorInterface.h>
8
+
9
+ #include <os/signpost.h>
10
+ #include <os/log.h>
11
+
12
+ #include <sstream>
13
+ #include <string>
14
+ #include <atomic>
15
+ #include <unordered_map>
16
+ #include <utility>
17
+ #include <ctime>
18
+
19
+ namespace at::mps {
20
+
21
+ namespace Profiler {
22
+
23
+ struct BaseInfo {
24
+ // profiling info types
25
+ enum class Type {
26
+ GRAPH,
27
+ KERNEL,
28
+ COPY,
29
+ CPU_FALLBACK,
30
+ };
31
+
32
+ BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle) :
33
+ type(infoType), profileId(Id), handle(Handle) { }
34
+ virtual ~BaseInfo() = default;
35
+
36
+ // type of profiling info
37
+ Type type;
38
+ // unique profile ID for execution instances of operations or copies
39
+ uint64_t profileId;
40
+ // ID generated by os_signpost
41
+ // since it's possible to use event and interval-based signposts at the
42
+ // same time, we need separate IDs for each.
43
+ os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0;
44
+ // accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime - GPUStartTime")
45
+ std::atomic<double> totalGpuTime{0.0};
46
+ // accumulated Scheduling time in ms (obtained from CompletionHandler's "KernelEndTime - KernelStartTime")
47
+ std::atomic<double> totalSchedulingTime{0.0};
48
+ // indicates if the operation or copy execution has completed
49
+ std::atomic_bool completed{false};
50
+ // handle used to identify the profile info's instance (usually the pointer)
51
+ const uintptr_t handle;
52
+
53
+ virtual const std::string toString(double gpuTime = 0, double schedulingTime = 0) const;
54
+ // builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
55
+ static std::string buildTensorString(const Tensor& tensor, bool includeBufferId = false) {
56
+ if (tensor.defined()) {
57
+ std::stringstream tensorStr;
58
+ auto deviceType = tensor.device().type();
59
+ tensorStr << c10::DeviceTypeName(deviceType);
60
+ // see comments for INCLUDE_BUFFER_ID
61
+ if (includeBufferId && deviceType == at::kMPS) {
62
+ id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
63
+ tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer))
64
+ << ":" << buffer.retainCount << ")";
65
+ }
66
+ tensorStr << ":"
67
+ << tensor.scalar_type() << tensor.sizes();
68
+ return tensorStr.str();
69
+ } else {
70
+ return "undefined";
71
+ }
72
+ }
73
+ static uint64_t getTime() {
74
+ return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
75
+ }
76
+ };
77
+
78
+ struct OperationInfo : BaseInfo {
79
+ OperationInfo(const void* Handle, bool IsGraph, uint64_t Id, const std::string& StrKey) :
80
+ BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)), strKey(StrKey) { }
81
+
82
+ uint64_t runCount = 0;
83
+ std::string strKey;
84
+
85
+ const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
86
+
87
+ // builds a string for a kernel
88
+ static std::string buildKernelString(const std::string& kernelName,
89
+ const TensorList& tensors,
90
+ bool includeBufferId = false) {
91
+ std::stringstream kernelStr;
92
+ kernelStr << kernelName;
93
+ for (const Tensor& tensor: tensors) {
94
+ kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
95
+ }
96
+ return kernelStr.str();
97
+ }
98
+ };
99
+
100
+ struct CpuFbInfo : BaseInfo {
101
+ CpuFbInfo(uint64_t Id, const std::string& OpName) :
102
+ BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) { }
103
+
104
+ uint64_t runCount = 0;
105
+ // the current and total overhead of copies in bytes required to convert the Op's
106
+ // input tensors from MPS to CPU and then output from CPU back to MPS
107
+ size_t currentCopyOverhead = 0;
108
+ size_t totalCopyOverhead = 0;
109
+ std::string opName;
110
+ std::string strKey;
111
+ uint64_t startTime = 0;
112
+
113
+ const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
114
+
115
+ void updateCopyOverhead(const TensorList& tensors) {
116
+ currentCopyOverhead = 0;
117
+ for (const Tensor& tensor: tensors) {
118
+ if (tensor.defined()) {
119
+ currentCopyOverhead += tensor.nbytes();
120
+ }
121
+ }
122
+ totalCopyOverhead += currentCopyOverhead;
123
+ }
124
+ };
125
+
126
+ struct CopyInfo : BaseInfo {
127
+ enum class Kind {
128
+ MPS_TO_MPS,
129
+ MPS_TO_CPU,
130
+ CPU_TO_MPS,
131
+ };
132
+
133
+ CopyInfo(const void* Handle, size_t Length, uint64_t Id, bool IsNonBlocking, bool UsesBlitter) :
134
+ BaseInfo(Type::COPY, Id, uintptr_t(Handle)), kind(Kind::MPS_TO_MPS),
135
+ length(Length), isNonBlocking(IsNonBlocking), usesBlitter(UsesBlitter) { }
136
+
137
+ Kind kind;
138
+ size_t length;
139
+ bool isNonBlocking;
140
+ bool usesBlitter;
141
+ std::string srcStrKey;
142
+ std::string dstStrKey;
143
+ // for copies that don't use blitters, we measure CPU time
144
+ uint64_t startTime = 0;
145
+
146
+ const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
147
+
148
+ static std::string buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId = false);
149
+
150
+ static bool isStorageOnMPS(const void* buffer, const OptionalTensorRef tensor) {
151
+ if (tensor.has_value()) {
152
+ return tensor->device().type() == at::kMPS;
153
+ }
154
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer);
155
+ // getUnalignedBufferSize() returns -1 if input buffer is not on MPS device
156
+ return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0;
157
+ }
158
+
159
+ static Kind getCopyKind(const void* srcBuffer, const void* dstBuffer,
160
+ const OptionalTensorRef srcTensor, const OptionalTensorRef dstTensor) {
161
+ const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor);
162
+ const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor);
163
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS);
164
+ if (isSrcOnMPS && !isDstOnMPS) {
165
+ return Kind::MPS_TO_CPU;
166
+ } else if (!isSrcOnMPS && isDstOnMPS) {
167
+ return Kind::CPU_TO_MPS;
168
+ }
169
+ return Kind::MPS_TO_MPS;
170
+ }
171
+ };
172
+
173
+ struct CopyStat : CopyInfo {
174
+ explicit CopyStat(std::string CopyKindStr) :
175
+ CopyInfo(nullptr, 0, 0, false, false), kindStr(std::move(CopyKindStr)) {}
176
+ // total number of copies
177
+ size_t totalCount = 0;
178
+ // number of Scalar copies (i.e., less than sizeof(int64))
179
+ size_t scalarsCount = 0;
180
+ // number of blocking copies (i.e., require syncing to GPU)
181
+ size_t blockingCount = 0;
182
+ // number of copies that used memcpy(), instead of Metal Blit Encoder
183
+ size_t memcpyCount = 0;
184
+ // accumulated GPU time in ms for the scalar copies
185
+ std::atomic<double> scalarsGpuTime{0.0};
186
+ // copy kind in string type
187
+ std::string kindStr;
188
+ };
189
+
190
+ class MPSProfiler {
191
+ public:
192
+ // lower 16 bits used for profiler options
193
+ enum ProfileOptions : uint32_t {
194
+ OPTIONS_NONE = 0,
195
+ // ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK, etc.)
196
+ // (used for convenience to not compute bit flags by OR-ing manually)
197
+ // trace all signpost types using events
198
+ ALL_SIGNPOST_EVENTS = (1 << 0),
199
+ // trace all signpost types using intervals
200
+ ALL_SIGNPOST_INTERVALS = (1 << 1),
201
+ // always wait for command buffer to finish executing after each commit
202
+ WAIT_UNTIL_COMPLETED = (1 << 2),
203
+ // for interval-based signposts, include the scheduling portion of
204
+ // Graph/Kernel/Copy executions as well.
205
+ // if flag is disable, only "GPU run time" is included in interval,
206
+ // and not schedule time.
207
+ INCLUDE_SCHEDULE_INTERVAL = (1 << 3),
208
+
209
+ // use these if you need to trace signposts types individually (rarely required)
210
+ // trace signpost using intervals
211
+ USE_INTERVALS = (1 << 4),
212
+ // trace signpost by emitting events
213
+ USE_EVENTS = (1 << 5),
214
+ // used for sanity check (Change this when new option added)
215
+ OPTIONS_COUNT = (USE_EVENTS << 1) - 1,
216
+ };
217
+
218
+ // when adding new types, #define the type string in MPSProfiler.mm as well.
219
+ // upper 16 bits used for event types
220
+ enum SignpostTypes : uint32_t {
221
+ SIGNPOST_NONE = 0,
222
+ // trace signposts for PyTorch operation executions
223
+ RUN_OPERATION = (1 << 16),
224
+ // trace signposts for blitter copies
225
+ BLIT_COPY = (1 << 17),
226
+ // trace signposts for ops that fall back on CPU
227
+ CPU_FALLBACK = (1 << 18),
228
+ // used for sanity check (Change this when new type added)
229
+ SIGNPOST_COUNT = (CPU_FALLBACK << 1) - 1,
230
+ };
231
+
232
+ enum LogOptions : uint32_t {
233
+ LOG_NONE = 0,
234
+
235
+ // Info logging options during execution
236
+ // -------------------------------------
237
+ // prints operation info (id/key/run_count) during execution
238
+ OPERATION_INFO = (1 << 0),
239
+ // prints copy info (src/dst tensors/buffers, size, etc.) during execution
240
+ COPY_INFO = (1 << 1),
241
+ // prints CPU Fallback info (id/runCount/opName/copyOverhead) during execution
242
+ CPU_FALLBACK_INFO = (1 << 2),
243
+
244
+ // Profiling Statistics logging options when process terminates
245
+ // ------------------------------------------------------------
246
+ // prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before process terminates
247
+ // this is convenient to not combine following stats bit flags manually
248
+ ALL_STATS = (1 << 3),
249
+ // prints operation stats (GPU times, run count, etc.) before process terminates
250
+ OPERATION_STATS = (1 << 4),
251
+ // prints copies stats (GPU times, copy kinds, sizes, etc.) before process terminates
252
+ COPY_STATS = (1 << 5),
253
+ // prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies
254
+ // for tensors, etc.) before process terminates
255
+ CPU_FALLBACK_STATS = (1 << 6),
256
+
257
+ // Metadata format options when logging the info
258
+ // ---------------------------------------------
259
+ // if enabled, includes GPU run time in metadata (i.e., GPUEndTime-GPUStartTime
260
+ // from Metal Command Buffers) (e.g., [GPU=0.324 ms])
261
+ INCLUDE_GPU_TIME = (1 << 7),
262
+ // if enabled, includes GPU scheduling time in metadata separately
263
+ // (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers)
264
+ // e.g., [GPU=0.324 ms, KRNL=0.036 ms]
265
+ INCLUDE_KERNEL_TIME = (1 << 8),
266
+ // if enabled, includes the unique buffer ID in metadata for the storage
267
+ // of a tensor that was allocated on MPSAllocator. This is useful (along with
268
+ // the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are involved
269
+ // with various operations.
270
+ INCLUDE_BUFFER_ID = (1 << 9),
271
+
272
+ // used for sanity check (Change this when new option added)
273
+ LOG_COUNT = (INCLUDE_BUFFER_ID << 1) - 1,
274
+ };
275
+
276
+ explicit MPSProfiler();
277
+ ~MPSProfiler();
278
+
279
+ // the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
280
+ // the beginProfile*() functions return a profileId which is unique per graph/kernel/copy
281
+ uint64_t beginProfileKernel(const void* handle, const std::string& strKey, bool isGraph);
282
+ uint64_t beginProfileKernel(const void* handle, const std::string& kernelName, const TensorList& tensors);
283
+ uint64_t beginProfileCopy(const void* srcBuffer, const void* dstBuffer,
284
+ const OptionalTensorRef srcTensor,
285
+ const OptionalTensorRef dstTensor,
286
+ size_t length, bool isNonBlocking, bool usesBlitter = true);
287
+ uint64_t beginProfileCPUFallback(const std::string& opName, const TensorList& tensors);
288
+ void beginProfileGPUInterval(const void* handle);
289
+
290
+ void endProfileCopy(uint64_t profileId, SyncType syncType);
291
+ void endProfileKernel(const void* handle, SyncType syncType = SyncType::NONE);
292
+ void endProfileCPUFallback(const std::string& opName);
293
+
294
+ // these are used to hook into Python bindings for torch.mps.profiler module.
295
+ // this enables generating OS Signpost traces from MPSProfiler on-demand
296
+ // during runtime (instead of environment variables).
297
+ // The "mode" could be either "interval", "event", or both "interval,event"
298
+ // for interval-based and/or event-based signpost tracing.
299
+ void StartTrace(const string& mode, bool waitUntilCompleted);
300
+ void StopTrace();
301
+
302
+ // convenience functions to indicate whether signpost tracing or
303
+ // logging are enabled for the SignpostTypes
304
+ bool isOperationProfilingEnabled() const {
305
+ return (m_signpost_types & SignpostTypes::RUN_OPERATION) ||
306
+ (m_log_options & (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
307
+ }
308
+ bool isCopyProfilingEnabled() const {
309
+ return (m_signpost_types & SignpostTypes::BLIT_COPY) ||
310
+ (m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS));
311
+ }
312
+ bool isCPUFallbackProfilingEnabled() const {
313
+ return (m_signpost_types & SignpostTypes::CPU_FALLBACK) ||
314
+ (m_log_options & (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
315
+ }
316
+ bool isSignpostTracingEnabled() const {
317
+ return (m_signpost_types != SignpostTypes::SIGNPOST_NONE);
318
+ }
319
+
320
+ private:
321
+ // indicates what type of signpost types are enabled and traced by MPS profiler.
322
+ uint32_t m_signpost_types = 0;
323
+ uint32_t m_profile_options = 0;
324
+ uint32_t m_log_options = 0;
325
+ uint64_t m_kernel_counter = 0;
326
+ uint64_t m_graph_counter = 0;
327
+ uint64_t m_cpu_fb_counter = 0;
328
+ uint64_t m_copy_counter = 0;
329
+ // technically, it's possible to trace both events and intervals at the same time
330
+ // so we use separate os_log categories for them
331
+ os_log_t m_os_log_events;
332
+ os_log_t m_os_log_intervals;
333
+ // stats logging could run either from destructor or signal handler
334
+ // so this is used to check if logging has already started.
335
+ std::atomic_bool hasLoggedStats{false};
336
+ // indicates there are pending completionHandler callbacks that haven't been called yet.
337
+ std::atomic_bool hasPendingCompletionHandlers{false};
338
+ // used to capture sigint signal to log profiling stats
339
+ static struct sigaction currentSigint, previousSigint;
340
+
341
+ // We use the following lists for two reasons:
342
+ // 1- for interval-based signposts the "begin" point won't be in same function
343
+ // as the "end" point where we need to be able to retrieve signpost's info
344
+ // 2- if Operations info need to be logged when process ends using LogOptions::OPERATION_INFO.
345
+
346
+ // the pointer key for this map is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
347
+ // this list is retained and could be logged along with aggregate profiling numbers when the process ends.
348
+ std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>> m_op_info_list{};
349
+ // the string key for this map is the op name that we fall back to execute on CPU
350
+ // this list is retained and could be logged along with aggregate profiling numbers when the process ends.
351
+ std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>> m_cpu_fb_info_list{};
352
+ // this list contains the info for copies, and its key is the unique profileId
353
+ // which is generated from m_copy_counter
354
+ // The copyInfo list is not retained.
355
+ std::unordered_map<uint64_t, std::unique_ptr<CopyInfo>> m_copy_info_list{};
356
+ // a short list that contains copy stats
357
+ std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>> m_copy_stat_list{};
358
+
359
+ void initialize();
360
+ void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
361
+ void endProfileExecution(BaseInfo& info, os_signpost_id_t event_signpost_id,
362
+ os_signpost_id_t interval_signpost_id,
363
+ double gpuTime, double schedulingTime);
364
+ void addProfilerScheduledHandler(BaseInfo& info);
365
+ void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType);
366
+ void emitSignpostEvent(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
367
+ const std::string& msg) const;
368
+ void beginSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
369
+ const std::string& msg) const;
370
+ void endSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id) const;
371
+
372
+ void updateCopyStats(const CopyInfo& copyInfo, double gpuTime, double schedulingTime);
373
+ // returns true if logging the profiling info "during the execution" is enabled
374
+ bool isProfileInfoLoggingEnabled(BaseInfo::Type infoType, bool isExecutionEnded);
375
+ // logs all the profiling stats that are enabled
376
+ void logProfilingStats();
377
+ // logs kernel profiling stats when the process ends.
378
+ void logOperationsProfilingStats(std::FILE* f) const;
379
+ // logs CPU Fallback profiling stats when the process ends.
380
+ void logCPUFallbackProfilingStats(std::FILE* f) const;
381
+ // logs copy profiling stats when the process ends.
382
+ void logCopyProfilingStats(std::FILE* f) const;
383
+
384
+ os_signpost_id_t generateSignpostId(os_signpost_type_t signpostType, const void* ptr = nullptr);
385
+ static SignpostTypes getSignpostType(BaseInfo::Type infoType);
386
+ static void handleIntSignal(int signal);
387
+ };
388
+
389
+ } // namespace Profiler
390
+
391
+ Profiler::MPSProfiler& getMPSProfiler();
392
+
393
+ } // namespace at::mps
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/BatchLinearAlgebra.h ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/util/Optional.h>
4
+ #include <c10/util/string_view.h>
5
+ #include <ATen/Config.h>
6
+ #include <ATen/native/DispatchStub.h>
7
+
8
+ // Forward declare TI
9
+ namespace at {
10
+ class Tensor;
11
+ struct TensorIterator;
12
+
13
+ namespace native {
14
+ enum class TransposeType;
15
+ }
16
+
17
+ }
18
+
19
+ namespace at::native {
20
+
21
+ enum class LapackLstsqDriverType : int64_t { Gels, Gelsd, Gelsy, Gelss};
22
+
23
+ #if AT_BUILD_WITH_LAPACK()
24
+ // Define per-batch functions to be used in the implementation of batched
25
+ // linear algebra operations
26
+
27
+ template <class scalar_t>
28
+ void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info);
29
+
30
+ template <class scalar_t>
31
+ void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info);
32
+
33
+ template <class scalar_t, class value_t=scalar_t>
34
+ void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);
35
+
36
+ template <class scalar_t>
37
+ void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
38
+
39
+ template <class scalar_t>
40
+ void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
41
+
42
+ template <class scalar_t>
43
+ void lapackOrmqr(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info);
44
+
45
+ template <class scalar_t, class value_t = scalar_t>
46
+ void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info);
47
+
48
+ template <class scalar_t>
49
+ void lapackGels(char trans, int m, int n, int nrhs,
50
+ scalar_t *a, int lda, scalar_t *b, int ldb,
51
+ scalar_t *work, int lwork, int *info);
52
+
53
+ template <class scalar_t, class value_t = scalar_t>
54
+ void lapackGelsd(int m, int n, int nrhs,
55
+ scalar_t *a, int lda, scalar_t *b, int ldb,
56
+ value_t *s, value_t rcond, int *rank,
57
+ scalar_t* work, int lwork,
58
+ value_t *rwork, int* iwork, int *info);
59
+
60
+ template <class scalar_t, class value_t = scalar_t>
61
+ void lapackGelsy(int m, int n, int nrhs,
62
+ scalar_t *a, int lda, scalar_t *b, int ldb,
63
+ int *jpvt, value_t rcond, int *rank,
64
+ scalar_t *work, int lwork, value_t* rwork, int *info);
65
+
66
+ template <class scalar_t, class value_t = scalar_t>
67
+ void lapackGelss(int m, int n, int nrhs,
68
+ scalar_t *a, int lda, scalar_t *b, int ldb,
69
+ value_t *s, value_t rcond, int *rank,
70
+ scalar_t *work, int lwork,
71
+ value_t *rwork, int *info);
72
+
73
+ template <LapackLstsqDriverType, class scalar_t, class value_t = scalar_t>
74
+ struct lapackLstsq_impl;
75
+
76
+ template <class scalar_t, class value_t>
77
+ struct lapackLstsq_impl<LapackLstsqDriverType::Gels, scalar_t, value_t> {
78
+ static void call(
79
+ char trans, int m, int n, int nrhs,
80
+ scalar_t *a, int lda, scalar_t *b, int ldb,
81
+ scalar_t *work, int lwork, int *info, // Gels flavor
82
+ int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
83
+ value_t *s, // Gelss flavor
84
+ int *iwork // Gelsd flavor
85
+ ) {
86
+ lapackGels<scalar_t>(
87
+ trans, m, n, nrhs,
88
+ a, lda, b, ldb,
89
+ work, lwork, info);
90
+ }
91
+ };
92
+
93
+ template <class scalar_t, class value_t>
94
+ struct lapackLstsq_impl<LapackLstsqDriverType::Gelsy, scalar_t, value_t> {
95
+ static void call(
96
+ char trans, int m, int n, int nrhs,
97
+ scalar_t *a, int lda, scalar_t *b, int ldb,
98
+ scalar_t *work, int lwork, int *info, // Gels flavor
99
+ int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
100
+ value_t *s, // Gelss flavor
101
+ int *iwork // Gelsd flavor
102
+ ) {
103
+ lapackGelsy<scalar_t, value_t>(
104
+ m, n, nrhs,
105
+ a, lda, b, ldb,
106
+ jpvt, rcond, rank,
107
+ work, lwork, rwork, info);
108
+ }
109
+ };
110
+
111
+ template <class scalar_t, class value_t>
112
+ struct lapackLstsq_impl<LapackLstsqDriverType::Gelsd, scalar_t, value_t> {
113
+ static void call(
114
+ char trans, int m, int n, int nrhs,
115
+ scalar_t *a, int lda, scalar_t *b, int ldb,
116
+ scalar_t *work, int lwork, int *info, // Gels flavor
117
+ int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
118
+ value_t *s, // Gelss flavor
119
+ int *iwork // Gelsd flavor
120
+ ) {
121
+ lapackGelsd<scalar_t, value_t>(
122
+ m, n, nrhs,
123
+ a, lda, b, ldb,
124
+ s, rcond, rank,
125
+ work, lwork,
126
+ rwork, iwork, info);
127
+ }
128
+ };
129
+
130
+ template <class scalar_t, class value_t>
131
+ struct lapackLstsq_impl<LapackLstsqDriverType::Gelss, scalar_t, value_t> {
132
+ static void call(
133
+ char trans, int m, int n, int nrhs,
134
+ scalar_t *a, int lda, scalar_t *b, int ldb,
135
+ scalar_t *work, int lwork, int *info, // Gels flavor
136
+ int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
137
+ value_t *s, // Gelss flavor
138
+ int *iwork // Gelsd flavor
139
+ ) {
140
+ lapackGelss<scalar_t, value_t>(
141
+ m, n, nrhs,
142
+ a, lda, b, ldb,
143
+ s, rcond, rank,
144
+ work, lwork,
145
+ rwork, info);
146
+ }
147
+ };
148
+
149
+ template <LapackLstsqDriverType driver_type, class scalar_t, class value_t = scalar_t>
150
+ void lapackLstsq(
151
+ char trans, int m, int n, int nrhs,
152
+ scalar_t *a, int lda, scalar_t *b, int ldb,
153
+ scalar_t *work, int lwork, int *info, // Gels flavor
154
+ int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
155
+ value_t *s, // Gelss flavor
156
+ int *iwork // Gelsd flavor
157
+ ) {
158
+ lapackLstsq_impl<driver_type, scalar_t, value_t>::call(
159
+ trans, m, n, nrhs,
160
+ a, lda, b, ldb,
161
+ work, lwork, info,
162
+ jpvt, rcond, rank, rwork,
163
+ s,
164
+ iwork);
165
+ }
166
+
167
+ template <class scalar_t>
168
+ void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info);
169
+
170
+ template <class scalar_t>
171
+ void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
172
+
173
+ template <class scalar_t>
174
+ void lapackLdlHermitian(
175
+ char uplo,
176
+ int n,
177
+ scalar_t* a,
178
+ int lda,
179
+ int* ipiv,
180
+ scalar_t* work,
181
+ int lwork,
182
+ int* info);
183
+
184
+ template <class scalar_t>
185
+ void lapackLdlSymmetric(
186
+ char uplo,
187
+ int n,
188
+ scalar_t* a,
189
+ int lda,
190
+ int* ipiv,
191
+ scalar_t* work,
192
+ int lwork,
193
+ int* info);
194
+
195
+ template <class scalar_t>
196
+ void lapackLdlSolveHermitian(
197
+ char uplo,
198
+ int n,
199
+ int nrhs,
200
+ scalar_t* a,
201
+ int lda,
202
+ int* ipiv,
203
+ scalar_t* b,
204
+ int ldb,
205
+ int* info);
206
+
207
+ template <class scalar_t>
208
+ void lapackLdlSolveSymmetric(
209
+ char uplo,
210
+ int n,
211
+ int nrhs,
212
+ scalar_t* a,
213
+ int lda,
214
+ int* ipiv,
215
+ scalar_t* b,
216
+ int ldb,
217
+ int* info);
218
+
219
+ template<class scalar_t, class value_t=scalar_t>
220
+ void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info);
221
+ #endif
222
+
223
+ #if AT_BUILD_WITH_BLAS()
224
+ template <class scalar_t>
225
+ void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb);
226
+ #endif
227
+
228
+ using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
229
+ DECLARE_DISPATCH(cholesky_fn, cholesky_stub);
230
+
231
+ using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/);
232
+
233
+ DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
234
+
235
+ using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/);
236
+
237
+ DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub);
238
+
239
+ using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/);
240
+ DECLARE_DISPATCH(geqrf_fn, geqrf_stub);
241
+
242
+ using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/);
243
+ DECLARE_DISPATCH(orgqr_fn, orgqr_stub);
244
+
245
+ using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/);
246
+ DECLARE_DISPATCH(ormqr_fn, ormqr_stub);
247
+
248
+ using linalg_eigh_fn = void (*)(
249
+ const Tensor& /*eigenvalues*/,
250
+ const Tensor& /*eigenvectors*/,
251
+ const Tensor& /*infos*/,
252
+ bool /*upper*/,
253
+ bool /*compute_eigenvectors*/);
254
+ DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub);
255
+
256
+ using lstsq_fn = void (*)(
257
+ const Tensor& /*a*/,
258
+ Tensor& /*b*/,
259
+ Tensor& /*rank*/,
260
+ Tensor& /*singular_values*/,
261
+ Tensor& /*infos*/,
262
+ double /*rcond*/,
263
+ std::string /*driver_name*/);
264
+ DECLARE_DISPATCH(lstsq_fn, lstsq_stub);
265
+
266
+ using triangular_solve_fn = void (*)(
267
+ const Tensor& /*A*/,
268
+ const Tensor& /*B*/,
269
+ bool /*left*/,
270
+ bool /*upper*/,
271
+ TransposeType /*transpose*/,
272
+ bool /*unitriangular*/);
273
+ DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
274
+
275
+ using lu_factor_fn = void (*)(
276
+ const Tensor& /*input*/,
277
+ const Tensor& /*pivots*/,
278
+ const Tensor& /*infos*/,
279
+ bool /*compute_pivots*/);
280
+ DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);
281
+
282
+ using unpack_pivots_fn = void(*)(
283
+ TensorIterator& iter,
284
+ const int64_t dim_size,
285
+ const int64_t max_pivot);
286
+ DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub);
287
+
288
+ using lu_solve_fn = void (*)(
289
+ const Tensor& /*LU*/,
290
+ const Tensor& /*pivots*/,
291
+ const Tensor& /*B*/,
292
+ TransposeType /*trans*/);
293
+ DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub);
294
+
295
+ using ldl_factor_fn = void (*)(
296
+ const Tensor& /*LD*/,
297
+ const Tensor& /*pivots*/,
298
+ const Tensor& /*info*/,
299
+ bool /*upper*/,
300
+ bool /*hermitian*/);
301
+ DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub);
302
+
303
+ using svd_fn = void (*)(
304
+ const Tensor& /*A*/,
305
+ const bool /*full_matrices*/,
306
+ const bool /*compute_uv*/,
307
+ const c10::optional<c10::string_view>& /*driver*/,
308
+ const Tensor& /*U*/,
309
+ const Tensor& /*S*/,
310
+ const Tensor& /*Vh*/,
311
+ const Tensor& /*info*/);
312
+ DECLARE_DISPATCH(svd_fn, svd_stub);
313
+
314
+ using ldl_solve_fn = void (*)(
315
+ const Tensor& /*LD*/,
316
+ const Tensor& /*pivots*/,
317
+ const Tensor& /*result*/,
318
+ bool /*upper*/,
319
+ bool /*hermitian*/);
320
+ DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub);
321
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/EmbeddingBag.h ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+ #include <ATen/Config.h>
3
+ #include <cstdint>
4
+
5
+ #ifdef USE_FBGEMM
6
+ #include <fbgemm/FbgemmEmbedding.h>
7
+ #endif
8
+
9
+ namespace at::native {
10
+
11
+ void check_arguments(
12
+ const Tensor& weight,
13
+ const Tensor& indices,
14
+ const Tensor& offsets,
15
+ const int64_t mode,
16
+ const c10::optional<Tensor>& per_sample_weights,
17
+ bool include_last_offset);
18
+
19
+ void make_bag_size_out(
20
+ Tensor& bag_size_out,
21
+ const Tensor& offsets,
22
+ const Tensor& indices,
23
+ const int64_t mode,
24
+ const bool include_last_offset,
25
+ const bool requires_grad);
26
+
27
+ void make_max_indices_out(
28
+ Tensor& max_indices_out,
29
+ const Tensor& weight,
30
+ const Tensor& indices,
31
+ const Tensor& offsets,
32
+ const Tensor& bag_size,
33
+ const int64_t mode,
34
+ bool include_last_offset);
35
+
36
+ void make_offset2bag_out(
37
+ Tensor& offset2bag,
38
+ Tensor& output,
39
+ const Tensor& weight,
40
+ const Tensor& indices,
41
+ const Tensor& offsets,
42
+ const int64_t mode,
43
+ const c10::optional<Tensor>& per_sample_weights,
44
+ const int64_t padding_idx = -1);
45
+
46
+ #ifdef USE_FBGEMM
47
+
48
+ template<bool has_weight, typename TIndex, typename TData>
49
+ struct _CallbackAndBlockSize {
50
+ using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;
51
+
52
+ int64_t blockSize = -1;
53
+ TCallback callback = nullptr;
54
+
55
+ static TCallback generateCallback(int64_t block_size) {
56
+ return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
57
+ block_size,
58
+ has_weight,
59
+ /* normalize_by_lengths */false,
60
+ /* prefetch */16,
61
+ /* is_weight_positional */false,
62
+ /* use_offsets */true);
63
+ }
64
+
65
+ _CallbackAndBlockSize() = default;
66
+
67
+ explicit _CallbackAndBlockSize(c10::optional<int64_t> maybe_block_size)
68
+ : blockSize(maybe_block_size.value_or(-1))
69
+ , callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
70
+ {}
71
+ };
72
+
73
+ template<typename... StorageMixins>
74
+ struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
75
+
76
+ _EmbeddingBagKernelCacheImpl() = default;
77
+ // use each of the mixins to store corresponding kernel and block size
78
+ explicit _EmbeddingBagKernelCacheImpl(c10::optional<int64_t> maybe_block_size)
79
+ : StorageMixins(maybe_block_size)...
80
+ {}
81
+
82
+ // this method is thread safe (call sites may call from different threads)
83
+ template<bool has_weight, typename TIndex, typename TData>
84
+ typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
85
+ getCallback(int64_t block_size) const {
86
+ // if the cache doesn't store the kernel for the incoming block size
87
+ // (so it is different from the one stored in corresponding mixin)
88
+ // regenerate the kernel (not writing it into the cache so we avoid locks)
89
+ if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
90
+ return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
91
+ }
92
+ // else retrieve the cached kernel from the corresponding mixin
93
+ return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
94
+ }
95
+ };
96
+
97
+ // instantiate the cache with the list of storage mixins
98
+ // for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
99
+ using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
100
+ _CallbackAndBlockSize<true, int32_t, float>,
101
+ _CallbackAndBlockSize<false, int32_t, float>,
102
+ _CallbackAndBlockSize<true, int64_t, float>,
103
+ _CallbackAndBlockSize<false, int64_t, float>,
104
+ _CallbackAndBlockSize<true, int32_t, unsigned short>,
105
+ _CallbackAndBlockSize<false, int32_t, unsigned short>,
106
+ _CallbackAndBlockSize<true, int64_t, unsigned short>,
107
+ _CallbackAndBlockSize<false, int64_t, unsigned short>>;
108
+ #else
109
+ struct _EmbeddingBagKernelCache {
110
+ explicit _EmbeddingBagKernelCache(c10::optional<int64_t> /* maybe_block_size */) {}
111
+ };
112
+ #endif
113
+
114
+ void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
115
+ Tensor& bag_size, Tensor* max_indices,
116
+ const Tensor &weight, const Tensor &indices,
117
+ const Tensor &offsets, const int64_t mode = 0,
118
+ const c10::optional<Tensor>& per_sample_weights = c10::nullopt,
119
+ bool include_last_offset = false,
120
+ int64_t padding_idx = -1,
121
+ _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
122
+
123
+ void _embedding_bag_cpu_out(
124
+ at::Tensor& output,
125
+ at::Tensor& offset2bag,
126
+ at::Tensor& bag_size,
127
+ at::Tensor* p_max_indices,
128
+ const at::Tensor& weight,
129
+ const at::Tensor& indices,
130
+ const at::Tensor& offsets,
131
+ const bool scale_grad_by_freq,
132
+ const int64_t mode,
133
+ const bool sparse,
134
+ const c10::optional<at::Tensor>& per_sample_weights,
135
+ const bool include_last_offset,
136
+ const c10::optional<int64_t>& padding_idx,
137
+ _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
138
+
139
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Fill.h ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Functions that fill Tensors with constants. Implementations are in Fill.cpp.
2
+
3
+ #pragma once
4
+
5
+ #include <ATen/native/DispatchStub.h>
6
+
7
+ namespace c10 {
8
+ class Scalar;
9
+ }
10
+
11
+ namespace at {
12
+ class Tensor;
13
+ struct TensorIterator;
14
+
15
+ namespace native {
16
+
17
+ DECLARE_DISPATCH(void(*)(TensorIterator&, const c10::Scalar&), fill_stub);
18
+
19
+ Tensor& fill_out(Tensor& self, const Scalar& value);
20
+
21
+ }} // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/LossMulti.h ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <ATen/AccumulateType.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/TensorUtils.h>
6
+
7
+ namespace at::native {
8
+ namespace {
9
+ static C10_UNUSED void multilabel_margin_loss_shape_check(
10
+ int64_t& nframe,
11
+ int64_t& dim,
12
+ const int64_t& ndims,
13
+ const Tensor& input,
14
+ const Tensor& target) {
15
+ TORCH_CHECK(
16
+ (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
17
+ "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
18
+ input.sizes());
19
+
20
+ if (ndims <= 1) {
21
+ nframe = 1;
22
+ dim = ndims == 0 ? 1 : input.size(0);
23
+ TORCH_CHECK(
24
+ target.dim() <= 1 && target.numel() == dim,
25
+ "inconsistent target size: ", target.sizes(), " for input of size: ",
26
+ input.sizes());
27
+ } else {
28
+ nframe = input.size(0);
29
+ dim = input.size(1);
30
+ TORCH_CHECK(
31
+ target.dim() == 2 && target.size(0) == nframe &&
32
+ target.size(1) == dim,
33
+ "inconsistent target size: ", target.sizes(), " for input of size: ",
34
+ input.sizes());
35
+ }
36
+ }
37
+
38
+ static C10_UNUSED void multi_margin_loss_shape_check(
39
+ int64_t& nframe,
40
+ int64_t& dim,
41
+ const int64_t& ndims,
42
+ const Tensor& input,
43
+ const Tensor& target,
44
+ const c10::optional<Tensor>& weight) {
45
+ TORCH_CHECK(
46
+ (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
47
+ "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
48
+ input.sizes());
49
+
50
+ if (ndims <= 1) {
51
+ nframe = 1;
52
+ dim = ndims == 0 ? 1 : input.size(0);
53
+ } else {
54
+ nframe = input.size(0);
55
+ dim = input.size(1);
56
+ }
57
+
58
+ TORCH_CHECK(
59
+ target.dim() <= 1 && target.numel() == nframe,
60
+ "inconsistent target size, expected ", nframe, " but got ",
61
+ target.sizes());
62
+ if (weight && weight->defined()) {
63
+ TORCH_CHECK(
64
+ weight->dim() <= 1 && weight->numel() == dim,
65
+ "inconsistent weight size, expected ", dim, " but got ",
66
+ weight->sizes());
67
+ }
68
+ }
69
+
70
+
71
+ } // anonymous namespace
72
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Normalization.h ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/TensorIterator.h>
4
+ #include <ATen/native/DispatchStub.h>
5
+
6
+ namespace at::native {
7
+
8
+ using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
9
+ DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
10
+
11
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/Pow.h ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+
5
+ namespace c10 {
6
+ class Scalar;
7
+ }
8
+
9
+ namespace at {
10
+
11
+ struct TensorIterator;
12
+ struct TensorIteratorBase;
13
+
14
+ namespace native {
15
+
16
+ #if defined(__CUDACC__) || defined(__HIPCC__)
17
+ #define HOST_DEVICE __host__ __device__
18
+ #else
19
+ #define HOST_DEVICE
20
+ #endif
21
+
22
+ // integral power in pytorch allows for negative exponents, giving truncated integral results.
23
+ // e.g. since 2**-1==0.5, the truncated integral result is zero. 1**negative_exponent is the
24
+ // only non-zero result.
25
+ template <class T,
26
+ typename std::enable_if<std::is_integral<T>::value, T>::type* = nullptr>
27
+ static inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
28
+ T result = 1;
29
+ while (b) {
30
+ if (b & 1) {
31
+ result *= a;
32
+ }
33
+ b /= 2;
34
+ a *= a;
35
+ }
36
+ return result;
37
+ }
38
+
39
+ template <class T,
40
+ typename std::enable_if<std::is_integral<T>::value && !std::is_signed<T>::value, T>::type* = nullptr>
41
+ static inline HOST_DEVICE T powi(T a, T b) {
42
+ return powi_impl(a, b);
43
+ }
44
+
45
+ template <class T,
46
+ typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value, T>::type* = nullptr>
47
+ static inline HOST_DEVICE T powi(T a, T b) {
48
+ if ( b < 0 ) {
49
+ if ( a == 1 ) {
50
+ return 1;
51
+ } else if ( a == -1 ) {
52
+ auto negative = (-b) % static_cast<T>(2);
53
+ return negative ? -1 : 1;
54
+ } else {
55
+ return 0;
56
+ }
57
+ }
58
+ return powi_impl(a, b);
59
+ }
60
+
61
+ using pow_tensor_tensor_fn = void (*)(TensorIteratorBase&);
62
+ using pow_tensor_scalar_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
63
+
64
+ DECLARE_DISPATCH(pow_tensor_tensor_fn, pow_tensor_tensor_stub);
65
+ DECLARE_DISPATCH(pow_tensor_scalar_fn, pow_tensor_scalar_stub);
66
+
67
+ } // namespace native
68
+
69
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/ReduceOps.h ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <c10/util/ArrayRef.h>
5
+ #include <c10/util/Optional.h>
6
+
7
+ namespace c10 {
8
+ class Scalar;
9
+ }
10
+
11
+ namespace at {
12
+ struct TensorIterator;
13
+ class Tensor;
14
+ }
15
+
16
+ namespace at::native {
17
+
18
+ using reduce_fn = void(*)(TensorIterator &);
19
+
20
+ DECLARE_DISPATCH(reduce_fn, sum_stub);
21
+ DECLARE_DISPATCH(reduce_fn, nansum_stub);
22
+ DECLARE_DISPATCH(reduce_fn, prod_stub);
23
+ DECLARE_DISPATCH(reduce_fn, mean_stub);
24
+ DECLARE_DISPATCH(reduce_fn, and_stub);
25
+ DECLARE_DISPATCH(reduce_fn, or_stub);
26
+ DECLARE_DISPATCH(reduce_fn, min_values_stub);
27
+ DECLARE_DISPATCH(reduce_fn, max_values_stub);
28
+ DECLARE_DISPATCH(reduce_fn, argmax_stub);
29
+ DECLARE_DISPATCH(reduce_fn, argmin_stub);
30
+
31
+ using reduce_std_var_function =
32
+ void (*)(TensorIterator&, double correction, bool take_sqrt);
33
+ DECLARE_DISPATCH(reduce_std_var_function, std_var_stub);
34
+
35
+ using reduce_norm_fn =
36
+ void (*)(Tensor&, const Tensor&, const c10::Scalar&, c10::optional<int64_t>);
37
+ DECLARE_DISPATCH(reduce_norm_fn, norm_kernel);
38
+
39
+ using reduce_fn_flag = void(*)(TensorIterator &, const c10::Scalar&);
40
+ DECLARE_DISPATCH(reduce_fn_flag, norm_stub);
41
+
42
+ using structured_cum_fn = void (*)(const Tensor&, const Tensor&, int64_t);
43
+ using cum_fn = void (*)(Tensor&, const Tensor&, int64_t);
44
+ DECLARE_DISPATCH(structured_cum_fn, cumsum_stub);
45
+ DECLARE_DISPATCH(structured_cum_fn, cumprod_stub);
46
+ DECLARE_DISPATCH(cum_fn, logcumsumexp_stub);
47
+
48
+ DECLARE_DISPATCH(void (*)(const Tensor&, int64_t, bool, Tensor&, Tensor&), aminmax_stub);
49
+ DECLARE_DISPATCH(void (*)(const Tensor&, Tensor&, Tensor&), aminmax_allreduce_stub);
50
+
51
+ // Used in cuda/Normalization.cu
52
+ TORCH_API std::tuple<Tensor&,Tensor&> var_mean_out(
53
+ Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim,
54
+ int64_t correction, bool keepdim);
55
+
56
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/SobolEngineOpsUtils.h ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /// This file contains some tensor-agnostic operations to be used in the
2
+ /// core functions of the `SobolEngine`
3
+ #include <ATen/core/Tensor.h>
4
+
5
+ #ifndef AT_PER_OPERATOR_HEADERS
6
+ #include <ATen/Functions.h>
7
+ #else
8
+ #include <ATen/ops/arange.h>
9
+ #include <ATen/ops/mul.h>
10
+ #include <ATen/ops/pow.h>
11
+ #endif
12
+
13
+ namespace at::native::sobol_utils {
14
+
15
+ /// Function to return the minimum of number of bits to represent the integer `n`
16
+ inline int64_t bit_length(const int64_t n) {
17
+ int64_t nbits, nloc;
18
+ for (nloc = n, nbits = 0; nloc > 0; nloc /= 2, nbits++);
19
+ return nbits;
20
+ }
21
+
22
+ /// Function to get the position of the rightmost zero in the bit representation of an integer
23
+ /// This value is the zero-indexed position
24
+ inline int64_t rightmost_zero(const int64_t n) {
25
+ int64_t z, i;
26
+ for (z = n, i = 0; z % 2 == 1; z /= 2, i++);
27
+ return i;
28
+ }
29
+
30
+ /// Function to get a subsequence of bits in the representation of an integer starting from
31
+ /// `pos` and of length `length`
32
+ inline int64_t bitsubseq(const int64_t n, const int64_t pos, const int64_t length) {
33
+ return (n >> pos) & ((1 << length) - 1);
34
+ }
35
+
36
+ /// Function to perform the inner product between a batched square matrix and a power of 2 vector
37
+ inline at::Tensor cdot_pow2(const at::Tensor& bmat) {
38
+ at::Tensor inter = at::arange(bmat.size(-1) - 1, -1, -1, bmat.options());
39
+ inter = at::pow(2, inter).expand_as(bmat);
40
+ return at::mul(inter, bmat).sum(-1);
41
+ }
42
+
43
+ /// All definitions below this point are data. These are constant, and should not be modified
44
+ /// without notice
45
+
46
+ constexpr int64_t MAXDIM = 21201;
47
+ constexpr int64_t MAXDEG = 18;
48
+ constexpr int64_t MAXBIT = 30;
49
+ constexpr int64_t LARGEST_NUMBER = 1 << MAXBIT;
50
+ constexpr float RECIPD = 1.0 / LARGEST_NUMBER;
51
+
52
+ extern const int64_t poly[MAXDIM];
53
+ extern const int64_t initsobolstate[MAXDIM][MAXDEG];
54
+
55
+ } // namespace at::native::sobol_utils
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorCompare.h ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+
5
+ namespace c10 {
6
+ class Scalar;
7
+ }
8
+
9
+ namespace at {
10
+ class Tensor;
11
+ struct TensorIterator;
12
+ struct TensorIteratorBase;
13
+ }
14
+
15
+ namespace at::native {
16
+
17
+ using reduce_minmax_fn =
18
+ void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
19
+ using structured_reduce_minmax_fn =
20
+ void (*)(const Tensor&, const Tensor&, const Tensor&, int64_t, bool);
21
+
22
+ DECLARE_DISPATCH(structured_reduce_minmax_fn, max_stub);
23
+ DECLARE_DISPATCH(structured_reduce_minmax_fn, min_stub);
24
+
25
+ using where_fn = void (*)(TensorIterator &);
26
+ DECLARE_DISPATCH(where_fn, where_kernel);
27
+
28
+ using is_infinity_op_fn = void (*)(TensorIteratorBase &);
29
+ DECLARE_DISPATCH(is_infinity_op_fn, isposinf_stub);
30
+ DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub);
31
+
32
+ using mode_fn = void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
33
+ DECLARE_DISPATCH(mode_fn, mode_stub);
34
+
35
+ using clamp_tensor_fn = void (*)(TensorIteratorBase &);
36
+ DECLARE_DISPATCH(clamp_tensor_fn, clamp_stub);
37
+
38
+ namespace detail {
39
+ enum class ClampLimits {Min, Max, MinMax};
40
+ }
41
+
42
+ DECLARE_DISPATCH(void (*)(TensorIteratorBase &, const c10::Scalar&, const c10::Scalar&), clamp_scalar_stub);
43
+ DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_min_scalar_stub);
44
+ DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_max_scalar_stub);
45
+
46
+ using isin_default_fn = void (*)(const Tensor&, const Tensor&, bool, const Tensor&);
47
+ DECLARE_DISPATCH(isin_default_fn, isin_default_stub);
48
+
49
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TensorIterator.h ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #pragma once
2
+ #include <ATen/TensorIterator.h>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/TriangularOpsUtils.h ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/Tensor.h>
2
+ #include <ATen/native/LinearAlgebraUtils.h>
3
+
4
+ namespace at::native {
5
+
6
+ /*
7
+ * Given batches of matrices with arbitrary batch dim,
8
+ * computes the number of batches for Triu and Tril. This ignores stride 0 dimension
9
+ */
10
+ static inline int64_t batchCountTrilTriu(const Tensor& batched_matrices) {
11
+ int64_t result = 1;
12
+ for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
13
+ if (batched_matrices.stride(i) != 0) {
14
+ result *= batched_matrices.size(i);
15
+ }
16
+ }
17
+ return result;
18
+ }
19
+
20
+ /* Checks a necessary property for the triu and tril implementations, hence the name.
21
+ * Here batch contiguity is checked for tensors with greater than 4 dimensions.
22
+ * Contiguous tensors and tensors with less than 3 dimensions pass this check
23
+ */
24
+ static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor& tensor, bool allow_zero_stride) {
25
+ // Complete contiguity is the most desired property, which is why
26
+ // we return true if the tensor is contiguous
27
+ if (tensor.is_contiguous()) {
28
+ auto default_strides_for_size = batched_matrix_contiguous_strides(tensor.sizes());
29
+ if (tensor.strides() == default_strides_for_size) {
30
+ return std::make_tuple(true, tensor);
31
+ } else {
32
+ return std::make_tuple(false, tensor.as_strided(tensor.sizes(), default_strides_for_size));
33
+ }
34
+ }
35
+
36
+ int64_t dims = tensor.dim();
37
+
38
+ // Tensors with dimension less than 4 are handled by default
39
+ if (allow_zero_stride && dims <= 3) {
40
+ return std::make_tuple(true, tensor);
41
+ }
42
+
43
+ int64_t expected_stride = tensor.size(-1) * tensor.size(-2);
44
+ for (int64_t i = dims - 3; i >= 0; i--) {
45
+ // Skip trivial dimension;
46
+ if (allow_zero_stride && i == 0 && (tensor.stride(i) == 0 || tensor.size(i) == 1)) {
47
+ continue;
48
+ }
49
+ if (expected_stride != tensor.stride(i)) {
50
+ return std::make_tuple(false, tensor.contiguous());
51
+ }
52
+ expected_stride *= tensor.size(i);
53
+ }
54
+ return std::make_tuple(true, tensor);
55
+ }
56
+
57
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/IsContiguous.h ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace at { namespace native { inline namespace CPU_CAPABILITY {
4
+
5
+ // n: number of function arguments (arity)
6
+ // traits: function_traits (see FunctionTraits.h)
7
+ // s: index of scalar argument or -1
8
+ template <int n, int stride_index, typename traits, int s=-1>
9
+ struct IsContiguous {
10
+ static bool eval(const int64_t* strides) {
11
+ using type = typename traits::template arg<n - 1>::type;
12
+ return strides[stride_index] == (s == n ? 0 : sizeof(type)) &&
13
+ IsContiguous<n - 1, stride_index - 1, traits, s>::eval(strides);
14
+ }
15
+ };
16
+
17
+ // will be called when there is an output exists
18
+ template <typename traits, int s>
19
+ struct IsContiguous<0, 0, traits, s> {
20
+ static bool eval(const int64_t* strides) {
21
+ return strides[0] == sizeof(typename traits::result_type);
22
+ }
23
+ };
24
+
25
+ // will be called when there is no output
26
+ template <typename traits, int s>
27
+ struct IsContiguous<0, -1, traits, s> {
28
+ static bool eval(const int64_t* /*strides*/) {
29
+ return true;
30
+ }
31
+ };
32
+
33
+ // output and all inputs are contiguous
34
+ template <typename traits,
35
+ typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
36
+ static inline bool is_contiguous(const int64_t* strides) {
37
+ return IsContiguous<traits::arity, traits::arity - 1, traits>::eval(strides);
38
+ }
39
+
40
+ template <typename traits,
41
+ typename std::enable_if<!std::is_void<typename traits::result_type>::value>::type* = nullptr>
42
+ static inline bool is_contiguous(const int64_t* strides) {
43
+ return IsContiguous<traits::arity, traits::arity, traits>::eval(strides);
44
+ }
45
+
46
+ // input at `s` is scalar (stride 0); output and other inputs are contiguous
47
+ // NB: output is typically at strides[0] so first input corresponds to s=1
48
+ template <typename traits, int s,
49
+ typename std::enable_if<std::is_void<typename traits::result_type>::value>::type* = nullptr>
50
+ static inline bool is_contiguous_scalar(const int64_t* strides) {
51
+ static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
52
+ return IsContiguous<traits::arity, traits::arity - 1, traits, s>::eval(strides);
53
+ }
54
+
55
+ template <typename traits, int s,
56
+ typename std::enable_if<!std::is_void<typename traits::result_type>::value>::type* = nullptr>
57
+ static inline bool is_contiguous_scalar(const int64_t* strides) {
58
+ static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds");
59
+ return IsContiguous<traits::arity, traits::arity, traits, s>::eval(strides);
60
+ }
61
+
62
+ }}}
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cpu/SoftmaxKernel.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/native/DispatchStub.h>
4
+ #include <cstdint>
5
+
6
+ namespace at {
7
+ class Tensor;
8
+
9
+ namespace native {
10
+
11
+ using forward_fn = void (*)(const Tensor&, const Tensor&);
12
+ using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&);
13
+
14
+ DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel);
15
+ DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel);
16
+ DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel);
17
+ DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel);
18
+
19
+ using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t);
20
+ using backward_fn_with_dim =
21
+ void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t);
22
+
23
+ DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel);
24
+ DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel);
25
+ DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel);
26
+ DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel);
27
+ }
28
+ }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/CUDAJitLoops.cuh ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/jit_macros.h>
3
+
4
+ // Jiterator functions are guarded behind this macro
5
+ #if AT_USE_JITERATOR()
6
+
7
+ #include <ATen/OpMathType.h>
8
+ #include <ATen/TensorIterator.h>
9
+ #include <ATen/core/Array.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <ATen/cuda/detail/OffsetCalculator.cuh>
12
+ #include <ATen/native/cuda/jit_utils.h>
13
+ #include <ATen/native/cuda/MemoryAccess.cuh>
14
+ #include <ATen/native/cuda/thread_constants.h>
15
+
16
+ #include <ATen/native/cuda/Loops.cuh>
17
+
18
+ #include <c10/macros/Macros.h>
19
+ #include <c10/core/ScalarType.h>
20
+ #include <c10/util/SmallBuffer.h>
21
+
22
+ #include <initializer_list>
23
+ #include <type_traits>
24
+ #include <tuple>
25
+ #include <mutex>
26
+
27
+ namespace at {
28
+ namespace native {
29
+
30
+ template <typename Tuple, std::size_t... I>
31
+ constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence<I...> seq) {
32
+ constexpr auto size = seq.size();
33
+ (void)t; // warning : unused parameter when tuple is empty.
34
+ return std::array<void*, size>{static_cast<void*>(&std::get<I>(t))...};
35
+ }
36
+
37
+ // Helper function convert tuple to std::array<void*, N>
38
+ // for passing the arguments to CUDA Kernel
39
+ // NOTE: We capture tuple by reference,
40
+ // so the pointers in returned array are only valid
41
+ // till tuple is alive.
42
+ template <typename ...Args>
43
+ constexpr auto tuple_to_array(std::tuple<Args...>& extra_args) {
44
+ constexpr auto tuple_size = sizeof...(Args);
45
+ return tuple_to_array_helper(extra_args, std::make_index_sequence<tuple_size>{});
46
+ }
47
+
48
+ struct JittedVecKernelCache {
49
+ // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
50
+ at::cuda::jit::NvrtcFunction vec1;
51
+ at::cuda::jit::NvrtcFunction vec2;
52
+ at::cuda::jit::NvrtcFunction vec4;
53
+ };
54
+
55
+ struct JittedKernelVariantCache {
56
+ JittedVecKernelCache vec;
57
+ at::cuda::jit::NvrtcFunction noncontiguous;
58
+ at::cuda::jit::NvrtcFunction dynamic_contiguous;
59
+ at::cuda::jit::NvrtcFunction dynamic_noncontiguous;
60
+ };
61
+
62
+ inline c10::SmallBuffer<void*, 64> pack_kernel_args(
63
+ std::initializer_list<void*> args,
64
+ c10::ArrayRef<void*> extra_args) {
65
+ c10::SmallBuffer<void*, 64> ret(args.size() + extra_args.size());
66
+ std::copy(args.begin(), args.end(), ret.data());
67
+ std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size());
68
+ return ret;
69
+ }
70
+
71
+ template<typename array_t,
72
+ typename inp_calc_t,
73
+ typename out_calc_t,
74
+ typename loader_t,
75
+ typename storer_t>
76
+ void launch_jitted_unrolled_kernel(
77
+ std::mutex &jiterator_mutex,
78
+ at::cuda::jit::NvrtcFunction &fn_cache,
79
+ const at::cuda::jit::KernelDescriptor &desc,
80
+ int64_t N,
81
+ array_t data,
82
+ inp_calc_t ic,
83
+ out_calc_t oc,
84
+ loader_t l,
85
+ storer_t s,
86
+ bool contiguous,
87
+ at::cuda::jit::BinaryFuncVariant scalar_pos,
88
+ void* scalar_val,
89
+ c10::ArrayRef<void*> extra_args) {
90
+
91
+ TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
92
+ //casting result to int is always safe, intermediate is int64 and won't overflow
93
+ const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
94
+
95
+ if (!fn_cache.function) {
96
+ const std::lock_guard<std::mutex> lock{jiterator_mutex};
97
+ if (!fn_cache.function) {
98
+ constexpr bool dynamic_casting = !std::is_same<decltype(l), memory::LoadWithoutCast>() ||
99
+ !std::is_same<decltype(s), memory::StoreWithoutCast>();
100
+ auto code = at::cuda::jit::generate_code(
101
+ desc, contiguous, dynamic_casting, scalar_pos);
102
+ fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
103
+ }
104
+ }
105
+
106
+ auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
107
+ at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u},
108
+ {num_threads(), 1u, 1u});
109
+ }
110
+
111
+ template<int arity, typename array_t>
112
+ void launch_jitted_vectorized_kernel(
113
+ std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache,
114
+ const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data,
115
+ at::cuda::jit::BinaryFuncVariant scalar_pos,
116
+ void *scalar_val, c10::ArrayRef<void*> extra_args) {
117
+ TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
118
+ // N is still int64_t for the computation, but it's always safe to cast result to int
119
+ const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
120
+ const int vec_size = at::cuda::jit::can_vectorize_up_to(
121
+ desc, c10::ArrayRef<char*>(data.data, data.size()));
122
+
123
+ // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
124
+ // fn_ptr is set to the appropriate function based on the vec size and GPU used
125
+ at::cuda::jit::NvrtcFunction* fn_ptr;
126
+ if (vec_size == 4) {
127
+ fn_ptr = &fn_cache.vec4;
128
+ } else if (vec_size == 2) {
129
+ fn_ptr = &fn_cache.vec2;
130
+ } else if (vec_size ==1) {
131
+ fn_ptr = &fn_cache.vec1;
132
+ } else {
133
+ TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
134
+ }
135
+
136
+ bool vectorized = vec_size > 1;
137
+
138
+ if (!fn_ptr->function) {
139
+ const std::lock_guard<std::mutex> lock{jiterator_mutex};
140
+ if (!fn_ptr->function) { // cache miss!
141
+
142
+ // Generates program
143
+ auto code = at::cuda::jit::generate_code(
144
+ desc, /*contiguous=*/true, /*dynamic_casting=*/false,
145
+ scalar_pos, vectorized, vec_size);
146
+ std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name;
147
+
148
+ // Acquires the program
149
+ *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
150
+ }
151
+ }
152
+
153
+ if (vectorized) {
154
+ auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args);
155
+ at::cuda::jit::launch_jitted_pwise_function(
156
+ *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
157
+ } else {
158
+ // NVCC complains about unused variables l and s.
159
+ // It should be false positive in most cases, so we suppress the warnings.
160
+ #pragma nv_diagnostic push
161
+ #pragma nv_diag_suppress 177
162
+ auto ic = TrivialOffsetCalculator<arity>();
163
+ auto oc = TrivialOffsetCalculator<1>();
164
+ auto l = memory::LoadWithoutCast();
165
+ auto s = memory::StoreWithoutCast();
166
+
167
+ auto args = pack_kernel_args(
168
+ {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
169
+ at::cuda::jit::launch_jitted_pwise_function(
170
+ *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
171
+ #pragma nv_diagnostic pop
172
+ }
173
+ }
174
+
175
+ template <int arity>
176
+ void jitted_gpu_kernel_generic(
177
+ std::mutex &jiterator_mutex,
178
+ JittedKernelVariantCache &cache,
179
+ const at::cuda::jit::KernelDescriptor &desc,
180
+ at::cuda::jit::BinaryFuncVariant scalar_pos,
181
+ c10::ArrayRef<void*> extra_args,
182
+ TensorIteratorBase& iter,
183
+ const bool dynamic_casting,
184
+ void *scalar_val) {
185
+ TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
186
+ TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
187
+ TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
188
+
189
+ constexpr int ntensors = arity + 1;
190
+ at::detail::Array<char*, ntensors> data;
191
+ for (auto i : c10::irange(ntensors)) {
192
+ data[i] = (char*)iter.data_ptr(i);
193
+ }
194
+
195
+ int64_t numel = iter.numel();
196
+ bool contiguous = iter.is_contiguous();
197
+
198
+ // Decides which of 4 kernel types to launch
199
+ // Variations are:
200
+ // - Case 1: no dynamic casting and contiguous
201
+ // - Case 2: no dynamic casting and noncontiguous
202
+ // - Case 3: dynamic casting and contiguous
203
+ // - Case 4: dynamic casting and noncontiguous
204
+ // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl
205
+
206
+ if (!dynamic_casting) {
207
+ if (contiguous) {
208
+ // Case 1: no dynamic casting and contiguous
209
+ launch_jitted_vectorized_kernel<arity>(
210
+ jiterator_mutex, cache.vec, desc,
211
+ numel, data, scalar_pos, scalar_val, extra_args);
212
+ return;
213
+ }
214
+
215
+ // Case 2: no dynamic casting and noncontiguous
216
+ auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
217
+ auto output_offset_calculator = make_output_offset_calculator(iter);
218
+ auto loader = memory::LoadWithoutCast();
219
+ auto storer = memory::StoreWithoutCast();
220
+ launch_jitted_unrolled_kernel(
221
+ jiterator_mutex, cache.noncontiguous, desc, numel, data,
222
+ input_offset_calculator, output_offset_calculator, loader,
223
+ storer, contiguous, scalar_pos, scalar_val, extra_args);
224
+ return;
225
+ }
226
+
227
+ // Cases 3 and 4 are handled below
228
+ // Both require construction of a storer (this asserts 1 output) and one or more loaders
229
+
230
+ // Creates store cast to output (the zeroth tensor in TensorIterator)
231
+ auto storer = memory::StoreWithCast<1>(iter);
232
+
233
+ // Creates load casts from inputs (note offset indexing into the iterators 1...n tensors)
234
+ auto loader = memory::LoadWithCast<arity>(iter);
235
+
236
+ if (contiguous) {
237
+ // Case 3: dynamic casting and contiguous
238
+ auto input_offset_calculator = TrivialOffsetCalculator<arity>();
239
+ auto output_offset_calculator = TrivialOffsetCalculator<1>();
240
+ launch_jitted_unrolled_kernel(
241
+ jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator,
242
+ output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
243
+ return;
244
+ }
245
+
246
+ // Case 4: dynamic casting and noncontiguous
247
+ auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
248
+ auto output_offset_calculator = make_output_offset_calculator(iter);
249
+ launch_jitted_unrolled_kernel(
250
+ jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator,
251
+ output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
252
+ }
253
+
254
+ // NOTE: static to reduce chances of name collision.
255
+ template <
256
+ char const* name,
257
+ typename result_type,
258
+ typename f_inputs_type,
259
+ int arity,
260
+ at::cuda::jit::BinaryFuncVariant scalar_pos =
261
+ at::cuda::jit::BinaryFuncVariant::NoScalar,
262
+ typename... ExtraArgs>
263
+ static void jitted_gpu_kernel_impl(
264
+ TensorIteratorBase& iter,
265
+ const std::string &f,
266
+ const bool dynamic_casting,
267
+ at::opmath_type<f_inputs_type> scalar_val,
268
+ std::tuple<ExtraArgs...> extra_args) {
269
+
270
+ // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
271
+ // the same compute capability
272
+ static std::mutex jiterator_mutex;
273
+ static std::vector<JittedKernelVariantCache> device_caches(c10::cuda::device_count());
274
+
275
+ constexpr int nInputs = arity;
276
+ constexpr int nOutputs = 1; // TODO: Support more than 1 output
277
+ static const auto desc = at::cuda::jit::make_kernel_descriptor<
278
+ result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs);
279
+
280
+ auto &cache = device_caches[iter.device().index()];
281
+ auto extra_args_array = tuple_to_array(extra_args);
282
+ return jitted_gpu_kernel_generic<arity>(
283
+ jiterator_mutex,
284
+ cache,
285
+ desc,
286
+ scalar_pos,
287
+ extra_args_array,
288
+ iter,
289
+ dynamic_casting,
290
+ &scalar_val
291
+ );
292
+ }
293
+
294
+ }} // at::native
295
+
296
+ #endif // AT_USE_JITERATOR()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/GridSampler.h ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <array>
3
+ #include <cstdint>
4
+
5
+ namespace at {
6
+ class TensorBase;
7
+ }
8
+
9
+ namespace at {
10
+ namespace native {
11
+
12
+ void launch_grid_sampler_2d_forward_kernel(
13
+ const TensorBase &output, const TensorBase &input, const TensorBase &grid,
14
+ int64_t interpolation_mode, int64_t padding_mode, bool align_corners);
15
+
16
+ void launch_grid_sampler_3d_forward_kernel(
17
+ const TensorBase &output, const TensorBase &input, const TensorBase &grid,
18
+ int64_t interpolation_mode, int64_t padding_mode, bool align_corners);
19
+
20
+ void launch_grid_sampler_2d_backward_kernel(
21
+ const TensorBase &grad_input, const TensorBase &grad_grid,
22
+ const TensorBase &grad_output, const TensorBase &input,
23
+ const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode,
24
+ bool align_corners, std::array<bool, 2> output_mask);
25
+
26
+ void launch_grid_sampler_3d_backward_kernel(
27
+ const TensorBase &grad_input, const TensorBase &grad_grid,
28
+ const TensorBase &grad_output, const TensorBase &input,
29
+ const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode,
30
+ bool align_corners, std::array<bool, 2> output_mask);
31
+
32
+ }} // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MemoryAccess.cuh ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cstdint>
4
+ #include <type_traits>
5
+ #include <c10/core/DynamicCast.h>
6
+ #include <c10/util/Exception.h>
7
+ #include <c10/util/TypeCast.h>
8
+ #include <c10/macros/Macros.h>
9
+ #include <ATen/core/Array.h>
10
+ #include <ATen/detail/FunctionTraits.h>
11
+ #include <ATen/cuda/detail/OffsetCalculator.cuh>
12
+ #include <ATen/native/cuda/thread_constants.h>
13
+
14
+ #include <thrust/tuple.h>
15
+
16
+ // References:
17
+ // https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/
18
+
19
+ namespace at { namespace native { namespace memory {
20
+
21
+ namespace detail {
22
+
23
+ // What does the `static_unroll` do?
24
+ //
25
+ // We want to do something like:
26
+ //
27
+ // using args_t = typename traits::ArgsTuple;
28
+ // args_t args;
29
+ // #pragma unroll
30
+ // for (int i = 0; i < traits::arity; i++) {
31
+ // std::get<i>(args) = ....
32
+ // }
33
+ //
34
+ // but unfortunately the above code does not work because
35
+ // the template argument has to be a compile time constant
36
+ // so `static_unroll` is created to simulate `#pragma unroll`
37
+ // using template metaprogramming.
38
+
39
+ template<template<int i> typename func, int end, int current=0>
40
+ struct static_unroll {
41
+ template<typename... Args>
42
+ static inline C10_HOST_DEVICE void with_args(Args&&... args) {
43
+ func<current>::apply(std::forward<Args>(args)...);
44
+ static_unroll<func, end, current+1>::with_args(args...);
45
+ }
46
+ };
47
+
48
+ template<template<int i> typename func, int end>
49
+ struct static_unroll<func, end, end> {
50
+ template<typename... Args>
51
+ static inline C10_HOST_DEVICE void with_args(Args... args) {}
52
+ };
53
+
54
+ // helper structs to be used with static_unroll to load arguments
55
+ // one by one
56
+
57
+ template<int arg_index>
58
+ struct vectorized_load_helper {
59
+ template <typename args_t, typename policy_t>
60
+ static __device__ void apply(policy_t &self, args_t *args, int idx) {
61
+ using arg_t = std::tuple_element_t<arg_index, args_t>;
62
+ // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
63
+ // need a +1 offset to get the input
64
+ auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size() * idx;
65
+ auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get<arg_index>(args[thread_unroll_idx]); };
66
+ self.load_single_arg(args_accessor, ptr);
67
+ }
68
+ };
69
+
70
+ template<int arg_index>
71
+ struct unroll_load_helper {
72
+ template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
73
+ static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) {
74
+ using arg_t = std::tuple_element_t<arg_index, args_t>;
75
+ // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
76
+ // need a +1 offset to get the input
77
+ std::get<arg_index>(args[j]) = loader.template load<arg_t>(self.data[arg_index + num_outputs], offset[arg_index], arg_index);
78
+ }
79
+ };
80
+
81
+ template <int current>
82
+ struct multi_outputs_store_helper {
83
+ template<int ntensors, int num_outputs, typename ...Args>
84
+ C10_HOST_DEVICE static void apply(
85
+ at::detail::Array<char*, ntensors> data,
86
+ at::detail::Array<uint32_t, num_outputs> offsets,
87
+ thrust::tuple<Args...> ret) {
88
+ using T = typename thrust::tuple_element<current, thrust::tuple<Args...>>::type;
89
+ T *to = reinterpret_cast<T *>(data[current]) + offsets[current];
90
+ *to = thrust::get<current>(ret);
91
+ }
92
+ };
93
+
94
+ } // namespace detail
95
+
96
+ struct LoadWithoutCast {
97
+ template<typename scalar_t>
98
+ __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
99
+ return c10::load(reinterpret_cast<scalar_t *>(base_ptr) + offset);
100
+ }
101
+ };
102
+
103
+ template <int N>
104
+ struct LoadWithCast {
105
+ using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
106
+ using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;
107
+
108
+ array_t dtypes;
109
+ size_array_t element_sizes;
110
+
111
+ LoadWithCast(const TensorIteratorBase& iter) {
112
+ CUDA_KERNEL_ASSERT(iter.ninputs() == N);
113
+ #pragma unroll
114
+ for (auto i = 0; i < N; ++i) {
115
+ this->dtypes[i] = iter.dtype(i + iter.noutputs());
116
+ element_sizes[i] = c10::elementSize(iter.dtype(i + iter.noutputs()));
117
+ }
118
+ }
119
+
120
+ template<typename scalar_t>
121
+ __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
122
+ void *ptr = base_ptr + element_sizes[arg] * offset;
123
+ return c10::fetch_and_cast<scalar_t>(dtypes[arg], ptr);
124
+ }
125
+ };
126
+
127
+ struct StoreWithoutCast {
128
+ template<typename scalar_t>
129
+ __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
130
+ *(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
131
+ }
132
+ };
133
+
134
+ template <int N = 1>
135
+ struct StoreWithCast {
136
+ using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
137
+ using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;
138
+
139
+ array_t dtypes;
140
+ size_array_t element_sizes;
141
+
142
+ StoreWithCast(const TensorIteratorBase& iter) {
143
+ CUDA_KERNEL_ASSERT(iter.noutputs() == N);
144
+ #pragma unroll
145
+ for (auto i = 0; i < N; ++i) {
146
+ this->dtypes[i] = iter.dtype(i);
147
+ element_sizes[i] = c10::elementSize(iter.dtype(i));
148
+ }
149
+ }
150
+
151
+ template<typename scalar_t>
152
+ __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
153
+ void *ptr = base_ptr + element_sizes[arg] * offset;
154
+ c10::cast_and_store<scalar_t>(dtypes[arg], ptr, value);
155
+ }
156
+ };
157
+
158
+ // aligned vector generates vectorized load/store on CUDA
159
+ template<typename scalar_t, int vec_size>
160
+ struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
161
+ scalar_t val[vec_size];
162
+ };
163
+
164
+ template <int vec_size, typename scalar_t>
165
+ __device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
166
+ using vec_t = aligned_vector<scalar_t, vec_size>;
167
+ auto *from = reinterpret_cast<const vec_t *>(base_ptr);
168
+ return from[offset];
169
+ }
170
+
171
+ template <int vec_size>
172
+ __device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
173
+ // See NOTE [Loading boolean values]
174
+ auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
175
+ aligned_vector<bool, vec_size> ret;
176
+ for (int i = 0; i < vec_size; ++i) {
177
+ ret.val[i] = bool(tmp.val[i]);
178
+ }
179
+ return ret;
180
+ }
181
+
182
+ namespace policies {
183
+
184
+ // Assumption:
185
+ // all tensors are contiguous, that is: stride == sizeof(type) for all tensors
186
+ template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
187
+ struct unroll {
188
+
189
+ data_t data;
190
+ int remaining;
191
+ inp_calc_t input_offset_calculator;
192
+ out_calc_t output_offset_calculator;
193
+ loader_t loader;
194
+ storer_t storer;
195
+
196
+ __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
197
+ data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}
198
+
199
+ __device__ inline bool check_inbounds(int thread_work_elem) {
200
+ return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining);
201
+ }
202
+
203
+ template<typename args_t>
204
+ __device__ inline void load(args_t *args, int idx) {
205
+ constexpr int arity = std::tuple_size<args_t>::value;
206
+ int thread_idx = threadIdx.x;
207
+ #pragma unroll
208
+ for (int i = 0; i < thread_work_size(); i++) {
209
+ if (thread_idx >= remaining) {
210
+ return;
211
+ }
212
+ int linear_idx = thread_idx + block_work_size() * idx;
213
+ auto offset = input_offset_calculator.get(linear_idx);
214
+ detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
215
+ thread_idx += num_threads();
216
+ }
217
+ }
218
+
219
+ template<typename scalar_t>
220
+ __device__ inline void store(scalar_t *from, int idx) {
221
+ int thread_idx = threadIdx.x;
222
+ #pragma unroll
223
+ for (int i = 0; i < thread_work_size(); i++) {
224
+ if (thread_idx >= remaining) {
225
+ return;
226
+ }
227
+ int linear_idx = thread_idx + block_work_size() * idx;
228
+ int offset = output_offset_calculator.get(linear_idx)[0];
229
+ storer.store(from[i], data[0], offset);
230
+ thread_idx += num_threads();
231
+ }
232
+ }
233
+ };
234
+
235
+ // Assumption:
236
+ // all tensors are contiguous, that is: stride == sizeof(type) for all tensors
237
+ // Note:
238
+ // Functions in vectorized policy does not do boundary check. It assumes the whole block
239
+ // has its job to do. So the reminders should be handled by the caller manually.
240
+ template <int vec_size, typename data_t> // vec_size: number of scalars, can be 1, 2, or 4.
241
+ struct vectorized {
242
+
243
+ static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size");
244
+ static constexpr int loop_size = thread_work_size() / vec_size;
245
+
246
+ data_t data;
247
+
248
+ __device__ vectorized(data_t data) : data(data) {}
249
+
250
+ __device__ inline constexpr bool check_inbounds(int thread_work_elem) {
251
+ return true;
252
+ }
253
+
254
+ template<typename accessor_t, typename scalar_t>
255
+ __device__ inline void load_single_arg(accessor_t to, scalar_t *from) {
256
+ int thread_idx = threadIdx.x;
257
+ #pragma unroll
258
+ for (int i = 0; i < loop_size; i++) {
259
+ int index = thread_idx + i * num_threads();
260
+ auto v = load_vector<vec_size>(from, index);
261
+ #pragma unroll
262
+ for (int j = 0; j < vec_size; j++) {
263
+ to(vec_size * i + j) = v.val[j];
264
+ }
265
+ }
266
+ }
267
+
268
+ template<typename args_t>
269
+ __device__ inline void load(args_t *args, int idx) {
270
+ constexpr int arity = std::tuple_size<args_t>::value;
271
+ detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
272
+ }
273
+
274
+ template<typename scalar_t>
275
+ __device__ inline void store(scalar_t *from, int idx) {
276
+ using vec_t = aligned_vector<scalar_t, vec_size>;
277
+ scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
278
+ vec_t *to_ = reinterpret_cast<vec_t *>(to);
279
+ int thread_idx = threadIdx.x;
280
+ #pragma unroll
281
+ for (int i = 0; i < loop_size; i++) {
282
+ int index = thread_idx + i * num_threads();
283
+ vec_t v;
284
+ for (int j = 0; j < vec_size; j++) {
285
+ v.val[j] = from[vec_size * i + j];
286
+ }
287
+ to_[index] = v;
288
+ }
289
+ }
290
+ };
291
+
292
+ template <typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
293
+ struct multi_outputs_unroll {
294
+ //multi_outputs_unroll struct members and check_inbounds and load methods are copypasted from unroll struct
295
+ //we don't use inheritance because of compiler bug in cuda 10.2+
296
+ data_t data;
297
+ int remaining;
298
+ inp_calc_t input_offset_calculator;
299
+ out_calc_t output_offset_calculator;
300
+ LoadWithoutCast loader;
301
+ StoreWithoutCast storer;
302
+
303
+ __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
304
+ data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}
305
+
306
+ __device__ inline bool check_inbounds(int thread_work_elem) {
307
+ return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining);
308
+ }
309
+
310
+ template<typename args_t>
311
+ __device__ inline void load(args_t *args, int idx) {
312
+ constexpr int arity = std::tuple_size<args_t>::value;
313
+ int thread_idx = threadIdx.x;
314
+ #pragma unroll
315
+ for (int i = 0; i < thread_work_size(); i++) {
316
+ if (thread_idx >= remaining) {
317
+ return;
318
+ }
319
+ int linear_idx = thread_idx + block_work_size() * idx;
320
+ auto offset = input_offset_calculator.get(linear_idx);
321
+ detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
322
+ thread_idx += num_threads();
323
+ }
324
+ }
325
+
326
+
327
+ template <typename return_t>
328
+ __device__ inline void store(return_t *from, int idx) {
329
+ int thread_idx = threadIdx.x;
330
+ #pragma unroll
331
+ for (int i = 0; i < thread_work_size(); i++) {
332
+ if (thread_idx >= this->remaining) {
333
+ return;
334
+ }
335
+ int linear_idx = thread_idx + block_work_size() * idx;
336
+ auto offsets = this->output_offset_calculator.get(linear_idx);
337
+ memory::detail::static_unroll<detail::multi_outputs_store_helper, num_outputs>::with_args(this->data, offsets, from[i]);
338
+ thread_idx += num_threads();
339
+ }
340
+ }
341
+ };
342
+
343
+ } // namespace policies
344
+
345
+ // This is only used in host, but we will wrap this into some templates
346
+ // which is C10_HOST_DEVICE, so we have to make this C10_HOST_DEVICE
347
+ // in order to compile
348
+ template<typename scalar_t>
349
+ inline C10_HOST_DEVICE int can_vectorize_up_to(char *pointer) {
350
+ uint64_t address = reinterpret_cast<uint64_t>(pointer);
351
+ constexpr int vec2_alignment = std::alignment_of<aligned_vector<scalar_t, 2>>::value;
352
+ constexpr int vec4_alignment = std::alignment_of<aligned_vector<scalar_t, 4>>::value;
353
+ if (address % vec4_alignment == 0) {
354
+ return 4;
355
+ } else if (address % vec2_alignment == 0) {
356
+ return 2;
357
+ }
358
+ return 1;
359
+ }
360
+
361
+ template<int i>
362
+ struct can_vectorize_up_to_helper {
363
+ template <typename array_t, typename traits>
364
+ static C10_HOST_DEVICE void apply(int &result, array_t pointers, traits _) {
365
+ using arg_t = typename traits::template arg<i>::type;
366
+ // `pointers` hold the data_ptr for tensors [output, input0, input1, ...], so we
367
+ // need a +1 offset to get the input
368
+ result = std::min<int>(result, can_vectorize_up_to<arg_t>(pointers[i + 1]));
369
+ }
370
+ };
371
+
372
+ template<typename func_t, typename array_t>
373
+ inline int can_vectorize_up_to(array_t pointers) {
374
+ using traits = function_traits<func_t>;
375
+ using return_t = typename traits::result_type;
376
+ constexpr int arity = traits::arity;
377
+ int result = can_vectorize_up_to<return_t>(pointers[0]);
378
+ // We need to get the type for each argument of `func_t`, this can only
379
+ // be done at compile time.
380
+ detail::static_unroll<can_vectorize_up_to_helper, arity>::with_args(result, pointers, traits());
381
+ return result;
382
+ }
383
+
384
+ }}} // namespace at::native::memory
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/MultiTensorApply.cuh ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+ #include <ATen/cuda/CUDAContext.h>
4
+ #include <c10/cuda/CUDAGuard.h>
5
+ #include <ATen/native/cuda/Loops.cuh>
6
+ #include <ATen/native/cuda/MemoryAccess.cuh>
7
+ #include <vector>
8
+
9
+ namespace at::native {
10
+
11
+ namespace {
12
+
13
+ static constexpr int64_t kILP = 4;
14
+ static constexpr int64_t kChunkSize = 65536;
15
+ static constexpr int64_t kBlockSize = 512;
16
+
17
+ // TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
18
+ // TensorListMetadata has to be < 4KB - the limit for kernel launch argument
19
+ static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
20
+ static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
21
+ static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
22
+ static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
23
+ 72,
24
+ 60};
25
+
26
+ template <typename T>
27
+ __device__ __forceinline__ bool is_aligned(T* p) {
28
+ return ((uint64_t)p) % (kILP * sizeof(T)) == 0;
29
+ }
30
+
31
+ template <typename T>
32
+ __device__ __forceinline__ void load_store(
33
+ T* dst,
34
+ T* src,
35
+ int64_t dst_offset,
36
+ int64_t src_offset) {
37
+ using LT = at::native::memory::aligned_vector<T, kILP>;
38
+ ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
39
+ }
40
+
41
+ template <int n>
42
+ struct TensorListMetadata {
43
+ const void* addresses[n][depth_to_max_tensors[n - 1]];
44
+ int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
45
+ unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
46
+ int block_to_chunk[depth_to_max_blocks[n - 1]];
47
+ int start_tensor_this_launch;
48
+ };
49
+
50
+ template <typename scalar_vals_t, int n>
51
+ struct TensorListScalarListMetadata {
52
+ const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]];
53
+ int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]];
54
+ scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]];
55
+ unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
56
+ int block_to_chunk[depth_to_max_blocks[n - 1]];
57
+ };
58
+
59
+ // note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of
60
+ // 4kb with `c10::complex<double>`
61
+ template <>
62
+ struct TensorListScalarListMetadata<c10::complex<double>, 1> {
63
+ const void* addresses[1]
64
+ [depth_to_max_tensors_scalarlist_of_complex_double[0]];
65
+ int64_t
66
+ numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]];
67
+ c10::complex<double>
68
+ scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]];
69
+ unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]];
70
+ int block_to_chunk[depth_to_max_blocks[1 - 1]];
71
+ };
72
+
73
+ template <>
74
+ struct TensorListScalarListMetadata<c10::complex<double>, 2> {
75
+ const void* addresses[2]
76
+ [depth_to_max_tensors_scalarlist_of_complex_double[1]];
77
+ int64_t
78
+ numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]];
79
+ c10::complex<double>
80
+ scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]];
81
+ unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]];
82
+ int block_to_chunk[depth_to_max_blocks[2 - 1]];
83
+ };
84
+
85
+ // NOTE(crcrpar): This is a conservative resolution to handle `state_steps`
86
+ // whose each element is `at::Tensor` of 1 element representing the number of
87
+ // `step`s called so far.
88
+ template <int n>
89
+ struct FusedOptimizerTensorListMetadata {
90
+ const void* addresses[n][depth_to_max_tensors[n - 1]];
91
+ int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
92
+ const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]];
93
+ unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
94
+ int block_to_chunk[depth_to_max_blocks[n - 1]];
95
+ int start_tensor_this_launch;
96
+ };
97
+
98
+ template <typename T, typename U, typename... ArgTypes>
99
+ C10_LAUNCH_BOUNDS_1(kBlockSize)
100
+ __global__ void multi_tensor_apply_kernel(
101
+ T tensorListMeta,
102
+ U callable,
103
+ ArgTypes... args) {
104
+ // Hand the chunk information to the user-supplied functor to process however
105
+ // it likes.
106
+ callable(kChunkSize, tensorListMeta, args...);
107
+ }
108
+
109
+ } // namespace
110
+
111
+ // multi_tensor_apply enables horizontal fusion across lists of tensors.
112
+ // For example, whereas you once had a for-loop of a + b = c, where a, b,
113
+ // and c are individual tensors in lists as, bs, and cs, you can now with
114
+ // fewer kernel launches compute as + bs = cs.
115
+ //
116
+ // You can also imagine bs to be a scalar list vs a tensor list.
117
+ //
118
+ // The function below takes in tensor lists, scalars, and a callable and
119
+ // chunks up the computation to launch as few kernels as possible by iterating
120
+ // through every "chunk" in every tensor (thus the nested for loops). In the
121
+ // simplest case, everything gets bundled into just one kernel launch, but
122
+ // due to blocksize constraints, we may need to launch multiple kernels.
123
+ // Each kernel launch is defined by one tensorListMeta construct, which we
124
+ // use to track and reset the necessary metadata for each launch.
125
+ template <int depth, typename scalar_T, typename T, typename... ArgTypes>
126
+ void multi_tensor_apply(
127
+ std::vector<std::vector<at::Tensor>>& tensor_lists,
128
+ at::ArrayRef<Scalar> scalars,
129
+ T callable,
130
+ ArgTypes... args) {
131
+ TORCH_CHECK(
132
+ tensor_lists.size() == depth,
133
+ "Number of tensor lists has to match the depth.");
134
+ const size_t n_tensors = tensor_lists[0].size();
135
+ using scalar_vals_t = typename T::opmath_t;
136
+ TensorListScalarListMetadata<scalar_vals_t, depth> tensorListMeta;
137
+
138
+ int loc_block_info = 0;
139
+ int loc_tensor_info = 0;
140
+ for (size_t t = 0; t < n_tensors; t++) {
141
+ // short-circuit to avoid adding empty tensors to tensorListMeta
142
+ if (tensor_lists[0][t].numel() == 0) {
143
+ continue;
144
+ }
145
+ tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to<scalar_T>();
146
+ tensorListMeta.numel_for_tensor[loc_tensor_info] =
147
+ tensor_lists[0][t].numel();
148
+ for (int d = 0; d < depth; d++) {
149
+ tensorListMeta.addresses[d][loc_tensor_info] =
150
+ tensor_lists[d][t].const_data_ptr();
151
+ }
152
+ loc_tensor_info++;
153
+
154
+ // now we enter [chunking territory].
155
+ // we will launch a kernel when EITHER the blocks get filled up OR
156
+ // the tensors get filled up. There will always be at least one block
157
+ // per tensor since the zero-sized ones will not enter the loop, so
158
+ // the nested forloop within represents iterating through the chunks
159
+ // of a single tensor.
160
+ const auto numel = tensor_lists[0][t].numel();
161
+ const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
162
+ for (auto chunk = 0; chunk < chunks; chunk++) {
163
+ tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
164
+ tensorListMeta.block_to_chunk[loc_block_info] = chunk;
165
+ loc_block_info++;
166
+
167
+ // a tensor is not considered full unless all its chunks have been
168
+ // processed
169
+ const bool tensors_full =
170
+ (loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] &&
171
+ chunk == chunks - 1);
172
+ const bool blocks_full =
173
+ (loc_block_info == depth_to_max_blocks[depth - 1]);
174
+
175
+ if (tensors_full || blocks_full) {
176
+ multi_tensor_apply_kernel<<<
177
+ loc_block_info,
178
+ kBlockSize,
179
+ 0,
180
+ at::cuda::getCurrentCUDAStream()>>>(
181
+ tensorListMeta, callable, args...);
182
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
183
+
184
+ // Reset.
185
+ loc_block_info = 0;
186
+ // all chunks have already been handled in the kernel
187
+ if (chunk == chunks - 1) {
188
+ loc_tensor_info = 0;
189
+ } else { // blocks were full and tensor chunks remain
190
+ tensorListMeta.numel_for_tensor[0] =
191
+ tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
192
+ tensorListMeta.scalar_vals[0] =
193
+ tensorListMeta.scalar_vals[loc_tensor_info - 1];
194
+ for (int d = 0; d < depth; d++) {
195
+ tensorListMeta.addresses[d][0] =
196
+ tensorListMeta.addresses[d][loc_tensor_info - 1];
197
+ }
198
+ loc_tensor_info = 1;
199
+ }
200
+ }
201
+ }
202
+ }
203
+
204
+ // note: [finishing what we started]
205
+ // if there's remaining work to be done but the tensors/blocks aren't full
206
+ // yet we are at the end, submit the kernel to do the work!
207
+ if (loc_block_info != 0) {
208
+ multi_tensor_apply_kernel<<<
209
+ loc_block_info,
210
+ kBlockSize,
211
+ 0,
212
+ at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
213
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
214
+ }
215
+ }
216
+
217
+ template <int depth, typename T, typename... ArgTypes>
218
+ void multi_tensor_apply(
219
+ std::vector<std::vector<at::Tensor>>& tensor_lists,
220
+ T callable,
221
+ ArgTypes... args) {
222
+ TORCH_CHECK(
223
+ tensor_lists.size() == depth,
224
+ "Number of tensor lists has to match the depth.");
225
+ const size_t n_tensors = tensor_lists[0].size();
226
+ TensorListMetadata<depth> tensorListMeta;
227
+ tensorListMeta.start_tensor_this_launch = 0;
228
+
229
+ int loc_block_info = 0;
230
+ int loc_tensor_info = 0;
231
+ for (size_t t = 0; t < n_tensors; t++) {
232
+ // short-circuit to avoid adding empty tensors to tensorListMeta
233
+ if (tensor_lists[0][t].numel() == 0) {
234
+ continue;
235
+ }
236
+ tensorListMeta.numel_for_tensor[loc_tensor_info] =
237
+ tensor_lists[0][t].numel();
238
+ for (int d = 0; d < depth; d++) {
239
+ tensorListMeta.addresses[d][loc_tensor_info] =
240
+ tensor_lists[d][t].const_data_ptr();
241
+ }
242
+ loc_tensor_info++;
243
+
244
+ // see note: [chunking territory].
245
+ const auto numel = tensor_lists[0][t].numel();
246
+ const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
247
+ for (auto chunk = 0; chunk < chunks; chunk++) {
248
+ tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
249
+ tensorListMeta.block_to_chunk[loc_block_info] = chunk;
250
+ loc_block_info++;
251
+
252
+ const bool tensors_full =
253
+ (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
254
+ chunk == chunks - 1);
255
+ const bool blocks_full =
256
+ (loc_block_info == depth_to_max_blocks[depth - 1]);
257
+
258
+ if (tensors_full || blocks_full) {
259
+ multi_tensor_apply_kernel<<<
260
+ loc_block_info,
261
+ kBlockSize,
262
+ 0,
263
+ at::cuda::getCurrentCUDAStream()>>>(
264
+ tensorListMeta, callable, args...);
265
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
266
+
267
+ // Reset.
268
+ loc_block_info = 0;
269
+ if (chunk == chunks - 1) {
270
+ loc_tensor_info = 0;
271
+ tensorListMeta.start_tensor_this_launch = t + 1;
272
+ } else {
273
+ tensorListMeta.numel_for_tensor[0] =
274
+ tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
275
+ for (int d = 0; d < depth; d++) {
276
+ tensorListMeta.addresses[d][0] =
277
+ tensorListMeta.addresses[d][loc_tensor_info - 1];
278
+ }
279
+ loc_tensor_info = 1;
280
+ tensorListMeta.start_tensor_this_launch = t;
281
+ }
282
+ }
283
+ }
284
+ }
285
+
286
+ // see note: [finishing what we started]
287
+ if (loc_block_info != 0) {
288
+ multi_tensor_apply_kernel<<<
289
+ loc_block_info,
290
+ kBlockSize,
291
+ 0,
292
+ at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
293
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
294
+ }
295
+ }
296
+
297
+ template <int depth, typename T, typename... ArgTypes>
298
+ void multi_tensor_apply_for_fused_optimizer(
299
+ std::vector<std::vector<at::Tensor>>& tensor_lists,
300
+ at::TensorList state_steps,
301
+ T callable,
302
+ ArgTypes... args) {
303
+ TORCH_CHECK(
304
+ tensor_lists.size() == depth,
305
+ "Number of tensor lists has to match the depth");
306
+ const auto num_tensors = tensor_lists[0].size();
307
+ FusedOptimizerTensorListMetadata<depth> tensorListMeta;
308
+
309
+ int loc_block_info = 0;
310
+ int loc_tensor_info = 0;
311
+ for (const auto& tensor_index : c10::irange(num_tensors)) {
312
+ // short-circuit to avoid adding empty tensors to tensorListMeta
313
+ if (tensor_lists[0][tensor_index].numel() == 0) {
314
+ continue;
315
+ }
316
+ tensorListMeta.state_steps_addresses[loc_tensor_info] =
317
+ state_steps[tensor_index].const_data_ptr();
318
+ tensorListMeta.numel_for_tensor[loc_tensor_info] =
319
+ tensor_lists[0][tensor_index].numel();
320
+ for (const auto& d : c10::irange(depth)) {
321
+ tensorListMeta.addresses[d][loc_tensor_info] =
322
+ tensor_lists[d][tensor_index].const_data_ptr();
323
+ }
324
+ loc_tensor_info++;
325
+
326
+ // see above note: [chunking territory]
327
+ const auto numel = tensor_lists[0][tensor_index].numel();
328
+ const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
329
+ TORCH_CHECK(chunks > -1);
330
+ for (const auto& chunk : c10::irange(chunks)) {
331
+ tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
332
+ tensorListMeta.block_to_chunk[loc_block_info] = chunk;
333
+ loc_block_info++;
334
+
335
+ const auto tensor_full =
336
+ (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
337
+ chunk == chunks - 1);
338
+ const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1];
339
+
340
+ if (tensor_full || blocks_full) {
341
+ multi_tensor_apply_kernel<<<
342
+ loc_block_info,
343
+ kBlockSize,
344
+ 0,
345
+ at::cuda::getCurrentCUDAStream()>>>(
346
+ tensorListMeta, callable, args...);
347
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
348
+
349
+ // Reset.
350
+ loc_block_info = 0;
351
+ if (chunk == chunks - 1) {
352
+ loc_tensor_info = 0;
353
+ } else {
354
+ tensorListMeta.numel_for_tensor[0] =
355
+ tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
356
+ tensorListMeta.state_steps_addresses[0] =
357
+ tensorListMeta.state_steps_addresses[loc_tensor_info - 1];
358
+ for (const auto& d : c10::irange(depth)) {
359
+ tensorListMeta.addresses[d][0] =
360
+ tensorListMeta.addresses[d][loc_tensor_info - 1];
361
+ }
362
+ loc_tensor_info = 1;
363
+ }
364
+ }
365
+ }
366
+ }
367
+
368
+ // see above note: [finishing what we've started]
369
+ if (loc_block_info != 0) {
370
+ multi_tensor_apply_kernel<<<
371
+ loc_block_info,
372
+ kBlockSize,
373
+ 0,
374
+ at::cuda::getCurrentCUDAStream()>>>(tensorListMeta, callable, args...);
375
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
376
+ }
377
+ }
378
+
379
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Resize.h ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/EmptyTensor.h>
4
+ #include <ATen/native/ResizeCommon.h>
5
+
6
+ #include <c10/cuda/CUDAGuard.h>
7
+
8
+ namespace at { namespace native {
9
+
10
+ TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes);
11
+
12
+ static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) {
13
+ // It does not make sense to try to resize a storage
14
+ // to hold 0 elements, and this can break
15
+ // if storage_offset is positive but
16
+ // new_size is 0, so just bail in that case
17
+ // (same comment is in Resize.h)
18
+ if (self->numel() == 0) {
19
+ return;
20
+ }
21
+
22
+ const Storage &storage = self->unsafe_storage();
23
+ TORCH_CHECK(storage, "Tensor: invalid null storage");
24
+ if (new_size_bytes > storage.nbytes()) {
25
+ resize_bytes_cuda(storage.unsafeGetStorageImpl(), new_size_bytes);
26
+ }
27
+ }
28
+
29
+ inline TensorImpl* resize_impl_cuda_(
30
+ TensorImpl* self,
31
+ IntArrayRef size,
32
+ at::OptionalIntArrayRef stride,
33
+ bool device_guard = true) {
34
+ if (self->sizes() == size && (!stride || self->strides() == stride)) {
35
+ return self;
36
+ }
37
+
38
+ // NB: We don't need to hold the device guard when calling from TH
39
+ cuda::OptionalCUDAGuard guard;
40
+ if (device_guard) {
41
+ guard.set_index(self->storage().device().index());
42
+ }
43
+
44
+ const auto itemsize = self->dtype().itemsize();
45
+ const auto storage_offset = self->storage_offset();
46
+ size_t storage_size = 1;
47
+ if (stride) {
48
+ self->set_sizes_and_strides(size, *stride);
49
+ storage_size = at::detail::computeStorageNbytes(
50
+ size, *stride, itemsize, storage_offset);
51
+ } else {
52
+ self->set_sizes_contiguous(size);
53
+ storage_size = at::detail::computeStorageNbytesContiguous(
54
+ size, itemsize, storage_offset);
55
+ }
56
+ maybe_resize_storage_cuda(self, storage_size);
57
+
58
+ return self;
59
+ }
60
+
61
+ }}
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/Sort.h ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <cstdint>
3
+ #include <ATen/core/TensorBase.h>
4
+ #include <ATen/native/cuda/SortStable.h>
5
+
6
+ namespace at {
7
+ namespace native {
8
+
9
+ inline bool should_use_small_sort(const TensorBase &self, int64_t dim) {
10
+ return self.size(dim) <= 4096;
11
+ }
12
+
13
+ void sortKeyValueInplace(
14
+ const TensorBase &key, const TensorBase &value, int dim,
15
+ bool descending, bool stable=false);
16
+
17
+ }} // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+
4
+ namespace at {
5
+ namespace native {
6
+
7
+ void _fused_adam_amsgrad_cuda_impl_(
8
+ at::TensorList params,
9
+ at::TensorList grads,
10
+ at::TensorList exp_avgs,
11
+ at::TensorList exp_avg_sqs,
12
+ at::TensorList max_exp_avg_sqs,
13
+ at::TensorList state_steps,
14
+ const double lr,
15
+ const double beta1,
16
+ const double beta2,
17
+ const double weight_decay,
18
+ const double eps,
19
+ const bool maximize,
20
+ const c10::optional<at::Tensor>& grad_scale,
21
+ const c10::optional<at::Tensor>& found_inf);
22
+
23
+ void _fused_adam_amsgrad_cuda_impl_(
24
+ at::TensorList params,
25
+ at::TensorList grads,
26
+ at::TensorList exp_avgs,
27
+ at::TensorList exp_avg_sqs,
28
+ at::TensorList max_exp_avg_sqs,
29
+ at::TensorList state_steps,
30
+ const at::Tensor& lr,
31
+ const double beta1,
32
+ const double beta2,
33
+ const double weight_decay,
34
+ const double eps,
35
+ const bool maximize,
36
+ const c10::optional<at::Tensor>& grad_scale,
37
+ const c10::optional<at::Tensor>& found_inf);
38
+
39
+ } // namespace native
40
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adam_impl.cuh ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+
4
+ namespace at {
5
+ namespace native {
6
+
7
+ void _fused_adam_cuda_impl_(
8
+ at::TensorList params,
9
+ at::TensorList grads,
10
+ at::TensorList exp_avgs,
11
+ at::TensorList exp_avg_sqs,
12
+ at::TensorList state_steps,
13
+ const double lr,
14
+ const double beta1,
15
+ const double beta2,
16
+ const double weight_decay,
17
+ const double eps,
18
+ const bool maximize,
19
+ const c10::optional<at::Tensor>& grad_scale,
20
+ const c10::optional<at::Tensor>& found_inf);
21
+
22
+ void _fused_adam_cuda_impl_(
23
+ at::TensorList params,
24
+ at::TensorList grads,
25
+ at::TensorList exp_avgs,
26
+ at::TensorList exp_avg_sqs,
27
+ at::TensorList state_steps,
28
+ const at::Tensor& lr,
29
+ const double beta1,
30
+ const double beta2,
31
+ const double weight_decay,
32
+ const double eps,
33
+ const bool maximize,
34
+ const c10::optional<at::Tensor>& grad_scale,
35
+ const c10::optional<at::Tensor>& found_inf);
36
+
37
+ } // namespace native
38
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+
4
+ namespace at {
5
+ namespace native {
6
+
7
+ void _fused_adamw_amsgrad_cuda_impl_(
8
+ at::TensorList params,
9
+ at::TensorList grads,
10
+ at::TensorList exp_avgs,
11
+ at::TensorList exp_avg_sqs,
12
+ at::TensorList max_exp_avg_sqs,
13
+ at::TensorList state_steps,
14
+ const double lr,
15
+ const double beta1,
16
+ const double beta2,
17
+ const double weight_decay,
18
+ const double eps,
19
+ const bool maximize,
20
+ const c10::optional<at::Tensor>& grad_scale,
21
+ const c10::optional<at::Tensor>& found_inf);
22
+
23
+ void _fused_adamw_amsgrad_cuda_impl_(
24
+ at::TensorList params,
25
+ at::TensorList grads,
26
+ at::TensorList exp_avgs,
27
+ at::TensorList exp_avg_sqs,
28
+ at::TensorList max_exp_avg_sqs,
29
+ at::TensorList state_steps,
30
+ const at::Tensor& lr,
31
+ const double beta1,
32
+ const double beta2,
33
+ const double weight_decay,
34
+ const double eps,
35
+ const bool maximize,
36
+ const c10::optional<at::Tensor>& grad_scale,
37
+ const c10::optional<at::Tensor>& found_inf);
38
+
39
+ } // namespace native
40
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/fused_adamw_impl.cuh ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Tensor.h>
3
+
4
+ namespace at {
5
+ namespace native {
6
+
7
+ void _fused_adamw_cuda_impl_(
8
+ at::TensorList params,
9
+ at::TensorList grads,
10
+ at::TensorList exp_avgs,
11
+ at::TensorList exp_avg_sqs,
12
+ at::TensorList state_steps,
13
+ const double lr,
14
+ const double beta1,
15
+ const double beta2,
16
+ const double weight_decay,
17
+ const double eps,
18
+ const bool maximize,
19
+ const c10::optional<at::Tensor>& grad_scale,
20
+ const c10::optional<at::Tensor>& found_inf);
21
+
22
+ void _fused_adamw_cuda_impl_(
23
+ at::TensorList params,
24
+ at::TensorList grads,
25
+ at::TensorList exp_avgs,
26
+ at::TensorList exp_avg_sqs,
27
+ at::TensorList state_steps,
28
+ const at::Tensor& lr,
29
+ const double beta1,
30
+ const double beta2,
31
+ const double weight_decay,
32
+ const double eps,
33
+ const bool maximize,
34
+ const c10::optional<at::Tensor>& grad_scale,
35
+ const c10::optional<at::Tensor>& found_inf);
36
+
37
+ } // namespace native
38
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/reduction_template.cuh ADDED
@@ -0,0 +1,680 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ namespace at {
2
+ namespace cuda {
3
+ //windows doesn't like large string literals, so split in two
4
+ const std::string reduction_template_0 = R"ESCAPE(
5
+ #define C10_HOST_DEVICE __host__ __device__
6
+ #define C10_DEVICE __device__
7
+ #if defined(__clang__) && defined(__HIP__)
8
+ #ifndef __forceinline__
9
+ #define __forceinline__ inline __attribute__((always_inline))
10
+ #endif
11
+ // until ROCm support for kernel asserts is restored
12
+ #define assert(expr) (static_cast<void>(0))
13
+ #endif
14
+
15
+ template <typename T>
16
+ __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
17
+ {
18
+ #if defined(__clang__) && defined(__HIP__)
19
+ return __shfl_down(value, delta, width);
20
+ #else
21
+ return __shfl_down_sync(mask, value, delta, width);
22
+ #endif
23
+ }
24
+
25
+
26
+ #if ${complex}
27
+ template <typename T>
28
+ __device__ __forceinline__ std::complex<T> WARP_SHFL_DOWN(std::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
29
+ {
30
+ return std::complex<T>(
31
+ #if defined(__clang__) && defined(__HIP__)
32
+ __shfl_down(value.real(), delta, width),
33
+ __shfl_down(value.imag(), delta, width));
34
+ #else
35
+ __shfl_down_sync(mask, value.real(), delta, width),
36
+ __shfl_down_sync(mask, value.imag(), delta, width));
37
+ #endif
38
+ }
39
+ #endif
40
+
41
+ // aligned vector generates vectorized load/store on CUDA
42
+ template<typename scalar_t, int vec_size>
43
+ struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
44
+ scalar_t val[vec_size];
45
+ };
46
+
47
+
48
+ C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) {
49
+ // get GCD of num and denom using Euclid's algorithm.
50
+ // Can replace this with std::gcd if we ever support c++17.
51
+ size_t a = denominator;
52
+ size_t b = numerator;
53
+ while (b != 0) {
54
+ a %= b;
55
+ // swap(a,b)
56
+ size_t tmp = a;
57
+ a = b;
58
+ b = tmp;
59
+ }
60
+
61
+ // a is now the GCD
62
+ numerator /= a;
63
+ denominator /= a;
64
+ }
65
+
66
+
67
+
68
+
69
+ struct ReduceConfig {
70
+ //has to match host-side ReduceConfig in the eager code
71
+ static constexpr int BLOCK_X = 0;
72
+ static constexpr int BLOCK_Y = 1;
73
+ static constexpr int CTA = 2;
74
+
75
+ static constexpr int input_vec_size = 4;
76
+ int element_size_bytes;
77
+ int num_inputs;
78
+ int num_outputs;
79
+ int step_input = 1;
80
+ int step_output = 1;
81
+ int ctas_per_output = 1;
82
+ int input_mult[3] = {0, 0, 0};
83
+ int output_mult[2] = {0, 0};
84
+
85
+ int block_width;
86
+ int block_height;
87
+ int num_threads;
88
+
89
+ bool vectorize_input = false;
90
+ int output_vec_size = 1;
91
+
92
+ C10_HOST_DEVICE bool should_block_x_reduce() const {
93
+ return input_mult[BLOCK_X] != 0;
94
+ }
95
+
96
+ C10_HOST_DEVICE bool should_block_y_reduce() const {
97
+ return input_mult[BLOCK_Y] != 0;
98
+ }
99
+
100
+ C10_HOST_DEVICE bool should_global_reduce() const {
101
+ return input_mult[CTA] != 0;
102
+ }
103
+
104
+ C10_DEVICE bool should_store(int output_idx) const {
105
+ return output_idx < num_outputs &&
106
+ (!should_block_x_reduce() || threadIdx.x == 0) &&
107
+ (!should_block_y_reduce() || threadIdx.y == 0);
108
+ }
109
+
110
+ C10_DEVICE bool should_reduce_tail() const {
111
+ return (!should_block_y_reduce() || threadIdx.y == 0) &&
112
+ (!should_global_reduce() || blockIdx.y == 0);
113
+ }
114
+
115
+ C10_HOST_DEVICE int input_idx() const {
116
+ int lane = threadIdx.x;
117
+ int warp = threadIdx.y;
118
+ int cta2 = blockIdx.y;
119
+ return (lane * input_mult[BLOCK_X] +
120
+ warp * input_mult[BLOCK_Y] +
121
+ cta2 * input_mult[CTA]);
122
+ }
123
+
124
+ template <int output_vec_size>
125
+ C10_HOST_DEVICE int output_idx() const {
126
+ int lane = threadIdx.x;
127
+ int warp = threadIdx.y;
128
+ int cta1 = blockIdx.x;
129
+ return (lane * output_mult[BLOCK_X] +
130
+ warp * output_mult[BLOCK_Y] +
131
+ cta1 * step_output) * output_vec_size;
132
+ }
133
+
134
+ C10_DEVICE int shared_memory_offset(int offset) const {
135
+ return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
136
+ }
137
+
138
+ C10_DEVICE int staging_memory_offset(int cta2) const {
139
+ int offset = cta2 + blockIdx.x * gridDim.y;
140
+ if (!should_block_x_reduce()) {
141
+ offset = threadIdx.x + offset * blockDim.x;
142
+ }
143
+ return offset;
144
+ }
145
+
146
+
147
+ };
148
+
149
+
150
+ //TODO this will need to be different for more generic reduction functions
151
+ namespace reducer {
152
+
153
+ using scalar_t = ${scalar_type};
154
+ using arg_t = ${reduction_accum_type};
155
+ using out_scalar_t = ${result_type};
156
+
157
+
158
+ inline __device__ ${functor}
159
+
160
+ inline __device__ out_scalar_t project(arg_t arg) {
161
+ return (out_scalar_t) arg;
162
+ }
163
+
164
+ inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
165
+ return WARP_SHFL_DOWN(arg, offset);
166
+ }
167
+
168
+ inline __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) {
169
+ return acc;
170
+ }
171
+
172
+ // wrap a normal reduction that ignores the index
173
+ inline __device__ arg_t reduce(arg_t acc, arg_t val, int64_t idx) {
174
+ return combine(acc, val);
175
+ }
176
+ }
177
+
178
+
179
+ struct ReduceJitOp {
180
+ using scalar_t = ${scalar_type};
181
+ using arg_t = ${reduction_accum_type};
182
+ using out_scalar_t = ${result_type};
183
+
184
+ using InputCalculator = OffsetCalculator<1>;
185
+ using OutputCalculator = OffsetCalculator<2>;
186
+
187
+ // static constexpr bool can_accumulate_in_output =
188
+ // std::is_convertible<arg_t, out_scalar_t>::value
189
+ // && std::is_convertible<out_scalar_t, arg_t>::value;
190
+
191
+ static constexpr int input_vec_size = ReduceConfig::input_vec_size;
192
+
193
+ arg_t ident;
194
+ ReduceConfig config;
195
+ InputCalculator input_calc;
196
+ OutputCalculator output_calc;
197
+ const void* src;
198
+ const char* dst[2]; //it accepts at most two destinations
199
+ // acc_buf used for accumulation among sub Tensor Iterator when accumulation on
200
+ // output is not permissible
201
+ void* acc_buf;
202
+ // cta_buf used for accumulation between blocks during global reduction
203
+ void* cta_buf;
204
+ int* semaphores;
205
+ int64_t base_idx;
206
+ bool accumulate;
207
+ bool final_output;
208
+ int noutputs;
209
+
210
+
211
+ C10_DEVICE void run() const {
212
+ extern __shared__ char shared_memory[];
213
+ uint32_t output_idx = config.output_idx<${output_vec_size}>();
214
+ uint32_t input_idx = config.input_idx();
215
+ auto base_offsets1 = output_calc.get(output_idx)[1];
216
+
217
+ using arg_vec_t = Array<arg_t, ${output_vec_size}>;
218
+ arg_vec_t value;
219
+
220
+ if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
221
+ const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);
222
+
223
+ value = thread_reduce<${output_vec_size}>(input_slice);
224
+ }
225
+
226
+ if (config.should_block_y_reduce()) {
227
+ value = block_y_reduce<${output_vec_size}>(value, shared_memory);
228
+ }
229
+ if (config.should_block_x_reduce()) {
230
+ value = block_x_reduce<${output_vec_size}>(value, shared_memory);
231
+ }
232
+
233
+ using out_ptr_vec_t = Array<out_scalar_t*, ${output_vec_size}>;
234
+ using offset_vec_t = Array<uint32_t, ${output_vec_size}>;
235
+ offset_vec_t base_offsets;
236
+ out_ptr_vec_t out;
237
+
238
+ #pragma unroll
239
+ for (int i = 0; i < ${output_vec_size}; i++) {
240
+ base_offsets[i] = output_calc.get(output_idx + i)[0];
241
+ out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
242
+ }
243
+
244
+ arg_vec_t* acc = nullptr;
245
+ if (acc_buf != nullptr) {
246
+ size_t numerator = sizeof(arg_t);
247
+ size_t denominator = sizeof(out_scalar_t);
248
+ reduce_fraction(numerator, denominator);
249
+ acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
250
+ }
251
+
252
+ if (config.should_global_reduce()) {
253
+ value = global_reduce<${output_vec_size}>(value, acc, shared_memory);
254
+ } else if (config.should_store(output_idx)) {
255
+ if (accumulate) {
256
+ #pragma unroll
257
+ for (int i = 0; i < ${output_vec_size}; i++) {
258
+ value[i] = reducer::translate_idx(value[i], base_idx);
259
+ }
260
+ }
261
+
262
+ if (acc == nullptr) {
263
+ if (accumulate) {
264
+ value = accumulate_in_output<${output_vec_size}>(out, value);
265
+ }
266
+ if (final_output) {
267
+ set_results_to_output<${output_vec_size}>(value, base_offsets);
268
+ } else {
269
+ #pragma unroll
270
+ for (int i = 0; i < ${output_vec_size}; i++) {
271
+ *(out[i]) = get_accumulated_output(out[i], value[i]);
272
+ }
273
+ }
274
+ } else {
275
+ if (accumulate) {
276
+ #pragma unroll
277
+ for (int i = 0; i < ${output_vec_size}; i++) {
278
+ value[i] = reducer::combine((*acc)[i], value[i]);
279
+ }
280
+ }
281
+ if (final_output) {
282
+ set_results_to_output<${output_vec_size}>(value, base_offsets);
283
+ } else {
284
+ *acc = value;
285
+ }
286
+ }
287
+ }
288
+ }
289
+
290
+ template <int output_vec_size>
291
+ C10_DEVICE Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
292
+ if (config.vectorize_input) {
293
+ assert(output_vec_size == 1);
294
+ // reduce at the header of input_slice where memory is not aligned,
295
+ // so that thread_reduce will have an aligned memory to work on.
296
+ return {input_vectorized_thread_reduce_impl(data)};
297
+ } else {
298
+ uint32_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
299
+ bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
300
+ if (is_contiguous) {
301
+ return thread_reduce_impl<output_vec_size>(data, [](uint32_t idx) { return idx; });
302
+ } else if (input_calc.dims == 1) {
303
+ return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return idx * element_stride; });
304
+ } else {
305
+ return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
306
+ }
307
+ }
308
+ }
309
+
310
+ C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
311
+ uint32_t end = config.num_inputs;
312
+
313
+ // Handle the head of input slice where data is not aligned
314
+ arg_t value = ident;
315
+ constexpr int align_bytes = alignof(aligned_vector<scalar_t, input_vec_size>);
316
+ constexpr int align_elements = align_bytes / sizeof(scalar_t);
317
+ int shift = ((int64_t)data) % align_bytes / sizeof(scalar_t);
318
+ if (shift > 0) {
319
+ data -= shift;
320
+ end += shift;
321
+ if(threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()){
322
+ value = reducer::reduce(value, data[threadIdx.x], threadIdx.x - shift);
323
+ }
324
+ end -= align_elements;
325
+ data += align_elements;
326
+ shift = align_elements - shift;
327
+ }
328
+
329
+ // Do the vectorized reduction
330
+ using load_t = aligned_vector<scalar_t, input_vec_size>;
331
+
332
+ uint32_t idx = config.input_idx();
333
+ const uint32_t stride = config.step_input;
334
+
335
+ // Multiple accumulators to remove dependency between unrolled loops.
336
+ arg_t value_list[input_vec_size];
337
+ value_list[0] = value;
338
+
339
+ #pragma unroll
340
+ for (int i = 1; i < input_vec_size; i++) {
341
+ value_list[i] = ident;
342
+ }
343
+
344
+ scalar_t values[input_vec_size];
345
+
346
+ load_t *values_vector = reinterpret_cast<load_t*>(&values[0]);
347
+
348
+ while (idx * input_vec_size + input_vec_size - 1 < end) {
349
+ *values_vector = reinterpret_cast<const load_t*>(data)[idx];
350
+ #pragma unroll
351
+ for (uint32_t i = 0; i < input_vec_size; i++) {
352
+ value_list[i] = reducer::reduce(value_list[i], values[i], shift + idx * input_vec_size + i);
353
+ }
354
+ idx += stride;
355
+ }
356
+
357
+ // tail
358
+ uint32_t tail_start = end - end % input_vec_size;
359
+ if (config.should_reduce_tail()) {
360
+ int idx = tail_start + threadIdx.x;
361
+ if (idx < end) {
362
+ value_list[0] = reducer::reduce(value_list[0], data[idx], idx + shift);
363
+ }
364
+ }
365
+
366
+ // combine accumulators
367
+ #pragma unroll
368
+ for (int i = 1; i < input_vec_size; i++) {
369
+ value_list[0] = reducer::combine(value_list[0], value_list[i]);
370
+ }
371
+ return value_list[0];
372
+ }
373
+
374
+ template <int output_vec_size, typename offset_calc_t>
375
+ C10_DEVICE Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
376
+ uint32_t idx = config.input_idx();
377
+ const uint32_t end = config.num_inputs;
378
+ const uint32_t stride = config.step_input;
379
+ const int vt0=${vt0};
380
+
381
+ using arg_vec_t = Array<arg_t, output_vec_size>;
382
+ using load_t = aligned_vector<scalar_t, output_vec_size>;
383
+ const load_t* data = reinterpret_cast<const load_t*>(data_);
384
+
385
+ // Multiple accumulators to remove dependency between unrolled loops.
386
+ arg_vec_t value_list[vt0];
387
+
388
+ #pragma unroll
389
+ for (int i = 0; i < vt0; i++) {
390
+ #pragma unroll
391
+ for (int j = 0; j < output_vec_size; j++) {
392
+ value_list[i][j] = ident;
393
+ }
394
+ }
395
+
396
+ load_t values[vt0];
397
+
398
+ while (idx + (vt0 - 1) * stride < end) {
399
+ #pragma unroll
400
+ for (uint32_t i = 0; i < vt0; i++) {
401
+ values[i] = data[calc(idx + i * stride) / output_vec_size];
402
+ }
403
+ #pragma unroll
404
+ for (uint32_t i = 0; i < vt0; i++) {
405
+ #pragma unroll
406
+ for (uint32_t j = 0; j < output_vec_size; j++) {
407
+ value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx + i * stride);
408
+ }
409
+ }
410
+ idx += stride * vt0;
411
+ }
412
+
413
+ // tail
414
+ int idx_ = idx;
415
+ #pragma unroll
416
+ for (uint32_t i = 0; i < vt0; i++) {
417
+ if (idx >= end) {
418
+ break;
419
+ }
420
+ values[i] = data[calc(idx) / output_vec_size];
421
+ idx += stride;
422
+ }
423
+ idx = idx_;
424
+ #pragma unroll
425
+ for (uint32_t i = 0; i < vt0; i++) {
426
+ if (idx >= end) {
427
+ break;
428
+ }
429
+ #pragma unroll
430
+ for (uint32_t j = 0; j < output_vec_size; j++) {
431
+ value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx);
432
+ }
433
+ idx += stride;
434
+ }
435
+
436
+ // combine accumulators
437
+ #pragma unroll
438
+ for (int i = 1; i < vt0; i++) {
439
+ #pragma unroll
440
+ for (uint32_t j = 0; j < output_vec_size; j++) {
441
+ value_list[0][j] = reducer::combine(value_list[0][j], value_list[i][j]);
442
+ }
443
+ }
444
+ return value_list[0];
445
+ }
446
+ template <int output_vec_size>
447
+ C10_DEVICE Array<arg_t, output_vec_size> block_x_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
448
+ using args_vec_t = Array<arg_t, output_vec_size>;
449
+ int dim_x = blockDim.x;
450
+ args_vec_t* shared = (args_vec_t*)shared_memory;
451
+ if (dim_x > warpSize) {
452
+ int address_base = threadIdx.x + threadIdx.y*blockDim.x;
453
+ shared[address_base] = value;
454
+ for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
455
+ __syncthreads();
456
+ if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
457
+ args_vec_t other = shared[address_base + offset];
458
+ #pragma unroll
459
+ for (int i = 0; i < output_vec_size; i++) {
460
+ value[i] = reducer::combine(value[i], other[i]);
461
+ }
462
+ shared[address_base] = value;
463
+ }
464
+ }
465
+ dim_x = warpSize;
466
+ }
467
+
468
+ __syncthreads();
469
+
470
+ for (int offset = 1; offset < dim_x; offset <<= 1) {
471
+ #pragma unroll
472
+ for (int i = 0; i < output_vec_size; i++) {
473
+ arg_t other = reducer::warp_shfl_down(value[i], offset);
474
+ value[i] = reducer::combine(value[i], other);
475
+ }
476
+ }
477
+ return value;
478
+ }
479
+
480
+ template <int output_vec_size>
481
+ C10_DEVICE Array<arg_t, output_vec_size> block_y_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
482
+ using args_vec_t = Array<arg_t, output_vec_size>;
483
+ args_vec_t* shared = (args_vec_t*)shared_memory;
484
+ shared[config.shared_memory_offset(0)] = value;
485
+ for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
486
+ __syncthreads();
487
+ if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
488
+ args_vec_t other = shared[config.shared_memory_offset(offset)];
489
+ #pragma unroll
490
+ for (int i = 0; i < output_vec_size; i++) {
491
+ value[i] = reducer::combine(value[i], other[i]);
492
+ }
493
+ shared[config.shared_memory_offset(0)] = value;
494
+ }
495
+ }
496
+ return value;
497
+ }
498
+ )ESCAPE";
499
+
500
+ const std::string reduction_template_1 = R"ESCAPE(
501
+
502
+ C10_DEVICE bool mark_block_finished() const {
503
+ __shared__ bool is_last_block_done_shared;
504
+
505
+ __syncthreads();
506
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
507
+ int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
508
+ is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
509
+ }
510
+
511
+ __syncthreads();
512
+
513
+ return is_last_block_done_shared;
514
+ }
515
+
516
+ template <int output_vec_size>
517
+ C10_DEVICE Array<arg_t, output_vec_size> accumulate_in_output(
518
+ Array<out_scalar_t*, output_vec_size> out,
519
+ Array<arg_t, output_vec_size> value
520
+ ) const {
521
+ Array<arg_t, output_vec_size> ret;
522
+ #pragma unroll
523
+ for (int i = 0; i < output_vec_size; i++) {
524
+ ret[i] = reducer::combine(*(out[i]), value[i]);
525
+ }
526
+ return ret;
527
+ }
528
+
529
+
530
+ C10_DEVICE out_scalar_t get_accumulated_output(
531
+ out_scalar_t* out, arg_t value
532
+ ) const {
533
+ assert(!final_output);
534
+ return (out_scalar_t)value;
535
+ }
536
+
537
+ template<class T>
538
+ C10_DEVICE void set_results(const T x, const uint32_t base_offset) const {
539
+ assert(noutputs == 1);
540
+ auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
541
+ *res = x;
542
+ }
543
+
544
+ //TODO - multi-output reduction - we won't be able to use thrust::pair
545
+ //just explicitly specify typed output reads/writes
546
+ //Currently implemented for max of two outputs
547
+ // template<class T1, class T2>
548
+ // C10_DEVICE void set_results(const thrust::pair<T1, T2> x, const index_t base_offset) const {
549
+ // if (noutputs >= 1) {
550
+ // auto res0 = (T1*)((char*)dst[0] + base_offset);
551
+ // *res0 = x.first;
552
+ // }
553
+ // if (noutputs >= 2) {
554
+ // // base offset is computed assuming element size being sizeof(T1), so we need to make a
555
+ // // correction to obtain the correct base offset
556
+ // auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2));
557
+ // *res1 = x.second;
558
+ // }
559
+ // }
560
+
561
+ template <int output_vec_size>
562
+ C10_DEVICE void set_results_to_output(Array<arg_t, output_vec_size> value, Array<uint32_t, output_vec_size> base_offset) const {
563
+ assert(final_output);
564
+ #pragma unroll
565
+ for (int i = 0; i < output_vec_size; i++) {
566
+ set_results(reducer::project(value[i]), base_offset[i]);
567
+ }
568
+ }
569
+
570
+ template <int output_vec_size>
571
+ C10_DEVICE Array<arg_t, output_vec_size> global_reduce(Array<arg_t, output_vec_size> value, Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
572
+ using arg_vec_t = Array<arg_t, output_vec_size>;
573
+ using out_ptr_vec_t = Array<out_scalar_t*, output_vec_size>;
574
+ using offset_vec_t = Array<uint32_t, output_vec_size>;
575
+
576
+ arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
577
+ uint32_t output_idx = config.output_idx<output_vec_size>();
578
+ offset_vec_t base_offsets;
579
+ out_ptr_vec_t out;
580
+
581
+ #pragma unroll
582
+ for (int i = 0; i < output_vec_size; i++) {
583
+ base_offsets[i] = output_calc.get(output_idx + i)[0];
584
+ out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
585
+ }
586
+
587
+ bool should_store = config.should_store(output_idx);
588
+ if (should_store) {
589
+ uint32_t offset = config.staging_memory_offset(blockIdx.y);
590
+ reduce_buffer[offset] = value;
591
+ }
592
+
593
+ __threadfence(); // make sure writes are globally visible
594
+ __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
595
+ bool is_last_block_done = mark_block_finished();
596
+
597
+ if (is_last_block_done) {
598
+ value = ident;
599
+ if (config.should_block_x_reduce()) {
600
+ uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
601
+ uint32_t step = blockDim.x * blockDim.y;
602
+ for (; input_offset < config.ctas_per_output; input_offset += step) {
603
+ uint32_t idx = config.staging_memory_offset(input_offset);
604
+ arg_vec_t next = reduce_buffer[idx];
605
+ #pragma unroll
606
+ for (int i = 0; i < output_vec_size; i++) {
607
+ value[i] = reducer::combine(value[i], next[i]);
608
+ }
609
+ }
610
+ } else {
611
+ uint32_t input_offset = threadIdx.y;
612
+ uint32_t step = blockDim.y;
613
+ for (; input_offset < config.ctas_per_output; input_offset += step) {
614
+ uint32_t idx = config.staging_memory_offset(input_offset);
615
+ arg_vec_t next = reduce_buffer[idx];
616
+ #pragma unroll
617
+ for (int i = 0; i < output_vec_size; i++) {
618
+ value[i] = reducer::combine(value[i], next[i]);
619
+ }
620
+ }
621
+ }
622
+ value = block_y_reduce(value, shared_memory);
623
+ if (config.should_block_x_reduce()) {
624
+ value = block_x_reduce<output_vec_size>(value, shared_memory);
625
+ }
626
+ if (should_store) {
627
+ if (accumulate) {
628
+ #pragma unroll
629
+ for (int i = 0; i < output_vec_size; i++) {
630
+ value[i] = reducer::translate_idx(value[i], base_idx);
631
+ }
632
+ }
633
+
634
+ if (acc == nullptr) {
635
+ if (accumulate) {
636
+ value = accumulate_in_output<output_vec_size>(out, value);
637
+ }
638
+ if (final_output) {
639
+ set_results_to_output<output_vec_size>(value, base_offsets);
640
+ } else {
641
+ #pragma unroll
642
+ for (int i = 0; i < output_vec_size; i++) {
643
+ *(out[i]) = get_accumulated_output(out[i], value[i]);
644
+ }
645
+ }
646
+ } else {
647
+ if (accumulate) {
648
+ #pragma unroll
649
+ for (int i = 0; i < output_vec_size; i++) {
650
+ value[i] = reducer::combine((*acc)[i], value[i]);
651
+ }
652
+ }
653
+ if (final_output) {
654
+ set_results_to_output<output_vec_size>(value, base_offsets);
655
+ } else {
656
+ *acc = value;
657
+ }
658
+ }
659
+ }
660
+ }
661
+
662
+ return value;
663
+ }
664
+ };
665
+
666
+ extern "C"
667
+ __launch_bounds__(${max_threads_lb}, 4)
668
+ __global__ void reduction_${name}_kernel(ReduceJitOp r){
669
+ r.run();
670
+ }
671
+ )ESCAPE";
672
+
673
+ const std::string reduction_template = reduction_template_0 + reduction_template_1;
674
+
675
+
676
+ const std::string &get_reduction_template() {
677
+ return reduction_template;
678
+ }
679
+
680
+ }}
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/cuda/thread_constants.h ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <c10/macros/Macros.h>
3
+
4
+ // Marks a lambda as executable on both the host and device. The __host__
5
+ // attribute is important so that we can access static type information from
6
+ // the host, even if the function is typically only executed on the device.
7
+ #ifndef GPU_LAMBDA
8
+ #define GPU_LAMBDA __host__ __device__
9
+ #endif
10
+
11
+ #if defined(USE_ROCM)
12
+ constexpr int num_threads() {
13
+ return 256;
14
+ }
15
+ #else
16
+ constexpr uint32_t num_threads() {
17
+ return C10_WARP_SIZE * 4;
18
+ }
19
+ #endif
20
+
21
+ constexpr int thread_work_size() { return 4; }
22
+ constexpr int block_work_size() { return thread_work_size() * num_threads(); }
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/OperationUtils.h ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
6
+ #include <ATen/Tensor.h>
7
+ #include <ATen/Utils.h>
8
+ #include <ATen/mps/MPSStream.h>
9
+ #include <ATen/native/mps/TensorFactory.h>
10
+ #include <c10/util/Optional.h>
11
+ #include <c10/core/ScalarType.h>
12
+ #include <torch/library.h>
13
+ #include <exception>
14
+ #include <unordered_map>
15
+
16
+ #ifndef AT_PER_OPERATOR_HEADERS
17
+ #include <ATen/Functions.h>
18
+ #include <ATen/NativeFunctions.h>
19
+ #else
20
+ #include <ATen/ops/empty.h>
21
+ #include <ATen/ops/empty_like.h>
22
+ #include <ATen/ops/zeros.h>
23
+ #include <ATen/ops/zeros_like.h>
24
+ #endif
25
+
26
+ #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
27
+
28
+ // Fwd declarations
29
+ namespace at {
30
+ struct TensorIteratorBase;
31
+ }
32
+ using namespace at::mps;
33
+
34
+ namespace at::native::mps {
35
+
36
+ void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
37
+
38
+ struct MPSScalar {
39
+ id<MTLBuffer> getMTLBuffer() const { return __builtin_bit_cast(id<MTLBuffer>, buffer.get()); }
40
+
41
+ size_t size = 0;
42
+ ScalarType type = ScalarType::Undefined;
43
+ c10::DataPtr buffer; // stores MTLBuffer (frees buffer if MPSScalar instance goes out of scope)
44
+ union {
45
+ float f; // MPS doesn't support 'double'
46
+ at::Half h;
47
+ int64_t i;
48
+ bool b;
49
+ c10::complex<float> cf;
50
+ c10::complex<at::Half> ch;
51
+ at::BFloat16 bf16;
52
+ } value {};
53
+ };
54
+
55
+ void runMPSGraph(MPSStream* mpsStream,
56
+ MPSGraph* mpsGraph,
57
+ NSDictionary* feeds,
58
+ NSDictionary* results);
59
+
60
+ MPSDataType getMPSDataType(ScalarType scalar_type);
61
+ static inline MPSDataType getMPSDataType(const Tensor& t) {
62
+ return getMPSDataType(t.scalar_type());
63
+ }
64
+ MPSDataType getMPSScalarType(ScalarType scalar_type);
65
+ static inline MPSDataType getMPSScalarType(const Tensor& t) {
66
+ return getMPSScalarType(t.scalar_type());
67
+ }
68
+ MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
69
+ std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
70
+ static inline std::string getMPSTypeString(const Tensor& t, bool short_name = false) {
71
+ return getMPSTypeString(t.scalar_type(), short_name);
72
+ }
73
+ std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
74
+ NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
75
+ NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
76
+ std::string getMPSShapeString(MPSShape* shape);
77
+ std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true);
78
+ std::string getArrayRefString(const IntArrayRef s);
79
+ // use has_storage() on the returned tensor to determine if src actually is a view
80
+ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
81
+ Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output);
82
+ bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape);
83
+ MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType);
84
+ MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
85
+ MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
86
+
87
+ // The MPSShape could vary based on memory format
88
+ MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
89
+ MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
90
+
91
+ static inline id<MTLBuffer> getMTLBufferStorage(const at::Tensor& tensor) {
92
+ return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
93
+ }
94
+
95
+ class Placeholder {
96
+ public:
97
+ Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}
98
+ Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {}
99
+ Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr,
100
+ bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid);
101
+ MPSGraphTensor* getMPSGraphTensor() {
102
+ return _placeholder;
103
+ }
104
+ MPSGraphTensorData* getMPSGraphTensorData() {
105
+ return _value;
106
+ }
107
+ bool isIntermediate() {
108
+ return _value == nullptr;
109
+ }
110
+
111
+ private:
112
+ MPSGraphTensor* _placeholder;
113
+ MPSGraphTensorData* _value;
114
+ Tensor _tensor;
115
+ };
116
+
117
+ void resize_tensor(Tensor* output);
118
+ Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device);
119
+ MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
120
+ MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor);
121
+ MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
122
+ MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
123
+ MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor);
124
+ MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar);
125
+
126
+ MPSGraph* make_mps_graph();
127
+ void printTensorNDArray(const Tensor& t);
128
+ MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType mpsType);
129
+
130
+ MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
131
+ MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
132
+ MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor);
133
+ MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
134
+ MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar);
135
+
136
+ string get_mem_format_string(c10::MemoryFormat memory_format);
137
+
138
+ using MPSCacheKey = uint64_t;
139
+
140
+ // derive this class to cache a graph and its inputs/outputs
141
+ // can be used to store any NSObject
142
+ struct MPSCachedGraph
143
+ {
144
+ MPSCachedGraph(NSObject *object) : _object([object retain]) {}
145
+ virtual ~MPSCachedGraph() {
146
+ [_object release];
147
+ _object = nullptr;
148
+ }
149
+
150
+ template<typename T>
151
+ inline T* as() {
152
+ return static_cast<T*>(this);
153
+ }
154
+
155
+ MPSGraph *graph() const { return (MPSGraph *)_object; }
156
+ NSObject *object() const { return _object; }
157
+ private:
158
+ NSObject *_object = nullptr;
159
+ };
160
+
161
+ struct MPSUnaryCachedGraph : public MPSCachedGraph
162
+ {
163
+ MPSUnaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
164
+ MPSGraphTensor *inputTensor_ = nil;
165
+ MPSGraphTensor *outputTensor_ = nil;
166
+ };
167
+
168
+ struct MPSUnaryGradCachedGraph : public MPSCachedGraph
169
+ {
170
+ MPSUnaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
171
+ MPSGraphTensor *gradOutputTensor_ = nil;
172
+ MPSGraphTensor *inputTensor_ = nil;
173
+ MPSGraphTensor *outputTensor_ = nil; // some backward input is actually the forward's output
174
+ MPSGraphTensor *gradInputTensor_ = nil;
175
+ };
176
+
177
+ struct MPSBinaryCachedGraph : public MPSCachedGraph
178
+ {
179
+ MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
180
+ MPSGraphTensor *inputTensor_ = nil;
181
+ MPSGraphTensor *otherTensor_ = nil;
182
+ MPSGraphTensor *outputTensor_ = nil;
183
+ };
184
+
185
+ struct MPSBinaryGradCachedGraph : public MPSCachedGraph
186
+ {
187
+ MPSBinaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
188
+ MPSGraphTensor *gradOutputTensor_ = nil;
189
+ MPSGraphTensor *inputTensor_ = nil;
190
+ MPSGraphTensor *otherTensor_ = nil;
191
+ MPSGraphTensor *gradInputTensor_ = nil;
192
+ };
193
+
194
+ // TODO: Improve the overall design of MPSGraphCache.
195
+ // https://github.com/pytorch/pytorch/issues/77176
196
+ // Cache holding various keys mapped to graphs
197
+ struct MPSGraphCache
198
+ {
199
+ typedef MPSCachedGraph * (^CreateCachedGraphBlock)();
200
+
201
+ struct CacheEntry {
202
+ CacheEntry(const std::string& key, MPSCachedGraph *cachedGraph) : cachedGraph_(cachedGraph), key_(key) {}
203
+ MPSCachedGraph* cachedGraph_ = nullptr;
204
+ std::string key_;
205
+ };
206
+
207
+ public:
208
+
209
+ static MPSGraphCache* getInstance() {
210
+ if(_instance_cache == nullptr) {
211
+ _instance_cache = new MPSGraphCache();
212
+ }
213
+ return _instance_cache;
214
+ }
215
+
216
+ ~MPSGraphCache() {
217
+ dispatch_release(serialQueue_);
218
+
219
+ for (const auto& i : cache_) {
220
+ delete i.second.cachedGraph_;
221
+ }
222
+ }
223
+
224
+ // Disallow the copy constructor and operator= functions
225
+ MPSGraphCache(const MPSGraphCache&) = delete;
226
+ void operator=(const MPSGraphCache&) = delete;
227
+
228
+ MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
229
+
230
+ __block MPSCachedGraph* cachedGraph = nil;
231
+
232
+ MPSCacheKey hash = std::hash<std::string>{}(key);
233
+
234
+ dispatch_sync_with_rethrow(serialQueue_, ^() {
235
+ // verify the cached entry doesn't already exist
236
+ if (cache_.count(hash) != 0) {
237
+ auto& entry = cache_.at(hash);
238
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
239
+ cachedGraph = entry.cachedGraph_;
240
+ } else {
241
+ cachedGraph = createCacheBlock();
242
+ CacheEntry entry(key, cachedGraph);
243
+ cache_.emplace(hash, entry);
244
+ profileCachedGraph(entry);
245
+ }
246
+ });
247
+ return cachedGraph;
248
+ }
249
+
250
+ template<typename T>
251
+ inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
252
+ return static_cast<T *>(CreateCachedGraph(key, createCacheBlock));
253
+ }
254
+
255
+ MPSCachedGraph* LookUp(const std::string& key) const {
256
+
257
+ __block MPSCachedGraph* cachedGraph = nullptr;
258
+
259
+ MPSCacheKey hash = std::hash<std::string>{}(key);
260
+
261
+ dispatch_sync(serialQueue_, ^() {
262
+
263
+ if (cache_.count(hash) != 0) {
264
+ auto& entry = cache_.at(hash);
265
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
266
+ cachedGraph = entry.cachedGraph_;
267
+ profileCachedGraph(entry);
268
+ }
269
+ });
270
+ return cachedGraph;
271
+ }
272
+
273
+ template<typename T>
274
+ inline T* LookUpAs(const std::string& key) const {
275
+ return static_cast<T *>(LookUp(key));
276
+ }
277
+
278
+ private:
279
+ MPSGraphCache() {
280
+ serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL);
281
+ }
282
+ // this is defined in OperationUtils.mm to not include
283
+ // MPSProfiler.h in header OperationUtils.h
284
+ void profileCachedGraph(const CacheEntry& cacheEntry) const;
285
+
286
+ static MPSGraphCache* _instance_cache;
287
+ std::unordered_map<MPSCacheKey, CacheEntry> cache_;
288
+ dispatch_queue_t serialQueue_ = nullptr;
289
+
290
+ };
291
+
292
+ // Common template for creating graph with a specified cache if missing
293
+ template<typename T>
294
+ inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(MPSGraph*, T*)> instantiate) {
295
+ auto cache_ = MPSGraphCache::getInstance();
296
+ if (auto rc = cache_->LookUpAs<T>(key)) {
297
+ return rc;
298
+ }
299
+ return cache_->CreateCachedGraphAs<T>(key, ^mps::MPSCachedGraph*() {
300
+ T* newCachedGraph = nil;
301
+ @autoreleasepool {
302
+ // Initialize graph
303
+ auto mpsGraph = mps::make_mps_graph();
304
+ newCachedGraph = new T(mpsGraph);
305
+ instantiate(mpsGraph, newCachedGraph);
306
+ }
307
+ return newCachedGraph;
308
+ });
309
+ }
310
+
311
+ // Common math operations
312
+ MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
313
+
314
+ #define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
315
+ if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
316
+ TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name, \
317
+ ", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
318
+ }
319
+
320
+ /**
321
+ * Returns distance from lowest to highest element offset in given tensor.
322
+ */
323
+ size_t compute_storage_numel_distance(const at::Tensor& t);
324
+
325
+ /**
326
+ * Checks whether tensor is mapped to a contiguous area in the storage.
327
+ */
328
+ inline bool is_dense_in_storage(const at::Tensor& t) {
329
+ return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
330
+ }
331
+
332
+ static inline void mtl_setBuffer(id<MTLComputeCommandEncoder> encoder, const Tensor& t, unsigned idx) {
333
+ [encoder setBuffer:getMTLBufferStorage(t)
334
+ offset:t.storage_offset() * t.element_size()
335
+ atIndex:idx];
336
+ }
337
+
338
+ static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
339
+ id<MTLComputePipelineState> cplState,
340
+ uint32_t length) {
341
+ const uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
342
+ auto size = MTLSizeMake(length, 1, 1);
343
+ auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1);
344
+ [encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
345
+ }
346
+
347
+ id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder, const TensorIteratorBase& iter, bool use_64bit_index = false);
348
+
349
+ inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) {
350
+ return @{ p1.getMPSGraphTensor(): p1.getMPSGraphTensorData() };
351
+ }
352
+
353
+ inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) {
354
+ return @{
355
+ p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
356
+ p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
357
+ };
358
+ }
359
+
360
+ inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) {
361
+ return @{
362
+ p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
363
+ p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
364
+ p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
365
+ };
366
+ }
367
+
368
+ inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) {
369
+ return @{
370
+ p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
371
+ p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
372
+ p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
373
+ p4.getMPSGraphTensor(): p4.getMPSGraphTensorData(),
374
+ };
375
+ }
376
+
377
+ inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) {
378
+ runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
379
+ }
380
+
381
+ inline bool supportsComplex() {
382
+ return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
383
+ }
384
+
385
+ // MPS yet to support double types, but starting from MacOS 14, supports bfloat16
386
+ inline bool supportedFloatingType(ScalarType dtype) {
387
+ return dtype == kFloat || dtype == kHalf || dtype == kBFloat16;
388
+ }
389
+
390
+ inline bool supportedFloatingType(const Tensor& t) {
391
+ return supportedFloatingType(t.scalar_type());
392
+ }
393
+
394
+ } // namespace at::native::mps
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/mps/TensorFactory.h ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2022 Apple Inc.
2
+
3
+ #define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) \
4
+ AT_DISPATCH_SWITCH( \
5
+ TYPE, NAME, \
6
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
7
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
8
+ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
9
+ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
10
+ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
11
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
12
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Transformer-specific NestedTensor utility functions.
3
+ *
4
+ * Not co-located with NestedTensor core code yet because they only
5
+ * support specific cases needed in transformers.
6
+ */
7
+ #pragma once
8
+
9
+ #include <vector>
10
+
11
+ #include <c10/macros/Macros.h>
12
+ #include <c10/util/Optional.h>
13
+
14
+ namespace c10 {
15
+ class Scalar;
16
+ } // namespace c10
17
+
18
+ namespace at {
19
+ class Tensor;
20
+ namespace native {
21
+ struct NestedTensorImpl;
22
+
23
+ // Requires that self is a contiguous NestedTensor, other is not a
24
+ // NestedTensor, self.dim() == 3, and other.dim() == 2. Also, self
25
+ // must have a consistent last dimension across its included Tensors
26
+ // and that dimension must match other.size(0).
27
+ Tensor NestedTensor_matmul(const Tensor& self, const Tensor& other);
28
+
29
+ // Requires that mat1 is a contiguous NestedTensor, self & mat2 are
30
+ // not NestedTensors, mat1.dim() == 3, mat2.dim() == 2, and that mat1
31
+ // has a consistent last dimension across its included Tensors that
32
+ // matches mat2.size(0).
33
+ Tensor NestedTensor_times_Tensor_plus_Tensor_addmm(
34
+ const Tensor& self,
35
+ const Tensor& mat1,
36
+ const Tensor& mat2,
37
+ const c10::Scalar& beta,
38
+ const c10::Scalar& alpha,
39
+ c10::optional<bool> use_gelu = c10::nullopt);
40
+
41
+ Tensor NestedTensor_add_NestedTensor_in_place(
42
+ const Tensor& self,
43
+ const Tensor& other);
44
+
45
+ TORCH_API Tensor NestedTensor_batch_offsets_from_size_tensor(
46
+ const Tensor& sizes,
47
+ int64_t extra_elements);
48
+
49
+ Tensor NestedTensor_from_padded_tensor_cpu(
50
+ const Tensor& padded,
51
+ const NestedTensorImpl& nt);
52
+
53
+ Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim, c10::optional<int64_t> mask_dim_length);
54
+
55
+ template <typename T>
56
+ void remove_padding_kernelLauncher(
57
+ const T* input,
58
+ T* output,
59
+ const int* offsets,
60
+ const int* input_sizes,
61
+ const int* output_sizes,
62
+ int output_dim,
63
+ const int batch_size);
64
+
65
+ template <typename T>
66
+ void remove_padding_transform0213_kernelLauncher(
67
+ const T* input,
68
+ T* output,
69
+ const int* offsets,
70
+ const int* input_sizes,
71
+ const int* output_sizes,
72
+ int output_dim,
73
+ const int batch_size);
74
+
75
+ template <typename T>
76
+ void add_padding_kernelLauncher(
77
+ T* input,
78
+ T* output,
79
+ T padding_value,
80
+ const int* offsets,
81
+ const int* input_sizes,
82
+ int input_dim,
83
+ const std::vector<int64_t>& output_sizes,
84
+ const int batch_size,
85
+ const int output_batch_size);
86
+
87
+ TORCH_API Tensor flash_attention_helper(
88
+ const Tensor& query,
89
+ const Tensor& key,
90
+ const Tensor& value,
91
+ double dropout_p,
92
+ bool need_attn_weights,
93
+ bool is_causal);
94
+
95
+ TORCH_API std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
96
+ const Tensor& query,
97
+ const Tensor& key,
98
+ const Tensor& value,
99
+ double dropout_p,
100
+ bool need_attn_weights,
101
+ bool is_causal);
102
+ } // namespace native
103
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/AffineQuantizer.h ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/Dispatch.h>
5
+ #include <ATen/native/DispatchStub.h>
6
+ #include <ATen/native/quantized/AffineQuantizerBase.h>
7
+
8
+ namespace at {
9
+ namespace native {
10
+
11
+ Tensor& quantize_tensor_per_tensor_affine(
12
+ const Tensor& rtensor,
13
+ Tensor& qtensor,
14
+ double scale,
15
+ int64_t zero_point);
16
+ Tensor& quantize_tensor_per_channel_affine(
17
+ const Tensor& rtensor,
18
+ Tensor& qtensor,
19
+ Tensor scales,
20
+ Tensor zero_points,
21
+ int64_t axis);
22
+
23
+ Tensor& quantize_tensor_per_channel_float_qparams(
24
+ const Tensor& rtensor,
25
+ Tensor& qtensor,
26
+ Tensor scales,
27
+ Tensor zero_points,
28
+ int64_t axis);
29
+
30
+ Tensor& dequantize_tensor_per_tensor_affine(
31
+ const Tensor& qtensor,
32
+ Tensor& rtensor,
33
+ double scale,
34
+ int64_t zero_point);
35
+ Tensor& dequantize_tensor_per_channel_affine(
36
+ const Tensor& qtensor,
37
+ Tensor& rtensor,
38
+ Tensor scales,
39
+ Tensor zero_points,
40
+ int64_t axis);
41
+ Tensor& dequantize_tensor_per_channel_float_qparams(
42
+ const Tensor& qtensor,
43
+ Tensor& rtensor,
44
+ Tensor scales,
45
+ Tensor zero_points,
46
+ int64_t axis);
47
+
48
+ using quantize_tensor_per_tensor_affine_fn =
49
+ void (*)(const Tensor& rtensor, Tensor& qtensor, double scale, int64_t zero_point);
50
+
51
+ using quantize_tensor_per_channel_affine_fn = void (*)(
52
+ const Tensor& rtensor,
53
+ Tensor& qtensor,
54
+ const Tensor& scales,
55
+ const Tensor& zero_points,
56
+ int64_t axis);
57
+
58
+ using quantize_tensor_per_channel_float_qparams_fn = void (*)(
59
+ const Tensor& rtensor,
60
+ Tensor& qtensor,
61
+ const Tensor& scales,
62
+ const Tensor& zero_points,
63
+ int64_t axis);
64
+
65
+ using dequantize_tensor_per_tensor_affine_fn =
66
+ void (*)(const Tensor& qtensor, Tensor& rtensor, double scale, int64_t zero_point);
67
+
68
+ using dequantize_tensor_per_channel_affine_fn = void (*)(
69
+ const Tensor& qtensor,
70
+ Tensor& rtensor,
71
+ const Tensor& scales,
72
+ const Tensor& zero_points,
73
+ int64_t axis);
74
+
75
+ using dequantize_tensor_per_channel_float_qparams_fn = void (*)(
76
+ const Tensor& qtensor,
77
+ Tensor& rtensor,
78
+ const Tensor& scales,
79
+ const Tensor& zero_points,
80
+ int64_t axis);
81
+
82
+ using quantize_tensor_per_tensor_affine_sub_byte_fn =
83
+ void (*)(const Tensor& rtensor, Tensor& qtensor, float scale, float zero_point);
84
+
85
+ using dequantize_tensor_per_tensor_affine_sub_byte_fn =
86
+ void (*)(const Tensor& qtensor, Tensor& rtensor, float scale, float zero_point);
87
+
88
+ DECLARE_DISPATCH(
89
+ quantize_tensor_per_tensor_affine_fn,
90
+ quantize_tensor_per_tensor_affine_stub);
91
+ DECLARE_DISPATCH(
92
+ quantize_tensor_per_channel_affine_fn,
93
+ quantize_tensor_per_channel_affine_stub);
94
+ DECLARE_DISPATCH(
95
+ quantize_tensor_per_channel_float_qparams_fn,
96
+ quantize_tensor_per_channel_float_qparams_stub);
97
+
98
+ DECLARE_DISPATCH(
99
+ dequantize_tensor_per_tensor_affine_fn,
100
+ dequantize_tensor_per_tensor_affine_stub);
101
+ DECLARE_DISPATCH(
102
+ dequantize_tensor_per_channel_affine_fn,
103
+ dequantize_tensor_per_channel_affine_stub);
104
+ DECLARE_DISPATCH(
105
+ dequantize_tensor_per_channel_float_qparams_fn,
106
+ dequantize_tensor_per_channel_float_qparams_stub);
107
+
108
+ DECLARE_DISPATCH(
109
+ quantize_tensor_per_tensor_affine_sub_byte_fn,
110
+ quantize_tensor_per_tensor_affine_sub_byte_stub);
111
+
112
+ DECLARE_DISPATCH(
113
+ dequantize_tensor_per_tensor_affine_sub_byte_fn,
114
+ dequantize_tensor_per_tensor_affine_sub_byte_stub);
115
+
116
+ template <typename T>
117
+ TORCH_API Tensor quantize_tensor(
118
+ Tensor rtensor,
119
+ Tensor qtensor,
120
+ double scale,
121
+ int64_t zero_point);
122
+ template <typename T>
123
+ TORCH_API Tensor dequantize_tensor(
124
+ Tensor qtensor,
125
+ Tensor rtensor,
126
+ double scale,
127
+ int64_t zero_point);
128
+
129
+ } // namespace native
130
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/ConvUtils.h ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/List.h>
3
+ #include <ATen/native/ConvUtils.h>
4
+
5
+ namespace at::native::quantized {
6
+ namespace {
7
+ // MakeConvOutputShape used from both CPU and CUDA libraries
8
+ // and exporting symbol from torch_cpu would probably take more storage
9
+ // than duplicating implementation which likely be inlined away
10
+ template <int kSpatialDim>
11
+ at::SmallVector<int64_t, kSpatialDim + 2> MakeConvOutputShape(
12
+ int N, // mini-batch
13
+ int M, // output channels
14
+ const std::array<int64_t, kSpatialDim>& input_image_shape,
15
+ const std::vector<int64_t>& kernel,
16
+ const torch::List<int64_t>& stride,
17
+ const torch::List<int64_t>& padding,
18
+ const torch::List<int64_t>& dilation);
19
+
20
+ #if defined(USE_CUDA) || defined(USE_PYTORCH_QNNPACK)
21
+ template <>
22
+ at::SmallVector<int64_t, 4> MakeConvOutputShape<2>(
23
+ int N, // mini-batch
24
+ int M, // output channels
25
+ const std::array<int64_t, 2>& input_image_shape,
26
+ const std::vector<int64_t>& kernel,
27
+ const at::List<int64_t>& stride,
28
+ const at::List<int64_t>& padding,
29
+ const at::List<int64_t>& dilation) {
30
+ const int H = input_image_shape[0];
31
+ const int W = input_image_shape[1];
32
+ const int64_t Y_H =
33
+ (H + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1;
34
+ const int64_t Y_W =
35
+ (W + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1;
36
+ return {N, M, Y_H, Y_W};
37
+ }
38
+
39
+ template <>
40
+ at::SmallVector<int64_t, 5> MakeConvOutputShape<3>(
41
+ int N, // mini-batch
42
+ int M, // output channels
43
+ const std::array<int64_t, 3>& input_image_shape,
44
+ const std::vector<int64_t>& kernel,
45
+ const at::List<int64_t>& stride,
46
+ const at::List<int64_t>& padding,
47
+ const torch::List<int64_t>& dilation) {
48
+ const int D = input_image_shape[0];
49
+ const int H = input_image_shape[1];
50
+ const int W = input_image_shape[2];
51
+ const int64_t Y_D =
52
+ (D + 2 * padding[0] - dilation[0] * (kernel[0] - 1) - 1) / stride[0] + 1;
53
+ const int64_t Y_H =
54
+ (H + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1;
55
+ const int64_t Y_W =
56
+ (W + 2 * padding[2] - dilation[2] * (kernel[2] - 1) - 1) / stride[2] + 1;
57
+ return {N, M, Y_D, Y_H, Y_W};
58
+ }
59
+
60
+ #endif
61
+ } // anonymous namespace
62
+ } // namespace at::native::quantized
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/IndexKernel.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/TensorIterator.h>
3
+
4
+ namespace at {
5
+ namespace native {
6
+ using masked_fill_kernel_quantized_fn = void(*)(TensorIterator& iter, const Scalar& value, double scale, int zero_point);
7
+ using index_put_kernel_quantized_fn = void(*)(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate, double scale, int zero_point);
8
+
9
+ DECLARE_DISPATCH(masked_fill_kernel_quantized_fn, masked_fill_kernel_quantized_stub);
10
+ DECLARE_DISPATCH(index_put_kernel_quantized_fn, index_put_kernel_quantized_stub);
11
+
12
+
13
+ } // native
14
+ } // at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/PackedParams.h ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/core/ivalue.h>
5
+
6
+ struct LinearPackedParamsBase : public torch::jit::CustomClassHolder {
7
+ virtual at::Tensor apply(
8
+ at::Tensor input,
9
+ double output_scale,
10
+ int64_t output_zero_point) = 0;
11
+ virtual at::Tensor apply_relu(
12
+ at::Tensor input,
13
+ double output_scale,
14
+ int64_t output_zero_point) = 0;
15
+
16
+ // out variant of LinearPackedParamsBase::apply
17
+ virtual at::Tensor& apply_out(
18
+ const at::Tensor& /*input*/,
19
+ double /*output_scale*/,
20
+ int64_t /*output_zero_point*/,
21
+ at::Tensor& output) {
22
+ throw std::runtime_error(
23
+ "apply_out is not implemented for this packed "
24
+ "parameter type");
25
+ return output;
26
+ }
27
+
28
+ virtual at::Tensor& apply_relu_out(
29
+ const at::Tensor& /*input*/,
30
+ double /*output_scale*/,
31
+ int64_t /*output_zero_point*/,
32
+ at::Tensor& output) {
33
+ throw std::runtime_error(
34
+ "apply_relu_out is not implemented for this packed "
35
+ "parameter type");
36
+ return output;
37
+ }
38
+
39
+ // Corresponding pattern (the ops with `*` are part of the pattern that
40
+ // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_output_fp32):
41
+ // input -> q* -> dq* -> linear* ->
42
+ // qweight -> dq* /
43
+ //
44
+ // After fusion:
45
+ // input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* ->
46
+ // qweight /
47
+ //
48
+ // Additional Note: the weight is packed as well
49
+ // Params:
50
+ // X: float32 Tensor, will be quantized to quint8 in the op
51
+ // W_prepack: packed qint8 quantized weight and bias
52
+ // Returns:
53
+ // Y: float32 Tensor
54
+ virtual at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32(
55
+ at::Tensor input,
56
+ double input_scale,
57
+ int64_t input_zero_point) {
58
+ throw std::runtime_error(
59
+ "apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed "
60
+ "parameter type");
61
+ return {};
62
+ }
63
+
64
+ // Corresponding pattern (the ops with `*` are part of the pattern that
65
+ // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32):
66
+ // input -> q* -> dq* -> linear* -> relu* ->
67
+ // qweight -> dq* /
68
+ //
69
+ // After fusion:
70
+ // input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* ->
71
+ // qweight /
72
+ //
73
+ // Additional Note: the weight is packed as well
74
+ // Params:
75
+ // input: float32 Tensor, will be quantized to quint8 in the op
76
+ // Returns:
77
+ // float32 Tensor
78
+ virtual at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32(
79
+ at::Tensor input,
80
+ double input_scale,
81
+ int64_t input_zero_point) {
82
+ throw std::runtime_error(
83
+ "apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed "
84
+ "parameter type");
85
+ return {};
86
+ }
87
+
88
+ virtual at::Tensor apply_dynamic(
89
+ at::Tensor input,
90
+ bool reduce_range = false) = 0;
91
+ virtual at::Tensor apply_dynamic_relu(
92
+ at::Tensor input,
93
+ bool reduce_range = false) = 0;
94
+
95
+ virtual at::Tensor& apply_dynamic_out(
96
+ const at::Tensor& /* input */,
97
+ at::Tensor& output,
98
+ bool /* reduce_range */) {
99
+ throw std::runtime_error(
100
+ "apply_dynamic_out is not implemented for this packed "
101
+ "parameter type");
102
+ return output;
103
+ }
104
+ virtual at::Tensor& apply_dynamic_relu_out(
105
+ const at::Tensor& /* input */,
106
+ at::Tensor& output,
107
+ bool /* reduce_range */) {
108
+ throw std::runtime_error(
109
+ "apply_dynamic_relu_out is not implemented for this packed "
110
+ "parameter type");
111
+ return output;
112
+ }
113
+
114
+ virtual std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() = 0;
115
+
116
+ virtual c10::optional<at::Tensor> bias() = 0;
117
+
118
+ virtual void set_bias(c10::optional<at::Tensor> /*bias*/) {
119
+ throw std::runtime_error(
120
+ "set_bias is not implemented for this packed "
121
+ "parameter type");
122
+ }
123
+ };
124
+
125
+ template <int kSpatialDim = 2>
126
+ struct ConvPackedParamsBase : public torch::jit::CustomClassHolder {
127
+ virtual at::Tensor apply(
128
+ const at::Tensor& input,
129
+ double output_scale,
130
+ int64_t output_zero_point) = 0;
131
+ virtual at::Tensor apply_relu(
132
+ const at::Tensor& input,
133
+ double output_scale,
134
+ int64_t output_zero_point) = 0;
135
+ virtual at::Tensor apply_dynamic(
136
+ const at::Tensor& input,
137
+ bool reduce_range) = 0;
138
+
139
+ virtual std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() = 0;
140
+
141
+ virtual torch::List<int64_t> stride() const = 0;
142
+ virtual torch::List<int64_t> padding() const = 0;
143
+ virtual torch::List<int64_t> output_padding() const = 0;
144
+ virtual torch::List<int64_t> dilation() const = 0;
145
+ virtual int64_t groups() const = 0;
146
+ virtual bool transpose() const = 0;
147
+ };
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/Tensor.h>
4
+ #include <ATen/core/ivalue.h>
5
+
6
// Base interface for packed quantized embedding(-bag) weights.  Concrete
// packed-parameter types derive from this and are exposed to TorchScript
// via CustomClassHolder.
struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder {
  // Embedding-bag lookup over byte (8-bit) quantized weights.
  // `is_embedding_op` distinguishes the plain embedding call path from the
  // embedding_bag one in implementations — confirm semantics there.
  virtual at::Tensor embeddingbag_byte(
      const at::Tensor& indices,
      const c10::optional<at::Tensor>& offsets,
      bool pruned_weights,
      const c10::optional<at::Tensor>& per_sample_weights_,
      const c10::optional<at::Tensor>& compressed_indices_mapping,
      bool include_last_offset,
      bool is_embedding_op) = 0;

  // 4-bit quantized variant of the lookup above; same parameter contract.
  virtual at::Tensor embeddingbag_4bit(
      const at::Tensor& indices,
      const c10::optional<at::Tensor>& offsets,
      bool pruned_weights,
      const c10::optional<at::Tensor>& per_sample_weights_,
      const c10::optional<at::Tensor>& compressed_indices_mapping,
      bool include_last_offset,
      bool is_embedding_op) = 0;

  // Recovers the original weight tensor.
  virtual at::Tensor unpack() = 0;

  // Bits per weight element (the `byte`/`4bit` entry points suggest 8 or 4).
  virtual int64_t bit_rate() const = 0;
  // NOTE(review): presumably the packed-format/serialization version —
  // confirm against the implementations.
  virtual int64_t version() const = 0;
};
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/native/quantized/cpu/QnnpackUtils.h ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #ifdef USE_PYTORCH_QNNPACK
4
+ #include <ATen/core/Tensor.h>
5
+ #include <c10/util/irange.h>
6
+ #include <pytorch_qnnpack.h>
7
+ #include <qnnpack_func.h>
8
+ #include <ATen/native/quantized/cpu/XnnpackUtils.h>
9
+ #include <ATen/native/quantized/PackedParams.h>
10
+ #include <ATen/native/utils/Factory.h>
11
+
12
+ #ifndef AT_PER_OPERATOR_HEADERS
13
+ #include <ATen/Functions.h>
14
+ #else
15
+ #include <ATen/ops/empty.h>
16
+ #endif
17
+
18
+ #include <utility>
19
// Number of extra output channels QNNPACK-facing buffers are padded with
// (used by make_zero_points_and_scales_tensor below).  `constexpr`: this is
// a compile-time constant; the previous plain `inline int` was a mutable
// global in a header, which invited accidental modification across TUs.
inline constexpr int kPaddingChannels = 8;
20
// Custom deleter so a std::unique_ptr can own a raw QNNPACK operator handle
// and release it through the QNNPACK API rather than `delete`.
struct QnnpackOperatorDeleter {
  void operator()(pytorch_qnnp_operator_t op) {
    pytorch_qnnp_delete_operator(op);
  }
};
25
+
26
+ // PackedWeight struct for QNNPACK stores the original Weight and Bias as
27
+ // QNNPACK currently does not support an unpack function.
28
+ // For PyTorch Mobile, once the model is scripted and serialized we don't need
29
+ // to call unpack, so we can save some memory by checking for this case and free
30
+ // the original weights after packing.
31
+ // Input scale is set to null in pre-pack step. QNNPACK needs bias quantized
32
+ // with input scale which is available at runtime in pytorch. During runtime if
33
+ // input scale value changes then we requantize bias with the updated scale. For
34
+ // inference we expect the graph to be static so the input scale should not
35
+ // change across consecutive inference calls.
36
// Packed weight container for QNNPACK quantized linear layers.
// Holds both the QNNPACK pre-packed matrix (`w`) and the original weight
// tensor, which is retained so unpack() can be implemented (see the comment
// block above: QNNPACK itself has no unpack).  `input_scale` is left unset
// at prepack time and filled in at run time, when the actual input scale is
// known and the bias can be requantized with it.
struct PackedLinearWeightsQnnp : public LinearPackedParamsBase {
  PackedLinearWeightsQnnp(
      std::unique_ptr<qnnpack::PackBMatrix> w,
      at::Tensor orig_weight,
      at::Tensor bias,
      c10::optional<double> input_scale,
      at::Tensor w_scales,
      std::vector<uint8_t>&& w_zps)
      : w(std::move(w)),
        orig_weight(std::move(orig_weight)),
        // Mobile allocator may need padded/contiguous storage for the bias.
        bias_(at::native::mobile::allocate_padded_contiguous_if_needed(
            bias, bias.suggest_memory_format())),
        per_channel_(this->orig_weight.qscheme() == at::kPerChannelAffine),
        input_scale(std::move(input_scale)),
        w_scales(std::move(w_scales)),
        w_zero_points(std::move(w_zps)),
        q_scheme(this->orig_weight.qscheme()) {
    weight_sizes = this->orig_weight.sizes().vec();
  }

  std::unique_ptr<qnnpack::PackBMatrix> w;  // QNNPACK pre-packed weight matrix
  at::Tensor orig_weight;                   // original quantized weight (for unpack())
  at::Tensor bias_;                         // bias, padded/contiguous if needed
  bool per_channel_;                        // true iff weight is per-channel quantized
  c10::optional<double> input_scale;        // nullopt until first run (see above)
  at::Tensor w_scales;                      // weight scales
  std::vector<uint8_t> w_zero_points;       // weight zero points
  std::vector<float> requantization_scales; // computed at run time
  std::vector<int64_t> weight_sizes;        // cached orig_weight.sizes()
  c10::QScheme q_scheme;                    // quantization scheme of orig_weight

  // Quantized linear with fixed output quantization parameters.
  at::Tensor apply(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;
  // Same, with fused ReLU.
  at::Tensor apply_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;

  // Dynamic variants (output quantization parameters not fixed up front).
  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;

  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

  c10::optional<at::Tensor> bias() override {
    return bias_;
  }

  // Builds a packed-parameter object from a quantized weight + optional bias.
  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      at::Tensor weight,
      c10::optional<at::Tensor> bias);

  bool per_channel() const {
    return per_channel_;
  }

 private:
  // Guards QNNPACK state shared across calls (defined in the .cpp).
  std::mutex qnnp_mutex_;

#ifdef USE_XNNPACK
  xnnpack_operator xnnp_linear_op;

  template <typename scalar_t, bool kReluFused>
  at::Tensor apply_impl_xnnp(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);
#endif // USE_XNNPACK

  template <bool ReluFused>
  at::Tensor apply_impl(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point);

  template <bool ReluFused>
  at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range);
};
115
+
116
// Packed weight container for QNNPACK quantized convolutions (2-D or 3-D).
// The constructor pre-selects the QNNPACK micro-kernel type from the conv
// geometry, allocates the raw pytorch_qnnp_operator via calloc (so all
// fields start zeroed), hands ownership to `convolution_op`, and sizes the
// zero-padding buffer the kernels read from.
template <int kSpatialDim = 2>
struct PackedConvWeightsQnnp : public ConvPackedParamsBase<kSpatialDim> {
  PackedConvWeightsQnnp(
      std::unique_ptr<qnnpack::PrePackConvWeights> w,
      at::Tensor orig_weight,
      at::Tensor bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose,
      c10::optional<double> input_scale,
      std::vector<int64_t> kernel,
      at::Tensor w_scale,
      std::vector<uint8_t>&& w_zps,
      bool is_per_channel)
      : w(std::move(w)),
        orig_weight(std::move(orig_weight)),
        bias(std::move(bias)),
        stride_(std::move(stride)),
        padding_(std::move(padding)),
        output_padding_(std::move(output_padding)),
        dilation_(std::move(dilation)),
        groups_(groups),
        transpose_(transpose),
        is_per_channel_(is_per_channel),
        input_scale(input_scale),
        kernel_(std::move(kernel)),
        w_scales(std::move(w_scale)),
        w_zero_points(std::move(w_zps)) {
    const bool any_padding = std::any_of(
        padding_.begin(), padding_.end(), [](const auto& e) { return e != 0; });
    const size_t kernel_size =
        std::accumulate(kernel_.begin(), kernel_.end(), 1, std::multiplies<>());

    // Weight layout differs for transposed convs: the output-channel
    // dimension is dim 1 and dim 0 is divided across groups.
    const size_t group_input_channels = transpose
        ? this->orig_weight.size(0) / groups
        : this->orig_weight.size(1);
    const size_t group_output_channels = transpose
        ? this->orig_weight.size(1)
        : this->orig_weight.size(0) / groups;

    const size_t kernel_depth = kSpatialDim == 3 ? kernel_[0] : 1;
    const size_t kernel_height = kernel_[kSpatialDim - 2];
    const size_t kernel_width = kernel_[kSpatialDim - 1];

    // Pick the QNNPACK micro-kernel: transposed convs always use the
    // generic conv kernel; otherwise prefer depthwise, then GEMM for 1x1
    // stride-1 unpadded convs, falling back to the generic conv kernel.
    pytorch_qnnp_ukernel_type ukernel_type;
    if (transpose_) {
      ukernel_type = pytorch_qnnp_ukernel_type_conv;
    } else {
      ukernel_type = pytorch_qnnp_ukernel_type_none;

      const bool has_depthwise_dimensions =
          (kSpatialDim == 2 &&
           ((kernel_height == 3 && kernel_width == 3) ||
            (kernel_height == 5 && kernel_width == 5))) ||
          (kSpatialDim == 3 && kernel_height == 3 && kernel_width == 3 &&
           kernel_depth == 3);
      const bool has_depthwise_grouping =
          group_input_channels == 1 && group_output_channels == 1 && groups > 1;

      if (has_depthwise_dimensions && has_depthwise_grouping) {
        ukernel_type = pytorch_qnnp_ukernel_type_dwconv;
      } else if (
          kernel_size == 1 &&
          std::all_of(
              stride_.begin(),
              stride_.end(),
              [](const auto& e) { return e == 1; }) &&
          !any_padding) {
        // NOTE(review): `group_input_channels >= SIZE_MAX` can only hold in
        // the degenerate case, so the XZP GEMM branch is effectively never
        // taken here — confirm whether XZP is intentionally disabled.
        ukernel_type = group_input_channels >= SIZE_MAX
            ? pytorch_qnnp_ukernel_type_xzp_gemm
            : pytorch_qnnp_ukernel_type_gemm;
      } else {
        ukernel_type = pytorch_qnnp_ukernel_type_conv;
      }
    }

    if (is_per_channel && ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) {
      TORCH_INTERNAL_ASSERT(
          false, "Per channel quantized weights are not supported for XZP kernels");
    }

    pytorch_qnnp_operator_t convolution{nullptr};
    // Initially all the params are set to zero.
    convolution = static_cast<pytorch_qnnp_operator_t>(
        calloc(1, sizeof(struct pytorch_qnnp_operator)));
    if (convolution == nullptr) {
      TORCH_INTERNAL_ASSERT(
          false, "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
          sizeof(struct pytorch_qnnp_operator));
    }

    // From here on, `convolution_op` owns the raw operator and frees it via
    // QnnpackOperatorDeleter.
    convolution_op =
        std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter>(
            convolution);

    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
    convolution->ukernel_type = ukernel_type;
    convolution->groups = groups;
    convolution->group_input_channels = group_input_channels;
    convolution->group_output_channels = group_output_channels;
    convolution->kernel_depth = kernel_depth;
    convolution->kernel_height = kernel_height;
    convolution->kernel_width = kernel_width;
    // Depth components default to 1 (stride/dilation) or 0 (padding) in 2-D.
    convolution->stride_depth = kSpatialDim == 3 ? stride_[0] : 1;
    convolution->stride_height = stride_[kSpatialDim - 2];
    convolution->stride_width = stride_[kSpatialDim - 1];
    convolution->dilation_depth = kSpatialDim == 3 ? dilation_[0] : 1;
    convolution->dilation_height = dilation_[kSpatialDim - 2];
    convolution->dilation_width = dilation_[kSpatialDim - 1];
    convolution->input_padding_height = padding_[kSpatialDim - 2];
    convolution->input_padding_width = padding_[kSpatialDim - 1];
    convolution->input_padding_depth = kSpatialDim == 3 ? padding_[0] : 0;
    convolution->per_channel = is_per_channel_;
    convolution->transpose = transpose_;

    // Round the per-group input channels up to the kernel's kr multiple.
    const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
    const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;

    size_t zero_size = sizeof(uint8_t) * k_stride;
    size_t zero_offset = 0;

    if (transpose_) {
      convolution->adjustment_width = output_padding_[1];
      convolution->adjustment_height = output_padding_[0];
      if (group_input_channels < 8) {
        zero_size += 8;
        zero_offset = 8;
      }
    } else {
      zero_buffer_size = 0;
      if (any_padding) {
        zero_size = 0;
        zero_offset = 0;
        if (ukernel_type == pytorch_qnnp_ukernel_type_dwconv) {
          const uint32_t cr = pytorch_qnnp_params.q8dw9.cr;
          const size_t group_stride = (groups + (cr - 1)) & -cr;
          if (groups >= 8) {
            zero_size = sizeof(uint8_t) * group_stride;
            zero_offset = 0;
          } else {
            // Small channel counts get 8 bytes of leading slack so kernels
            // can read behind the pointer safely.
            zero_size = sizeof(uint8_t) * group_stride + 8;
            zero_offset = sizeof(uint8_t) * 8;
          }
        } else if (
            ukernel_type == pytorch_qnnp_ukernel_type_conv ||
            ukernel_type == pytorch_qnnp_ukernel_type_gemm) {
          if (group_input_channels >= 8) {
            zero_size = sizeof(uint8_t) * k_stride;
            zero_offset = 0;
          } else {
            zero_size = sizeof(uint8_t) * k_stride + 8;
            zero_offset = 8;
          }
        }
      }
    }

    // NOLINTNEXTLINE(clang-analyzer-optin.portability.UnixAPI)
    void* zero_buffer = malloc(zero_size);
    if (zero_buffer == nullptr) {
      pytorch_qnnp_delete_operator(convolution);
      TORCH_INTERNAL_ASSERT(
          false, "failed to allocate %zu bytes for zero padding",
          zero_size);
    }
    // Need to set to input zero point
    // memset(zero_buffer, input_zero_point, zero_size);
    zero_buffer_size = zero_size;
    convolution->zero_buffer = zero_buffer;
    convolution->zero_pointer = (void*)((uintptr_t)zero_buffer + zero_offset);
  }

  // Owns the raw QNNPACK operator allocated in the constructor.
  std::unique_ptr<pytorch_qnnp_operator, QnnpackOperatorDeleter> convolution_op;
#ifdef USE_XNNPACK
  xnnpack_operator xnnp_convolution_op;
#endif // USE_XNNPACK
  std::unique_ptr<qnnpack::PrePackConvWeights> w; // pre-packed weights
  at::Tensor orig_weight;                         // original quantized weight (for unpack())
  at::Tensor bias;
  torch::List<int64_t> stride_;
  torch::List<int64_t> padding_;
  torch::List<int64_t> output_padding_;
  torch::List<int64_t> dilation_;
  int64_t groups_;
  bool transpose_;
  bool is_per_channel_;
  c10::optional<double> input_scale;        // unset until run time
  std::vector<int64_t> kernel_;
  at::Tensor w_scales;
  std::vector<uint8_t> w_zero_points;
  std::vector<float> requantization_scales; // computed at run time
  size_t zero_buffer_size;

  at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range=false) override;

  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

  // Builds a packed-parameter object from a quantized weight and the conv
  // geometry; mirrors the accessor set below.
  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
      at::Tensor weight,
      c10::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose);

  torch::List<int64_t> stride() const override {
    return stride_;
  }

  torch::List<int64_t> padding() const override {
    return padding_;
  }

  torch::List<int64_t> output_padding() const override {
    return output_padding_;
  }

  torch::List<int64_t> dilation() const override {
    return dilation_;
  }

  int64_t groups() const override {
    return groups_;
  }

  bool transpose() const override {
    return transpose_;
  }

  bool per_channel() const {
    return is_per_channel_;
  }

 private:
  // Guards QNNPACK state shared across calls (defined in the .cpp).
  std::mutex qnnp_mutex_;
  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);

#ifdef USE_XNNPACK
  template <typename scalar_t, bool ReluFused>
  at::Tensor apply_impl_xnnp(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);
#endif // USE_XNNPACK
};
382
+
383
// Fused-activation tag used when computing clamping bounds for quantized ops.
enum class Activation : uint8_t { NONE = 0, RELU = 1 };

#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
// Old Android toolchains lack std::nearbyint; fall back to the C functions.
template <class T>
inline float Round(const float x) {
  return ::nearbyintf(x);
}
inline double Round(const double x) {
  return ::nearbyint(x);
}
#else
template <class T>
inline T Round(const T x) {
  return std::nearbyint(x);
}
#endif

// Quantizes a real `value` to integer type T using the affine mapping
// q = zero_point + round(value / scale), saturating to T's range.
template<typename T>
inline T QuantizeValue(float scale, int32_t zero_point, float value) {
  constexpr int32_t qmin = std::numeric_limits<T>::min();
  constexpr int32_t qmax = std::numeric_limits<T>::max();
  const auto quantized =
      zero_point + static_cast<int32_t>(Round(value / scale));
  // Saturate into [qmin, qmax] before narrowing to T.
  return static_cast<T>(std::min(std::max(quantized, qmin), qmax));
}

// Returns the [min, max] clamp bounds (in the quantized domain) implied by
// the fused activation: NONE keeps T's full range; RELU clamps below the
// quantized representation of real 0.
template<typename T>
inline std::pair<T, T> activationLimits(
    float scale,
    int32_t zero_point,
    Activation Ac) {
  const T upper = std::numeric_limits<T>::max();
  if (Ac == Activation::RELU) {
    return {QuantizeValue<T>(scale, zero_point, 0.0), upper};
  }
  if (Ac == Activation::NONE) {
    return {std::numeric_limits<T>::min(), upper};
  }
  // Activation has exactly the two values handled above.
#ifdef _MSC_VER
  __assume(0);
#else
  __builtin_unreachable();
#endif
}
430
+
431
namespace at {
namespace native {
namespace qnnp_avgpool_helper {
// QNNPACK-backed average pooling over a quantized 2-D input.  Declaration
// only; the definition lives in the corresponding quantized-cpu source file.
Tensor qnnpack_avg_pool2d(
    Tensor input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad,
    c10::optional<int64_t> divisor_override);
} // qnnp_avgpool_helper
} // namespace native
} // namespace at
445
+
446
+ namespace {
447
+ C10_UNUSED std::vector<float> generate_requantization_scales(
448
+ const at::Tensor& weight_scales,
449
+ const float input_scale,
450
+ const float output_scale,
451
+ std::vector<float>& requant_scales) {
452
+ // Since weight scale is allocated with padding
453
+ // weight_scales.numel() gives us padded num elements.
454
+ const auto num_output_channels_padded = weight_scales.numel();
455
+ float *const weight_scales_data = weight_scales.data_ptr<float>();
456
+ if (static_cast<int64_t>(requant_scales.size()) < num_output_channels_padded) {
457
+ requant_scales.resize(num_output_channels_padded);
458
+ }
459
+ for (const auto i : c10::irange(num_output_channels_padded)) {
460
+ const auto inverse_output_scale = 1.f /output_scale;
461
+ requant_scales[i] = (weight_scales_data[i] * input_scale) * inverse_output_scale;
462
+ TORCH_CHECK(
463
+ (requant_scales[i] > 0.0f && std::isnormal(requant_scales[i])),
464
+ "failed to create op with requantization scale: ",
465
+ requant_scales[i],
466
+ ": requantization scale must be finite and positive");
467
+ }
468
+ return requant_scales;
469
+ }
470
+
471
+ C10_UNUSED std::pair<std::vector<uint8_t>, at::Tensor> make_zero_points_and_scales_tensor(
472
+ const at::Tensor& weight_contig,
473
+ bool transpose = false,
474
+ uint32_t groups = 1
475
+ ) {
476
+ const int out_ch_idx = transpose ? 1 : 0;
477
+ const auto num_output_channels = weight_contig.size(out_ch_idx) * (transpose ? groups : 1);
478
+ // Add 8 to account for bufferring needed by QNNPACK.
479
+ const auto num_output_channels_padded = num_output_channels + kPaddingChannels;
480
+ const auto qtype = weight_contig.qscheme();
481
+ std::vector<uint8_t> weight_zp(num_output_channels_padded, 0);
482
+ // Adjust weight zero point, similar to weight data.
483
+ if (qtype == at::kPerTensorAffine) {
484
+ for (const auto i : c10::irange(num_output_channels)) {
485
+ weight_zp[i] = (uint8_t)(weight_contig.q_zero_point() + 128);
486
+ }
487
+ } else if (qtype == at::kPerChannelAffine) {
488
+ TORCH_CHECK(
489
+ weight_contig.q_per_channel_zero_points().scalar_type() == at::kLong,
490
+ "Per channel zero points dtype must be long int.");
491
+ const int64_t* per_channel_zero_points =
492
+ weight_contig.q_per_channel_zero_points().data_ptr<int64_t>();
493
+ for (const auto i : c10::irange(num_output_channels)) {
494
+ weight_zp[i] = (uint8_t)(per_channel_zero_points[i] + 128);
495
+ }
496
+ } else {
497
+ TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme.");
498
+ }
499
+ at:: Tensor weight_scales =
500
+ at::empty(
501
+ {num_output_channels_padded},
502
+ at::device(at::kCPU).dtype(at::kFloat));
503
+ float *const weight_scales_data = weight_scales.data_ptr<float>();
504
+ if (qtype == at::kPerTensorAffine) {
505
+ for (const auto i : c10::irange(num_output_channels)) {
506
+ weight_scales_data[i] = weight_contig.q_scale();
507
+ }
508
+ } else if (qtype == at::kPerChannelAffine) {
509
+ TORCH_CHECK(
510
+ weight_contig.q_per_channel_scales().scalar_type() == at::kDouble,
511
+ "Per channel scales dtype must be double.");
512
+ const double *const per_channel_scales =
513
+ weight_contig.q_per_channel_scales().data_ptr<double>();
514
+ for (const auto i : c10::irange(num_output_channels)) {
515
+ weight_scales_data[i] = static_cast<float>(per_channel_scales[i]);
516
+ }
517
+ } else {
518
+ TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme.");
519
+ }
520
+ for (const auto i : c10::irange(num_output_channels, num_output_channels_padded)) {
521
+ weight_scales_data[i] = 1.f;
522
+ }
523
+ return {weight_zp, weight_scales};
524
+ }
525
+ } // namespace
526
+
527
+ #endif