koichi12 commited on
Commit
a1a152c
·
verified ·
1 Parent(s): a378ef8

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/__init__.cpython-311.pyc +0 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-311.pyc +0 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/_sanitizer.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/_utils.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/comm.cpython-311.pyc +0 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/nvtx.cpython-311.pyc +0 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/profiler.cpython-311.pyc +0 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/sparse.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/autocast_mode.py +144 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/error.py +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/graphs.py +479 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/jiterator.py +185 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/nccl.py +137 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/profiler.py +61 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/random.py +179 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/sparse.py +1 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/streams.py +241 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/__init__.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph.cpython-311.pyc +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/interpreter.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/operator_schemas.cpython-311.pyc +0 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-311.pyc +0 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ATen.h +37 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/AccumulateType.h +153 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Backend.h +2 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CPUFixedAllocator.h +33 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CollapseDims.h +94 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h +29 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h +25 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h +808 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch_v2.h +186 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/EmptyTensor.h +160 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ExpandBase.h +30 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalTensorWrapper.h +408 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Generator.h +2 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedFallback.h +25 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h +160 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MapAllocator.h +139 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensor.h +1 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NestedTensorImpl.h +283 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/PadNd.h +28 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Parallel-inl.h +93 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Parallel.h +160 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelNative.h +19 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SavedTensorHooks.h +52 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorAccessor.h +2 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorIteratorInternal.h +72 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorMeta.h +137 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorNames.h +75 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (64.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-311.pyc ADDED
Binary file (37.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/_sanitizer.cpython-311.pyc ADDED
Binary file (36.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/_utils.cpython-311.pyc ADDED
Binary file (2.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/comm.cpython-311.pyc ADDED
Binary file (520 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/nvtx.cpython-311.pyc ADDED
Binary file (3.85 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/profiler.cpython-311.pyc ADDED
Binary file (3.51 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/sparse.cpython-311.pyc ADDED
Binary file (209 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/autocast_mode.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import functools
3
+
4
+ import torch
5
+
6
+ try:
7
+ import numpy as np
8
+
9
+ HAS_NUMPY = True
10
+ except ModuleNotFoundError:
11
+ np = None # type: ignore[assignment]
12
+ from typing import Any
13
+
14
+ __all__ = ["autocast", "custom_fwd", "custom_bwd"]
15
+
16
+
17
+ class autocast(torch.amp.autocast_mode.autocast):
18
+ r"""See :class:`torch.autocast`.
19
+
20
+ ``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)``
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ enabled: bool = True,
26
+ dtype: torch.dtype = torch.float16,
27
+ cache_enabled: bool = True,
28
+ ):
29
+ if torch._jit_internal.is_scripting():
30
+ self._enabled = enabled
31
+ self.device = "cuda"
32
+ self.fast_dtype = dtype
33
+ return
34
+ super().__init__(
35
+ "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
36
+ )
37
+
38
+ def __enter__(self):
39
+ if torch._jit_internal.is_scripting():
40
+ return self
41
+ return super().__enter__()
42
+
43
+ # TODO: discuss a unified TorchScript-friendly API for autocast
44
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override]
45
+ if torch._jit_internal.is_scripting():
46
+ return
47
+ return super().__exit__(exc_type, exc_val, exc_tb)
48
+
49
+ def __call__(self, func):
50
+ if torch._jit_internal.is_scripting():
51
+ return func
52
+ return super().__call__(func)
53
+
54
+
55
+ # Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
56
+ # may be falsely detected as "Iterables."
57
+ def _cast(value, dtype):
58
+ if isinstance(value, torch.Tensor):
59
+ is_eligible = (
60
+ value.is_floating_point()
61
+ and value.is_cuda
62
+ and (value.dtype is not torch.float64)
63
+ )
64
+ return value.to(dtype) if is_eligible else value
65
+ elif isinstance(value, (str, bytes)):
66
+ return value
67
+ elif HAS_NUMPY and isinstance(value, np.ndarray):
68
+ return value
69
+ elif isinstance(value, collections.abc.Mapping):
70
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
71
+ elif isinstance(value, collections.abc.Iterable):
72
+ iterable = (_cast(v, dtype) for v in value)
73
+ if isinstance(value, (list, tuple)):
74
+ return type(value)(iterable)
75
+ else:
76
+ return iterable
77
+ else:
78
+ return value
79
+
80
+
81
+ # custom_fwd is a decorator that may or may not be used with arguments, following
82
+ # https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
83
+ # this works:
84
+ # @custom_fwd
85
+ # def forward(...):
86
+ # this also works:
87
+ # @custom_fwd(cast_inputs=torch.float)
88
+ # def forward(...):
89
def custom_fwd(fwd=None, *, cast_inputs=None):
    """
    Create a helper decorator for ``forward`` methods of custom autograd functions.

    Autograd functions are subclasses of :class:`torch.autograd.Function`.
    See the :ref:`example page<amp-custom-examples>` for more detail.

    Args:
        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
            when ``forward`` runs in an autocast-enabled region, casts incoming
            floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected),
            then executes ``forward`` with autocast disabled.
            If ``None``, ``forward``'s internal ops execute with the current autocast state.

    .. note::
        If the decorated ``forward`` is called outside an autocast-enabled region,
        :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
    """
    # Supports both @custom_fwd and @custom_fwd(cast_inputs=...): in the
    # parenthesized form we are called without ``fwd`` and return a partial
    # that will receive it on the second call.
    if fwd is None:
        return functools.partial(custom_fwd, cast_inputs=cast_inputs)

    @functools.wraps(fwd)
    def decorate_fwd(*args, **kwargs):
        ctx = args[0]
        # Remember the autocast dtype so custom_bwd can restore it.
        ctx._dtype = torch.get_autocast_gpu_dtype()
        if cast_inputs is None:
            ctx._fwd_used_autocast = torch.is_autocast_enabled()
            return fwd(*args, **kwargs)
        ctx._fwd_used_autocast = False
        if not torch.is_autocast_enabled():
            return fwd(*args, **kwargs)
        # Cast the inputs once, then run forward with autocast turned off so
        # its internal ops see exactly the requested dtype.
        with autocast(enabled=False):
            return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))

    return decorate_fwd
126
+
127
+
128
+ # Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
129
+ # cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
130
+ # cast_inputs supplied to custom_fwd.
131
def custom_bwd(bwd):
    """Create a helper decorator for backward methods of custom autograd functions.

    Autograd functions are subclasses of :class:`torch.autograd.Function`.
    Replays the autocast state recorded by :func:`custom_fwd` so that
    ``backward`` executes with the same autocast configuration as ``forward``.
    See the :ref:`example page<amp-custom-examples>` for more detail.
    """

    @functools.wraps(bwd)
    def decorate_bwd(*args, **kwargs):
        ctx = args[0]
        # ctx._fwd_used_autocast / ctx._dtype were stashed by custom_fwd.
        with autocast(enabled=ctx._fwd_used_autocast, dtype=ctx._dtype):
            return bwd(*args, **kwargs)

    return decorate_bwd
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/error.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/graphs.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch.utils import _pytree
6
+ from .._utils import _dummy_type
7
+
8
# Builds of torch._C without CUDA stream support do not expose the CUDA
# graph bindings. Install placeholder types so this module can still be
# imported on such builds (the placeholders come from _dummy_type in
# .._utils — presumably they error out if actually used; verify there).
if not hasattr(torch._C, "_CudaStreamBase"):
    # Define dummy base classes
    torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
    torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle")
    torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type(
        "_cuda_isCurrentStreamCapturing"
    )
15
+
16
+ from torch._C import ( # noqa: F401
17
+ _cuda_isCurrentStreamCapturing,
18
+ _CUDAGraph,
19
+ _graph_pool_handle,
20
+ )
21
+
22
+
23
def is_current_stream_capturing():
    r"""Report whether CUDA graph capture is underway on the current CUDA stream.

    Returns ``True`` while the current CUDA stream is capturing and ``False``
    otherwise. If a CUDA context does not exist on the current device, returns
    ``False`` without initializing the context.
    """
    capturing = _cuda_isCurrentStreamCapturing()
    return capturing
29
+
30
+
31
+ # Python shim helps Sphinx process docstrings more reliably.
32
# Python shim helps Sphinx process docstrings more reliably.
def graph_pool_handle():
    r"""Return an opaque token representing the id of a graph memory pool.

    The token can be handed to other captures to request memory sharing.
    See :ref:`Graph memory management<graph-memory-management>`.

    .. warning::
        This API is in beta and may change in future releases.
    """
    handle = _graph_pool_handle()
    return handle
41
+
42
+
43
+ # Python shim helps Sphinx process docstrings more reliably.
44
# Python shim helps Sphinx process docstrings more reliably.
class CUDAGraph(torch._C._CUDAGraph):
    r"""Python wrapper around the C++ CUDA graph binding.

    .. warning::
        This API is in beta and may change in future releases.
    """

    def __new__(cls):
        return super().__new__(cls)

    def capture_begin(self, pool=None, capture_error_mode="global"):
        r"""Start capturing CUDA work issued on the current stream.

        Prefer :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`
        over calling this directly; both invoke ``capture_begin`` internally.

        Arguments:
            pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
                :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting that this
                graph may share memory with the indicated pool.
                See :ref:`Graph memory management<graph-memory-management>`.
            capture_error_mode (str, optional): the ``cudaStreamCaptureMode`` for the graph
                capture stream; one of "global", "thread_local" or "relaxed". Some actions,
                such as ``cudaMalloc``, can be unsafe during capture: "global" errors on such
                actions in other threads, "thread_local" only in the current thread, and
                "relaxed" never. Do NOT change this setting unless you're familiar with
                `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
        """  # noqa: B950
        super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)

    def capture_end(self):
        r"""Stop capturing CUDA work on the current stream.

        Once ``capture_end`` has run, ``replay`` may be called on this instance.
        Prefer :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`
        over calling this directly; both invoke ``capture_end`` internally.
        """
        super().capture_end()

    def replay(self):
        r"""Replay the CUDA work captured by this graph."""
        super().replay()

    def reset(self):
        r"""Delete the graph currently held by this instance."""
        super().reset()

    def pool(self):
        r"""Return an opaque token representing the id of this graph's memory pool.

        Passing this token to another graph's ``capture_begin`` hints that the
        other graph may share this graph's memory pool.
        """
        return super().pool()

    def enable_debug_mode(self):
        r"""Enable debugging mode for CUDAGraph.debug_dump."""
        return super().enable_debug_mode()

    def debug_dump(self, debug_path):
        r"""
        Arguments:
            debug_path (required): Path to dump the graph to.

        Calls a debugging function to dump the graph if the debugging is
        enabled via CUDAGraph.enable_debug_mode()
        """
        return super().debug_dump(debug_path)
113
+
114
+
115
class graph:
    r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay.

    See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
    detailed use, and constraints.

    Arguments:
        cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
        pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
            may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
        stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
            If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
        capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
            Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
            may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
            actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
            unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_

    .. note::
        For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
        used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.

    .. warning::
        This API is in beta and may change in future releases.

    .. _cudaStreamCaptureMode:
        https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
    """  # noqa: B950

    # Shared side stream used when the caller supplies no explicit stream.
    default_capture_stream: Optional["torch.cuda.Stream"] = None

    def __init__(
        self,
        cuda_graph,
        pool=None,
        stream=None,
        capture_error_mode: str = "global",
    ):
        # Lazy-init of default_capture_stream helps avoid circular-import errors.
        # Not thread safe, but graphs already have the general (explicitly documented)
        # restriction that only one capture may be underway at a time in the process.
        if self.__class__.default_capture_stream is None:
            self.__class__.default_capture_stream = torch.cuda.Stream()

        # Stored as a 0- or 1-tuple so __enter__ can unpack it straight into
        # capture_begin's optional positional pool argument.
        self.pool = () if pool is None else (pool,)
        self.capture_stream = (
            stream if stream is not None else self.__class__.default_capture_stream
        )
        assert self.capture_stream is not None
        self.stream_ctx = torch.cuda.stream(self.capture_stream)
        self.cuda_graph = cuda_graph
        self.capture_error_mode = capture_error_mode

    def __enter__(self):
        # Free as much memory as we can for the graph
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()

        # Stackoverflow seems comfortable with this pattern
        # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
        self.stream_ctx.__enter__()

        # Capture must begin only after the side stream is current.
        self.cuda_graph.capture_begin(
            *self.pool, capture_error_mode=self.capture_error_mode
        )

    def __exit__(self, exc_type, exc_value, traceback):
        # End capture before restoring the previous stream, mirroring __enter__.
        self.cuda_graph.capture_end()
        self.stream_ctx.__exit__(exc_type, exc_value, traceback)
        # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
187
+
188
+
189
def make_graphed_callables(
    callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
):
    r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.

    Each graphed callable's forward pass runs its source callable's
    forward CUDA work as a CUDA graph inside a single autograd node.

    The graphed callable's forward pass also appends
    a backward node to the autograd graph. During backward, this node runs the
    callable's backward work as a CUDA graph.

    Therefore, each graphed callable should be a drop-in replacement for its source callable
    in an autograd-enabled training loop.

    See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.

    If you pass a tuple of several callables, their captures will use the same memory pool.
    See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.

    Arguments:
        callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
            See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
            is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order
            they'll run in the live workload.
        sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable.
            If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
            If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors.
        num_warmup_iters (int): The number of warmup iterations. Currently, ``DataDistributedParallel`` needs
            11 iterations for warm up. Default: ``3``.
        allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
            (and therefore their grad is always zero) is an error. Defaults to False.
        pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
            with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
    .. note::
        The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
        that's expected for the corresponding real input in the training loop.

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        ``sample_args`` for each callable must contain only Tensors. Other types are not allowed.

    .. warning::
        Returned callables do not support higher order differentiation (e.g., double backward).

    .. warning::
        In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
        may be trainable. Buffers must have ``requires_grad=False``.

    .. warning::
        After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
        you may not add or remove any of that Module's parameters or buffers.

    .. warning::
        :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
        registered on them at the time they are passed. However, registering hooks on modules *after* passing them
        through :func:`~torch.cuda.make_graphed_callables` is allowed.

    .. warning::
        When running a graphed callable, you must pass its arguments in the same order and format
        they appeared in that callable's ``sample_args``.

    .. warning::
        The automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with disabled
        caching. The context manager `torch.cuda.amp.autocast()` must have `cache_enabled=False`.
    """
    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
        raise RuntimeError(
            "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
        )

    just_one_callable = False

    # Normalize the single-callable form to the tuple form; remember which
    # form we got so the return value can be unwrapped to match.
    if not isinstance(callables, tuple):
        just_one_callable = True
        callables = (callables,)
        sample_args = (sample_args,)

    flatten_sample_args = []

    # Validate callables/args up front, before any capture state is created.
    for c, args in zip(callables, sample_args):
        if isinstance(c, torch.nn.Module):
            assert (
                len(c._backward_hooks) == 0
                and len(c._forward_hooks) == 0
                and len(c._forward_pre_hooks) == 0
            ), (
                "Modules must not have hooks registered at the time they are passed. However, registering hooks "
                + "on modules after passing them through make_graphed_callables is allowed."
            )
            assert all(b.requires_grad is False for b in c.buffers()), (
                "In any :class:`~torch.nn.Module` passed to "
                + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
                + "``requires_grad=False``."
            )
        flatten_arg = _pytree.arg_tree_leaves(*args)
        flatten_sample_args.append(tuple(flatten_arg))
        assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
            "In the beta API, sample_args "
            + "for each callable must contain only Tensors. Other types are not allowed."
        )

    # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
    # passes to forward (ie, its sample_args) AND the module's parameter attributes.
    per_callable_len_user_args = [len(args) for args in flatten_sample_args]
    per_callable_module_params = [
        tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
        for c in callables
    ]
    per_callable_static_input_surfaces = [
        flatten_sample_args[i] + per_callable_module_params[i]
        for i in range(len(callables))
    ]

    # One forward and one backward graph per callable.
    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]

    mempool = graph_pool_handle() if pool is None else pool

    # Warmup
    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
    # from ending up in any captures.
    torch.cuda.synchronize()
    with torch.cuda.stream(torch.cuda.Stream()):
        for func, args, static_input_surface in zip(
            callables, sample_args, per_callable_static_input_surfaces
        ):
            for _ in range(num_warmup_iters):
                outputs = _pytree.tree_leaves(func(*args))
                grad_inputs = torch.autograd.grad(
                    outputs=tuple(o for o in outputs if o.requires_grad),
                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
                    grad_outputs=tuple(
                        torch.empty_like(o) for o in outputs if o.requires_grad
                    ),
                    only_inputs=True,
                    allow_unused=allow_unused_input,
                )
            del outputs, grad_inputs  # type: ignore[possibly-undefined]
    torch.cuda.synchronize()

    # All captures here share a mempool. To avoid replays corrupting each other's memory,
    # the safest approach is to capture all passes in the same order they'll run:
    # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.

    # Capture forward graphs
    per_callable_static_outputs = []
    per_callable_output_unflatten_spec = []
    for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
        with torch.cuda.graph(fwd_graph, pool=mempool):
            outputs = func(*args)

        # Keep the flattened outputs (the graph's static output buffers) plus
        # the pytree spec needed to rebuild the original output structure.
        flatten_outputs, spec = _pytree.tree_flatten(outputs)
        per_callable_static_outputs.append(tuple(flatten_outputs))
        per_callable_output_unflatten_spec.append(spec)

    # Capture backward graphs in reverse order
    per_callable_static_grad_outputs = []
    per_callable_static_grad_inputs = []
    for static_input_surface, static_outputs, bwd_graph, module_params in zip(
        reversed(per_callable_static_input_surfaces),
        reversed(per_callable_static_outputs),
        reversed(bwd_graphs),
        reversed(per_callable_module_params),
    ):
        # For now, assumes all static_outputs require grad
        # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
        static_grad_outputs = tuple(
            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
        )

        with torch.cuda.graph(bwd_graph, pool=mempool):
            grad_inputs = torch.autograd.grad(
                outputs=tuple(o for o in static_outputs if o.requires_grad),
                inputs=tuple(i for i in static_input_surface if i.requires_grad),
                grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                only_inputs=True,
                allow_unused=allow_unused_input,
            )

        # Constructs a tuple suitable for returning from Graphed.backward:
        # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
        # I couldn't think of a slick one-liner for this pattern.
        static_grad_inputs = []
        grad_idx = 0
        for arg in static_input_surface:
            if arg.requires_grad:
                static_grad_inputs.append(grad_inputs[grad_idx])
                grad_idx += 1
            else:
                static_grad_inputs.append(None)  # type: ignore[arg-type]
        static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]

        per_callable_static_grad_outputs.append(static_grad_outputs)
        per_callable_static_grad_inputs.append(static_grad_inputs)

    # Reverses the most recent two lists
    per_callable_static_grad_outputs.reverse()
    per_callable_static_grad_inputs.reverse()
    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.

    def make_graphed_autograd_function(
        fwd_graph,
        bwd_graph,
        module_params,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
        static_grad_outputs,
        static_grad_inputs,
    ):
        # Builds one autograd.Function whose forward/backward replay the
        # captured graphs after copying fresh data into the static buffers.
        class Graphed(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *inputs):
                # At this stage, only the user args may (potentially) be new tensors.
                for i in range(len_user_args):
                    if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                        static_input_surface[i].copy_(inputs[i])
                fwd_graph.replay()
                assert isinstance(static_outputs, tuple)
                return tuple(o.detach() for o in static_outputs)

            @staticmethod
            @torch.autograd.function.once_differentiable
            def backward(ctx, *grads):
                assert len(grads) == len(static_grad_outputs)
                for g, grad in zip(static_grad_outputs, grads):
                    if g is not None:
                        # don't copy if autograd gods have been kind and the
                        # incoming grad is already in the right place
                        if g.data_ptr() != grad.data_ptr():
                            g.copy_(grad)
                bwd_graph.replay()

                # Input args that didn't require grad expect a None gradient.
                assert isinstance(static_grad_inputs, tuple)
                return tuple(
                    b.detach() if b is not None else b for b in static_grad_inputs
                )

        def functionalized(*user_args):
            # Runs the autograd function with inputs == all inputs to the graph that might require grad
            # (explicit user args + module parameters)
            # Assumes module params didn't change since capture.
            flatten_user_args = _pytree.arg_tree_leaves(*user_args)
            out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
            return _pytree.tree_unflatten(out, output_unflatten_spec)

        return functionalized

    # Put together the final graphed callables
    ret = []
    for i, func in enumerate(callables):
        graphed = make_graphed_autograd_function(
            fwd_graphs[i],
            bwd_graphs[i],
            per_callable_module_params[i],
            per_callable_len_user_args[i],
            per_callable_output_unflatten_spec[i],
            per_callable_static_input_surfaces[i],
            per_callable_static_outputs[i],
            per_callable_static_grad_outputs[i],
            per_callable_static_grad_inputs[i],
        )

        if isinstance(func, torch.nn.Module):

            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
                def new_fwd(*user_args):
                    # If the module's training-or-eval state matches what we graphed,
                    # run the graph, otherwise run the original forward method
                    if func.training == graph_training_state:
                        return graphed(*user_args)
                    else:
                        return orig_fwd(*user_args)

                return new_fwd

            # Modules are returned as-is with their forward method swapped for
            # the graphed version, so callers keep the same object identity.
            func.forward = make_graphed_forward(func, func.training, graphed, func.forward)  # type: ignore[assignment]
            ret.append(func)
        else:
            ret.append(graphed)

    if just_one_callable:
        return ret[0]

    return tuple(ret)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/jiterator.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Callable, List
3
+
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ __all__: List[str] = []
8
+
9
+
10
+ class _CodeParser:
11
+ def __init__(self, code_string: str):
12
+ optional_ws = r"\s*"
13
+ required_ws = r"\s+"
14
+ template_params = r"(?P<template_params>\<.+\>)"
15
+ return_type = r"(?P<return_type>\w+)"
16
+ function_name = r"(?P<function_name>\w+)"
17
+ function_params = r"(?P<function_params>\(.+\))"
18
+ function_body = r"(?P<function_body>\{.+\})"
19
+
20
+ pattern = (
21
+ optional_ws
22
+ + "template"
23
+ + optional_ws
24
+ + template_params
25
+ + optional_ws
26
+ + return_type
27
+ + required_ws
28
+ + function_name
29
+ + optional_ws
30
+ + function_params
31
+ + optional_ws
32
+ + function_body
33
+ + optional_ws
34
+ )
35
+
36
+ result = re.match(
37
+ pattern, code_string, re.DOTALL
38
+ ) # DOTALL for matching multiline
39
+
40
+ if result is None:
41
+ raise Exception(
42
+ f"Couldn't parse code, please check correctness:\n {code_string}"
43
+ )
44
+
45
+ self.template_params = result["template_params"]
46
+ self.return_type = result["return_type"]
47
+ self.function_name = result["function_name"]
48
+ self.function_params = result["function_params"]
49
+ self.function_body = result["function_body"]
50
+
51
+
52
+ class _JittedFunction:
53
+ def __init__(
54
+ self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
55
+ ):
56
+ self.code_string = code_string
57
+
58
+ assert (
59
+ return_by_ref or num_outputs == 1
60
+ ), "Return by value only works for single output. "
61
+ self.return_by_ref = return_by_ref
62
+ self.num_outputs = num_outputs
63
+
64
+ parsed_code = _CodeParser(code_string)
65
+ self.kernel_name = parsed_code.function_name
66
+
67
+ self.kwargs_dict = kwargs
68
+ self.is_cuda_available = torch.cuda.is_available()
69
+
70
+ def __call__(self, *tensors: Tensor, **kwargs):
71
+ # Jiterator follow torch.cuda's lazy initialization behavior
72
+ # Defer checking cuda's availability at the function invocation time
73
+ assert (
74
+ self.is_cuda_available
75
+ ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available."
76
+
77
+ assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."
78
+
79
+ expanded_kwargs = self.kwargs_dict.copy()
80
+ for key, value in kwargs.items():
81
+ if key in self.kwargs_dict:
82
+ expanded_kwargs[key] = value
83
+ else:
84
+ raise KeyError(f"{key} is not declared in function definition")
85
+
86
+ return torch._C._cuda_jiterator_compile_and_launch_kernel(
87
+ self.code_string,
88
+ self.kernel_name,
89
+ self.return_by_ref,
90
+ self.num_outputs,
91
+ tensors,
92
+ expanded_kwargs,
93
+ )
94
+
95
+
96
+ def _create_jit_fn(code_string: str, **kwargs) -> Callable:
97
+ """
98
+ Create a jiterator-generated cuda kernel for an elementwise op.
99
+
100
+ The code string has to be a valid CUDA function that describes the computation for a single element. The code
101
+ string has to follow the c++ template pattern, as shown in the example below. This function will be inlined
102
+ into elementwise kernel template, and compiled on the fly. Compiled kernel will be cached in memory, as well as
103
+ local temp dir.
104
+
105
+ Jiterator-generated kernels accepts noncontiguous tensors, and supports broadcasting and type promotion.
106
+
107
+ Args:
108
+ code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
109
+ kwargs (Dict, optional): Keyword arguments for generated function
110
+
111
+ Example::
112
+
113
+ code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
114
+ jitted_fn = create_jit_fn(code_string, alpha=1.0)
115
+ a = torch.rand(3, device='cuda')
116
+ b = torch.rand(3, device='cuda')
117
+ # invoke jitted function like a regular python function
118
+ result = jitted_fn(a, b, alpha=3.14)
119
+
120
+ code_string also allows multiple function definitions, and the last function will be treated as the entry function.
121
+
122
+ Example::
123
+
124
+ code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
125
+ code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
126
+ jitted_fn = create_jit_fn(code_string, val=0.0)
127
+ a = torch.rand(3, device='cuda')
128
+ b = torch.rand(3, device='cuda')
129
+ # invoke jitted function like a regular python function
130
+ result = jitted_fn(a, b) # using default val=0.0
131
+
132
+ Jiterator can be used together with python registration to override an operator's cuda kernel.
133
+ Following example is overriding gelu's cuda kernel with relu.
134
+
135
+ Example::
136
+
137
+ code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
138
+ my_gelu = create_jit_fn(code_string)
139
+ my_lib = torch.library.Library("aten", "IMPL")
140
+ my_lib.impl('aten::gelu', my_gelu, "CUDA")
141
+ # torch.nn.GELU and torch.nn.function.gelu are now overridden
142
+ a = torch.rand(3, device='cuda')
143
+ torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))
144
+
145
+ .. warning::
146
+ This API is in beta and may change in future releases.
147
+
148
+ .. warning::
149
+ This API only supports up to 8 inputs and 1 output
150
+
151
+ .. warning::
152
+ All input tensors must live in CUDA device
153
+ """
154
+ return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)
155
+
156
+
157
+ def _create_multi_output_jit_fn(
158
+ code_string: str, num_outputs: int, **kwargs
159
+ ) -> Callable:
160
+ """
161
+ Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.
162
+
163
+ Args:
164
+ code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return value by reference.
165
+ num_outputs(int): number of outputs return by the kernel
166
+ kwargs (Dict, optional): Keyword arguments for generated function
167
+
168
+ Example::
169
+
170
+ code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
171
+ jitted_fn = create_jit_fn(code_string, alpha=1.0)
172
+ a = torch.rand(3, device='cuda')
173
+ b = torch.rand(3, device='cuda')
174
+ # invoke jitted function like a regular python function
175
+ result = jitted_fn(a, b, alpha=3.14)
176
+
177
+ .. warning::
178
+ This API is in beta and may change in future releases.
179
+
180
+ .. warning::
181
+ This API only supports up to 8 inputs and 8 outputs
182
+ """
183
+ return _JittedFunction(
184
+ code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs
185
+ )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/nccl.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import warnings
3
+ from typing import Optional, Sequence, Union
4
+
5
+ import torch.cuda
6
+
7
+
8
+ __all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"]
9
+
10
+ SUM = 0 # ncclRedOp_t
11
+
12
+
13
+ def is_available(tensors):
14
+ if not hasattr(torch._C, "_nccl_all_reduce"):
15
+ warnings.warn("PyTorch is not compiled with NCCL support")
16
+ return False
17
+
18
+ devices = set()
19
+ for tensor in tensors:
20
+ if tensor.is_sparse:
21
+ return False
22
+ if not tensor.is_contiguous():
23
+ return False
24
+ if not tensor.is_cuda:
25
+ return False
26
+ device = tensor.get_device()
27
+ if device in devices:
28
+ return False
29
+ devices.add(device)
30
+
31
+ return True
32
+
33
+
34
+ def version():
35
+ ver = torch._C._nccl_version()
36
+ major = ver >> 32
37
+ minor = (ver >> 16) & 65535
38
+ patch = ver & 65535
39
+ suffix = torch._C._nccl_version_suffix().decode("utf-8")
40
+ if suffix == "":
41
+ return (major, minor, patch)
42
+ else:
43
+ return (major, minor, patch, suffix)
44
+
45
+
46
+ def unique_id():
47
+ return torch._C._nccl_unique_id()
48
+
49
+
50
+ def init_rank(num_ranks, uid, rank):
51
+ return torch._C._nccl_init_rank(num_ranks, uid, rank)
52
+
53
+
54
+ def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
55
+ if not isinstance(inputs, collections.abc.Container) or isinstance(
56
+ inputs, torch.Tensor
57
+ ):
58
+ raise TypeError("Inputs should be a collection of tensors")
59
+
60
+
61
+ def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
62
+ _check_sequence_type(inputs)
63
+ if outputs is None:
64
+ outputs = inputs
65
+ _check_sequence_type(outputs)
66
+ torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)
67
+
68
+
69
+ # `output` used to be `outputs`, taking in a list of tensors. So we have two
70
+ # arguments for BC reasons.
71
+ def reduce(
72
+ inputs: Sequence[torch.Tensor],
73
+ output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
74
+ root: int = 0,
75
+ op: int = SUM,
76
+ streams: Optional[Sequence[torch.cuda.Stream]] = None,
77
+ comms=None,
78
+ *,
79
+ outputs: Optional[Sequence[torch.Tensor]] = None,
80
+ ) -> None:
81
+ _check_sequence_type(inputs)
82
+ _output: torch.Tensor
83
+ if outputs is not None:
84
+ if output is not None:
85
+ raise ValueError(
86
+ "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in "
87
+ "favor of 'output', taking in a single output tensor. The signature of reduce is: "
88
+ "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."
89
+ )
90
+ else:
91
+ warnings.warn(
92
+ "nccl.reduce with an output tensor list is deprecated. "
93
+ "Please specify a single output tensor with argument 'output' instead instead."
94
+ )
95
+ _output = outputs[root]
96
+ elif not isinstance(output, torch.Tensor) and isinstance(
97
+ output, collections.abc.Sequence
98
+ ):
99
+ # User called old API with positional arguments of list of output tensors.
100
+ warnings.warn(
101
+ "nccl.reduce with an output tensor list is deprecated. "
102
+ "Please specify a single output tensor."
103
+ )
104
+ _output = output[root]
105
+ else:
106
+ _output = inputs[root] if output is None else output
107
+ torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
108
+
109
+
110
+ def broadcast(
111
+ inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None
112
+ ) -> None:
113
+ _check_sequence_type(inputs)
114
+ torch._C._nccl_broadcast(inputs, root, streams, comms)
115
+
116
+
117
+ def all_gather(
118
+ inputs: Sequence[torch.Tensor],
119
+ outputs: Sequence[torch.Tensor],
120
+ streams=None,
121
+ comms=None,
122
+ ) -> None:
123
+ _check_sequence_type(inputs)
124
+ _check_sequence_type(outputs)
125
+ torch._C._nccl_all_gather(inputs, outputs, streams, comms)
126
+
127
+
128
+ def reduce_scatter(
129
+ inputs: Sequence[torch.Tensor],
130
+ outputs: Sequence[torch.Tensor],
131
+ op: int = SUM,
132
+ streams=None,
133
+ comms=None,
134
+ ) -> None:
135
+ _check_sequence_type(inputs)
136
+ _check_sequence_type(outputs)
137
+ torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/profiler.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import tempfile
3
+
4
+ import torch
5
+ from . import check_error, cudart
6
+
7
+ __all__ = ["init", "start", "stop", "profile"]
8
+
9
+ DEFAULT_FLAGS = [
10
+ "gpustarttimestamp",
11
+ "gpuendtimestamp",
12
+ "gridsize3d",
13
+ "threadblocksize",
14
+ "streamid",
15
+ "enableonstart 0",
16
+ "conckerneltrace",
17
+ ]
18
+
19
+
20
+ def init(output_file, flags=None, output_mode="key_value"):
21
+ rt = cudart()
22
+ if not hasattr(rt, "cudaOutputMode"):
23
+ raise AssertionError("HIP does not support profiler initialization!")
24
+ if (
25
+ hasattr(torch.version, "cuda")
26
+ and torch.version.cuda is not None
27
+ and int(torch.version.cuda.split(".")[0]) >= 12
28
+ ):
29
+ # Check https://github.com/pytorch/pytorch/pull/91118
30
+ # cudaProfilerInitialize is no longer needed after CUDA 12
31
+ raise AssertionError("CUDA12+ does not need profiler initialization!")
32
+ flags = DEFAULT_FLAGS if flags is None else flags
33
+ if output_mode == "key_value":
34
+ output_mode_enum = rt.cudaOutputMode.KeyValuePair
35
+ elif output_mode == "csv":
36
+ output_mode_enum = rt.cudaOutputMode.CSV
37
+ else:
38
+ raise RuntimeError(
39
+ "supported CUDA profiler output modes are: key_value and csv"
40
+ )
41
+ with tempfile.NamedTemporaryFile(delete=True) as f:
42
+ f.write(b"\n".join(f.encode("ascii") for f in flags))
43
+ f.flush()
44
+ check_error(rt.cudaProfilerInitialize(f.name, output_file, output_mode_enum))
45
+
46
+
47
+ def start():
48
+ check_error(cudart().cudaProfilerStart())
49
+
50
+
51
+ def stop():
52
+ check_error(cudart().cudaProfilerStop())
53
+
54
+
55
+ @contextlib.contextmanager
56
+ def profile():
57
+ try:
58
+ start()
59
+ yield
60
+ finally:
61
+ stop()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/random.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterable, List, Union
2
+
3
+ import torch
4
+ from .. import Tensor
5
+ from . import _lazy_call, _lazy_init, current_device, device_count
6
+
7
+ __all__ = [
8
+ "get_rng_state",
9
+ "get_rng_state_all",
10
+ "set_rng_state",
11
+ "set_rng_state_all",
12
+ "manual_seed",
13
+ "manual_seed_all",
14
+ "seed",
15
+ "seed_all",
16
+ "initial_seed",
17
+ ]
18
+
19
+
20
+ def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
21
+ r"""Return the random number generator state of the specified GPU as a ByteTensor.
22
+
23
+ Args:
24
+ device (torch.device or int, optional): The device to return the RNG state of.
25
+ Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
26
+
27
+ .. warning::
28
+ This function eagerly initializes CUDA.
29
+ """
30
+ _lazy_init()
31
+ if isinstance(device, str):
32
+ device = torch.device(device)
33
+ elif isinstance(device, int):
34
+ device = torch.device("cuda", device)
35
+ idx = device.index
36
+ if idx is None:
37
+ idx = current_device()
38
+ default_generator = torch.cuda.default_generators[idx]
39
+ return default_generator.get_state()
40
+
41
+
42
+ def get_rng_state_all() -> List[Tensor]:
43
+ r"""Return a list of ByteTensor representing the random number states of all devices."""
44
+ results = []
45
+ for i in range(device_count()):
46
+ results.append(get_rng_state(i))
47
+ return results
48
+
49
+
50
+ def set_rng_state(
51
+ new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
52
+ ) -> None:
53
+ r"""Set the random number generator state of the specified GPU.
54
+
55
+ Args:
56
+ new_state (torch.ByteTensor): The desired state
57
+ device (torch.device or int, optional): The device to set the RNG state.
58
+ Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
59
+ """
60
+ with torch._C._DisableFuncTorch():
61
+ new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
62
+ if isinstance(device, str):
63
+ device = torch.device(device)
64
+ elif isinstance(device, int):
65
+ device = torch.device("cuda", device)
66
+
67
+ def cb():
68
+ idx = device.index
69
+ if idx is None:
70
+ idx = current_device()
71
+ default_generator = torch.cuda.default_generators[idx]
72
+ default_generator.set_state(new_state_copy)
73
+
74
+ _lazy_call(cb)
75
+
76
+
77
+ def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
78
+ r"""Set the random number generator state of all devices.
79
+
80
+ Args:
81
+ new_states (Iterable of torch.ByteTensor): The desired state for each device.
82
+ """
83
+ for i, state in enumerate(new_states):
84
+ set_rng_state(state, i)
85
+
86
+
87
+ def manual_seed(seed: int) -> None:
88
+ r"""Set the seed for generating random numbers for the current GPU.
89
+
90
+ It's safe to call this function if CUDA is not available; in that
91
+ case, it is silently ignored.
92
+
93
+ Args:
94
+ seed (int): The desired seed.
95
+
96
+ .. warning::
97
+ If you are working with a multi-GPU model, this function is insufficient
98
+ to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
99
+ """
100
+ seed = int(seed)
101
+
102
+ def cb():
103
+ idx = current_device()
104
+ default_generator = torch.cuda.default_generators[idx]
105
+ default_generator.manual_seed(seed)
106
+
107
+ _lazy_call(cb, seed=True)
108
+
109
+
110
+ def manual_seed_all(seed: int) -> None:
111
+ r"""Set the seed for generating random numbers on all GPUs.
112
+
113
+ It's safe to call this function if CUDA is not available; in that
114
+ case, it is silently ignored.
115
+
116
+ Args:
117
+ seed (int): The desired seed.
118
+ """
119
+ seed = int(seed)
120
+
121
+ def cb():
122
+ for i in range(device_count()):
123
+ default_generator = torch.cuda.default_generators[i]
124
+ default_generator.manual_seed(seed)
125
+
126
+ _lazy_call(cb, seed_all=True)
127
+
128
+
129
+ def seed() -> None:
130
+ r"""Set the seed for generating random numbers to a random number for the current GPU.
131
+
132
+ It's safe to call this function if CUDA is not available; in that
133
+ case, it is silently ignored.
134
+
135
+ .. warning::
136
+ If you are working with a multi-GPU model, this function will only initialize
137
+ the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
138
+ """
139
+
140
+ def cb():
141
+ idx = current_device()
142
+ default_generator = torch.cuda.default_generators[idx]
143
+ default_generator.seed()
144
+
145
+ _lazy_call(cb)
146
+
147
+
148
+ def seed_all() -> None:
149
+ r"""Set the seed for generating random numbers to a random number on all GPUs.
150
+
151
+ It's safe to call this function if CUDA is not available; in that
152
+ case, it is silently ignored.
153
+ """
154
+
155
+ def cb():
156
+ random_seed = 0
157
+ seeded = False
158
+ for i in range(device_count()):
159
+ default_generator = torch.cuda.default_generators[i]
160
+ if not seeded:
161
+ default_generator.seed()
162
+ random_seed = default_generator.initial_seed()
163
+ seeded = True
164
+ else:
165
+ default_generator.manual_seed(random_seed)
166
+
167
+ _lazy_call(cb)
168
+
169
+
170
+ def initial_seed() -> int:
171
+ r"""Return the current random seed of the current GPU.
172
+
173
+ .. warning::
174
+ This function eagerly initializes CUDA.
175
+ """
176
+ _lazy_init()
177
+ idx = current_device()
178
+ default_generator = torch.cuda.default_generators[idx]
179
+ return default_generator.initial_seed()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/sparse.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # The Tensor classes are added to this module by python_tensor.cpp
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/streams.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+
3
+ import torch
4
+ from torch._streambase import _EventBase, _StreamBase
5
+ from .._utils import _dummy_type
6
+
7
+
8
+ if not hasattr(torch._C, "_CudaStreamBase"):
9
+ # Define dummy base classes
10
+ torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase")
11
+ torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")
12
+
13
+
14
+ class Stream(torch._C._CudaStreamBase, _StreamBase):
15
+ r"""Wrapper around a CUDA stream.
16
+
17
+ A CUDA stream is a linear sequence of execution that belongs to a specific
18
+ device, independent from other streams. See :ref:`cuda-semantics` for
19
+ details.
20
+
21
+ Args:
22
+ device(torch.device or int, optional): a device on which to allocate
23
+ the stream. If :attr:`device` is ``None`` (default) or a negative
24
+ integer, this will use the current device.
25
+ priority(int, optional): priority of the stream, should be 0 or
26
+ negative, where negative numbers indicate higher priority. By default,
27
+ streams have priority 0.
28
+
29
+ """
30
+
31
+ def __new__(cls, device=None, priority=0, **kwargs):
32
+ # setting device manager is expensive, so we avoid it unless necessary
33
+ if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
34
+ return super().__new__(cls, priority=priority, **kwargs)
35
+ else:
36
+ with torch.cuda.device(device):
37
+ return super().__new__(cls, priority=priority, **kwargs)
38
+
39
+ def wait_event(self, event):
40
+ r"""Make all future work submitted to the stream wait for an event.
41
+
42
+ Args:
43
+ event (torch.cuda.Event): an event to wait for.
44
+
45
+ .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
46
+ `CUDA Stream documentation`_ for more info.
47
+
48
+ This function returns without waiting for :attr:`event`: only future
49
+ operations are affected.
50
+
51
+ .. _CUDA Stream documentation:
52
+ https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
53
+ """
54
+ event.wait(self)
55
+
56
+ def wait_stream(self, stream):
57
+ r"""Synchronize with another stream.
58
+
59
+ All future work submitted to this stream will wait until all kernels
60
+ submitted to a given stream at the time of call complete.
61
+
62
+ Args:
63
+ stream (Stream): a stream to synchronize.
64
+
65
+ .. note:: This function returns without waiting for currently enqueued
66
+ kernels in :attr:`stream`: only future operations are affected.
67
+ """
68
+ self.wait_event(stream.record_event())
69
+
70
+ def record_event(self, event=None):
71
+ r"""Record an event.
72
+
73
+ Args:
74
+ event (torch.cuda.Event, optional): event to record. If not given, a new one
75
+ will be allocated.
76
+
77
+ Returns:
78
+ Recorded event.
79
+ """
80
+ if event is None:
81
+ event = Event()
82
+ event.record(self)
83
+ return event
84
+
85
+ def query(self):
86
+ r"""Check if all the work submitted has been completed.
87
+
88
+ Returns:
89
+ A boolean indicating if all kernels in this stream are completed.
90
+ """
91
+ return super().query()
92
+
93
+ def synchronize(self):
94
+ r"""Wait for all the kernels in this stream to complete.
95
+
96
+ .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
97
+ `CUDA Stream documentation`_ for more info.
98
+ """
99
+ super().synchronize()
100
+
101
+ @property
102
+ def _as_parameter_(self):
103
+ return ctypes.c_void_p(self.cuda_stream)
104
+
105
+ def __eq__(self, o):
106
+ if isinstance(o, Stream):
107
+ return super().__eq__(o)
108
+ return False
109
+
110
+ def __hash__(self):
111
+ return hash((self.cuda_stream, self.device))
112
+
113
+ def __repr__(self):
114
+ return f"<torch.cuda.Stream device={self.device} cuda_stream={self.cuda_stream:#x}>"
115
+
116
+
117
+ class ExternalStream(Stream):
118
+ r"""Wrapper around an externally allocated CUDA stream.
119
+
120
+ This class is used to wrap streams allocated in other libraries in order
121
+ to facilitate data exchange and multi-library interactions.
122
+
123
+ .. note:: This class doesn't manage the stream life-cycle, it is the user
124
+ responsibility to keep the referenced stream alive while this class is
125
+ being used.
126
+
127
+ Args:
128
+ stream_ptr(int): Integer representation of the `cudaStream_t` value.
129
+ allocated externally.
130
+ device(torch.device or int, optional): the device where the stream
131
+ was originally allocated. if device is specified incorrectly,
132
+ subsequent launches using this stream may fail.
133
+ """
134
+
135
+ def __new__(cls, stream_ptr, device=None, **kwargs):
136
+ with torch.cuda.device(device):
137
+ return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)
138
+
139
+
140
+ class Event(torch._C._CudaEventBase, _EventBase):
141
+ r"""Wrapper around a CUDA event.
142
+
143
+ CUDA events are synchronization markers that can be used to monitor the
144
+ device's progress, to accurately measure timing, and to synchronize CUDA
145
+ streams.
146
+
147
+ The underlying CUDA events are lazily initialized when the event is first
148
+ recorded or exported to another process. After creation, only streams on the
149
+ same device may record the event. However, streams on any device can wait on
150
+ the event.
151
+
152
+ Args:
153
+ enable_timing (bool, optional): indicates if the event should measure time
154
+ (default: ``False``)
155
+ blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
156
+ interprocess (bool): if ``True``, the event can be shared between processes
157
+ (default: ``False``)
158
+
159
+ .. _CUDA Event Documentation:
160
+ https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
161
+ """
162
+
163
+ def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
164
+ return super().__new__(
165
+ cls,
166
+ enable_timing=enable_timing,
167
+ blocking=blocking,
168
+ interprocess=interprocess,
169
+ )
170
+
171
+ @classmethod
172
+ def from_ipc_handle(cls, device, handle):
173
+ r"""Reconstruct an event from an IPC handle on the given device."""
174
+ return super().from_ipc_handle(device, handle)
175
+
176
+ def record(self, stream=None):
177
+ r"""Record the event in a given stream.
178
+
179
+ Uses ``torch.cuda.current_stream()`` if no stream is specified. The
180
+ stream's device must match the event's device.
181
+ """
182
+ if stream is None:
183
+ stream = torch.cuda.current_stream()
184
+ super().record(stream)
185
+
186
+ def wait(self, stream=None):
187
+ r"""Make all future work submitted to the given stream wait for this event.
188
+
189
+ Use ``torch.cuda.current_stream()`` if no stream is specified.
190
+
191
+ .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
192
+ `CUDA Event documentation`_ for more info.
193
+ """
194
+ if stream is None:
195
+ stream = torch.cuda.current_stream()
196
+ super().wait(stream)
197
+
198
+ def query(self):
199
+ r"""Check if all work currently captured by event has completed.
200
+
201
+ Returns:
202
+ A boolean indicating if all work currently captured by event has
203
+ completed.
204
+ """
205
+ return super().query()
206
+
207
+ def elapsed_time(self, end_event):
208
+ r"""Return the time elapsed.
209
+
210
+ Time reported in milliseconds after the event was recorded and
211
+ before the end_event was recorded.
212
+ """
213
+ return super().elapsed_time(end_event)
214
+
215
+ def synchronize(self):
216
+ r"""Wait for the event to complete.
217
+
218
+ Waits until the completion of all work currently captured in this event.
219
+ This prevents the CPU thread from proceeding until the event completes.
220
+
221
+ .. note:: This is a wrapper around ``cudaEventSynchronize()``: see
222
+ `CUDA Event documentation`_ for more info.
223
+ """
224
+ super().synchronize()
225
+
226
+ def ipc_handle(self):
227
+ r"""Return an IPC handle of this event.
228
+
229
+ If not recorded yet, the event will use the current device.
230
+ """
231
+ return super().ipc_handle()
232
+
233
+ @property
234
+ def _as_parameter_(self):
235
+ return ctypes.c_void_p(self.cuda_event)
236
+
237
+ def __repr__(self):
238
+ if self.cuda_event:
239
+ return f"<torch.cuda.Event {self._as_parameter_.value:#x}>"
240
+ else:
241
+ return "<torch.cuda.Event uninitialized>"
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.31 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-311.pyc ADDED
Binary file (9.08 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph.cpython-311.pyc ADDED
Binary file (86.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/interpreter.cpython-311.pyc ADDED
Binary file (29.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/operator_schemas.cpython-311.pyc ADDED
Binary file (23.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-311.pyc ADDED
Binary file (16 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ATen.h ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #if !defined(_MSC_VER) && __cplusplus < 201703L
4
+ #error C++17 or later compatible compiler is required to use ATen.
5
+ #endif
6
+
7
+ #include <ATen/Context.h>
8
+ #include <ATen/Device.h>
9
+ #include <ATen/DeviceGuard.h>
10
+ #include <ATen/DimVector.h>
11
+ #include <ATen/Dispatch.h>
12
+ #include <ATen/Formatting.h>
13
+ #include <ATen/Functions.h>
14
+ #include <ATen/NamedTensor.h>
15
+ #include <ATen/ScalarOps.h>
16
+ #include <ATen/Tensor.h>
17
+ #include <ATen/TensorGeometry.h>
18
+ #include <ATen/TensorIndexing.h>
19
+ #include <ATen/TensorOperators.h>
20
+ #include <ATen/Version.h>
21
+ #include <ATen/core/ATenGeneral.h>
22
+ #include <ATen/core/Generator.h>
23
+ #include <ATen/core/Reduction.h>
24
+ #include <ATen/core/Scalar.h>
25
+ #include <ATen/core/UnsafeFromTH.h>
26
+ #include <ATen/core/ivalue.h>
27
+ #include <ATen/core/jit_type.h>
28
+ #include <c10/core/Allocator.h>
29
+ #include <c10/core/InferenceMode.h>
30
+ #include <c10/core/Layout.h>
31
+ #include <c10/core/Storage.h>
32
+ #include <c10/core/TensorOptions.h>
33
+ #include <c10/util/Exception.h>
34
+
35
+ // TODO: try to remove this
36
+ // There is some back story, see https://github.com/pytorch/pytorch/issues/48684
37
+ #include <ATen/NativeFunctions.h>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/AccumulateType.h ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/Config.h>
3
+ #include <c10/core/DeviceType.h>
4
+ #include <c10/core/ScalarType.h>
5
+ #include <c10/util/BFloat16.h>
6
+ #include <c10/util/Float8_e4m3fn.h>
7
+ #include <c10/util/Float8_e4m3fnuz.h>
8
+ #include <c10/util/Float8_e5m2.h>
9
+ #include <c10/util/Float8_e5m2fnuz.h>
10
+ #include <c10/util/Half.h>
11
+
12
+ // Defines the accumulation type for a scalar type.
13
+ // Example:
14
+ // using accscalar_t = acc_type<scalar_t, /*is_cuda*/true>;
15
+ //
16
+ // Accumulation types are an important concept in numeric computing
17
+ // because you frequently want to perform intermediate computations
18
+ // at a higher precision than the input and output precision, to avoid
19
+ // compounding internal rounding errors. Accumulation is the most
20
+ // well-known intermediate computation (it is of great importance for
21
+ // sum reduction and matrix multiply, for example), but in PyTorch
22
+ // acc_type ends up getting used for all sorts of other intermediate
23
+ // computations, so it perhaps would be more accurately (ahem) called an
24
+ // "accurate" type. acc_type is especially important for reduced
25
+ // precision operations like float16 and bfloat16, where relatively
26
+ // benign looking inputs can easily end up overflowing/underflowing.
27
+ //
28
+ // acc_type is parametrized by whether or not you are running on CUDA
29
+ // or not, because on CUDA double precision operations are expensive
30
+ // and so by default, we don't actually want to use double as an
31
+ // acc_type on CUDA. A lot of things are typed out below, but
32
+ // basically, the table is generated by a few rules:
33
+ //
34
+ // If bool:
35
+ // Use 'bool' as acc_type.
36
+ // If floating point:
37
+ // If CUDA, use 'float' as acc_type (unless scalar_t is double),
38
+ // otherwise (CPU) use 'double'
39
+ // If integral:
40
+ // Use 'int64_t' as acc_type
41
+ //
42
+ // You're not forced to use this template; if you happen to know
43
+ // something specific about your use case, you can specify your own
44
+ // desired behavior. This template, however, will give you a reasonable
45
+ // default that will work for all dtypes supported in PyTorch.
46
+
47
+ #if defined(__CUDACC__)
48
+ #include <cuda.h>
49
+ #include <cuda_fp16.h>
50
+ #elif defined(__HIPCC__)
51
+ #include <hip/hip_fp16.h>
52
+ #include <hip/hip_runtime.h>
53
+ #endif
54
+
55
+ namespace at {
56
+
57
+ template <typename T, c10::DeviceType D>
58
+ struct AccumulateTypeDevice {};
59
+
60
+ template <typename T, bool>
61
+ struct AccumulateType {};
62
+
63
+ template <typename T>
64
+ struct AccumulateType<T, false> {
65
+ using type = typename AccumulateTypeDevice<T, c10::DeviceType::CPU>::type;
66
+ };
67
+
68
+ template <typename T>
69
+ struct AccumulateType<T, true> {
70
+ using type = typename AccumulateTypeDevice<T, c10::DeviceType::CUDA>::type;
71
+ };
72
+
73
+ template <typename T, c10::DeviceType device>
74
+ using acc_type_device = typename AccumulateTypeDevice<T, device>::type;
75
+
76
+ template <typename T, bool is_cuda>
77
+ using acc_type = typename AccumulateType<T, is_cuda>::type;
78
+
79
+ #define ACC_TYPE(t, acc_t, device_type) \
80
+ template <> \
81
+ struct AccumulateTypeDevice<t, device_type> { \
82
+ using type = acc_t; \
83
+ };
84
+ #define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
85
+ #define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
86
+ #define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)
87
+
88
+ MPS_ACC_TYPE(BFloat16, float);
89
+ MPS_ACC_TYPE(Half, float);
90
+ MPS_ACC_TYPE(Float8_e5m2, float);
91
+ MPS_ACC_TYPE(Float8_e4m3fn, float);
92
+ MPS_ACC_TYPE(Float8_e5m2fnuz, float);
93
+ MPS_ACC_TYPE(Float8_e4m3fnuz, float);
94
+ MPS_ACC_TYPE(float, float);
95
+ MPS_ACC_TYPE(double, float);
96
+ MPS_ACC_TYPE(int8_t, int64_t);
97
+ MPS_ACC_TYPE(uint8_t, int64_t);
98
+ MPS_ACC_TYPE(char, int64_t);
99
+ MPS_ACC_TYPE(int16_t, int64_t);
100
+ MPS_ACC_TYPE(int32_t, int64_t);
101
+ MPS_ACC_TYPE(int64_t, int64_t);
102
+ MPS_ACC_TYPE(bool, bool);
103
+ MPS_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
104
+ MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
105
+ MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>);
106
+
107
+ #if defined(__CUDACC__) || defined(__HIPCC__)
108
+ CUDA_ACC_TYPE(half, float);
109
+ #endif
110
+ CUDA_ACC_TYPE(BFloat16, float);
111
+ CUDA_ACC_TYPE(Half, float);
112
+ CUDA_ACC_TYPE(Float8_e5m2, float);
113
+ CUDA_ACC_TYPE(Float8_e4m3fn, float);
114
+ CUDA_ACC_TYPE(Float8_e5m2fnuz, float);
115
+ CUDA_ACC_TYPE(Float8_e4m3fnuz, float);
116
+ CUDA_ACC_TYPE(float, float);
117
+ CUDA_ACC_TYPE(double, double);
118
+ CUDA_ACC_TYPE(int8_t, int64_t);
119
+ CUDA_ACC_TYPE(uint8_t, int64_t);
120
+ CUDA_ACC_TYPE(char, int64_t);
121
+ CUDA_ACC_TYPE(int16_t, int64_t);
122
+ CUDA_ACC_TYPE(int32_t, int64_t);
123
+ CUDA_ACC_TYPE(int64_t, int64_t);
124
+ CUDA_ACC_TYPE(bool, bool);
125
+ CUDA_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
126
+ CUDA_ACC_TYPE(c10::complex<float>, c10::complex<float>);
127
+ CUDA_ACC_TYPE(c10::complex<double>, c10::complex<double>);
128
+
129
+ CPU_ACC_TYPE(BFloat16, float);
130
+ CPU_ACC_TYPE(Half, float);
131
+ CPU_ACC_TYPE(Float8_e5m2, float);
132
+ CPU_ACC_TYPE(Float8_e4m3fn, float);
133
+ CPU_ACC_TYPE(Float8_e5m2fnuz, float);
134
+ CPU_ACC_TYPE(Float8_e4m3fnuz, float);
135
+ CPU_ACC_TYPE(float, double);
136
+ CPU_ACC_TYPE(double, double);
137
+ CPU_ACC_TYPE(int8_t, int64_t);
138
+ CPU_ACC_TYPE(uint8_t, int64_t);
139
+ CPU_ACC_TYPE(char, int64_t);
140
+ CPU_ACC_TYPE(int16_t, int64_t);
141
+ CPU_ACC_TYPE(int32_t, int64_t);
142
+ CPU_ACC_TYPE(int64_t, int64_t);
143
+ CPU_ACC_TYPE(bool, bool);
144
+ CPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
145
+ CPU_ACC_TYPE(c10::complex<float>, c10::complex<double>);
146
+ CPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);
147
+
148
+ TORCH_API c10::ScalarType toAccumulateType(
149
+ c10::ScalarType type,
150
+ c10::DeviceType device);
151
+ TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);
152
+
153
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Backend.h ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #pragma once
2
+ #include <c10/core/Backend.h>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CPUFixedAllocator.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/Allocator.h>
4
+ #include <c10/util/Exception.h>
5
+
6
+ // This file creates a fake allocator that just throws exceptions if
7
+ // it is actually used.
8
+
9
+ // state passed to the allocator is the std::function<void(void*)> called
10
+ // when the blob is release by ATen
11
+
12
+ namespace at {
13
+
14
+ static cpu_fixed_malloc(void*, ptrdiff_t) {
15
+ AT_ERROR("attempting to resize a tensor view of an external blob");
16
+ }
17
+
18
+ static cpu_fixed_realloc(void*, void*, ptrdiff_t) {
19
+ AT_ERROR("attempting to resize a tensor view of an external blob");
20
+ }
21
+
22
+ static cpu_fixed_free(void* state, void* allocation) {
23
+ auto on_release = static_cast<std::function<void(void*)>*>(state);
24
+ (*on_release)(allocation);
25
+ delete on_release;
26
+ }
27
+
28
+ static Allocator CPU_fixed_allocator = {
29
+ cpu_fixed_malloc,
30
+ cpu_fixed_realloc,
31
+ cpu_fixed_free};
32
+
33
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CollapseDims.h ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <c10/util/Exception.h>
2
+ #include <utility>
3
+
4
+ namespace at {
5
+
6
+ /*
7
+ [collapse dims] Updates sizes, and strides to reflect a "collapse" of
8
+ the info, possibly excluding the optional excludeDim. A "collapsed" version
9
+ of the info is the fewest dims that order the tensor's elements in the same
10
+ way as the original info. If excludeDim is specified, the collapse is the
11
+ fewest dims that order the tensor's elements as the original and preserve the
12
+ excluded dimension, unless the tensor collapses to a point.
13
+
14
+ This function returns a pair of values.
15
+
16
+ 1) The (new) index of the preserved dimension if excludeDim is
17
+ specified. 0 if the tensor is collapsed to a point. -1
18
+ otherwise.
19
+
20
+ 2) The new number of dimensions.
21
+ */
22
+ template <typename T>
23
+ inline std::pair<int64_t, int64_t> collapse_dims(
24
+ T* sizes,
25
+ T* strides,
26
+ int64_t dims,
27
+ const int excludeDim = -1) {
28
+ TORCH_CHECK(
29
+ excludeDim >= -1 && excludeDim < dims,
30
+ "expected excluded dim between -1 and dims - 1");
31
+
32
+ int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
33
+ int64_t newIndex = -1;
34
+ int64_t oldIndex = 0;
35
+ int64_t remappedExcludedDim = -1;
36
+
37
+ while (oldIndex < dims) {
38
+ // Finds a dimension to collapse into
39
+ for (; oldIndex < stopDim; ++oldIndex) {
40
+ if (sizes[oldIndex] == 1) {
41
+ continue;
42
+ }
43
+
44
+ ++newIndex;
45
+ sizes[newIndex] = sizes[oldIndex];
46
+ strides[newIndex] = strides[oldIndex];
47
+ ++oldIndex;
48
+ break;
49
+ }
50
+
51
+ // Collapses dims
52
+ for (; oldIndex < stopDim; ++oldIndex) {
53
+ if (sizes[oldIndex] == 1) {
54
+ continue;
55
+ }
56
+
57
+ if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
58
+ sizes[newIndex] *= sizes[oldIndex];
59
+ strides[newIndex] = strides[oldIndex];
60
+ } else {
61
+ ++newIndex;
62
+ sizes[newIndex] = sizes[oldIndex];
63
+ strides[newIndex] = strides[oldIndex];
64
+ }
65
+ }
66
+
67
+ // Handles excludeDim being set (oldIndex == excludeDim)
68
+ if (oldIndex != dims) {
69
+ // Preserves excluded dimension
70
+ ++newIndex;
71
+ sizes[newIndex] = sizes[oldIndex];
72
+ strides[newIndex] = strides[oldIndex];
73
+ remappedExcludedDim = newIndex;
74
+
75
+ // Restarts iteration after excludeDim
76
+ ++oldIndex;
77
+ stopDim = dims;
78
+ }
79
+ }
80
+
81
+ // Handles special case of all dims size 1
82
+ if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
83
+ dims = 1;
84
+ sizes[0] = 1;
85
+ strides[0] = 1;
86
+
87
+ return std::pair<int64_t, int64_t>(0, 1);
88
+ }
89
+
90
+ dims = newIndex + 1;
91
+ return std::pair<int64_t, int64_t>(remappedExcludedDim, dims);
92
+ }
93
+
94
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/TensorBody.h>
2
+
3
+ // TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
4
+ // Code introduced to avoid cyclic dependency in static dispatch is no longer
5
+ // needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
6
+ // to Operators.cpp for supporting multiple backends with multiple kernels.
7
+ //
8
+ // Note [Avoiding Include Cycles In Static Dispatch]
9
+ // In order to avoid #include cycles in the static dispatch build, we've carefully split out
10
+ // the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
11
+ //
12
+ // Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
13
+ // - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
14
+ // all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
15
+ // directly inlined into TensorBody.h.
16
+ // - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
17
+ // which include functions that have defaultable optional<Tensor> arguments.
18
+ // That requires knowing the full Tensor class definition.
19
+ //
20
+ // We break the cycle by doing the following:
21
+ // - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
22
+ // - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
23
+ // - CPUFunctions_inl.h includes everything else
24
+ // - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
25
+ // and then it includes CPUFunctions_inl.h.
26
+ // - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
27
+ // - This also means that static dispatch build, CPUFunctions.h only needs to
28
+ // #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
29
+ #include <ATen/CompositeImplicitAutogradFunctions_inl.h>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ // @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
3
+
4
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
5
+
6
+ // The only #includes we need are for custom classes that have defaults in the C++ API
7
+ #include <c10/core/MemoryFormat.h>
8
+ #include <c10/core/Scalar.h>
9
+ #include <ATen/core/Reduction.h>
10
+
11
+ #if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
12
+ #error This change adds a dependency on all pytorch operators, meaning the \
13
+ file will need to be re-compiled every time an operator is changed or added. \
14
+ Consider including a specific operator from \
15
+ <ATen/ops/{my_operator}_compositeimplicitautogradnestedtensor_dispatch.h>. \
16
+ See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
17
+ #endif
18
+
19
+ #include <ATen/ops/randn_like_compositeimplicitautogradnestedtensor_dispatch.h>
20
+ #include <ATen/ops/reshape_compositeimplicitautogradnestedtensor_dispatch.h>
21
+ #include <ATen/ops/reshape_as_compositeimplicitautogradnestedtensor_dispatch.h>
22
+ #include <ATen/ops/zeros_like_compositeimplicitautogradnestedtensor_dispatch.h>
23
+
24
+
25
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/core/DeprecatedTypeProperties.h>
4
+ #include <c10/macros/Macros.h>
5
+ #include <c10/util/Exception.h>
6
+ #include <c10/util/Half.h>
7
+ #include <c10/util/Metaprogramming.h>
8
+ #include <c10/util/complex.h>
9
+ #include <c10/util/string_view.h>
10
+
11
+ #ifdef __CUDACC__
12
+ #include <cuda.h> // For CUDA_VERSION
13
+ #endif
14
+
15
+ #ifdef TEMPLATE_SELECTIVE_BUILD
16
+ #include <ATen/selected_mobile_ops.h>
17
+ #else
18
+ namespace at {
19
+ /**
20
+ * The method should_include_kernel_dtype() returns true/false
21
+ * based on whether the switching code for a specific dtype should be
22
+ * included based on build time constants generated from tracing model
23
+ * execution. This method will be implmeneted via code-generation and
24
+ * included in this file when code-gen is ready.
25
+ */
26
+ inline constexpr bool should_include_kernel_dtype(
27
+ const char* /*kernel_tag_str*/,
28
+ at::ScalarType /*scalar_type*/
29
+ ) {
30
+ return true;
31
+ }
32
+ } // namespace at
33
+ #endif
34
+
35
+ /**
36
+ * In the Facebook internal build (using BUCK), this macro is enabled by
37
+ * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
38
+ * binary.
39
+ */
40
+ #if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
41
+ namespace at {
42
+ namespace detail {
43
+ TORCH_API void record_kernel_function_dtype(std::string name);
44
+ }
45
+ } // namespace at
46
+
47
+ #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
48
+ at::detail::record_kernel_function_dtype( \
49
+ std::string(NAME) + "$" + toString(enum_type));
50
+ #else
51
+ #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
52
+ #endif
53
+
54
+ #define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) \
55
+ do { \
56
+ if constexpr (!at::should_include_kernel_dtype( \
57
+ at_dispatch_name, enum_type)) { \
58
+ AT_ERROR( \
59
+ "dtype '", \
60
+ toString(enum_type), \
61
+ "' not selected for kernel tag ", \
62
+ at_dispatch_name); \
63
+ } \
64
+ } while (0)
65
+
66
+ #define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
67
+ case enum_type: { \
68
+ AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
69
+ using HINT C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
70
+ return __VA_ARGS__(); \
71
+ }
72
+
73
+ #define AT_DISPATCH_CASE(enum_type, ...) \
74
+ AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
75
+
76
+ #define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) \
77
+ case enum_type: { \
78
+ AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
79
+ using scalar_t = scalar_type; \
80
+ using underlying_t C10_UNUSED = typename scalar_t::underlying; \
81
+ const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
82
+ const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
83
+ return __VA_ARGS__(); \
84
+ }
85
+
86
+ #define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
87
+ enum_type, scalar_type, bitwidth, qmin, qmax, ...) \
88
+ case enum_type: { \
89
+ AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
90
+ using scalar_t = scalar_type; \
91
+ using underlying_t C10_UNUSED = typename scalar_t::underlying; \
92
+ const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
93
+ const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
94
+ C10_UNUSED int bit_width = bitwidth; \
95
+ C10_UNUSED int64_t quant_min = qmin; \
96
+ C10_UNUSED int64_t quant_max = qmax; \
97
+ return __VA_ARGS__(); \
98
+ }
99
+
100
+ namespace detail {
101
+
102
+ inline at::ScalarType scalar_type(at::ScalarType s) {
103
+ return s;
104
+ }
105
+
106
+ C10_DEPRECATED_MESSAGE(
107
+ "passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, "
108
+ "pass an at::ScalarType instead")
109
+ inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
110
+ return t.scalarType();
111
+ }
112
+
113
+ C10_DEPRECATED_MESSAGE(
114
+ "AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, "
115
+ "use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
116
+ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
117
+
118
+ C10_DEPRECATED_MESSAGE(
119
+ "AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "
120
+ "use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) "
121
+ "instead")
122
+ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
123
+
124
+ } // namespace detail
125
+
126
+ // The AT_DISPATCH_* family of macros provides the ability to
127
+ // conveniently generate specializations of a kernel over all of the
128
+ // dtypes we care about in PyTorch. We call it "dispatch" because
129
+ // we are "dispatching" to the correct, dtype-specific kernel.
130
+ //
131
+ // A standard usage looks like:
132
+ //
133
+ // AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
134
+ // // Your code here, with 'scalar_t' now defined to
135
+ // // be the dtype in question
136
+ // });
137
+ //
138
+ // There are many variations of this macro, so it's important to
139
+ // understand exactly /which/ dtypes you want to get instantiated, as
140
+ // well as what the "default" set is.
141
+ //
142
+ // The default set of dtypes that are instantiated (e.g., by
143
+ // AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
144
+ // and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
145
+ // but NOT booleans (bool), half-precision floats (Half) or
146
+ // complex number (c10::complex<float>, c10::complex<double>).
147
+ // This "cut" is somewhat historical (the default types are the
148
+ // ones that TH historically supported), but it also reflects the
149
+ // fact that the non-default types are "poorly" behaved (booleans
150
+ // are NOT integers mod 2, half precision operations ~essentially
151
+ // don't exist on CPU, complex numbers are an experimental application).
152
+ //
153
+ // Here are the questions you should generally ask to decide which
154
+ // dispatch you want:
155
+ //
156
+ // 1. Is this an integral or floating point specific operation?
157
+ // (If so, you'll want one of the FLOATING or INTEGRAL macros.)
158
+ //
159
+ // 2. Should half be supported? (If you're on CPU, the answer is almost
160
+ // definitely no. If you do want support, use one of the AND_HALF
161
+ // macros)
162
+ //
163
+ // Much rarer situations:
164
+ //
165
+ // 3. Should bool be supported? (You often have to write your kernel
166
+ // differently if arithmetic operations are involved.) If so,
167
+ // Use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool
168
+ //
169
+ // 4. Should complex be supported? The answer is almost always no,
170
+ // unless you are working on "generic" code that should work on
171
+ // all dtypes.
172
+ //
173
+ // Parameters:
174
+ // -----------
175
+ //
176
+ // 1. The NAME argument is a "tag" that is used to trace and then
177
+ // conditionally compile fragments of the case statements such
178
+ // that the kernel functions are specialized only for the dtypes
179
+ // that are needed. The NAME parameter *must* be a build time
180
+ // const char* (can't be std::string, etc...)
181
+ //
182
+ // Please ensure that the NAME is unique for every implementation
183
+ // or you run the risk of over-including code for the kernel
184
+ // functions. There is no risk of missing out on any code, so
185
+ // it's mostly a risk of a Type-2 error, and not a Type-1 error.
186
+ //
187
+ // Switch-like syntax:
188
+ // -------------------
189
+ // There is also a switch-case like syntax which is useful if a kernel
190
+ // needs to be specialized for particular scalar types
191
+ //
192
+ // AT_DISPATCH_SWITCH(self.scalar_type(), "op_name",
193
+ // AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
194
+ // op_integral<scalar_t>(iter);
195
+ // })
196
+ // AT_DISPATCH_CASE_FLOATING_TYPES([&] {
197
+ // op_floating<scalar_t>(iter);
198
+ // })
199
+ // AT_DISPATCH_CASE(kBool, [&] {
200
+ // op_bool(iter);
201
+ // })
202
+ // );
203
+ //
204
+ // For each AT_DISPATCH_FOO macro, there is a corresponding
205
+ // AT_DISPATCH_CASE_FOO macro which can be used inside of an
206
+ // AT_DISPATCH_SWITCH block.
207
+
208
+ // NB: the the_type variable is not used, but we have kept it for
209
+ // backwards compatibility. It's probably not used by anyone though;
210
+ // but we're just being safe (and it doesn't hurt.) Note we must
211
+ // use it to shut up warnings about unused store.
212
+
213
+ #define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
214
+ [&] { \
215
+ const auto& the_type = TYPE; \
216
+ constexpr const char* at_dispatch_name = NAME; \
217
+ /* don't use TYPE again in case it is an expensive or side-effect op */ \
218
+ at::ScalarType _st = ::detail::scalar_type(the_type); \
219
+ RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
220
+ switch (_st) { \
221
+ __VA_ARGS__ \
222
+ default: \
223
+ AT_ERROR( \
224
+ '"', \
225
+ at_dispatch_name, \
226
+ "\" not implemented for '", \
227
+ toString(_st), \
228
+ "'"); \
229
+ } \
230
+ }()
231
+
232
+ #define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
233
+ AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
234
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
235
+
236
+ #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
237
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
238
+
239
+ #define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \
240
+ AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
241
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
242
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
243
+
244
+ #define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
245
+ AT_DISPATCH_SWITCH( \
246
+ TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
247
+
248
+ #define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...) \
249
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
250
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
251
+
252
+ #define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \
253
+ AT_DISPATCH_SWITCH( \
254
+ TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__))
255
+
256
+ #define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
257
+ AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
258
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
259
+
260
+ #define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
261
+ AT_DISPATCH_SWITCH( \
262
+ TYPE, \
263
+ NAME, \
264
+ AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__))
265
+
266
+ #define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
267
+ AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
268
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
269
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
270
+
271
+ #define AT_DISPATCH_FLOATING_TYPES_AND2( \
272
+ SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
273
+ AT_DISPATCH_SWITCH( \
274
+ TYPE, \
275
+ NAME, \
276
+ AT_DISPATCH_CASE_FLOATING_TYPES_AND2( \
277
+ SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
278
+
279
+ #define AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
280
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
281
+ AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
282
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
283
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
284
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
285
+
286
+ #define AT_DISPATCH_FLOATING_TYPES_AND3( \
287
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
288
+ AT_DISPATCH_SWITCH( \
289
+ TYPE, \
290
+ NAME, \
291
+ AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
292
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
293
+
294
+ #define AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
295
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
296
+ AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
297
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
298
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
299
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
300
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
301
+
302
+ #define AT_DISPATCH_FLOATING_TYPES_AND4( \
303
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
304
+ AT_DISPATCH_SWITCH( \
305
+ TYPE, \
306
+ NAME, \
307
+ AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
308
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
309
+
310
+ #define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \
311
+ AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
312
+ AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
313
+
314
+ #define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
315
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))
316
+
317
+ #define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \
318
+ AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) \
319
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
320
+
321
+ #define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
322
+ AT_DISPATCH_SWITCH( \
323
+ TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__))
324
+
325
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
326
+ AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
327
+ AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
328
+
329
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
330
+ AT_DISPATCH_SWITCH( \
331
+ TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))
332
+
333
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \
334
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
335
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
336
+
337
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \
338
+ SCALARTYPE, TYPE, NAME, ...) \
339
+ AT_DISPATCH_SWITCH( \
340
+ TYPE, \
341
+ NAME, \
342
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \
343
+ SCALARTYPE, __VA_ARGS__))
344
+
345
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
346
+ SCALARTYPE1, SCALARTYPE2, ...) \
347
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
348
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
349
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
350
+
351
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( \
352
+ SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
353
+ AT_DISPATCH_SWITCH( \
354
+ TYPE, \
355
+ NAME, \
356
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
357
+ SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
358
+
359
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
360
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
361
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
362
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
363
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
364
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
365
+
366
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3( \
367
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
368
+ AT_DISPATCH_SWITCH( \
369
+ TYPE, \
370
+ NAME, \
371
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
372
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
373
+
374
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
375
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
376
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
377
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
378
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
379
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
380
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
381
+
382
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \
383
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
384
+ AT_DISPATCH_SWITCH( \
385
+ TYPE, \
386
+ NAME, \
387
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
388
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
389
+
390
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
391
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
392
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
393
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
394
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
395
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
396
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
397
+ AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
398
+
399
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND5( \
400
+ SCALARTYPE1, \
401
+ SCALARTYPE2, \
402
+ SCALARTYPE3, \
403
+ SCALARTYPE4, \
404
+ SCALARTYPE5, \
405
+ TYPE, \
406
+ NAME, \
407
+ ...) \
408
+ AT_DISPATCH_SWITCH( \
409
+ TYPE, \
410
+ NAME, \
411
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
412
+ SCALARTYPE1, \
413
+ SCALARTYPE2, \
414
+ SCALARTYPE3, \
415
+ SCALARTYPE4, \
416
+ SCALARTYPE5, \
417
+ __VA_ARGS__))
418
+
419
+ #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
420
+ SCALARTYPE1, \
421
+ SCALARTYPE2, \
422
+ SCALARTYPE3, \
423
+ SCALARTYPE4, \
424
+ SCALARTYPE5, \
425
+ SCALARTYPE6, \
426
+ ...) \
427
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
428
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
429
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
430
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
431
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
432
+ AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
433
+ AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
434
+
435
+ #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \
436
+ SCALARTYPE1, \
437
+ SCALARTYPE2, \
438
+ SCALARTYPE3, \
439
+ SCALARTYPE4, \
440
+ SCALARTYPE5, \
441
+ SCALARTYPE6, \
442
+ TYPE, \
443
+ NAME, \
444
+ ...) \
445
+ AT_DISPATCH_SWITCH( \
446
+ TYPE, \
447
+ NAME, \
448
+ AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
449
+ SCALARTYPE1, \
450
+ SCALARTYPE2, \
451
+ SCALARTYPE3, \
452
+ SCALARTYPE4, \
453
+ SCALARTYPE5, \
454
+ SCALARTYPE6, \
455
+ __VA_ARGS__))
456
+
457
+ #define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \
458
+ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
459
+ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
460
+ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
461
+ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
462
+ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)
463
+
464
+ #define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
465
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
466
+
467
+ #define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \
468
+ AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
469
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
470
+
471
+ #define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
472
+ AT_DISPATCH_SWITCH( \
473
+ TYPE, \
474
+ NAME, \
475
+ AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
476
+
477
+ #define AT_DISPATCH_CASE_ALL_TYPES(...) \
478
+ AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
479
+ AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)
480
+
481
+ #define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
482
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
483
+
484
+ #define AT_DISPATCH_CASE_QINT_TYPES(...) \
485
+ AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
486
+ AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) \
487
+ AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)
488
+
489
+ #define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
490
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))
491
+
492
+ #define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \
493
+ AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__) \
494
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
495
+
496
+ #define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
497
+ AT_DISPATCH_SWITCH( \
498
+ TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__))
499
+
500
+ #define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...) \
501
+ AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
502
+ AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)
503
+
504
+ #define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
505
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))
506
+
507
+ #define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...) \
508
+ AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
509
+ at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
510
+ AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
511
+ at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
512
+ AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
513
+ at::kQInt32, \
514
+ at::qint32, \
515
+ CHAR_BIT * sizeof(int), \
516
+ INT_MIN, \
517
+ INT_MAX, \
518
+ __VA_ARGS__) \
519
+ AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
520
+ at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__) \
521
+ AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
522
+ at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)
523
+
524
+ #define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
525
+ AT_DISPATCH_SWITCH( \
526
+ TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))
527
+
528
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
529
+ AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
530
+ AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
531
+
532
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
533
+ AT_DISPATCH_SWITCH( \
534
+ TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))
535
+
536
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \
537
+ AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
538
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
539
+
540
+ #define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
541
+ AT_DISPATCH_SWITCH( \
542
+ TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
543
+
544
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \
545
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
546
+ AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
547
+
548
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
549
+ AT_DISPATCH_SWITCH( \
550
+ TYPE, \
551
+ NAME, \
552
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__))
553
+
554
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
555
+ AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
556
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
557
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
558
+
559
+ #define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
560
+ AT_DISPATCH_SWITCH( \
561
+ TYPE, \
562
+ NAME, \
563
+ AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
564
+
565
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
566
+ SCALARTYPE1, SCALARTYPE2, ...) \
567
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
568
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
569
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
570
+
571
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
572
+ SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
573
+ AT_DISPATCH_SWITCH( \
574
+ TYPE, \
575
+ NAME, \
576
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
577
+ SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
578
+
579
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND3( \
580
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
581
+ AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
582
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
583
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
584
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
585
+
586
+ #define AT_DISPATCH_ALL_TYPES_AND3( \
587
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
588
+ AT_DISPATCH_SWITCH( \
589
+ TYPE, \
590
+ NAME, \
591
+ AT_DISPATCH_CASE_ALL_TYPES_AND3( \
592
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
593
+
594
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
595
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
596
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
597
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
598
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
599
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
600
+
601
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
602
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
603
+ AT_DISPATCH_SWITCH( \
604
+ TYPE, \
605
+ NAME, \
606
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
607
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
608
+
609
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
610
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
611
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
612
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
613
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
614
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
615
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
616
+
617
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
618
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
619
+ AT_DISPATCH_SWITCH( \
620
+ TYPE, \
621
+ NAME, \
622
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
623
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
624
+
625
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
626
+ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
627
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
628
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
629
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
630
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
631
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
632
+ AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
633
+
634
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
635
+ SCALARTYPE1, \
636
+ SCALARTYPE2, \
637
+ SCALARTYPE3, \
638
+ SCALARTYPE4, \
639
+ SCALARTYPE5, \
640
+ TYPE, \
641
+ NAME, \
642
+ ...) \
643
+ AT_DISPATCH_SWITCH( \
644
+ TYPE, \
645
+ NAME, \
646
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
647
+ SCALARTYPE1, \
648
+ SCALARTYPE2, \
649
+ SCALARTYPE3, \
650
+ SCALARTYPE4, \
651
+ SCALARTYPE5, \
652
+ __VA_ARGS__))
653
+
654
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
655
+ SCALARTYPE1, \
656
+ SCALARTYPE2, \
657
+ SCALARTYPE3, \
658
+ SCALARTYPE4, \
659
+ SCALARTYPE5, \
660
+ SCALARTYPE6, \
661
+ ...) \
662
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
663
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
664
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
665
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
666
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
667
+ AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
668
+ AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
669
+
670
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
671
+ SCALARTYPE1, \
672
+ SCALARTYPE2, \
673
+ SCALARTYPE3, \
674
+ SCALARTYPE4, \
675
+ SCALARTYPE5, \
676
+ SCALARTYPE6, \
677
+ TYPE, \
678
+ NAME, \
679
+ ...) \
680
+ AT_DISPATCH_SWITCH( \
681
+ TYPE, \
682
+ NAME, \
683
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
684
+ SCALARTYPE1, \
685
+ SCALARTYPE2, \
686
+ SCALARTYPE3, \
687
+ SCALARTYPE4, \
688
+ SCALARTYPE5, \
689
+ SCALARTYPE6, \
690
+ __VA_ARGS__))
691
+
692
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
693
+ SCALARTYPE1, \
694
+ SCALARTYPE2, \
695
+ SCALARTYPE3, \
696
+ SCALARTYPE4, \
697
+ SCALARTYPE5, \
698
+ SCALARTYPE6, \
699
+ SCALARTYPE7, \
700
+ ...) \
701
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
702
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
703
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
704
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
705
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
706
+ AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
707
+ AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
708
+ AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__)
709
+
710
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7( \
711
+ SCALARTYPE1, \
712
+ SCALARTYPE2, \
713
+ SCALARTYPE3, \
714
+ SCALARTYPE4, \
715
+ SCALARTYPE5, \
716
+ SCALARTYPE6, \
717
+ SCALARTYPE7, \
718
+ TYPE, \
719
+ NAME, \
720
+ ...) \
721
+ AT_DISPATCH_SWITCH( \
722
+ TYPE, \
723
+ NAME, \
724
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
725
+ SCALARTYPE1, \
726
+ SCALARTYPE2, \
727
+ SCALARTYPE3, \
728
+ SCALARTYPE4, \
729
+ SCALARTYPE5, \
730
+ SCALARTYPE6, \
731
+ SCALARTYPE7, \
732
+ __VA_ARGS__))
733
+
734
+ #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
735
+ SCALARTYPE1, \
736
+ SCALARTYPE2, \
737
+ SCALARTYPE3, \
738
+ SCALARTYPE4, \
739
+ SCALARTYPE5, \
740
+ SCALARTYPE6, \
741
+ SCALARTYPE7, \
742
+ SCALARTYPE8, \
743
+ ...) \
744
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
745
+ AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
746
+ AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
747
+ AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
748
+ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
749
+ AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
750
+ AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
751
+ AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__) \
752
+ AT_DISPATCH_CASE(SCALARTYPE8, __VA_ARGS__)
753
+
754
+ #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \
755
+ SCALARTYPE1, \
756
+ SCALARTYPE2, \
757
+ SCALARTYPE3, \
758
+ SCALARTYPE4, \
759
+ SCALARTYPE5, \
760
+ SCALARTYPE6, \
761
+ SCALARTYPE7, \
762
+ SCALARTYPE8, \
763
+ TYPE, \
764
+ NAME, \
765
+ ...) \
766
+ AT_DISPATCH_SWITCH( \
767
+ TYPE, \
768
+ NAME, \
769
+ AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
770
+ SCALARTYPE1, \
771
+ SCALARTYPE2, \
772
+ SCALARTYPE3, \
773
+ SCALARTYPE4, \
774
+ SCALARTYPE5, \
775
+ SCALARTYPE6, \
776
+ SCALARTYPE7, \
777
+ SCALARTYPE8, \
778
+ __VA_ARGS__))
779
+
780
+ #define AT_DISPATCH_CASE_BIT_TYPES(...) \
781
+ AT_DISPATCH_CASE(at::ScalarType::Bits1x8, __VA_ARGS__) \
782
+ AT_DISPATCH_CASE(at::ScalarType::Bits2x4, __VA_ARGS__) \
783
+ AT_DISPATCH_CASE(at::ScalarType::Bits4x2, __VA_ARGS__) \
784
+ AT_DISPATCH_CASE(at::ScalarType::Bits8, __VA_ARGS__) \
785
+ AT_DISPATCH_CASE(at::ScalarType::Bits16, __VA_ARGS__)
786
+
787
+ #define AT_DISPATCH_BIT_TYPES(TYPE, NAME, ...) \
788
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_BIT_TYPES(__VA_ARGS__))
789
+
790
+ #define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
791
+ AT_DISPATCH_SWITCH( \
792
+ TYPE, \
793
+ NAME, \
794
+ AT_PRIVATE_CASE_TYPE_USING_HINT( \
795
+ at::ScalarType::Int, index_t, __VA_ARGS__) \
796
+ AT_PRIVATE_CASE_TYPE_USING_HINT( \
797
+ at::ScalarType::Long, index_t, __VA_ARGS__))
798
+
799
+ // ----------------------------------------------------------------------------
800
+ // DEPRECATED MACROS, DON'T USE THESE
801
+ // ----------------------------------------------------------------------------
802
+
803
+ #define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
804
+ detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \
805
+ AT_DISPATCH_SWITCH( \
806
+ TYPE, \
807
+ NAME, \
808
+ AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__))
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch_v2.h ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/Dispatch.h>
2
+
3
+ // This is a new implementation of the AT_DISPATCH macro family from
4
+ // ATen/Dispatch.h
5
+ //
6
+ // The intended usage is:
7
+ //
8
+ // ScalarType scalar_type;
9
+ //
10
+ // AT_DISPATCH_V2(
11
+ // scalar_type,
12
+ // "debug string",
13
+ // AT_WRAP([&] {
14
+ // ... code to specialize with scalar_t ...
15
+ // }),
16
+ // kHalf,
17
+ // AT_EXPAND(AT_ALL_TYPES),
18
+ // ... as many types arguments as needed ...
19
+ // )
20
+ //
21
+ // For example, given an old style:
22
+ //
23
+ // AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
24
+ // kComplexHalf,
25
+ // kHalf,
26
+ // self.scalar_type(),
27
+ // "_local_scalar_dense_cpu",
28
+ // [&] {
29
+ // scalar_t value = *self.data_ptr<scalar_t>();
30
+ // r = Scalar(value);
31
+ // }
32
+ // )
33
+ //
34
+ // You now write:
35
+ //
36
+ // AT_DISPATCH_V2(
37
+ // self.scalar_type(),
38
+ // "_local_scalar_dense_cpu",
39
+ // AT_WRAP([&] {
40
+ // scalar_t value = *self.data_ptr<scalar_t>();
41
+ // r = Scalar(value);
42
+ // }),
43
+ // AT_EXPAND(AT_ALL_TYPES),
44
+ // AT_EXPAND(AT_COMPLEX_TYPES),
45
+ // kComplexHalf,
46
+ // kHalf,
47
+ // )
48
+ //
49
+ // Notably, it sports the following improvements:
50
+ //
51
+ // - It is not necessary to specify the arity (e.g.,
52
+ // AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3,4,...})
53
+ // when using the macro
54
+ //
55
+ // - It is not necessary to specify each dtype individually; if
56
+ // there is a set of related dtypes and you want to dispatch
57
+ // over all of them, you can simply say, e.g., AT_EXPAND(AT_INTEGRAL_TYPES)
58
+ // in your argument list.
59
+ //
60
+ // However, you must remember to wrap the payload body in AT_WRAP, or commas
61
+ // inside your lambda will be improperly handled. Furthermore, if you more
62
+ // entries to ScalarType than can be supported by this macro, it will fail
63
+ // with an obscure error (due to attempting to concatenate AT_AP with
64
+ // something that is not a number).
65
+ //
66
+ // The implementation strategy is to use the count arguments trick
67
+ // (e.g., as described in https://stackoverflow.com/a/2124385/23845)
68
+ // to discover how many dtypes have been passed, and then dispatch to a
69
+ // hand-written macro for each arity that applies as many DISPATCH_CASE as
70
+ // necessary. The hand-written macros can be regenerated for other arities
71
+ // with the script below.
72
+ //
73
+ // There is some delicacy in the implementation in controlling when
74
+ // macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
75
+ // relied on GPT4 to help me get it right.
76
+
77
+ // Public API macros
78
+
79
+ // See documentation above
80
+ #define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
81
+ AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
82
+
83
+ // This macro lets you pass an arbitrary expression that may contain internal
84
+ // commas to another macro without having the commas causing the expression
85
+ // to be interpreted as being multiple arguments
86
+ #define AT_WRAP(...) __VA_ARGS__
87
+
88
+ #define AT_FLOAT8_TYPES \
89
+ c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
90
+ c10::kFloat8_e4m3fnuz
91
+
92
+ #define AT_INTEGRAL_TYPES \
93
+ c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
94
+ #define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
95
+ #define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
96
+ #define AT_INTEGRAL_TYPES_V2 \
97
+ AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
98
+ #define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
99
+ #define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
100
+ // NB: not *actually* all types
101
+ #define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
102
+ #define AT_ALL_TYPES_AND_COMPLEX \
103
+ AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
104
+
105
+ // Helper macros
106
+
107
+ #define AT_AP_VAR(N, T, ...) \
108
+ AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
109
+ #define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
110
+ #define AT_CONCAT_AUX(a, b) a##b
111
+ #define AT_EXPAND(X) X
112
+
113
+ // Ensure we never have too many scalar types for the expansion here to
114
+ // support. To bump this, you must regenerate the macros below.
115
+ static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 45);
116
+
117
+ // Python code to regenerate generate code below:
118
+ #if 0
119
+
120
+ num_args = 45
121
+
122
+ nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
123
+ args = ', '.join(f'_{i}' for i in range(1, num_args+1))
124
+
125
+ print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
126
+ print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
127
+
128
+ for i in range(1, num_args+1):
129
+ args = ', '.join(f'_{i}' for i in range(1, i+1))
130
+ cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
131
+ print(f'#define AT_AP{i}(N, {args}) {cases}')
132
+
133
+ #endif
134
+
135
+ // Begin generated code
136
+ // clang-format off
137
+
138
+ #define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
139
+ #define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, N, ...) N
140
+ #define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
141
+ #define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
142
+ #define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
143
+ #define AT_AP4(N, _1, _2, _3, _4) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N)
144
+ #define AT_AP5(N, _1, _2, _3, _4, _5) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N)
145
+ #define AT_AP6(N, _1, _2, _3, _4, _5, _6) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N)
146
+ #define AT_AP7(N, _1, _2, _3, _4, _5, _6, _7) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N)
147
+ #define AT_AP8(N, _1, _2, _3, _4, _5, _6, _7, _8) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N)
148
+ #define AT_AP9(N, _1, _2, _3, _4, _5, _6, _7, _8, _9) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N)
149
+ #define AT_AP10(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N)
150
+ #define AT_AP11(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N)
151
+ #define AT_AP12(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N)
152
+ #define AT_AP13(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N)
153
+ #define AT_AP14(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N)
154
+ #define AT_AP15(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N)
155
+ #define AT_AP16(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N)
156
+ #define AT_AP17(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N)
157
+ #define AT_AP18(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N)
158
+ #define AT_AP19(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N)
159
+ #define AT_AP20(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N)
160
+ #define AT_AP21(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N)
161
+ #define AT_AP22(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N)
162
+ #define AT_AP23(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N)
163
+ #define AT_AP24(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N)
164
+ #define AT_AP25(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N)
165
+ #define AT_AP26(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N)
166
+ #define AT_AP27(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N)
167
+ #define AT_AP28(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N)
168
+ #define AT_AP29(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N)
169
+ #define AT_AP30(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N)
170
+ #define AT_AP31(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N)
171
+ #define AT_AP32(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N)
172
+ #define AT_AP33(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N)
173
+ #define AT_AP34(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N)
174
+ #define AT_AP35(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N)
175
+ #define AT_AP36(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N)
176
+ #define AT_AP37(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N)
177
+ #define AT_AP38(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N)
178
+ #define AT_AP39(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N)
179
+ #define AT_AP40(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N)
180
+ #define AT_AP41(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N)
181
+ #define AT_AP42(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N)
182
+ #define AT_AP43(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N)
183
+ #define AT_AP44(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N)
184
+ #define AT_AP45(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N)
185
+ // End generated code
186
+ // clang-format on
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/EmptyTensor.h ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/TensorBase.h>
3
+
4
+ namespace at::detail {
5
+
6
+ inline void check_size_nonnegative(ArrayRef<int64_t> size) {
7
+ for (const auto& x : size) {
8
+ TORCH_CHECK(
9
+ x >= 0,
10
+ "Trying to create tensor with negative dimension ",
11
+ x,
12
+ ": ",
13
+ size);
14
+ }
15
+ }
16
+
17
+ inline void check_size_nonnegative(ArrayRef<c10::SymInt> size) {
18
+ for (const auto& x : size) {
19
+ TORCH_CHECK(
20
+ x.expect_size(__FILE__, __LINE__),
21
+ "Trying to create tensor with negative dimension ",
22
+ x,
23
+ ": ",
24
+ size);
25
+ }
26
+ }
27
+
28
+ TORCH_API size_t computeStorageNbytesContiguous(
29
+ IntArrayRef sizes,
30
+ size_t itemsize,
31
+ size_t storage_offset = 0);
32
+ TORCH_API SymInt computeStorageNbytesContiguous(
33
+ SymIntArrayRef sizes,
34
+ const SymInt& itemsize,
35
+ const SymInt& storage_offset = 0);
36
+ TORCH_API size_t computeStorageNbytes(
37
+ IntArrayRef sizes,
38
+ IntArrayRef strides,
39
+ size_t itemsize,
40
+ size_t storage_offset = 0);
41
+ TORCH_API SymInt computeStorageNbytes(
42
+ SymIntArrayRef sizes,
43
+ SymIntArrayRef strides,
44
+ const SymInt& itemsize,
45
+ const SymInt& storage_offset = 0);
46
+
47
+ TORCH_API TensorBase empty_generic(
48
+ IntArrayRef size,
49
+ c10::Allocator* allocator,
50
+ c10::DispatchKeySet ks,
51
+ ScalarType scalar_type,
52
+ c10::optional<c10::MemoryFormat> memory_format_opt);
53
+
54
+ TORCH_API TensorBase empty_strided_generic(
55
+ IntArrayRef size,
56
+ IntArrayRef stride,
57
+ c10::Allocator* allocator,
58
+ c10::DispatchKeySet ks,
59
+ ScalarType scalar_type);
60
+
61
+ TORCH_API TensorBase empty_strided_symint_generic(
62
+ SymIntArrayRef size,
63
+ SymIntArrayRef stride,
64
+ c10::Allocator* allocator,
65
+ c10::DispatchKeySet ks,
66
+ ScalarType scalar_type);
67
+
68
+ TORCH_API TensorBase empty_cpu(
69
+ IntArrayRef size,
70
+ ScalarType dtype,
71
+ bool pin_memory = false,
72
+ c10::optional<c10::MemoryFormat> memory_format_opt = c10::nullopt);
73
+
74
+ TORCH_API TensorBase empty_cpu(
75
+ IntArrayRef size,
76
+ c10::optional<ScalarType> dtype_opt,
77
+ c10::optional<Layout> layout_opt,
78
+ c10::optional<Device> device_opt,
79
+ c10::optional<bool> pin_memory_opt,
80
+ c10::optional<c10::MemoryFormat> memory_format_opt);
81
+
82
+ TORCH_API TensorBase empty_cpu(IntArrayRef size, const TensorOptions& options);
83
+
84
+ TORCH_API TensorBase empty_strided_cpu(
85
+ IntArrayRef size,
86
+ IntArrayRef stride,
87
+ ScalarType dtype,
88
+ bool pin_memory = false);
89
+
90
+ TORCH_API TensorBase empty_strided_cpu(
91
+ IntArrayRef size,
92
+ IntArrayRef stride,
93
+ c10::optional<ScalarType> dtype_opt,
94
+ c10::optional<Layout> layout_opt,
95
+ c10::optional<Device> device_opt,
96
+ c10::optional<bool> pin_memory_opt);
97
+
98
+ TORCH_API TensorBase empty_strided_cpu(
99
+ IntArrayRef size,
100
+ IntArrayRef stride,
101
+ const TensorOptions& options);
102
+
103
+ TORCH_API TensorBase empty_meta(
104
+ IntArrayRef size,
105
+ ScalarType dtype,
106
+ c10::optional<c10::MemoryFormat> memory_format_opt = c10::nullopt);
107
+
108
+ TORCH_API TensorBase empty_meta(
109
+ IntArrayRef size,
110
+ c10::optional<ScalarType> dtype_opt,
111
+ c10::optional<Layout> layout_opt,
112
+ c10::optional<Device> device_opt,
113
+ c10::optional<bool> pin_memory_opt,
114
+ c10::optional<c10::MemoryFormat> memory_format_opt);
115
+
116
+ TORCH_API TensorBase empty_symint_meta(
117
+ SymIntArrayRef size,
118
+ c10::optional<ScalarType> dtype_opt,
119
+ c10::optional<Layout> layout_opt,
120
+ c10::optional<Device> device_opt,
121
+ c10::optional<bool> pin_memory_opt,
122
+ c10::optional<c10::MemoryFormat> memory_format_opt);
123
+
124
+ TORCH_API TensorBase empty_meta(IntArrayRef size, const TensorOptions& options);
125
+
126
+ TORCH_API TensorBase
127
+ empty_strided_meta(IntArrayRef size, IntArrayRef stride, ScalarType dtype);
128
+
129
+ TORCH_API TensorBase empty_strided_meta(
130
+ IntArrayRef size,
131
+ IntArrayRef stride,
132
+ c10::optional<ScalarType> dtype_opt,
133
+ c10::optional<Layout> layout_opt,
134
+ c10::optional<Device> device_opt,
135
+ c10::optional<bool> pin_memory_opt);
136
+
137
+ TORCH_API TensorBase empty_strided_meta(
138
+ IntArrayRef size,
139
+ IntArrayRef stride,
140
+ const TensorOptions& options);
141
+
142
+ TORCH_API TensorBase empty_strided_symint_meta(
143
+ SymIntArrayRef size,
144
+ SymIntArrayRef stride,
145
+ ScalarType dtype);
146
+
147
+ TORCH_API TensorBase empty_strided_symint_meta(
148
+ SymIntArrayRef size,
149
+ SymIntArrayRef stride,
150
+ c10::optional<ScalarType> dtype_opt,
151
+ c10::optional<Layout> layout_opt,
152
+ c10::optional<Device> device_opt,
153
+ c10::optional<bool> pin_memory_opt);
154
+
155
+ TORCH_API TensorBase empty_strided_symint_meta(
156
+ SymIntArrayRef size,
157
+ SymIntArrayRef stride,
158
+ const TensorOptions& options);
159
+
160
+ } // namespace at::detail
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ExpandBase.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/core/TensorBase.h>
2
+
3
+ // Broadcasting utilities for working with TensorBase
4
+ namespace at {
5
+ namespace internal {
6
+ TORCH_API TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size);
7
+ } // namespace internal
8
+
9
+ inline c10::MaybeOwned<TensorBase> expand_size(
10
+ const TensorBase& self,
11
+ IntArrayRef size) {
12
+ if (size.equals(self.sizes())) {
13
+ return c10::MaybeOwned<TensorBase>::borrowed(self);
14
+ }
15
+ return c10::MaybeOwned<TensorBase>::owned(
16
+ at::internal::expand_slow_path(self, size));
17
+ }
18
+ c10::MaybeOwned<TensorBase> expand_size(TensorBase&& self, IntArrayRef size) =
19
+ delete;
20
+
21
+ inline c10::MaybeOwned<TensorBase> expand_inplace(
22
+ const TensorBase& tensor,
23
+ const TensorBase& to_expand) {
24
+ return expand_size(to_expand, tensor.sizes());
25
+ }
26
+ c10::MaybeOwned<TensorBase> expand_inplace(
27
+ const TensorBase& tensor,
28
+ TensorBase&& to_expand) = delete;
29
+
30
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalTensorWrapper.h ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #pragma once
3
+
4
+ #include <ATen/ArrayRef.h>
5
+ #include <ATen/FunctionalStorageImpl.h>
6
+ #include <ATen/core/IListRef.h>
7
+ #include <ATen/core/List.h>
8
+ #include <ATen/core/boxing/BoxedKernel.h>
9
+ #include <ATen/core/boxing/impl/boxing.h>
10
+ #include <ATen/core/dispatch/Dispatcher.h>
11
+
12
+ #include <c10/core/DispatchKey.h>
13
+
14
+ namespace at {
15
+
16
+ // Note [Functionalization Pass In Core]
17
+ // The Functionalization pass is used to remove aliasing from a pytorch program.
18
+ //
19
+ // This is useful for backends that don't support aliasing, like XLA and Vulkan.
20
+ // It's also necessary in order to remove mutation from a program, which is
21
+ // needed in Functorch.
22
+ //
23
+ // Consider this program:
24
+ // a = torch.ones(...)
25
+ // b = a.view(...)
26
+ // b.add_(1)
27
+ //
28
+ // In this program, b is meant to alias with a due to the use of view(). At the
29
+ // end of the program, both a and b are full of 2's. However, backends that
30
+ // don't support aliasing aren't able to correctly implement the view()
31
+ // operator. Instead, they can opt into the Functionalization pass, which will
32
+ // sit between the user and the backend, and provide the necessary aliasing
33
+ // logic.
34
+ //
35
+ // The functionalization pass will turn the above program into a slightly
36
+ // different program that has the same semantics, transparently to the user,
37
+ // that backends like XLA/Vulkan are able to implement a = torch.ones(...) b =
38
+ // a.view_copy(...) # view() replaced with view_copy(). Backends like
39
+ // XLA/Vulkan can implement this! b.add_(1) a.add_(1) # Our functionalization
40
+ // pass machinery knows that a and b are aliased - it applies b's mutation to a
41
+ // too.
42
+ //
43
+ // So, how does the functionalization pass keep track of which tensors are
44
+ // aliased? The pass works by wrapping EVERY tensor in the program inside of a
45
+ // FunctionalTensorWrapper, which knows about its alias'd tensors.
46
+ //
47
+ // See Note [Functionalization: Alias Removal] for details on the aliasing
48
+ // machinery. See Note [Functionalization: Mutation Removal] for details on
49
+ // mutation removal.
50
+ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
51
+ explicit FunctionalTensorWrapper(const Tensor& value);
52
+ // Additional constructor to create a FunctionalTensorWrapper directly from an
53
+ // underlying tensor that was created from a view. For example, the code b =
54
+ // a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a,
55
+ // view1_meta)
56
+ explicit FunctionalTensorWrapper(
57
+ const Tensor& view_value,
58
+ const FunctionalTensorWrapper* base,
59
+ const functionalization::ViewMeta& meta);
60
+
61
+ // Get the underlying, actual tensor, that doesn't know anything about
62
+ // functionalization.
63
+ const Tensor& value() const {
64
+ return value_;
65
+ };
66
+ // The concept of "level" is only ever important to functorch; it's exposed
67
+ // here as more of a hook for functorch to use.
68
+ int64_t level() const {
69
+ return level_;
70
+ };
71
+ void set_level(int64_t level) {
72
+ level_ = level;
73
+ }
74
+ bool has_metadata_mutation() const {
75
+ return has_metadata_mutation_;
76
+ };
77
+
78
+ // Denotes a mutation that's hidden from autograd,
79
+ // e.g. for the purposes of passing a tensor to a triton kernel
80
+ void mark_mutation_hidden_from_autograd() {
81
+ mutation_hidden_from_autograd_counter_++;
82
+ }
83
+ void mark_mutation_during_no_grad_or_inference_mode() {
84
+ mutation_during_no_grad_or_inference_mode_++;
85
+ }
86
+ // Are all the mutations happening to the tensor hidden from autograd
87
+ bool are_all_mutations_hidden_from_autograd() const {
88
+ return mutation_hidden_from_autograd_counter_ == mutation_counter_;
89
+ }
90
+ // Did all mutations happen under no_grad or inference_mode
91
+ // (We also need to ignore mutations fully hidden from autograd here)
92
+ bool are_all_mutations_under_no_grad_or_inference_mode() const {
93
+ return mutation_hidden_from_autograd_counter_ +
94
+ mutation_during_no_grad_or_inference_mode_ ==
95
+ mutation_counter_;
96
+ }
97
+
98
+ // Sync's the underlying tensor with its alias, if it's out of date. This
99
+ // involves two steps: 1) Apply any pending updates/mutations to the alias 2)
100
+ // Replay the views (if any) to regenerate the current tensor off of the
101
+ // updated alias.
102
+ void sync_();
103
+ // Performs step (1) of the sync. This is its own public API because it's
104
+ // needed by view_inplace ops like transpose_. See Note [Functionalization
105
+ // Pass - Inplace View Ops]
106
+ void regenerate_from_base();
107
+ // Performs step (2) of the sync. This is its own public API because it's
108
+ // needed by functorch. functorch wants to make sure that all input tensors to
109
+ // a functionalized program have been properly synced so it can properly
110
+ // propagate mutations to inputs. It can't just call sync_(), because the
111
+ // FunctionalTensorWrapper will look like it has no aliases and sync_ will be
112
+ // a noop. We use the reference count on storage_ to determine if the wrapper
113
+ // is aliased, and by the time functorch is ready to propagate updates to
114
+ // inputs, any intermediate views of the input created by the program will
115
+ // have been deallocated. This function also returns whether or not the base
116
+ // actually had any updates to apply.
117
+ bool apply_updates();
118
+ // Takes the current state of value_ and snapshots it, sending it as a pending
119
+ // update to the alias.
120
+ void commit_update();
121
+ // When any tensor is mutated, the tensor increments its alias's "generation".
122
+ // Separately, each tensor maintains its own "generation" counter, which is
123
+ // used to determine if it's up-to-date with its alias. The act of syncing a
124
+ // tensor will set a tensor's generation equal to its alias's generation.
125
+ bool is_up_to_date() const;
126
+ // Freezes the storage of this tensor, preventing subsequent mutations
127
+ void freeze_storage() const;
128
+ // Every FunctionalTensorWrapper contains a vector<ViewMeta> objects
129
+ // describing the series of view ops that ran to generate the current tensor
130
+ // from the base tensor. This method is used by inplace-view ops like
131
+ // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
132
+ // tensor by replaying the views off of the alias.
133
+ void mutate_view_meta(const at::functionalization::ViewMeta& meta);
134
+
135
+ // Custom implementation of self.set_(src)
136
+ void set__impl(const FunctionalTensorWrapper* other);
137
+
138
+ // Returns whether the current tensor's data was ever mutated
139
+ bool has_data_mutation();
140
+ //
141
+ // Returns whether the current FunctionalTensorWrapper
142
+ // experienced a set_() call.
143
+ bool was_storage_changed() {
144
+ return was_storage_changed_;
145
+ }
146
+
147
+ // The functionalization pass can be used to remove mutations.
148
+ // It does so by replacing any mutation op with it's corresponding
149
+ // out-of-place op, followed by a call to replace_(). e.g:
150
+ //
151
+ // a.add_(1)
152
+ //
153
+ // will turn into:
154
+ //
155
+ // tmp = a.add(1)
156
+ // a.replace_(tmp)
157
+ //
158
+ // replace_() swaps out the wrapped tensor, value_, with tmp.
159
+ void replace_(const Tensor& other);
160
+
161
+ bool is_multi_output_view() {
162
+ return is_multi_output_view_;
163
+ }
164
+
165
+ // See Note[resize_() in functionalization pass]
166
+ void maybe_replace_storage(const Tensor& other);
167
+
168
+ // Replaces the storage with a new functional storage,
169
+ // and clears the view_metas_ stack.
170
+ // WARNING: Calling this function will sever the aliasing relationship between
171
+ // the current FunctionalTensorWrapper and any of its outstanding aliases.
172
+ // Please only call if you know what you're doing.
173
+ void _unsafe_reset_storage();
174
+
175
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
176
+ const c10::VariableVersion& version_counter,
177
+ bool allow_tensor_metadata_change) const override;
178
+
179
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
180
+ c10::VariableVersion&& version_counter,
181
+ bool allow_tensor_metadata_change) const override;
182
+
183
+ ~FunctionalTensorWrapper() override = default;
184
+
185
+ // FunctionalTensorWrapper overrides all custom size/stride function,
186
+ // so that if the inner tensor has a custom implementation
187
+ // we make sure to call that implementation.
188
+ at::IntArrayRef sizes_custom() const override;
189
+ at::IntArrayRef strides_custom() const override;
190
+ int64_t dim_custom() const override;
191
+ int64_t numel_custom() const override;
192
+ bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
193
+ c10::SymIntArrayRef sym_sizes_custom() const override;
194
+ c10::SymInt sym_size_custom(int64_t d) const override;
195
+ c10::SymIntArrayRef sym_strides_custom() const override;
196
+ c10::SymInt sym_storage_offset_custom() const override;
197
+ c10::Device device_custom() const override;
198
+
199
+ private:
200
+ const char* tensorimpl_type_name() const override;
201
+ void set_constructor_metadata();
202
+ functionalization::FunctionalStorageImpl* functional_storage_impl() const;
203
+
204
+ // This is used to re-implement shallow_copy_and_detach for
205
+ // FunctionalTensorWrapper. The implementation is identical, but we just need
206
+ // to return a subclass instead of a plain TensorImpl.
207
+ // TODO: maybe it's possible to arrange for that to happen automatically
208
+ // without an override here?
209
+ template <typename VariableVersion>
210
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
211
+ VariableVersion&& version_counter,
212
+ bool allow_tensor_metadata_change) const;
213
+
214
+ void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
215
+ void copy_tensor_metadata_and_refresh(
216
+ const FunctionalTensorWrapper* src_impl,
217
+ FunctionalTensorWrapper* dest_impl,
218
+ const c10::VariableVersion& version_counter,
219
+ bool allow_tensor_metadata_change) const;
220
+
221
+ // Note that value is not taken by reference: internally, the wrapper will
222
+ // change the value tensor that it points to over time.
223
+ Tensor value_;
224
+ int64_t level_{};
225
+ // These two counters are used for identifying
226
+ // whether all the mutations on a given tensor are hidden from autograd or
227
+ // not. If we have an input mutation that is hidden from autograd, then once
228
+ // we convert the input mutation to a copy_() we know it will be safe to hide
229
+ // the copy_() from autograd as well.
230
+ uint64_t mutation_counter_ = 0;
231
+ uint64_t mutation_hidden_from_autograd_counter_ = 0;
232
+ uint64_t mutation_during_no_grad_or_inference_mode_ = 0;
233
+ bool has_metadata_mutation_ = false;
234
+ bool is_multi_output_view_ = false;
235
+ // Did the tensor experience a set_() call.
236
+ bool was_storage_changed_ = false;
237
+
238
+ size_t generation_ = 0;
239
+ std::vector<at::functionalization::ViewMeta> view_metas_;
240
+
241
+ protected:
242
+ static void copy_tensor_metadata(
243
+ const FunctionalTensorWrapper* src_impl,
244
+ FunctionalTensorWrapper* dest_impl,
245
+ const c10::VariableVersion& version_counter,
246
+ bool allow_tensor_metadata_change);
247
+ };
248
+
249
+ // Utility functions for the functionalization pass.
250
+
251
+ namespace functionalization {
252
+ namespace impl {
253
+
254
+ TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
255
+ const Tensor& tensor) {
256
+ auto functional_impl =
257
+ static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
258
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
259
+ return functional_impl;
260
+ }
261
+
262
+ TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
263
+ TORCH_API bool isFunctionalTensor(const c10::optional<Tensor>& t);
264
+ TORCH_API bool isFunctionalTensor(
265
+ const c10::List<c10::optional<Tensor>>& t_list);
266
+ TORCH_API bool isFunctionalTensor(ITensorListRef list);
267
+
268
+ TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
269
+ TORCH_API c10::optional<Tensor> to_functional_tensor(
270
+ const c10::optional<Tensor>& tensor);
271
+ TORCH_API c10::List<c10::optional<Tensor>> to_functional_tensor(
272
+ const c10::List<c10::optional<Tensor>>& t_list);
273
+ TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);
274
+
275
+ TORCH_API void freeze_functional_tensor(const Tensor& tensor);
276
+
277
+ TORCH_API Tensor
278
+ from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
279
+ TORCH_API c10::optional<Tensor> from_functional_tensor(
280
+ const c10::optional<Tensor>& t,
281
+ bool assert_functional = true);
282
+ TORCH_API c10::List<c10::optional<Tensor>> from_functional_tensor(
283
+ const c10::List<c10::optional<Tensor>>& t_list);
284
+ TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
285
+
286
+ TORCH_API void sync(const at::Tensor& t);
287
+ TORCH_API void sync(const c10::optional<Tensor>& t);
288
+ TORCH_API void sync(const c10::List<c10::optional<Tensor>>& t_list);
289
+ TORCH_API void sync(ITensorListRef t_list);
290
+
291
+ TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
292
+ TORCH_API void replace_(
293
+ const ITensorListRef functional_tensor,
294
+ ITensorListRef other);
295
+
296
+ TORCH_API void commit_update(const Tensor& functional_tensor);
297
+ TORCH_API void commit_update(ITensorListRef functional_tensor);
298
+
299
+ TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);
300
+
301
+ TORCH_API void mark_mutation_hidden_from_autograd(
302
+ const Tensor& functional_tensor);
303
+
304
+ TORCH_API bool are_all_mutations_hidden_from_autograd(
305
+ const Tensor& functional_tensor);
306
+
307
+ TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
308
+ const Tensor& functional_tensor);
309
+
310
+ // These two methods are XLA-specific logic and are no-ops
311
+ // for the normal functionalization flow.
312
+ TORCH_API void propagate_xla_data(
313
+ const Tensor& functional_tensor,
314
+ const Tensor& other);
315
+ TORCH_API void propagate_xla_data(
316
+ const ITensorListRef functional_tensor,
317
+ ITensorListRef other);
318
+
319
+ Tensor create_functional_tensor_with_view_meta(
320
+ const Tensor& view_to_wrap,
321
+ const Tensor& base,
322
+ functionalization::ViewMeta meta,
323
+ int64_t out_idx = 0);
324
+ std::vector<Tensor> create_functional_tensor_with_view_meta(
325
+ ITensorListRef view_to_wrap,
326
+ const Tensor& base,
327
+ const functionalization::ViewMeta& meta);
328
+
329
+ void mutate_view_meta(
330
+ const Tensor& self,
331
+ const functionalization::ViewMeta& meta);
332
+
333
+ void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
334
+ void set_sizes_strides_offset(
335
+ const std::vector<Tensor>& outs,
336
+ const std::vector<Tensor>& meta_outs);
337
+
338
+ // ~~~~~ TLS used in functionalization ~~~~~
339
+
340
+ TORCH_API bool getFunctionalizationReapplyViewsTLS();
341
+ TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
342
+
343
+ class TORCH_API FunctionalizationReapplyViewsGuard {
344
+ public:
345
+ FunctionalizationReapplyViewsGuard(bool reapply_views)
346
+ : prev_(getFunctionalizationReapplyViewsTLS()) {
347
+ setFunctionalizationReapplyViewsTLS(reapply_views);
348
+ }
349
+
350
+ ~FunctionalizationReapplyViewsGuard() {
351
+ setFunctionalizationReapplyViewsTLS(prev_);
352
+ }
353
+
354
+ FunctionalizationReapplyViewsGuard(
355
+ const FunctionalizationReapplyViewsGuard&) = delete;
356
+ FunctionalizationReapplyViewsGuard operator=(
357
+ const FunctionalizationReapplyViewsGuard&) = delete;
358
+ FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
359
+ delete;
360
+ FunctionalizationReapplyViewsGuard operator=(
361
+ FunctionalizationReapplyViewsGuard&&) = delete;
362
+
363
+ private:
364
+ bool prev_;
365
+ };
366
+
367
+ } // namespace impl
368
+
369
+ // Helper function to call an out-of-place composite aten kernel that may use
370
+ // mutations / views internally, and functionalize them.
371
+ TORCH_API void functionalize_op_helper(
372
+ const c10::OperatorHandle& op,
373
+ torch::jit::Stack* stack);
374
+
375
+ template <class Op, bool symint, class ReturnType, class... ParameterTypes>
376
+ struct _functionalize_aten_op final {};
377
+
378
+ template <class Op, bool symint, class ReturnType, class... ParameterTypes>
379
+ struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
380
+ static ReturnType call(
381
+ typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
382
+ using FuncType = ReturnType(
383
+ typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
384
+ auto op = c10::Dispatcher::singleton()
385
+ .findSchemaOrThrow(
386
+ (const char*)Op::name, (const char*)Op::overload_name)
387
+ .typed<FuncType>();
388
+
389
+ return c10::impl::BoxedKernelWrapper<FuncType>::call(
390
+ c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
391
+ op,
392
+ // BoxedKernelWrapper knows to ignore this keyset argument,
393
+ // because functionalize_op_helper doesn't take in a DispatchKeySet
394
+ c10::DispatchKeySet(),
395
+ args...);
396
+ }
397
+ };
398
+
399
+ template <class Op>
400
+ using functionalize_aten_op =
401
+ _functionalize_aten_op<Op, false, typename Op::schema>;
402
+
403
+ template <class Op>
404
+ using functionalize_aten_op_symint =
405
+ _functionalize_aten_op<Op, true, typename Op::schema>;
406
+
407
+ } // namespace functionalization
408
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Generator.h ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/Generator.h>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedFallback.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/ATen.h>
3
+ #include <ATen/core/op_registration/op_registration.h>
4
+ #include <torch/library.h>
5
+
6
+ namespace at {
7
+
8
+ // If an operator doesn't have a batching rule implemented then we fallback
9
+ // to this implementation. The fallback only works on out-of-place operators
10
+ // that return only tensors with new memory. (e.g., no in-place operators, no
11
+ // view operations).
12
+ //
13
+ // The fallback effectively takes all of the BatchedTensors in `stack`, slices
14
+ // them, and runs `op` on all of the corresponding slices to produce slices
15
+ // of the outputs. The output slices then get `torch.stack`ed to create the
16
+ // final returns.
17
+ //
18
+ // The performance of the fallback is not very good because it introduces an
19
+ // extra copy from stacking the sliced outputs. Because of this, we prefer to
20
+ // write batching rules for operators whenever possible.
21
+ void batchedTensorForLoopFallback(
22
+ const c10::OperatorHandle& op,
23
+ torch::jit::Stack* stack);
24
+
25
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <bitset>
4
+
5
+ #include <ATen/ArrayRef.h>
6
+ #include <ATen/SmallVector.h>
7
+ #include <ATen/Tensor.h>
8
+
9
+ namespace at {
10
+
11
+ // We assume this in a few other places in the codebase,
12
+ // but there isn't a centralized definition.
13
+ constexpr int64_t kVmapMaxTensorDims = 64;
14
+
15
+ // The valid vmap levels range from [0, 64). This effectively means that we
16
+ // support a maximum of 64 nested vmaps.
17
+ constexpr int64_t kVmapNumLevels = 64;
18
+
19
+ // Store this number of elements of BatchDims on the stack. Most people will
20
+ // probably use <= 5 nested vmaps, but adjust this number as necessary.
21
+ constexpr int64_t kBatchDimsStackSize = 5;
22
+
23
+ // a BatchDim represents a "private" dimension on a Tensor created inside of
24
+ // vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
25
+ // is being vmap'ed over and the `level` being an identifier for which vmap
26
+ // said dimension was created inside. The `dim` corresponds to a "physical
27
+ // dim" - it is a dimension index on the underlying physical tensor that is
28
+ // being vmapped over.
29
+ struct BatchDim {
30
+ BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
31
+ int64_t dim() const {
32
+ return dim_;
33
+ }
34
+ int64_t level() const {
35
+ return level_;
36
+ }
37
+
38
+ private:
39
+ int64_t dim_;
40
+ int64_t level_;
41
+ };
42
+
43
+ using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
44
+ using BatchDimsRef = ArrayRef<BatchDim>;
45
+
46
+ // A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim
47
+ // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
48
+ // BatchedTensorImpl.
49
+ //
50
+ // The batch dimensions are treated as being "private"; they are not
51
+ // user-visible. For example, in the following Tensor,
52
+ // bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
53
+ // dimensions 0 and 1 are batch dimensions.
54
+ //
55
+ // bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
56
+ // dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7)
57
+ // tensor.
58
+ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
59
+ explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
60
+
61
+ // Returns a reference to BatchDims that represent which dimensions of this
62
+ // tensor are private.
63
+ BatchDimsRef bdims() const {
64
+ return bdims_;
65
+ }
66
+
67
+ // BatchedTensorImpl wraps a Tensor
68
+ const Tensor& value() const {
69
+ return value_;
70
+ };
71
+
72
+ // Given a public dimension index, return the dimension index in the
73
+ // underlying value() tensor. For example, if we have
74
+ // bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2,
75
+ // dim=2)])
76
+ // bt.actualDim(0) -> 1
77
+ // bt.actualDim(1) -> 3
78
+ // bt.actualDim(2) -> Error
79
+ int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
80
+
81
+ // We have to override this because we opted into CustomStrides
82
+ IntArrayRef strides_custom() const override;
83
+ // Override a bunch of methods inherited from TensorImpl to return error
84
+ // messages.
85
+ bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
86
+ void set_size(int64_t dim, int64_t new_size) override;
87
+ void set_stride(int64_t dim, int64_t new_stride) override;
88
+ void set_storage_offset(int64_t storage_offset) override;
89
+ #ifdef DEBUG
90
+ bool has_storage() const override;
91
+ #endif
92
+
93
+ private:
94
+ // see NOTE: [BatchedTensorImpl levels invariant]
95
+ void checkInvariants() const;
96
+ const char* tensorimpl_type_name() const override;
97
+
98
+ Tensor value_;
99
+
100
+ // Note: [BatchedTensorImpl levels invariant]
101
+ // There is an invariant that the BatchDims must be stored in increasing
102
+ // `level` order. That is, for i < j, bdims_[i].level must be less than
103
+ // bdims_[j].level.
104
+ BatchDims bdims_;
105
+ };
106
+
107
+ // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
108
+ // BatchedTensorImpl.
109
+ inline bool isBatchedTensor(const Tensor& tensor) {
110
+ return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
111
+ }
112
+
113
+ // It is unsafe to call this on a Tensor that is not backed by a
114
+ // BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
115
+ inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
116
+ return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
117
+ }
118
+
119
+ inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
120
+ if (!isBatchedTensor(tensor)) {
121
+ return nullptr;
122
+ }
123
+ return unsafeGetBatchedImpl(tensor);
124
+ }
125
+
126
+ // Returns a bitset. If bit i is set, then that means dim i is a batchdim.
127
+ inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
128
+ BatchDimsRef bdims) {
129
+ std::bitset<kVmapMaxTensorDims> is_bdim;
130
+ for (const auto& bdim : bdims) {
131
+ is_bdim.set(bdim.dim());
132
+ }
133
+ return is_bdim;
134
+ }
135
+
136
+ // Creates a bitset for all of the levels present in `bdims`
137
+ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
138
+ std::bitset<kVmapNumLevels> result;
139
+ for (const auto& bdim : bdims) {
140
+ result.set(bdim.level());
141
+ }
142
+ return result;
143
+ }
144
+
145
+ inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
146
+ out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
147
+ return out;
148
+ }
149
+
150
+ // Use this to construct a BatchedTensor from a regular Tensor
151
+ TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
152
+
153
+ // Adds a batch dim to `tensor`, returning a BatchedTensor
154
+ TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
155
+
156
+ // Checks if an inplace operation on self and other is "vmap compatible".
157
+ // See NOTE: [vmap-incompatible in-place operations] for the definition of this.
158
+ TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
159
+
160
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MapAllocator.h ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/core/Allocator.h>
4
+ #include <c10/util/string_view.h>
5
+
6
+ namespace at {
7
+
8
+ enum MappedAllocatorModes {
9
+ ALLOCATOR_MAPPED_SHARED = 1,
10
+ ALLOCATOR_MAPPED_SHAREDMEM = 2,
11
+ ALLOCATOR_MAPPED_EXCLUSIVE = 4,
12
+ ALLOCATOR_MAPPED_NOCREATE = 8,
13
+ ALLOCATOR_MAPPED_KEEPFD = 16,
14
+ ALLOCATOR_MAPPED_FROMFD = 32,
15
+ ALLOCATOR_MAPPED_UNLINK = 64
16
+ };
17
+
18
+ // Sentinel value/type to help distinguish the file descriptor constructor from
19
+ // the non-file descriptor constructor
20
+ enum WithFd { WITH_FD };
21
+
22
+ TORCH_API std::string NewProcessWideShmHandle();
23
+
24
+ class TORCH_API MapAllocator {
25
+ public:
26
+ MapAllocator(c10::string_view filename, int flags, size_t size);
27
+ MapAllocator(
28
+ WithFd,
29
+ c10::string_view filename,
30
+ int fd,
31
+ int flags,
32
+ size_t size);
33
+ MapAllocator(const MapAllocator&) = delete;
34
+ MapAllocator& operator=(const MapAllocator&) = delete;
35
+ MapAllocator(MapAllocator&&) = delete;
36
+ MapAllocator& operator=(MapAllocator&&) = delete;
37
+
38
+ const char* filename() const {
39
+ return filename_.c_str();
40
+ }
41
+ int fd() const {
42
+ #ifdef _WIN32
43
+ TORCH_CHECK(false, "MapAllocator::fd() is unsupported on Windows");
44
+ #else
45
+ return fd_;
46
+ #endif
47
+ }
48
+ ptrdiff_t size() const {
49
+ return size_;
50
+ }
51
+ // Return a pointer to the actual data for this allocator
52
+ // (in the case of the refcounted allocator, this is offset
53
+ // from the base pointer.)
54
+ virtual void* data() const {
55
+ return base_ptr_;
56
+ }
57
+
58
+ static MapAllocator* fromDataPtr(const at::DataPtr&);
59
+ static at::DataPtr makeDataPtr(
60
+ c10::string_view filename,
61
+ int flags,
62
+ size_t size,
63
+ size_t* actual_size_out);
64
+ static at::DataPtr makeDataPtr(
65
+ WithFd,
66
+ const char* filename,
67
+ int fd,
68
+ int flags,
69
+ size_t size,
70
+ size_t* actual_size_out);
71
+
72
+ // Closes the data. Helps us avoid destructor shenanigans
73
+ virtual void close();
74
+
75
+ // This is very dangerous. You have to redefine this destructor for each
76
+ // subclass
77
+ virtual ~MapAllocator();
78
+
79
+ protected:
80
+ bool closed_ = false;
81
+ std::string filename_;
82
+ int flags_ = 0;
83
+ ptrdiff_t size_; /* mapped size */
84
+ #ifdef _WIN32
85
+ void* handle_;
86
+ void* event_;
87
+ std::string eventname_;
88
+ #else
89
+ int fd_ = -1;
90
+ #endif
91
+ void* base_ptr_ = nullptr;
92
+ };
93
+
94
+ // Base-from-member idiom
95
+ struct TORCH_API RefcountedMapAllocatorArgCheck {
96
+ RefcountedMapAllocatorArgCheck(int flags);
97
+ };
98
+
99
+ class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck,
100
+ public MapAllocator {
101
+ public:
102
+ RefcountedMapAllocator(const char* filename, int flags, size_t size);
103
+ RefcountedMapAllocator(
104
+ WithFd,
105
+ const char* filename,
106
+ int fd,
107
+ int flags,
108
+ size_t size);
109
+
110
+ static RefcountedMapAllocator* fromDataPtr(const at::DataPtr&);
111
+ static at::DataPtr makeDataPtr(
112
+ const char* filename,
113
+ int flags,
114
+ size_t size,
115
+ size_t* actual_size_out);
116
+ static at::DataPtr makeDataPtr(
117
+ WithFd,
118
+ const char* filename,
119
+ int fd,
120
+ int flags,
121
+ size_t size,
122
+ size_t* actual_size_out);
123
+
124
+ void* data() const override;
125
+
126
+ void incref();
127
+ int decref();
128
+ void close() override;
129
+
130
+ ~RefcountedMapAllocator() override {
131
+ RefcountedMapAllocator::close();
132
+ }
133
+
134
+ protected:
135
+ void checkFlags();
136
+ void initializeAlloc();
137
+ };
138
+
139
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensor.h ADDED
@@ -0,0 +1 @@
 
 
1
+ #include <ATen/core/NamedTensor.h>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NestedTensorImpl.h ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/MemoryOverlap.h>
3
+ #include <ATen/Tensor.h>
4
+ #include <c10/core/DispatchKey.h>
5
+ #include <c10/core/DispatchKeySet.h>
6
+ #include <c10/core/MemoryFormat.h>
7
+ #include <c10/core/TensorImpl.h>
8
+ #include <c10/util/ArrayRef.h>
9
+ #include <c10/util/Exception.h>
10
+ #include <c10/util/Metaprogramming.h>
11
+ #include <c10/util/irange.h>
12
+
13
+ namespace at::native {
14
+ struct NestedTensorImpl;
15
+ inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
16
+ int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
17
+
18
+ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
19
+ explicit NestedTensorImpl(
20
+ Storage storage,
21
+ c10::DispatchKeySet key_set,
22
+ const caffe2::TypeMeta data_type,
23
+ at::Tensor nested_sizes,
24
+ at::Tensor nested_strides,
25
+ at::Tensor storage_offsets);
26
+
27
+ explicit NestedTensorImpl(
28
+ const at::Tensor& buffer,
29
+ at::Tensor nested_sizes,
30
+ at::Tensor nested_strides,
31
+ at::Tensor storage_offsets);
32
+ // assume contiguous, `nested_strides` and `offsets`
33
+ // can be infered from `nested_sizes`
34
+ explicit NestedTensorImpl(
35
+ const at::Tensor& buffer,
36
+ const at::Tensor& nested_sizes);
37
+
38
+ // This constructor is used creating view tensors from nested tensors
39
+ explicit NestedTensorImpl(
40
+ c10::TensorImpl::ImplType impl_type,
41
+ const at::Tensor& base_tensor,
42
+ at::Tensor nested_sizes,
43
+ at::Tensor nested_strides,
44
+ at::Tensor storage_offsets);
45
+
46
+ // TODO: don't expose private implementation details like this; in
47
+ // particular, resizing this tensor will mess up our dim() and
48
+ // callers cannot fix it.
49
+ const Tensor& get_nested_sizes() const {
50
+ return nested_sizes_;
51
+ }
52
+ // TODO: don't expose private implementation details like this
53
+ const Tensor& get_nested_strides() const {
54
+ return nested_strides_;
55
+ }
56
+ const Tensor& get_storage_offsets() const {
57
+ return storage_offsets_;
58
+ }
59
+ // Returns nullopt if the ith dimension is irregular. The ith dimension
60
+ // of a NestedTensor is regular if the unbound tensors match in
61
+ // size at the (i-1)th dimension.
62
+ c10::optional<int64_t> opt_size(int64_t d) const;
63
+
64
+ int64_t size(int64_t d) const {
65
+ c10::optional<int64_t> optional_size = this->opt_size(d);
66
+ TORCH_CHECK(
67
+ optional_size.has_value(),
68
+ "Given dimension ",
69
+ d,
70
+ " is irregular and does not have a size.");
71
+ return *optional_size;
72
+ }
73
+ /**
74
+ * Return a view of the nested tensor as a 1 dimensional contiguous tensor.
75
+ *
76
+ * The buffer tensor created by this function shares the same storage_impl as
77
+ * the original nested tensor, and therefore can be seen as a view.
78
+ *
79
+ * @return A newly constructed view tensor
80
+ */
81
+ at::Tensor get_buffer() const {
82
+ TORCH_CHECK(
83
+ nested_tensor_impl_is_contiguous(this),
84
+ "NestedTensor must be contiguous to get buffer.");
85
+ return get_unsafe_storage_as_tensor();
86
+ }
87
+ /**
88
+ * If possible use get_buffer() instead. This function returns the storage
89
+ * as a tensor directly, which is not safe to use in general. If using this
90
+ * function, The caller must ensure to account for nested_sizes,
91
+ * nested_strides and storage_offsets.
92
+ *
93
+ * @return A newly constructed view tensor
94
+ */
95
+ at::Tensor get_unsafe_storage_as_tensor() const {
96
+ auto buffer_key_set_ = generate_buffer_key_set();
97
+ const auto buffer_size = get_buffer_size();
98
+ auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>(
99
+ c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_);
100
+ buffer_tensor_impl->set_sizes_contiguous(
101
+ c10::makeArrayRef(static_cast<int64_t>(buffer_size)));
102
+ return Tensor(buffer_tensor_impl);
103
+ }
104
+
105
+ size_t get_buffer_size() const {
106
+ return storage_.nbytes() / data_type_.itemsize();
107
+ }
108
+
109
+ protected:
110
+ const char* tensorimpl_type_name() const override;
111
+
112
+ // TODO: numel_custom and is_contiguous_custom can be profitably overridden
113
+ // with real implementations
114
+ int64_t numel_custom() const override;
115
+ c10::SymInt sym_numel_custom() const override;
116
+ bool is_contiguous_custom(MemoryFormat) const override;
117
+ int64_t size_custom(int64_t d) const override {
118
+ return this->size(d);
119
+ }
120
+ c10::SymInt sym_size_custom(int64_t d) const override {
121
+ return c10::SymInt{this->size(d)};
122
+ }
123
+ IntArrayRef sizes_custom() const override;
124
+ c10::SymIntArrayRef sym_sizes_custom() const override;
125
+ IntArrayRef strides_custom() const override;
126
+ c10::SymIntArrayRef sym_strides_custom() const override;
127
+
128
+ // this one is real
129
+ int64_t dim_custom() const override;
130
+
131
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
132
+ const c10::VariableVersion& version_counter,
133
+ bool allow_tensor_metadata_change) const override;
134
+
135
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
136
+ c10::VariableVersion&& version_counter,
137
+ bool allow_tensor_metadata_change) const override;
138
+
139
+ void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
140
+ copy_tensor_metadata(
141
+ /*src_impl=*/impl.get(),
142
+ /*dest_impl=*/this,
143
+ /*version_counter=*/version_counter(),
144
+ /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
145
+ }
146
+
147
+ private:
148
+ // Must be called after any changes to our dim() to sync the state
149
+ // to TensorImpl.
150
+ void refresh_dim();
151
+
152
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
153
+ const at::Tensor nested_sizes_, nested_strides_;
154
+ // The starting positions of the underlying tensors in contiguous buffer
155
+ // i.e. the buffer memory offsets to get the underlying tensors
156
+ // The reason to keep this metadata is that, without strong enough constraint
157
+ // it cannot be derived from `nested_sizes_`
158
+ // and `nested_strides_`:
159
+ // 1. when buffer has blanks, e.g. [tensor1, blank, tensor2]
160
+ // this can happen e.g. after slicing a nested tensor
161
+ // 2. when multiple tensors share a same memory
162
+ // 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2]
163
+ // Some strong enough constraints are:
164
+ // 1. every underlying tensor is contiguous in memory
165
+ // && nesting in ascending order
166
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
167
+ const at::Tensor storage_offsets_;
168
+ // NOTE: -1 here means the size is missing
169
+ // Optional to allow it to be computed lazily from nested.
170
+ // TODO: maybe we can remove this metadata since
171
+ // we can compute it from `nested_sizes_`
172
+ mutable c10::optional<std::vector<int64_t>> opt_sizes_;
173
+
174
+ template <typename VariableVersion>
175
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
176
+ VariableVersion&& version_counter,
177
+ bool allow_tensor_metadata_change) const;
178
+
179
+ /**
180
+ * Generates a non-nested key_set from a nested tensor.
181
+ *
182
+ * For many nested tensor kernel implementations a buffer tensor
183
+ * is generated and redispatched to a non-nested kernel this function
184
+ * generates the key set used by that buffer tensor
185
+ *
186
+ * @return Appropriate key set for non-nested tensor
187
+ */
188
+ inline c10::DispatchKeySet generate_buffer_key_set() const {
189
+ auto buffer_key_set = this->key_set();
190
+ const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset);
191
+ // Remove nested tensor specific keys
192
+ buffer_key_set = buffer_key_set -
193
+ c10::DispatchKeySet{
194
+ c10::DispatchKey::NestedTensor,
195
+ c10::DispatchKey::AutogradNestedTensor};
196
+
197
+ // Add dense tensor specific keys
198
+ buffer_key_set =
199
+ buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense};
200
+ buffer_key_set = Autograd
201
+ ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set
202
+ : buffer_key_set;
203
+
204
+ return buffer_key_set;
205
+ }
206
+ };
207
+
208
+ inline NestedTensorImpl* get_nested_tensor_impl_or_null(
209
+ const at::Tensor& tensor) {
210
+ if (tensor.is_nested()) {
211
+ return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
212
+ }
213
+ return nullptr;
214
+ }
215
+
216
+ inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
217
+ TORCH_CHECK(
218
+ tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
219
+ return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
220
+ }
221
+
222
// Returns true iff every constituent tensor of the nested tensor is
// itself contiguous AND the constituents are laid out back-to-back in
// nesting order in the buffer, with no gaps between them.
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
  int64_t ntensors = nt->size(0);
  // An empty nested tensor is trivially contiguous.
  if (ntensors == 0) {
    return true;
  }
  const Tensor &sizemat = nt->get_nested_sizes(),
               &stridemat = nt->get_nested_strides();
  int64_t* offsets_ptr = nt->get_storage_offsets().data_ptr<int64_t>();
  // Each row of sizemat/stridemat describes one constituent tensor.
  int64_t orig_dim = sizemat.size(1);
  // nesting scalars
  if (orig_dim == 0) {
    // each scalar must be contiguous
    // if there is blank memory between underlying scalars
    // (scalar i must live exactly at buffer offset i)
    for (int64_t i = 0; i < ntensors; i++) {
      if (offsets_ptr[i] != i) {
        return false;
      }
    }
  }
  // nesting tensors
  else {
    // if any underlying tensor is non-contiguous:
    // the innermost stride must be 1 and every outer stride must equal
    // the product of the inner sizes (row-major contiguity).
    const int64_t *sizemat_ptr = sizemat.data_ptr<int64_t>(),
                  *stridemat_ptr = stridemat.data_ptr<int64_t>();
    for (int64_t i = 0; i < ntensors; i++) {
      if (stridemat_ptr[orig_dim - 1] != 1) {
        return false;
      }
      int64_t product = sizemat_ptr[orig_dim - 1];
      for (int64_t j = orig_dim - 2; j >= 0; j--) {
        if (stridemat_ptr[j] != product) {
          return false;
        }
        product *= sizemat_ptr[j];
      }
      // Advance to the next constituent's row.
      sizemat_ptr += orig_dim;
      stridemat_ptr += orig_dim;
    }
    // if there is blank memory between underlying tensors:
    // the first tensor must start at offset 0 ...
    if (offsets_ptr[0] != 0) {
      return false;
    }
    sizemat_ptr = sizemat.data_ptr<int64_t>();
    stridemat_ptr = stridemat.data_ptr<int64_t>();
    for (int64_t i = 1; i < ntensors; i++) {
      // ... and each subsequent tensor must start exactly where the
      // previous one ends: previous offset + size[0] * stride[0], which
      // is the numel of a contiguous row-major tensor.
      if (offsets_ptr[i] !=
          offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) {
        return false;
      }
      sizemat_ptr += orig_dim;
      stridemat_ptr += orig_dim;
    }
  }
  // everything is fine
  return true;
}
278
+
279
+ inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) {
280
+ return get_nested_tensor_impl(tensor)->get_nested_sizes();
281
+ }
282
+
283
+ } // namespace at::native
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/PadNd.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <c10/util/Exception.h>
3
+ #include <c10/util/string_view.h>
4
+
5
+ namespace at {
6
+
7
+ enum class padding_mode {
8
+ reflect,
9
+ replicate,
10
+ circular,
11
+ constant,
12
+ };
13
+
14
+ static inline c10::string_view padding_mode_string(padding_mode m) {
15
+ switch (m) {
16
+ case padding_mode::reflect:
17
+ return "reflect";
18
+ case padding_mode::replicate:
19
+ return "replicate";
20
+ case padding_mode::circular:
21
+ return "circular";
22
+ case padding_mode::constant:
23
+ return "constant";
24
+ }
25
+ TORCH_CHECK(false, "Invalid padding mode (", static_cast<int64_t>(m), ")");
26
+ }
27
+
28
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Parallel-inl.h ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/util/Exception.h>
4
+ #include <c10/util/ParallelGuard.h>
5
+ #include <c10/util/SmallVector.h>
6
+
7
+ namespace at {
8
+
9
// Implementation of parallel_for; see the contract documented in
// ATen/Parallel.h. Runs f over [begin, end), in parallel chunks of at
// least grain_size elements when a parallel backend is available.
template <class F>
inline void parallel_for(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const F& f) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
  // Empty range: nothing to do.
  if (begin >= end) {
    return;
  }

#ifdef INTRA_OP_PARALLEL
  at::internal::lazy_init_num_threads();
  const auto numiter = end - begin;
  // Run serially when the range is small, when we are already inside a
  // parallel region (no nested parallelism), or when only one thread is
  // configured.
  const bool use_parallel =
      (numiter > grain_size && numiter > 1 && !at::in_parallel_region() &&
       at::get_num_threads() > 1);
  if (!use_parallel) {
    internal::ThreadIdGuard tid_guard(0);
    c10::ParallelGuard guard(true);
    f(begin, end);
    return;
  }

  internal::invoke_parallel(
      begin, end, grain_size, [&](int64_t begin, int64_t end) {
        c10::ParallelGuard guard(true);
        f(begin, end);
      });
#else
  // No intra-op parallel backend compiled in: always serial.
  internal::ThreadIdGuard tid_guard(0);
  c10::ParallelGuard guard(true);
  f(begin, end);
#endif
}
44
+
45
// Implementation of parallel_reduce; see the contract documented in
// ATen/Parallel.h. Reduces [begin, end) with chunk function f and
// combiner sf, starting from identity `ident`.
template <class scalar_t, class F, class SF>
inline scalar_t parallel_reduce(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const scalar_t ident,
    const F& f,
    const SF& sf) {
  TORCH_CHECK(grain_size >= 0);
  // An empty range reduces to the identity.
  if (begin >= end) {
    return ident;
  }

#ifdef INTRA_OP_PARALLEL
  at::internal::lazy_init_num_threads();
  const auto max_threads = at::get_num_threads();
  // Serial fallback: small range, nested parallel region, or a
  // single-thread configuration.
  const bool use_parallel =
      ((end - begin) > grain_size && !at::in_parallel_region() &&
       max_threads > 1);
  if (!use_parallel) {
    internal::ThreadIdGuard tid_guard(0);
    c10::ParallelGuard guard(true);
    return f(begin, end, ident);
  }

  // One partial-result slot per thread, pre-filled with the identity so
  // threads that receive no work contribute nothing in the combine step.
  c10::SmallVector<scalar_t, 64> results(max_threads, ident);
  internal::invoke_parallel(
      begin,
      end,
      grain_size,
      [&](const int64_t my_begin, const int64_t my_end) {
        const auto tid = at::get_thread_num();
        c10::ParallelGuard guard(true);
        results[tid] = f(my_begin, my_end, ident);
      });

  // Sequentially fold the per-thread partial results with sf.
  scalar_t result = ident;
  for (auto partial_result : results) {
    result = sf(result, partial_result);
  }
  return result;
#else
  // No intra-op parallel backend compiled in: always serial.
  internal::ThreadIdGuard tid_guard(0);
  c10::ParallelGuard guard(true);
  return f(begin, end, ident);
#endif
}
92
+
93
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Parallel.h ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/Config.h>
3
+ #include <c10/macros/Macros.h>
4
+ #include <functional>
5
+ #include <string>
6
+
7
+ namespace at {
8
+
9
// Integer ceiling division: the smallest q with q * y >= x, for x >= 0
// and y > 0 (callers pass element counts and a positive chunk size;
// other inputs are unsupported).
// Written as x / y + (x % y != 0) instead of the classic
// (x + y - 1) / y so the intermediate sum cannot overflow int64_t when
// x is near INT64_MAX. Results are identical over the supported domain.
inline int64_t divup(int64_t x, int64_t y) {
  return x / y + (x % y != 0);
}
12
+
13
+ // Called during new thread initialization
14
+ TORCH_API void init_num_threads();
15
+
16
+ // Sets the number of threads to be used in parallel region
17
+ TORCH_API void set_num_threads(int);
18
+
19
+ // Returns the maximum number of threads that may be used in a parallel region
20
+ TORCH_API int get_num_threads();
21
+
22
+ // Returns the current thread number (starting from 0)
23
+ // in the current parallel region, or 0 in the sequential region
24
+ TORCH_API int get_thread_num();
25
+
26
+ // Checks whether the code runs in parallel region
27
+ TORCH_API bool in_parallel_region();
28
+
29
+ namespace internal {
30
+
31
+ // Initialise num_threads lazily at first parallel call
32
+ inline void lazy_init_num_threads() {
33
+ thread_local bool init = false;
34
+ if (C10_UNLIKELY(!init)) {
35
+ at::init_num_threads();
36
+ init = true;
37
+ }
38
+ }
39
+
40
+ TORCH_API void set_thread_num(int);
41
+
42
+ class TORCH_API ThreadIdGuard {
43
+ public:
44
+ ThreadIdGuard(int new_id) : old_id_(at::get_thread_num()) {
45
+ set_thread_num(new_id);
46
+ }
47
+
48
+ ~ThreadIdGuard() {
49
+ set_thread_num(old_id_);
50
+ }
51
+
52
+ private:
53
+ int old_id_;
54
+ };
55
+
56
+ } // namespace internal
57
+
58
+ /*
59
+ parallel_for
60
+
61
+ begin: index at which to start applying user function
62
+
63
+ end: index at which to stop applying user function
64
+
65
+ grain_size: number of elements per chunk. impacts the degree of parallelization
66
+
67
+ f: user function applied in parallel to the chunks, signature:
68
+ void f(int64_t begin, int64_t end)
69
+
70
+ Warning: parallel_for does NOT copy thread local
71
+ states from the current thread to the worker threads.
72
+ This means for example that Tensor operations CANNOT be used in the
73
+ body of your function, only data pointers.
74
+ */
75
+ template <class F>
76
+ inline void parallel_for(
77
+ const int64_t begin,
78
+ const int64_t end,
79
+ const int64_t grain_size,
80
+ const F& f);
81
+
82
+ /*
83
+ parallel_reduce
84
+
85
+ begin: index at which to start applying reduction
86
+
87
+ end: index at which to stop applying reduction
88
+
89
+ grain_size: number of elements per chunk. impacts number of elements in
90
+ intermediate results tensor and degree of parallelization.
91
+
92
+ ident: identity for binary combination function sf. sf(ident, x) needs to return
93
+ x.
94
+
95
+ f: function for reduction over a chunk. f needs to be of signature scalar_t
96
+ f(int64_t partial_begin, int64_t partial_end, scalar_t identifiy)
97
+
98
+ sf: function to combine two partial results. sf needs to be of signature
99
+ scalar_t sf(scalar_t x, scalar_t y)
100
+
101
+ For example, you might have a tensor of 10000 entires and want to sum together
102
+ all the elements. Parallel_reduce with a grain_size of 2500 will then allocate
103
+ an intermediate result tensor with 4 elements. Then it will execute the function
104
+ "f" you provide and pass the beginning and end index of these chunks, so
105
+ 0-2499, 2500-4999, etc. and the combination identity. It will then write out
106
+ the result from each of these chunks into the intermediate result tensor. After
107
+ that it'll reduce the partial results from each chunk into a single number using
108
+ the combination function sf and the identity ident. For a total summation this
109
+ would be "+" and 0 respectively. This is similar to tbb's approach [1], where
110
+ you need to provide a function to accumulate a subrange, a function to combine
111
+ two partial results and an identity.
112
+
113
+ Warning: parallel_reduce does NOT copy thread local
114
+ states from the current thread to the worker threads.
115
+ This means for example that Tensor operations CANNOT be used in the
116
+ body of your function, only data pointers.
117
+
118
+ [1] https://software.intel.com/en-us/node/506154
119
+ */
120
+ template <class scalar_t, class F, class SF>
121
+ inline scalar_t parallel_reduce(
122
+ const int64_t begin,
123
+ const int64_t end,
124
+ const int64_t grain_size,
125
+ const scalar_t ident,
126
+ const F& f,
127
+ const SF& sf);
128
+
129
+ // Returns a detailed string describing parallelization settings
130
+ TORCH_API std::string get_parallel_info();
131
+
132
+ // Sets number of threads used for inter-op parallelism
133
+ TORCH_API void set_num_interop_threads(int);
134
+
135
+ // Returns the number of threads used for inter-op parallelism
136
+ TORCH_API int get_num_interop_threads();
137
+
138
+ // Launches inter-op parallel task
139
+ TORCH_API void launch(std::function<void()> func);
140
+ namespace internal {
141
+ void launch_no_thread_state(std::function<void()> fn);
142
+ } // namespace internal
143
+
144
+ // Launches intra-op parallel task
145
+ TORCH_API void intraop_launch(std::function<void()> func);
146
+
147
+ // Returns number of intra-op threads used by default
148
+ TORCH_API int intraop_default_num_threads();
149
+
150
+ } // namespace at
151
+
152
+ #if AT_PARALLEL_OPENMP
153
+ #include <ATen/ParallelOpenMP.h> // IWYU pragma: keep
154
+ #elif AT_PARALLEL_NATIVE
155
+ #include <ATen/ParallelNative.h> // IWYU pragma: keep
156
+ #elif AT_PARALLEL_NATIVE_TBB
157
+ #include <ATen/ParallelNativeTBB.h> // IWYU pragma: keep
158
+ #endif
159
+
160
+ #include <ATen/Parallel-inl.h> // IWYU pragma: keep
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelNative.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <algorithm>
4
+ #include <cstddef>
5
+ #include <exception>
6
+
7
+ #include <c10/util/Exception.h>
8
+
9
+ #define INTRA_OP_PARALLEL
10
+
11
+ namespace at::internal {
12
+
13
+ TORCH_API void invoke_parallel(
14
+ const int64_t begin,
15
+ const int64_t end,
16
+ const int64_t grain_size,
17
+ const std::function<void(int64_t, int64_t)>& f);
18
+
19
+ } // namespace at::internal
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SavedTensorHooks.h ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <c10/macros/Export.h>
4
+ #include <c10/util/Optional.h>
5
+ #include <c10/util/python_stub.h>
6
+ #include <stack>
7
+ #include <string>
8
+
9
+ #include <utility>
10
+
11
+ namespace at {
12
+
13
+ namespace impl {
14
+
15
// Thread-local state backing SavedTensorDefaultHooks: the stack of
// Python hook pairs plus the enabled/disabled flag encoded as an
// optional error message.
struct TORCH_API SavedTensorDefaultHooksTLS {
  // PyObject is defined in c10/util/python_stub.h
  // Stack of (pack_hook, unpack_hook) pairs, stored as raw PyObject*.
  std::stack<std::pair<PyObject*, PyObject*>> stack;

  // See NOTE: [Disabling SavedTensorDefaultHooks] for context
  // NOTE: [disabled_error_message invariant]
  // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled
  // We did this for efficiency (so we didn't have to keep a separate bool
  // around)
  c10::optional<std::string> disabled_error_message;
};
26
+
27
+ } // namespace impl
28
+
29
// Static accessors for the default saved-tensor hooks, which are Python
// callables stored as raw PyObject* pairs in thread-local state
// (see impl::SavedTensorDefaultHooksTLS).
struct TORCH_API SavedTensorDefaultHooks {
  // Push/pop a (pack_hook, unpack_hook) pair on the TLS hook stack.
  static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook);
  static void pop_hooks();
  // Current hook pair — presumably the top of the stack; confirm empty-
  // stack behavior in the implementation before relying on it.
  static std::pair<PyObject*, PyObject*> get_hooks();
  static void lazy_initialize();
  // Snapshot / replace the entire TLS hook stack (copied by value).
  static std::stack<std::pair<PyObject*, PyObject*>> get_stack();
  static void set_stack(std::stack<std::pair<PyObject*, PyObject*>>);

  // Access the whole TLS blob at once (used to save/restore thread
  // state across thread boundaries).
  static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
  static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);

  // NOTE: [Disabling SavedTensorDefaultHooks]
  // A developer of a PyTorch feature may choose to disable SavedTensorDefault
  // hooks, especially if their feature does not work with it. If they are
  // disabled, then the following will raise an error:
  // - Attempting to push_hooks
  // - calling disable(message) with a non-zero stack (from get_stack) size
  static void disable(const std::string& error_message);
  static void enable();
  static bool is_enabled();
  // nullopt iff hooks are enabled (see invariant on the TLS struct).
  static const c10::optional<std::string>& get_disabled_error_message();
};
51
+
52
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorAccessor.h ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #pragma once
2
+ #include <ATen/core/TensorAccessor.h>
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorIteratorInternal.h ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <ATen/native/TensorIterator.h>
3
+ #include <c10/util/SmallBuffer.h>
4
+ #include <c10/util/irange.h>
5
+
6
+ namespace at {
7
+
8
// Tracks a multi-dimensional coordinate while walking a linear
// sub-range [range.begin, range.end) of a tensor iterator's elements,
// advancing in 2-D steps.
struct DimCounter {
  DimCounter(IntArrayRef shape, Range range);

  // Advance by a 2-D step previously obtained from max_2d_step().
  void increment(const std::array<int64_t, 2>& step);
  // True once the whole range has been consumed.
  bool is_done() const;
  // Largest {inner, outer} step that stays in bounds from the current
  // position.
  std::array<int64_t, 2> max_2d_step() const;

  IntArrayRef shape;
  Range range;
  // Per-dimension coordinates of the current position.
  c10::SmallBuffer<int64_t, 4> values;
  // Linear offset of the current position — presumably relative to
  // range.begin; confirm against the DimCounter implementation.
  int64_t offset;
};
20
+
21
+ namespace internal {
22
+
23
+ inline void get_data_ptrs(
24
+ char** ptrs,
25
+ ArrayRef<char*> base,
26
+ IntArrayRef strides,
27
+ IntArrayRef counter) {
28
+ const auto ntensors = base.size();
29
+ const auto ndim = counter.size();
30
+ std::copy(base.begin(), base.end(), ptrs);
31
+ for (const auto dim : c10::irange(ndim)) {
32
+ int64_t value = counter[dim];
33
+ for (const auto arg : c10::irange(ntensors)) {
34
+ ptrs[arg] += value * strides[dim * ntensors + arg];
35
+ }
36
+ }
37
+ }
38
+
39
// Serially applies the 2-D `loop` callback over the linear element
// sub-range `range` of an iterator with shape `shape` and interleaved
// operand strides `strides` (layout: strides[dim * ntensors + arg]).
inline void serial_for_each(
    IntArrayRef shape,
    IntArrayRef strides,
    char** base_ptrs,
    size_t ntensors,
    typename TensorIteratorBase::loop2d_t loop,
    Range range) {
  const auto ndim = shape.size();
  // The strides buffer is always padded to at least 2 dims per operand.
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      strides.size() == ntensors * std::max(size_t{2}, ndim));

  if (ndim <= 1) {
    // 0-D / 1-D: a single loop invocation covers the whole range.
    if (range.begin == 0) {
      // Starting at the origin: the base pointers can be used directly.
      loop(base_ptrs, strides.data(), range.size(), 1);
    } else {
      // Offset the base pointers to the start of the sub-range first.
      c10::SmallBuffer<char*, 4> ptrs(ntensors);
      get_data_ptrs(ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin});
      loop(ptrs.data(), strides.data(), range.size(), 1);
    }
  } else {
    // Higher dims: walk the range with DimCounter, invoking `loop` on
    // the largest in-bounds 2-D tile at each position.
    c10::SmallBuffer<char*, 4> ptrs(ntensors);
    auto counter = DimCounter(shape, range);
    while (!counter.is_done()) {
      get_data_ptrs(
          ptrs.data(), {base_ptrs, ntensors}, strides, counter.values);
      auto step = counter.max_2d_step();
      loop(ptrs.data(), strides.data(), step[0], step[1]);
      counter.increment(step);
    }
  }
}
70
+
71
+ } // namespace internal
72
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorMeta.h ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/DimVector.h>
4
+ #include <ATen/core/Dimname.h>
5
+ #include <c10/core/TensorOptions.h>
6
+ #include <c10/util/strides.h>
7
+
8
+ namespace at {
9
+
10
+ class Tensor;
11
+
12
+ namespace impl {
13
+
14
+ // Use this to define the prototype for a meta function. There are two
15
+ // versions; one that takes one argument (just the operator name), or FUNC2
16
+ // variant that takes two arguments (operator name and overload name).
17
+ //
18
+ // Example usage:
19
+ //
20
+ // TORCH_META_FUNC2(add, Tensor) (
21
+ // const Tensor& self, const Tensor& other
22
+ // ) {
23
+ // ... compute sizes and options ...
24
+ // set_output(sizes, options);
25
+ // }
26
+ //
27
+ #define TORCH_META_FUNC(name) void structured_##name::meta
28
+ #define TORCH_META_FUNC2(name, overload) \
29
+ void structured_##name##_##overload::meta
30
+
31
+ // These are versions of TORCH_META_FUNC(2) that include a precompute_out struct
32
+ // as a return value. They should be used when the kernel in question has
33
+ // precomputed values declared in native_functions.yaml and the corresponding
34
+ // implementation should return an instance of the aforementioned struct.
35
+ #define TORCH_PRECOMPUTE_META_FUNC(name) \
36
+ structured_##name::meta_return_ty structured_##name::meta
37
+ #define TORCH_PRECOMPUTE_META_FUNC2(name, overload) \
38
+ structured_##name##_##overload::meta_return_ty \
39
+ structured_##name##_##overload::meta
40
+
41
+ // Use this to create a precompute struct in a meta function.
42
+ #define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<>
43
+ #define TORCH_PRECOMPUTE_STRUCT2(name, overload) \
44
+ structured_##name##_##overload::precompute_out<>
45
+
46
+ // Use this to define the prototype for an implementation. This takes only
47
+ // one argument, which is the name of the dispatch key entry you're
48
+ // implementing.
49
+ //
50
+ // Example usage:
51
+ //
52
+ // TORCH_IMPL_FUNC(add_cpu) (
53
+ // Tensor& result, const Tensor& self, const Tensor& other
54
+ // ) {
55
+ // ... do the actual implementation ...
56
+ // }
57
+ //
58
+ #define TORCH_IMPL_FUNC(name) void structured_##name::impl
59
+
60
+ // Base class for all structured kernel classes. The set_output virtual
61
+ // method is varied depending whether or not the operator is
62
+ // functional/out/inplace, and could also be specialized for CPU/CUDA/etc
63
+ // (although presently it isn't).
64
+ //
65
+ // A notable subclass of this interface is TensorIteratorBase.
66
+ struct TORCH_API MetaBase {
67
+ MetaBase() = default;
68
+ MetaBase(const MetaBase&) = default;
69
+ MetaBase& operator=(const MetaBase&) = default;
70
+ MetaBase(MetaBase&&) noexcept = default;
71
+ MetaBase& operator=(MetaBase&&) noexcept = default;
72
+ virtual const Tensor& maybe_get_output(int64_t output_idx) = 0;
73
+
74
+ // Note: [set_output_*]
75
+ // See: https://github.com/pytorch/pytorch/issues/69813
76
+ // Whenever defining the output properties in the META function of a
77
+ // structured kernel (what was usually done with `set_output`), use one of
78
+ // these 3 variants, instead. In order to decide which variant to use, check
79
+ // the following decision tree:
80
+ //
81
+ // - Can the kernel you are going to implement support output tensors
82
+ // with arbitrary strides?
83
+ // |
84
+ // -- YES: `set_output_raw_strided`
85
+ // |
86
+ // -- NO: Should the output tensor strides be contiguous?
87
+ // |
88
+ // -- YES: `set_output_contiguous`
89
+ // |
90
+ // -- NO: `set_output_strided`
91
+ //
92
+ // Use this function whenever the kernel requires specific strides for the
93
+ // output. If `strides` does not match the given output strides, proxy outputs
94
+ // will be created and passed to the IMPL function.
95
+ virtual void set_output_strided(
96
+ int64_t output_idx,
97
+ IntArrayRef sizes,
98
+ IntArrayRef strides,
99
+ TensorOptions options,
100
+ DimnameList names = {}) {
101
+ TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
102
+ }
103
+
104
+ // Use this function whenever the kernel knows how to handle arbitrary strided
105
+ // outputs. This function has the same behavior as the old `set_output`: it
106
+ // will only re-stride if the given output was resized.
107
+ virtual void set_output_raw_strided(
108
+ int64_t output_idx,
109
+ IntArrayRef sizes,
110
+ IntArrayRef strides_hint,
111
+ TensorOptions options,
112
+ DimnameList names = {}) {
113
+ TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
114
+ }
115
+
116
+ // Use this function if the kernel requires contiguous strides.
117
+ // Alias for `set_output_strided`, but with contiguous strides.
118
+ void set_output_contiguous(
119
+ int64_t output_idx,
120
+ IntArrayRef sizes,
121
+ TensorOptions options,
122
+ DimnameList names = {}) {
123
+ auto strides = c10::contiguous_strides(sizes);
124
+ set_output_strided(output_idx, sizes, strides, options, names);
125
+ }
126
+
127
+ // Returns a reference to an undefined tensor if there is no presupplied
128
+ // output
129
+ const Tensor& maybe_get_output() {
130
+ return maybe_get_output(0);
131
+ }
132
+ virtual ~MetaBase() = default;
133
+ };
134
+
135
+ } // namespace impl
136
+
137
+ } // namespace at
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorNames.h ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/WrapDimUtils.h>
4
+
5
+ namespace at::namedinference {
6
+
7
+ // TensorName and TensorNames are wrappers around Dimname and DimnameList
8
+ // that contain helper functions to make writing name inference rules easier.
9
+ //
10
+ // A TensorName represents a Dimname associated with some DimnameList (from a
11
+ // Tensor). This encapsulates all the information that is needed to check if
12
+ // names *match* and to *unify* names.
13
+ //
14
+ // Definition: Two names in two tensors *match* if they are equal, or if at
15
+ // least one of them is a wildcard that can be *refined* to the other name.
16
+ //
17
+ // Definition: unify(name, other) fails if the names do not match. Otherwise,
18
+ // it returns the most refined of name and other.
19
+ //
20
+ // Here is an example of checking if two names match.
21
+ // tensor: Tensor[A, None]
22
+ // other: Tensor[A]
23
+ //
24
+ // Let's say we wish to check if tensor.names[-1] matches other.names[-1].
25
+ // None (in tensor) cannot match A (in other) because if the None were refined
26
+ // to A, `tensor` would have duplicate names [A, A]. Therefore we need to check
27
+ // tensor.names [A, None] for the existence of A.
28
// A single Dimname together with the full DimnameList it came from,
// which is the information needed to match/unify names across tensors
// (see the explanatory comment above).
struct TORCH_API TensorName {
  // origin_idx may be negative; it is wrapped against origin.size().
  explicit TensorName(ArrayRef<Dimname> origin, int origin_idx)
      : origin_(origin),
        name_(origin[maybe_wrap_dim(
            origin_idx,
            static_cast<int64_t>(origin.size()))]),
        origin_idx_(origin_idx) {}

  // op_name is only used for error reporting.
  const TensorName& unify(const TensorName& other, const char* op_name) const;
  Dimname toDimname() const;

 private:
  // The full name list of the originating tensor.
  ArrayRef<Dimname> origin_;
  // The name at origin_[wrapped origin_idx_].
  Dimname name_;
  int origin_idx_; // A named tensor can have at most 64 dims.

  TORCH_API friend std::ostream& operator<<(
      std::ostream& out,
      const TensorName& tensorname);
};
49
+
50
+ using TensorNameVec = SmallVector<TensorName, 10>;
51
+
52
// An ordered collection of TensorName used when unifying names across
// tensors during name inference.
struct TORCH_API TensorNames {
  explicit TensorNames(ArrayRef<Dimname> names);

  // Create TensorNames from names[start:end]. Each individual TensorName stores
  // `names`, NOT names[start:end], because the original tensor's names are
  // `names`.
  explicit TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end);

  // Unifies with `other` aligned from the right (broadcasting-style).
  // op_name is only used for error reporting.
  TensorNames& unifyFromRightInplace(
      const TensorNames& other,
      const char* op_name = "unify");
  void checkUnique(const char* op_name) const;

  void append(TensorName name);
  std::vector<Dimname> toDimnameVec() const;

 private:
  explicit TensorNames(TensorNameVec&& names) : names_(std::move(names)){};

  TensorNameVec names_;
};
74
+
75
+ } // namespace at::namedinference