Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/_sanitizer.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/comm.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/nvtx.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/profiler.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/sparse.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/autocast_mode.py +144 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/error.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/graphs.py +479 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/jiterator.py +185 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/nccl.py +137 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/profiler.py +61 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/random.py +179 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/sparse.py +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/streams.py +241 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/interpreter.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/operator_schemas.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ATen.h +37 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/AccumulateType.h +153 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Backend.h +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CPUFixedAllocator.h +33 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CollapseDims.h +94 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h +29 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h +808 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch_v2.h +186 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/EmptyTensor.h +160 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ExpandBase.h +30 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalTensorWrapper.h +408 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Generator.h +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedFallback.h +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h +160 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MapAllocator.h +139 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensor.h +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NestedTensorImpl.h +283 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/PadNd.h +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Parallel-inl.h +93 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Parallel.h +160 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelNative.h +19 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SavedTensorHooks.h +52 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorAccessor.h +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorIteratorInternal.h +72 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorMeta.h +137 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorNames.h +75 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (64.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/_memory_viz.cpython-311.pyc
ADDED
|
Binary file (37.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/_sanitizer.cpython-311.pyc
ADDED
|
Binary file (36.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/_utils.cpython-311.pyc
ADDED
|
Binary file (2.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/comm.cpython-311.pyc
ADDED
|
Binary file (520 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/nvtx.cpython-311.pyc
ADDED
|
Binary file (3.85 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/profiler.cpython-311.pyc
ADDED
|
Binary file (3.51 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/sparse.cpython-311.pyc
ADDED
|
Binary file (209 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/autocast_mode.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import functools
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
HAS_NUMPY = True
|
| 10 |
+
except ModuleNotFoundError:
|
| 11 |
+
np = None # type: ignore[assignment]
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
__all__ = ["autocast", "custom_fwd", "custom_bwd"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class autocast(torch.amp.autocast_mode.autocast):
|
| 18 |
+
r"""See :class:`torch.autocast`.
|
| 19 |
+
|
| 20 |
+
``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)``
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
enabled: bool = True,
|
| 26 |
+
dtype: torch.dtype = torch.float16,
|
| 27 |
+
cache_enabled: bool = True,
|
| 28 |
+
):
|
| 29 |
+
if torch._jit_internal.is_scripting():
|
| 30 |
+
self._enabled = enabled
|
| 31 |
+
self.device = "cuda"
|
| 32 |
+
self.fast_dtype = dtype
|
| 33 |
+
return
|
| 34 |
+
super().__init__(
|
| 35 |
+
"cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
def __enter__(self):
|
| 39 |
+
if torch._jit_internal.is_scripting():
|
| 40 |
+
return self
|
| 41 |
+
return super().__enter__()
|
| 42 |
+
|
| 43 |
+
# TODO: discuss a unified TorchScript-friendly API for autocast
|
| 44 |
+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override]
|
| 45 |
+
if torch._jit_internal.is_scripting():
|
| 46 |
+
return
|
| 47 |
+
return super().__exit__(exc_type, exc_val, exc_tb)
|
| 48 |
+
|
| 49 |
+
def __call__(self, func):
|
| 50 |
+
if torch._jit_internal.is_scripting():
|
| 51 |
+
return func
|
| 52 |
+
return super().__call__(func)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
|
| 56 |
+
# may be falsely detected as "Iterables."
|
| 57 |
+
def _cast(value, dtype):
|
| 58 |
+
if isinstance(value, torch.Tensor):
|
| 59 |
+
is_eligible = (
|
| 60 |
+
value.is_floating_point()
|
| 61 |
+
and value.is_cuda
|
| 62 |
+
and (value.dtype is not torch.float64)
|
| 63 |
+
)
|
| 64 |
+
return value.to(dtype) if is_eligible else value
|
| 65 |
+
elif isinstance(value, (str, bytes)):
|
| 66 |
+
return value
|
| 67 |
+
elif HAS_NUMPY and isinstance(value, np.ndarray):
|
| 68 |
+
return value
|
| 69 |
+
elif isinstance(value, collections.abc.Mapping):
|
| 70 |
+
return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
|
| 71 |
+
elif isinstance(value, collections.abc.Iterable):
|
| 72 |
+
iterable = (_cast(v, dtype) for v in value)
|
| 73 |
+
if isinstance(value, (list, tuple)):
|
| 74 |
+
return type(value)(iterable)
|
| 75 |
+
else:
|
| 76 |
+
return iterable
|
| 77 |
+
else:
|
| 78 |
+
return value
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# custom_fwd is a decorator that may or may not be used with arguments, following
|
| 82 |
+
# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
|
| 83 |
+
# this works:
|
| 84 |
+
# @custom_fwd
|
| 85 |
+
# def forward(...):
|
| 86 |
+
# this also works:
|
| 87 |
+
# @custom_fwd(cast_inputs=torch.float)
|
| 88 |
+
# def forward(...):
|
| 89 |
+
def custom_fwd(fwd=None, *, cast_inputs=None):
|
| 90 |
+
"""
|
| 91 |
+
Create a helper decorator for ``forward`` methods of custom autograd functions.
|
| 92 |
+
|
| 93 |
+
Autograd functions are subclasses of :class:`torch.autograd.Function`.
|
| 94 |
+
See the :ref:`example page<amp-custom-examples>` for more detail.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
|
| 98 |
+
when ``forward`` runs in an autocast-enabled region, casts incoming
|
| 99 |
+
floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected),
|
| 100 |
+
then executes ``forward`` with autocast disabled.
|
| 101 |
+
If ``None``, ``forward``'s internal ops execute with the current autocast state.
|
| 102 |
+
|
| 103 |
+
.. note::
|
| 104 |
+
If the decorated ``forward`` is called outside an autocast-enabled region,
|
| 105 |
+
:func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
|
| 106 |
+
"""
|
| 107 |
+
if fwd is None:
|
| 108 |
+
return functools.partial(custom_fwd, cast_inputs=cast_inputs)
|
| 109 |
+
|
| 110 |
+
@functools.wraps(fwd)
|
| 111 |
+
def decorate_fwd(*args, **kwargs):
|
| 112 |
+
args[0]._dtype = torch.get_autocast_gpu_dtype()
|
| 113 |
+
if cast_inputs is None:
|
| 114 |
+
args[0]._fwd_used_autocast = torch.is_autocast_enabled()
|
| 115 |
+
return fwd(*args, **kwargs)
|
| 116 |
+
else:
|
| 117 |
+
autocast_context = torch.is_autocast_enabled()
|
| 118 |
+
args[0]._fwd_used_autocast = False
|
| 119 |
+
if autocast_context:
|
| 120 |
+
with autocast(enabled=False):
|
| 121 |
+
return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
|
| 122 |
+
else:
|
| 123 |
+
return fwd(*args, **kwargs)
|
| 124 |
+
|
| 125 |
+
return decorate_fwd
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
|
| 129 |
+
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
|
| 130 |
+
# cast_inputs supplied to custom_fwd.
|
| 131 |
+
def custom_bwd(bwd):
|
| 132 |
+
"""Create a helper decorator for backward methods of custom autograd functions.
|
| 133 |
+
|
| 134 |
+
Autograd functions are subclasses of :class:`torch.autograd.Function`.
|
| 135 |
+
Ensures that ``backward`` executes with the same autocast state as ``forward``.
|
| 136 |
+
See the :ref:`example page<amp-custom-examples>` for more detail.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
@functools.wraps(bwd)
|
| 140 |
+
def decorate_bwd(*args, **kwargs):
|
| 141 |
+
with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
|
| 142 |
+
return bwd(*args, **kwargs)
|
| 143 |
+
|
| 144 |
+
return decorate_bwd
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/error.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/graphs.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils import _pytree
|
| 6 |
+
from .._utils import _dummy_type
|
| 7 |
+
|
| 8 |
+
if not hasattr(torch._C, "_CudaStreamBase"):
|
| 9 |
+
# Define dummy base classes
|
| 10 |
+
torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
|
| 11 |
+
torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle")
|
| 12 |
+
torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type(
|
| 13 |
+
"_cuda_isCurrentStreamCapturing"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
from torch._C import ( # noqa: F401
|
| 17 |
+
_cuda_isCurrentStreamCapturing,
|
| 18 |
+
_CUDAGraph,
|
| 19 |
+
_graph_pool_handle,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def is_current_stream_capturing():
|
| 24 |
+
r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.
|
| 25 |
+
|
| 26 |
+
If a CUDA context does not exist on the current device, returns False without initializing the context.
|
| 27 |
+
"""
|
| 28 |
+
return _cuda_isCurrentStreamCapturing()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Python shim helps Sphinx process docstrings more reliably.
|
| 32 |
+
def graph_pool_handle():
|
| 33 |
+
r"""Return an opaque token representing the id of a graph memory pool.
|
| 34 |
+
|
| 35 |
+
See :ref:`Graph memory management<graph-memory-management>`.
|
| 36 |
+
|
| 37 |
+
.. warning::
|
| 38 |
+
This API is in beta and may change in future releases.
|
| 39 |
+
"""
|
| 40 |
+
return _graph_pool_handle()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Python shim helps Sphinx process docstrings more reliably.
|
| 44 |
+
class CUDAGraph(torch._C._CUDAGraph):
|
| 45 |
+
r"""Wrapper around a CUDA graph.
|
| 46 |
+
|
| 47 |
+
.. warning::
|
| 48 |
+
This API is in beta and may change in future releases.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def __new__(cls):
|
| 52 |
+
return super().__new__(cls)
|
| 53 |
+
|
| 54 |
+
def capture_begin(self, pool=None, capture_error_mode="global"):
|
| 55 |
+
r"""Begin capturing CUDA work on the current stream.
|
| 56 |
+
|
| 57 |
+
Typically, you shouldn't call ``capture_begin`` yourself.
|
| 58 |
+
Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
|
| 59 |
+
which call ``capture_begin`` internally.
|
| 60 |
+
|
| 61 |
+
Arguments:
|
| 62 |
+
pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
|
| 63 |
+
:meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
|
| 64 |
+
with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
|
| 65 |
+
capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
|
| 66 |
+
Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
|
| 67 |
+
may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
|
| 68 |
+
actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
|
| 69 |
+
unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
|
| 70 |
+
""" # noqa: B950
|
| 71 |
+
super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)
|
| 72 |
+
|
| 73 |
+
def capture_end(self):
|
| 74 |
+
r"""End CUDA graph capture on the current stream.
|
| 75 |
+
|
| 76 |
+
After ``capture_end``, ``replay`` may be called on this instance.
|
| 77 |
+
|
| 78 |
+
Typically, you shouldn't call ``capture_end`` yourself.
|
| 79 |
+
Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
|
| 80 |
+
which call ``capture_end`` internally.
|
| 81 |
+
"""
|
| 82 |
+
super().capture_end()
|
| 83 |
+
|
| 84 |
+
def replay(self):
|
| 85 |
+
r"""Replay the CUDA work captured by this graph."""
|
| 86 |
+
super().replay()
|
| 87 |
+
|
| 88 |
+
def reset(self):
|
| 89 |
+
r"""Delete the graph currently held by this instance."""
|
| 90 |
+
super().reset()
|
| 91 |
+
|
| 92 |
+
def pool(self):
|
| 93 |
+
r"""Return an opaque token representing the id of this graph's memory pool.
|
| 94 |
+
|
| 95 |
+
This id can optionally be passed to another graph's ``capture_begin``,
|
| 96 |
+
which hints the other graph may share the same memory pool.
|
| 97 |
+
"""
|
| 98 |
+
return super().pool()
|
| 99 |
+
|
| 100 |
+
def enable_debug_mode(self):
|
| 101 |
+
r"""Enable debugging mode for CUDAGraph.debug_dump."""
|
| 102 |
+
return super().enable_debug_mode()
|
| 103 |
+
|
| 104 |
+
def debug_dump(self, debug_path):
|
| 105 |
+
r"""
|
| 106 |
+
Arguments:
|
| 107 |
+
debug_path (required): Path to dump the graph to.
|
| 108 |
+
|
| 109 |
+
Calls a debugging function to dump the graph if the debugging is
|
| 110 |
+
enabled via CUDAGraph.enable_debug_mode()
|
| 111 |
+
"""
|
| 112 |
+
return super().debug_dump(debug_path)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class graph:
|
| 116 |
+
r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay.
|
| 117 |
+
|
| 118 |
+
See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
|
| 119 |
+
detailed use, and constraints.
|
| 120 |
+
|
| 121 |
+
Arguments:
|
| 122 |
+
cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
|
| 123 |
+
pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
|
| 124 |
+
:meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
|
| 125 |
+
may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
|
| 126 |
+
stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
|
| 127 |
+
If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
|
| 128 |
+
capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
|
| 129 |
+
Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
|
| 130 |
+
may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
|
| 131 |
+
actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
|
| 132 |
+
unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
|
| 133 |
+
|
| 134 |
+
.. note::
|
| 135 |
+
For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
|
| 136 |
+
used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.
|
| 137 |
+
|
| 138 |
+
.. warning::
|
| 139 |
+
This API is in beta and may change in future releases.
|
| 140 |
+
|
| 141 |
+
.. _cudaStreamCaptureMode:
|
| 142 |
+
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
|
| 143 |
+
""" # noqa: B950
|
| 144 |
+
|
| 145 |
+
default_capture_stream: Optional["torch.cuda.Stream"] = None
|
| 146 |
+
|
| 147 |
+
def __init__(
|
| 148 |
+
self,
|
| 149 |
+
cuda_graph,
|
| 150 |
+
pool=None,
|
| 151 |
+
stream=None,
|
| 152 |
+
capture_error_mode: str = "global",
|
| 153 |
+
):
|
| 154 |
+
# Lazy-init of default_capture_stream helps avoid circular-import errors.
|
| 155 |
+
# Not thread safe, but graphs already have the general (explicitly documented)
|
| 156 |
+
# restriction that only one capture may be underway at a time in the process.
|
| 157 |
+
if self.__class__.default_capture_stream is None:
|
| 158 |
+
self.__class__.default_capture_stream = torch.cuda.Stream()
|
| 159 |
+
|
| 160 |
+
self.pool = () if pool is None else (pool,)
|
| 161 |
+
self.capture_stream = (
|
| 162 |
+
stream if stream is not None else self.__class__.default_capture_stream
|
| 163 |
+
)
|
| 164 |
+
assert self.capture_stream is not None
|
| 165 |
+
self.stream_ctx = torch.cuda.stream(self.capture_stream)
|
| 166 |
+
self.cuda_graph = cuda_graph
|
| 167 |
+
self.capture_error_mode = capture_error_mode
|
| 168 |
+
|
| 169 |
+
def __enter__(self):
|
| 170 |
+
# Free as much memory as we can for the graph
|
| 171 |
+
torch.cuda.synchronize()
|
| 172 |
+
gc.collect()
|
| 173 |
+
torch.cuda.empty_cache()
|
| 174 |
+
|
| 175 |
+
# Stackoverflow seems comfortable with this pattern
|
| 176 |
+
# https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
|
| 177 |
+
self.stream_ctx.__enter__()
|
| 178 |
+
|
| 179 |
+
self.cuda_graph.capture_begin(
|
| 180 |
+
*self.pool, capture_error_mode=self.capture_error_mode
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
| 184 |
+
self.cuda_graph.capture_end()
|
| 185 |
+
self.stream_ctx.__exit__(exc_type, exc_value, traceback)
|
| 186 |
+
# returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def make_graphed_callables(
    callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
):
    r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.

    Each graphed callable's forward pass runs its source callable's
    forward CUDA work as a CUDA graph inside a single autograd node.

    The graphed callable's forward pass also appends
    a backward node to the autograd graph. During backward, this node runs the
    callable's backward work as a CUDA graph.

    Therefore, each graphed callable should be a drop-in replacement for its source callable
    in an autograd-enabled training loop.

    See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.

    If you pass a tuple of several callables, their captures will use the same memory pool.
    See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.

    Arguments:
        callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
            See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
            is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order
            they'll run in the live workload.
        sample_args (tuple of Tensors, or tuple of tuples of Tensors): Samples args for each callable.
            If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
            If a tuple of callables was passed, ``sample_args`` must be tuple of tuples of argument Tensors.
        num_warmup_iters (int): The number of warmup iterations. Currently, ``DataDistributedParallel`` needs
            11 iterations for warm up. Default: ``3``.
        allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
            (and therefore their grad is always zero) is an error. Defaults to False.
        pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
            with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
    .. note::
        The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
        that's expected for the corresponding real input in the training loop.

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        ``sample_args`` for each callable must contain only Tensors. Other types are not allowed.

    .. warning::
        Returned callables do not support higher order differentiation (e.g., double backward).

    .. warning::
        In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
        may be trainable. Buffers must have ``requires_grad=False``.

    .. warning::
        After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
        you may not add or remove any of that Module's parameters or buffers.

    .. warning::
        :class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
        registered on them at the time they are passed. However, registering hooks on modules *after* passing them
        through :func:`~torch.cuda.make_graphed_callables` is allowed.

    .. warning::
        When running a graphed callable, you must pass its arguments in the same order and format
        they appeared in that callable's ``sample_args``.

    .. warning::
        The automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with disabled
        caching. The context manager `torch.cuda.amp.autocast()` must have `cache_enabled=False`.
    """
    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
        raise RuntimeError(
            "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
        )

    just_one_callable = False

    # Normalize to tuple-of-callables / tuple-of-arg-tuples so the rest of the
    # function can treat the single-callable case uniformly.
    if not isinstance(callables, tuple):
        just_one_callable = True
        callables = (callables,)
        sample_args = (sample_args,)

    flatten_sample_args = []

    for c, args in zip(callables, sample_args):
        if isinstance(c, torch.nn.Module):
            assert (
                len(c._backward_hooks) == 0
                and len(c._forward_hooks) == 0
                and len(c._forward_pre_hooks) == 0
            ), (
                "Modules must not have hooks registered at the time they are passed. However, registering hooks "
                + "on modules after passing them through make_graphed_callables is allowed."
            )
            assert all(b.requires_grad is False for b in c.buffers()), (
                "In any :class:`~torch.nn.Module` passed to "
                + ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
                + "``requires_grad=False``."
            )
        flatten_arg = _pytree.arg_tree_leaves(*args)
        flatten_sample_args.append(tuple(flatten_arg))
        assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
            "In the beta API, sample_args "
            + "for each callable must contain only Tensors. Other types are not allowed."
        )

    # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
    # passes to forward (ie, its sample_args) AND the module's parameter attributes.
    per_callable_len_user_args = [len(args) for args in flatten_sample_args]
    per_callable_module_params = [
        tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
        for c in callables
    ]
    per_callable_static_input_surfaces = [
        flatten_sample_args[i] + per_callable_module_params[i]
        for i in range(len(callables))
    ]

    # One forward and one backward CUDA graph per callable.
    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]

    mempool = graph_pool_handle() if pool is None else pool

    # Warmup
    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
    # from ending up in any captures.
    torch.cuda.synchronize()
    with torch.cuda.stream(torch.cuda.Stream()):
        for func, args, static_input_surface in zip(
            callables, sample_args, per_callable_static_input_surfaces
        ):
            for _ in range(num_warmup_iters):
                outputs = _pytree.tree_leaves(func(*args))
                grad_inputs = torch.autograd.grad(
                    outputs=tuple(o for o in outputs if o.requires_grad),
                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
                    grad_outputs=tuple(
                        torch.empty_like(o) for o in outputs if o.requires_grad
                    ),
                    only_inputs=True,
                    allow_unused=allow_unused_input,
                )
            # Release warmup tensors before any capture begins.
            del outputs, grad_inputs  # type: ignore[possibly-undefined]
    torch.cuda.synchronize()

    # All captures here share a mempool. To avoid replays corrupting each other's memory,
    # the safest approach is to capture all passes in the same order they'll run:
    # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.

    # Capture forward graphs
    per_callable_static_outputs = []
    per_callable_output_unflatten_spec = []
    for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
        with torch.cuda.graph(fwd_graph, pool=mempool):
            outputs = func(*args)

        flatten_outputs, spec = _pytree.tree_flatten(outputs)
        per_callable_static_outputs.append(tuple(flatten_outputs))
        per_callable_output_unflatten_spec.append(spec)

    # Capture backward graphs in reverse order
    per_callable_static_grad_outputs = []
    per_callable_static_grad_inputs = []
    for static_input_surface, static_outputs, bwd_graph, module_params in zip(
        reversed(per_callable_static_input_surfaces),
        reversed(per_callable_static_outputs),
        reversed(bwd_graphs),
        reversed(per_callable_module_params),
    ):
        # For now, assumes all static_outputs require grad
        # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
        static_grad_outputs = tuple(
            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
        )

        with torch.cuda.graph(bwd_graph, pool=mempool):
            grad_inputs = torch.autograd.grad(
                outputs=tuple(o for o in static_outputs if o.requires_grad),
                inputs=tuple(i for i in static_input_surface if i.requires_grad),
                grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                only_inputs=True,
                allow_unused=allow_unused_input,
            )

        # Constructs a tuple suitable for returning from Graphed.backward:
        # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
        # I couldn't think of a slick one-liner for this pattern.
        static_grad_inputs = []
        grad_idx = 0
        for arg in static_input_surface:
            if arg.requires_grad:
                static_grad_inputs.append(grad_inputs[grad_idx])
                grad_idx += 1
            else:
                static_grad_inputs.append(None)  # type: ignore[arg-type]
        static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]

        per_callable_static_grad_outputs.append(static_grad_outputs)
        per_callable_static_grad_inputs.append(static_grad_inputs)

    # Reverses the most recent two lists
    per_callable_static_grad_outputs.reverse()
    per_callable_static_grad_inputs.reverse()
    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.

    def make_graphed_autograd_function(
        fwd_graph,
        bwd_graph,
        module_params,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
        static_grad_outputs,
        static_grad_inputs,
    ):
        # Builds a callable that replays fwd_graph in forward and bwd_graph in
        # backward, copying live tensors into the graphs' static placeholders.
        class Graphed(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *inputs):
                # At this stage, only the user args may (potentially) be new tensors.
                for i in range(len_user_args):
                    if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                        static_input_surface[i].copy_(inputs[i])
                fwd_graph.replay()
                assert isinstance(static_outputs, tuple)
                return tuple(o.detach() for o in static_outputs)

            @staticmethod
            @torch.autograd.function.once_differentiable
            def backward(ctx, *grads):
                assert len(grads) == len(static_grad_outputs)
                for g, grad in zip(static_grad_outputs, grads):
                    if g is not None:
                        # don't copy if autograd gods have been kind and the
                        # incoming grad is already in the right place
                        if g.data_ptr() != grad.data_ptr():
                            g.copy_(grad)
                bwd_graph.replay()

                # Input args that didn't require grad expect a None gradient.
                assert isinstance(static_grad_inputs, tuple)
                return tuple(
                    b.detach() if b is not None else b for b in static_grad_inputs
                )

        def functionalized(*user_args):
            # Runs the autograd function with inputs == all inputs to the graph that might require grad
            # (explicit user args + module parameters)
            # Assumes module params didn't change since capture.
            flatten_user_args = _pytree.arg_tree_leaves(*user_args)
            out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
            return _pytree.tree_unflatten(out, output_unflatten_spec)

        return functionalized

    # Put together the final graphed callables
    ret = []
    for i, func in enumerate(callables):
        graphed = make_graphed_autograd_function(
            fwd_graphs[i],
            bwd_graphs[i],
            per_callable_module_params[i],
            per_callable_len_user_args[i],
            per_callable_output_unflatten_spec[i],
            per_callable_static_input_surfaces[i],
            per_callable_static_outputs[i],
            per_callable_static_grad_outputs[i],
            per_callable_static_grad_inputs[i],
        )

        if isinstance(func, torch.nn.Module):

            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
                def new_fwd(*user_args):
                    # If the module's training-or-eval state matches what we graphed,
                    # run the graph, otherwise run the original forward method
                    if func.training == graph_training_state:
                        return graphed(*user_args)
                    else:
                        return orig_fwd(*user_args)

                return new_fwd

            func.forward = make_graphed_forward(func, func.training, graphed, func.forward)  # type: ignore[assignment]
            ret.append(func)
        else:
            ret.append(graphed)

    if just_one_callable:
        return ret[0]

    return tuple(ret)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/jiterator.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Callable, List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
__all__: List[str] = []
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class _CodeParser:
|
| 11 |
+
def __init__(self, code_string: str):
|
| 12 |
+
optional_ws = r"\s*"
|
| 13 |
+
required_ws = r"\s+"
|
| 14 |
+
template_params = r"(?P<template_params>\<.+\>)"
|
| 15 |
+
return_type = r"(?P<return_type>\w+)"
|
| 16 |
+
function_name = r"(?P<function_name>\w+)"
|
| 17 |
+
function_params = r"(?P<function_params>\(.+\))"
|
| 18 |
+
function_body = r"(?P<function_body>\{.+\})"
|
| 19 |
+
|
| 20 |
+
pattern = (
|
| 21 |
+
optional_ws
|
| 22 |
+
+ "template"
|
| 23 |
+
+ optional_ws
|
| 24 |
+
+ template_params
|
| 25 |
+
+ optional_ws
|
| 26 |
+
+ return_type
|
| 27 |
+
+ required_ws
|
| 28 |
+
+ function_name
|
| 29 |
+
+ optional_ws
|
| 30 |
+
+ function_params
|
| 31 |
+
+ optional_ws
|
| 32 |
+
+ function_body
|
| 33 |
+
+ optional_ws
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
result = re.match(
|
| 37 |
+
pattern, code_string, re.DOTALL
|
| 38 |
+
) # DOTALL for matching multiline
|
| 39 |
+
|
| 40 |
+
if result is None:
|
| 41 |
+
raise Exception(
|
| 42 |
+
f"Couldn't parse code, please check correctness:\n {code_string}"
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
self.template_params = result["template_params"]
|
| 46 |
+
self.return_type = result["return_type"]
|
| 47 |
+
self.function_name = result["function_name"]
|
| 48 |
+
self.function_params = result["function_params"]
|
| 49 |
+
self.function_body = result["function_body"]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class _JittedFunction:
    """Callable wrapper around a jiterator kernel source string.

    Compilation is deferred: the kernel is compiled and launched by the
    private ``torch._C`` binding on every invocation, and CUDA availability
    is only checked at call time (matching torch.cuda's lazy init).
    """

    def __init__(
        self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
    ):
        self.code_string = code_string

        assert (
            return_by_ref or num_outputs == 1
        ), "Return by value only works for single output. "
        self.return_by_ref = return_by_ref
        self.num_outputs = num_outputs

        # The parsed entry function's name is what the launcher invokes.
        self.kernel_name = _CodeParser(code_string).function_name

        # Defaults for kernel scalar arguments; callers may override per call.
        self.kwargs_dict = kwargs
        self.is_cuda_available = torch.cuda.is_available()

    def __call__(self, *tensors: Tensor, **kwargs):
        # Jiterator follow torch.cuda's lazy initialization behavior
        # Defer checking cuda's availability at the function invocation time
        assert (
            self.is_cuda_available
        ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available."

        assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."

        # Merge call-site overrides over declared defaults; unknown names
        # are rejected so typos don't silently become no-ops.
        merged_kwargs = dict(self.kwargs_dict)
        for key, value in kwargs.items():
            if key not in merged_kwargs:
                raise KeyError(f"{key} is not declared in function definition")
            merged_kwargs[key] = value

        return torch._C._cuda_jiterator_compile_and_launch_kernel(
            self.code_string,
            self.kernel_name,
            self.return_by_ref,
            self.num_outputs,
            tensors,
            merged_kwargs,
        )
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _create_jit_fn(code_string: str, **kwargs) -> Callable:
    """
    Create a jiterator-generated cuda kernel for an elementwise op.

    The code string has to be a valid CUDA function that describes the computation for a single element. The code
    string has to follow the c++ template pattern, as shown in the example below. This function will be inlined
    into elementwise kernel template, and compiled on the fly. Compiled kernel will be cached in memory, as well as
    local temp dir.

    Jiterator-generated kernels accepts noncontiguous tensors, and supports broadcasting and type promotion.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
        kwargs (Dict, optional): Keyword arguments for generated function

    Example::

        code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
        jitted_fn = create_jit_fn(code_string, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    code_string also allows multiple function definitions, and the last function will be treated as the entry function.

    Example::

        code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
        code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
        jitted_fn = create_jit_fn(code_string, val=0.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b) # using default val=0.0

    Jiterator can be used together with python registration to override an operator's cuda kernel.
    Following example is overriding gelu's cuda kernel with relu.

    Example::

        code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
        my_gelu = create_jit_fn(code_string)
        my_lib = torch.library.Library("aten", "IMPL")
        my_lib.impl('aten::gelu', my_gelu, "CUDA")
        # torch.nn.GELU and torch.nn.function.gelu are now overridden
        a = torch.rand(3, device='cuda')
        torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 1 output

    .. warning::
        All input tensors must live in CUDA device
    """
    # Single-output, return-by-value wrapper; see _create_multi_output_jit_fn
    # for the return-by-reference, multi-output variant.
    return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _create_multi_output_jit_fn(
    code_string: str, num_outputs: int, **kwargs
) -> Callable:
    """
    Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return value by reference.
        num_outputs(int): number of outputs return by the kernel
        kwargs (Dict, optional): Keyword arguments for generated function

    Example::

        code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
        jitted_fn = create_jit_fn(code_string, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 8 outputs
    """
    # Outputs are written through reference parameters, so return_by_ref=True.
    return _JittedFunction(
        code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs
    )
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/nccl.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import warnings
|
| 3 |
+
from typing import Optional, Sequence, Union
|
| 4 |
+
|
| 5 |
+
import torch.cuda
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"]
|
| 9 |
+
|
| 10 |
+
SUM = 0 # ncclRedOp_t
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def is_available(tensors):
    """Return True if NCCL collectives can run on ``tensors``.

    Requires an NCCL-enabled build and that every tensor is a dense,
    contiguous CUDA tensor, with at most one tensor per device.
    """
    if not hasattr(torch._C, "_nccl_all_reduce"):
        warnings.warn("PyTorch is not compiled with NCCL support")
        return False

    seen_devices = set()
    for tensor in tensors:
        # NCCL only operates on dense, contiguous CUDA tensors.
        if tensor.is_sparse or not tensor.is_contiguous() or not tensor.is_cuda:
            return False
        # Each device may contribute at most one tensor to a collective.
        device = tensor.get_device()
        if device in seen_devices:
            return False
        seen_devices.add(device)

    return True
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def version():
    """Return the NCCL version as ``(major, minor, patch)`` or, when the
    build carries a version suffix, ``(major, minor, patch, suffix)``."""
    packed = torch._C._nccl_version()
    # The version is bit-packed: major in the high 32 bits, then two
    # 16-bit fields for minor and patch.
    major = packed >> 32
    minor = (packed >> 16) & 65535
    patch = packed & 65535
    suffix = torch._C._nccl_version_suffix().decode("utf-8")
    if suffix == "":
        return (major, minor, patch)
    return (major, minor, patch, suffix)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def unique_id():
    # Generate an NCCL unique id used to bootstrap a new communicator group.
    return torch._C._nccl_unique_id()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def init_rank(num_ranks, uid, rank):
    # Initialize this process's NCCL communicator: `uid` comes from
    # unique_id() on the root, `rank` is this process's index in [0, num_ranks).
    return torch._C._nccl_init_rank(num_ranks, uid, rank)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
|
| 55 |
+
if not isinstance(inputs, collections.abc.Container) or isinstance(
|
| 56 |
+
inputs, torch.Tensor
|
| 57 |
+
):
|
| 58 |
+
raise TypeError("Inputs should be a collection of tensors")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
    """All-reduce ``inputs`` across devices with reduction ``op``.

    When ``outputs`` is omitted the reduction is done in place, writing the
    results back into ``inputs``.
    """
    _check_sequence_type(inputs)
    outputs = inputs if outputs is None else outputs
    _check_sequence_type(outputs)
    torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# `output` used to be `outputs`, taking in a list of tensors. So we have two
# arguments for BC reasons.
def reduce(
    inputs: Sequence[torch.Tensor],
    output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
    root: int = 0,
    op: int = SUM,
    streams: Optional[Sequence[torch.cuda.Stream]] = None,
    comms=None,
    *,
    outputs: Optional[Sequence[torch.Tensor]] = None,
) -> None:
    """Reduce ``inputs`` across devices, writing the result on rank ``root``.

    Args:
        inputs: one contiguous CUDA tensor per participating device.
        output: tensor that receives the reduced result on ``root``. Defaults
            to ``inputs[root]`` (in-place reduction). A sequence of tensors is
            still accepted here for backward compatibility (deprecated).
        root: index of the rank that receives the result.
        op: NCCL reduction op; defaults to ``SUM``.
        streams: optional per-device CUDA streams to enqueue on.
        comms: optional pre-initialized NCCL communicators.
        outputs: deprecated keyword-only alias for the old list-of-tensors API;
            mutually exclusive with ``output``.

    Raises:
        ValueError: if both ``output`` and ``outputs`` are specified.
        TypeError: if ``inputs`` is not a collection of tensors.
    """
    _check_sequence_type(inputs)
    _output: torch.Tensor
    if outputs is not None:
        if output is not None:
            raise ValueError(
                "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in "
                "favor of 'output', taking in a single output tensor. The signature of reduce is: "
                "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."
            )
        else:
            # Fixed duplicated word ("instead instead") in this user-facing
            # deprecation message.
            warnings.warn(
                "nccl.reduce with an output tensor list is deprecated. "
                "Please specify a single output tensor with argument 'output' instead."
            )
            _output = outputs[root]
    elif not isinstance(output, torch.Tensor) and isinstance(
        output, collections.abc.Sequence
    ):
        # User called old API with positional arguments of list of output tensors.
        warnings.warn(
            "nccl.reduce with an output tensor list is deprecated. "
            "Please specify a single output tensor."
        )
        _output = output[root]
    else:
        _output = inputs[root] if output is None else output
    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def broadcast(
    inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None
) -> None:
    # Broadcast inputs[root] into every other tensor of `inputs`
    # (one tensor per participating device).
    _check_sequence_type(inputs)
    torch._C._nccl_broadcast(inputs, root, streams, comms)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def all_gather(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    streams=None,
    comms=None,
) -> None:
    # Gather every rank's input tensor into each rank's output tensor;
    # each output must be large enough to hold all inputs concatenated.
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def reduce_scatter(
    inputs: Sequence[torch.Tensor],
    outputs: Sequence[torch.Tensor],
    op: int = SUM,
    streams=None,
    comms=None,
) -> None:
    # Reduce `inputs` across ranks with `op`, then scatter equal chunks of the
    # result so each rank's output receives its corresponding slice.
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/profiler.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import tempfile
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from . import check_error, cudart
|
| 6 |
+
|
| 7 |
+
__all__ = ["init", "start", "stop", "profile"]
|
| 8 |
+
|
| 9 |
+
DEFAULT_FLAGS = [
|
| 10 |
+
"gpustarttimestamp",
|
| 11 |
+
"gpuendtimestamp",
|
| 12 |
+
"gridsize3d",
|
| 13 |
+
"threadblocksize",
|
| 14 |
+
"streamid",
|
| 15 |
+
"enableonstart 0",
|
| 16 |
+
"conckerneltrace",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def init(output_file, flags=None, output_mode="key_value"):
    """Initialize the legacy CUDA profiler (pre-CUDA-12, CUDA-only).

    Writes the profiler configuration ``flags`` (defaults to DEFAULT_FLAGS)
    to a temporary file and passes it to ``cudaProfilerInitialize`` together
    with ``output_file`` and the requested ``output_mode``.

    Raises:
        AssertionError: on HIP builds, or on CUDA 12+ where initialization
            is no longer needed (see pytorch/pytorch#91118).
        RuntimeError: if ``output_mode`` is not "key_value" or "csv".
    """
    rt = cudart()
    if not hasattr(rt, "cudaOutputMode"):
        raise AssertionError("HIP does not support profiler initialization!")
    if (
        hasattr(torch.version, "cuda")
        and torch.version.cuda is not None
        and int(torch.version.cuda.split(".")[0]) >= 12
    ):
        # Check https://github.com/pytorch/pytorch/pull/91118
        # cudaProfilerInitialize is no longer needed after CUDA 12
        raise AssertionError("CUDA12+ does not need profiler initialization!")
    flags = DEFAULT_FLAGS if flags is None else flags
    if output_mode == "key_value":
        output_mode_enum = rt.cudaOutputMode.KeyValuePair
    elif output_mode == "csv":
        output_mode_enum = rt.cudaOutputMode.CSV
    else:
        raise RuntimeError(
            "supported CUDA profiler output modes are: key_value and csv"
        )
    # NOTE(review): reopening f.name while the NamedTemporaryFile is still open
    # is POSIX-specific behavior — presumably acceptable for this CUDA/Linux
    # path, but verify if Windows support is ever required.
    with tempfile.NamedTemporaryFile(delete=True) as f:
        # The `f` in the generator expression shadows the file object only
        # inside the genexpr scope; the outer file handle is unaffected.
        f.write(b"\n".join(f.encode("ascii") for f in flags))
        f.flush()
        check_error(rt.cudaProfilerInitialize(f.name, output_file, output_mode_enum))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def start():
    # Begin collecting profiler data; raises via check_error on CUDA failure.
    check_error(cudart().cudaProfilerStart())
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def stop():
    r"""Stop collecting profiler data (wraps ``cudaProfilerStop``)."""
    status = cudart().cudaProfilerStop()
    check_error(status)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@contextlib.contextmanager
def profile():
    r"""Context manager that enables CUDA profiling for the duration of the block.

    ``start()`` is called on entry and ``stop()`` on exit; ``stop()`` runs even
    if the body (or ``start()`` itself) raises, because both live inside the
    ``try``/``finally``.
    """
    try:
        start()
        yield
    finally:
        stop()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/random.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Iterable, List, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from .. import Tensor
|
| 5 |
+
from . import _lazy_call, _lazy_init, current_device, device_count
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"get_rng_state",
|
| 9 |
+
"get_rng_state_all",
|
| 10 |
+
"set_rng_state",
|
| 11 |
+
"set_rng_state_all",
|
| 12 |
+
"manual_seed",
|
| 13 |
+
"manual_seed_all",
|
| 14 |
+
"seed",
|
| 15 |
+
"seed_all",
|
| 16 |
+
"initial_seed",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
    r"""Return the random number generator state of the specified GPU as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    # Normalize the argument into a torch.device instance.
    if isinstance(device, str):
        target = torch.device(device)
    elif isinstance(device, int):
        target = torch.device("cuda", device)
    else:
        target = device
    # Fall back to the current device when no index was given.
    index = target.index
    if index is None:
        index = current_device()
    return torch.cuda.default_generators[index].get_state()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_rng_state_all() -> List[Tensor]:
    r"""Return a list of ByteTensor representing the random number states of all devices.

    Returns:
        A list with one RNG-state tensor (from :func:`get_rng_state`) per
        visible CUDA device, indexed by device ordinal.
    """
    # Comprehension replaces the manual append loop (same order, same values).
    return [get_rng_state(i) for i in range(device_count())]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
) -> None:
    r"""Set the random number generator state of the specified GPU.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
    """
    # Snapshot the state eagerly (outside the deferred callback) so later
    # caller-side mutation of ``new_state`` cannot leak into the update.
    with torch._C._DisableFuncTorch():
        state_snapshot = new_state.clone(memory_format=torch.contiguous_format)
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    def _apply() -> None:
        # Resolve the device index lazily, at the time the callback runs.
        index = device.index
        if index is None:
            index = current_device()
        torch.cuda.default_generators[index].set_state(state_snapshot)

    _lazy_call(_apply)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Set the random number generator state of all devices.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each device.
    """
    # The i-th state is applied to device ordinal i.
    for device_index, device_state in enumerate(new_states):
        set_rng_state(device_state, device_index)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def manual_seed(seed: int) -> None:
    r"""Set the seed for generating random numbers for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        If you are working with a multi-GPU model, this function is insufficient
        to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
    """
    seed_value = int(seed)

    def _seed_current_device():
        generator = torch.cuda.default_generators[current_device()]
        generator.manual_seed(seed_value)

    # Deferred until CUDA is initialized; ``seed=True`` tags this as a seeding
    # call for the lazy-init machinery.
    _lazy_call(_seed_current_device, seed=True)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def manual_seed_all(seed: int) -> None:
    r"""Set the seed for generating random numbers on all GPUs.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    seed_value = int(seed)

    def _seed_every_device():
        for device_index in range(device_count()):
            torch.cuda.default_generators[device_index].manual_seed(seed_value)

    # Deferred until CUDA is initialized; ``seed_all=True`` tags this for the
    # lazy-init machinery.
    _lazy_call(_seed_every_device, seed_all=True)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def seed() -> None:
    r"""Set the seed for generating random numbers to a random number for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    .. warning::
        If you are working with a multi-GPU model, this function will only initialize
        the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
    """

    def _randomize_current_device():
        torch.cuda.default_generators[current_device()].seed()

    _lazy_call(_randomize_current_device)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def seed_all() -> None:
    r"""Set the seed for generating random numbers to a random number on all GPUs.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.
    """

    def _randomize_every_device():
        # Draw one random seed from the first device, then replicate it to the
        # remaining devices so all generators agree.
        shared_seed = 0
        have_seed = False
        for device_index in range(device_count()):
            generator = torch.cuda.default_generators[device_index]
            if not have_seed:
                generator.seed()
                shared_seed = generator.initial_seed()
                have_seed = True
            else:
                generator.manual_seed(shared_seed)

    _lazy_call(_randomize_every_device)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def initial_seed() -> int:
    r"""Return the current random seed of the current GPU.

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    generator = torch.cuda.default_generators[current_device()]
    return generator.initial_seed()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/sparse.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# The Tensor classes are added to this module by python_tensor.cpp
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/streams.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch._streambase import _EventBase, _StreamBase
|
| 5 |
+
from .._utils import _dummy_type
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
if not hasattr(torch._C, "_CudaStreamBase"):
|
| 9 |
+
# Define dummy base classes
|
| 10 |
+
torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase")
|
| 11 |
+
torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Stream(torch._C._CudaStreamBase, _StreamBase):
    r"""Wrapper around a CUDA stream.

    A CUDA stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams. See :ref:`cuda-semantics` for
    details.

    Args:
        device(torch.device or int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority(int, optional): priority of the stream, should be 0 or
            negative, where negative numbers indicate higher priority. By default,
            streams have priority 0.

    """

    def __new__(cls, device=None, priority=0, **kwargs):
        # setting device manager is expensive, so we avoid it unless necessary
        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
            return super().__new__(cls, priority=priority, **kwargs)
        else:
            with torch.cuda.device(device):
                return super().__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event):
        r"""Make all future work submitted to the stream wait for an event.

        Args:
            event (torch.cuda.Event): an event to wait for.

        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
            `CUDA Stream documentation`_ for more info.

            This function returns without waiting for :attr:`event`: only future
            operations are affected.

        .. _CUDA Stream documentation:
           https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
        """
        event.wait(self)

    def wait_stream(self, stream):
        r"""Synchronize with another stream.

        All future work submitted to this stream will wait until all kernels
        submitted to a given stream at the time of call complete.

        Args:
            stream (Stream): a stream to synchronize.

        .. note:: This function returns without waiting for currently enqueued
           kernels in :attr:`stream`: only future operations are affected.
        """
        # Implemented by recording an event on `stream` and waiting on it here.
        self.wait_event(stream.record_event())

    def record_event(self, event=None):
        r"""Record an event.

        Args:
            event (torch.cuda.Event, optional): event to record. If not given, a new one
                will be allocated.

        Returns:
            Recorded event.
        """
        if event is None:
            event = Event()
        event.record(self)
        return event

    def query(self):
        r"""Check if all the work submitted has been completed.

        Returns:
            A boolean indicating if all kernels in this stream are completed.
        """
        return super().query()

    def synchronize(self):
        r"""Wait for all the kernels in this stream to complete.

        .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
           `CUDA Stream documentation`_ for more info.
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        # ctypes protocol hook: lets a Stream be passed directly to C functions
        # that expect a raw cudaStream_t pointer.
        return ctypes.c_void_p(self.cuda_stream)

    def __eq__(self, o):
        # Only comparable against other Streams; anything else is unequal.
        if isinstance(o, Stream):
            return super().__eq__(o)
        return False

    def __hash__(self):
        # Hash on (raw stream pointer, device) to stay consistent with __eq__.
        return hash((self.cuda_stream, self.device))

    def __repr__(self):
        return f"<torch.cuda.Stream device={self.device} cuda_stream={self.cuda_stream:#x}>"
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class ExternalStream(Stream):
    r"""Wrapper around an externally allocated CUDA stream.

    This class is used to wrap streams allocated in other libraries in order
    to facilitate data exchange and multi-library interactions.

    .. note:: This class doesn't manage the stream life-cycle, it is the user
       responsibility to keep the referenced stream alive while this class is
       being used.

    Args:
        stream_ptr(int): Integer representation of the `cudaStream_t` value.
            allocated externally.
        device(torch.device or int, optional): the device where the stream
            was originally allocated. if device is specified incorrectly,
            subsequent launches using this stream may fail.
    """

    def __new__(cls, stream_ptr, device=None, **kwargs):
        # Activate the target device first so the base class associates the
        # wrapped pointer with the right device context.
        with torch.cuda.device(device):
            return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class Event(torch._C._CudaEventBase, _EventBase):
    r"""Wrapper around a CUDA event.

    CUDA events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize CUDA
    streams.

    The underlying CUDA events are lazily initialized when the event is first
    recorded or exported to another process. After creation, only streams on the
    same device may record the event. However, streams on any device can wait on
    the event.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
        blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
        interprocess (bool): if ``True``, the event can be shared between processes
            (default: ``False``)

    .. _CUDA Event Documentation:
       https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
    """

    def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
        return super().__new__(
            cls,
            enable_timing=enable_timing,
            blocking=blocking,
            interprocess=interprocess,
        )

    @classmethod
    def from_ipc_handle(cls, device, handle):
        r"""Reconstruct an event from an IPC handle on the given device."""
        return super().from_ipc_handle(device, handle)

    def record(self, stream=None):
        r"""Record the event in a given stream.

        Uses ``torch.cuda.current_stream()`` if no stream is specified. The
        stream's device must match the event's device.
        """
        if stream is None:
            stream = torch.cuda.current_stream()
        super().record(stream)

    def wait(self, stream=None):
        r"""Make all future work submitted to the given stream wait for this event.

        Use ``torch.cuda.current_stream()`` if no stream is specified.

        .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
            `CUDA Event documentation`_ for more info.
        """
        if stream is None:
            stream = torch.cuda.current_stream()
        super().wait(stream)

    def query(self):
        r"""Check if all work currently captured by event has completed.

        Returns:
            A boolean indicating if all work currently captured by event has
            completed.
        """
        return super().query()

    def elapsed_time(self, end_event):
        r"""Return the time elapsed.

        Time reported in milliseconds after the event was recorded and
        before the end_event was recorded.
        """
        return super().elapsed_time(end_event)

    def synchronize(self):
        r"""Wait for the event to complete.

        Waits until the completion of all work currently captured in this event.
        This prevents the CPU thread from proceeding until the event completes.

        .. note:: This is a wrapper around ``cudaEventSynchronize()``: see
            `CUDA Event documentation`_ for more info.
        """
        super().synchronize()

    def ipc_handle(self):
        r"""Return an IPC handle of this event.

        If not recorded yet, the event will use the current device.
        """
        return super().ipc_handle()

    @property
    def _as_parameter_(self):
        # ctypes protocol hook: lets an Event be passed directly to C functions
        # that expect a raw cudaEvent_t pointer.
        return ctypes.c_void_p(self.cuda_event)

    def __repr__(self):
        # cuda_event is falsy until the underlying event is lazily created.
        if self.cuda_event:
            return f"<torch.cuda.Event {self._as_parameter_.value:#x}>"
        else:
            return "<torch.cuda.Event uninitialized>"
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (4.31 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-311.pyc
ADDED
|
Binary file (9.08 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/graph.cpython-311.pyc
ADDED
|
Binary file (86.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/interpreter.cpython-311.pyc
ADDED
|
Binary file (29.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/operator_schemas.cpython-311.pyc
ADDED
|
Binary file (23.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-311.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ATen.h
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#if !defined(_MSC_VER) && __cplusplus < 201703L
|
| 4 |
+
#error C++17 or later compatible compiler is required to use ATen.
|
| 5 |
+
#endif
|
| 6 |
+
|
| 7 |
+
#include <ATen/Context.h>
|
| 8 |
+
#include <ATen/Device.h>
|
| 9 |
+
#include <ATen/DeviceGuard.h>
|
| 10 |
+
#include <ATen/DimVector.h>
|
| 11 |
+
#include <ATen/Dispatch.h>
|
| 12 |
+
#include <ATen/Formatting.h>
|
| 13 |
+
#include <ATen/Functions.h>
|
| 14 |
+
#include <ATen/NamedTensor.h>
|
| 15 |
+
#include <ATen/ScalarOps.h>
|
| 16 |
+
#include <ATen/Tensor.h>
|
| 17 |
+
#include <ATen/TensorGeometry.h>
|
| 18 |
+
#include <ATen/TensorIndexing.h>
|
| 19 |
+
#include <ATen/TensorOperators.h>
|
| 20 |
+
#include <ATen/Version.h>
|
| 21 |
+
#include <ATen/core/ATenGeneral.h>
|
| 22 |
+
#include <ATen/core/Generator.h>
|
| 23 |
+
#include <ATen/core/Reduction.h>
|
| 24 |
+
#include <ATen/core/Scalar.h>
|
| 25 |
+
#include <ATen/core/UnsafeFromTH.h>
|
| 26 |
+
#include <ATen/core/ivalue.h>
|
| 27 |
+
#include <ATen/core/jit_type.h>
|
| 28 |
+
#include <c10/core/Allocator.h>
|
| 29 |
+
#include <c10/core/InferenceMode.h>
|
| 30 |
+
#include <c10/core/Layout.h>
|
| 31 |
+
#include <c10/core/Storage.h>
|
| 32 |
+
#include <c10/core/TensorOptions.h>
|
| 33 |
+
#include <c10/util/Exception.h>
|
| 34 |
+
|
| 35 |
+
// TODO: try to remove this
|
| 36 |
+
// There is some back story, see https://github.com/pytorch/pytorch/issues/48684
|
| 37 |
+
#include <ATen/NativeFunctions.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/AccumulateType.h
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/Config.h>
|
| 3 |
+
#include <c10/core/DeviceType.h>
|
| 4 |
+
#include <c10/core/ScalarType.h>
|
| 5 |
+
#include <c10/util/BFloat16.h>
|
| 6 |
+
#include <c10/util/Float8_e4m3fn.h>
|
| 7 |
+
#include <c10/util/Float8_e4m3fnuz.h>
|
| 8 |
+
#include <c10/util/Float8_e5m2.h>
|
| 9 |
+
#include <c10/util/Float8_e5m2fnuz.h>
|
| 10 |
+
#include <c10/util/Half.h>
|
| 11 |
+
|
| 12 |
+
// Defines the accumulation type for a scalar type.
|
| 13 |
+
// Example:
|
| 14 |
+
// using accscalar_t = acc_type<scalar_t, /*is_cuda*/true>;
|
| 15 |
+
//
|
| 16 |
+
// Accumulation types are an important concept in numeric computing
|
| 17 |
+
// because you frequently want to perform intermediate computations
|
| 18 |
+
// at a higher precision than the input and output precision, to avoid
|
| 19 |
+
// compounding internal rounding errors. Accumulation is the most
|
| 20 |
+
// well-known intermediate computation (it is of great importance for
|
| 21 |
+
// sum reduction and matrix multiply, for example), but in PyTorch
|
| 22 |
+
// acc_type ends up getting used for all sorts of other intermediate
|
| 23 |
+
// computations, so it perhaps would be more accurately (ahem) called an
|
| 24 |
+
// "accurate" type. acc_type is especially important for reduced
|
| 25 |
+
// precision operations like float16 and bfloat16, where relatively
|
| 26 |
+
// benign looking inputs can easily end up overflowing/underflowing.
|
| 27 |
+
//
|
| 28 |
+
// acc_type is parametrized by whether or not you are running on CUDA
|
| 29 |
+
// or not, because on CUDA double precision operations are expensive
|
| 30 |
+
// and so by default, we don't actually want to use double as an
|
| 31 |
+
// acc_type on CUDA. A lot of things are typed out below, but
|
| 32 |
+
// basically, the table is generated by a few rules:
|
| 33 |
+
//
|
| 34 |
+
// If bool:
|
| 35 |
+
// Use 'bool' as acc_type.
|
| 36 |
+
// If floating point:
|
| 37 |
+
// If CUDA, use 'float' as acc_type (unless scalar_t is double),
|
| 38 |
+
// otherwise (CPU) use 'double'
|
| 39 |
+
// If integral:
|
| 40 |
+
// Use 'int64_t' as acc_type
|
| 41 |
+
//
|
| 42 |
+
// You're not forced to use this template; if you happen to know
|
| 43 |
+
// something specific about your use case, you can specify your own
|
| 44 |
+
// desired behavior. This template, however, will give you a reasonable
|
| 45 |
+
// default that will work for all dtypes supported in PyTorch.
|
| 46 |
+
|
| 47 |
+
#if defined(__CUDACC__)
|
| 48 |
+
#include <cuda.h>
|
| 49 |
+
#include <cuda_fp16.h>
|
| 50 |
+
#elif defined(__HIPCC__)
|
| 51 |
+
#include <hip/hip_fp16.h>
|
| 52 |
+
#include <hip/hip_runtime.h>
|
| 53 |
+
#endif
|
| 54 |
+
|
| 55 |
+
namespace at {

// Primary template; the ACC_TYPE macros below generate one explicit
// specialization per (scalar type, device type) pair.
template <typename T, c10::DeviceType D>
struct AccumulateTypeDevice {};

// Legacy bool-keyed interface: true selects the CUDA table, false the CPU one.
template <typename T, bool>
struct AccumulateType {};

template <typename T>
struct AccumulateType<T, false> {
  using type = typename AccumulateTypeDevice<T, c10::DeviceType::CPU>::type;
};

template <typename T>
struct AccumulateType<T, true> {
  using type = typename AccumulateTypeDevice<T, c10::DeviceType::CUDA>::type;
};

// Convenience aliases over the two interfaces above.
template <typename T, c10::DeviceType device>
using acc_type_device = typename AccumulateTypeDevice<T, device>::type;

template <typename T, bool is_cuda>
using acc_type = typename AccumulateType<T, is_cuda>::type;

#define ACC_TYPE(t, acc_t, device_type)         \
  template <>                                   \
  struct AccumulateTypeDevice<t, device_type> { \
    using type = acc_t;                         \
  };
#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)

// MPS table: note double accumulates in float (no fp64 here), and complex
// always accumulates in complex<float>.
MPS_ACC_TYPE(BFloat16, float);
MPS_ACC_TYPE(Half, float);
MPS_ACC_TYPE(Float8_e5m2, float);
MPS_ACC_TYPE(Float8_e4m3fn, float);
MPS_ACC_TYPE(Float8_e5m2fnuz, float);
MPS_ACC_TYPE(Float8_e4m3fnuz, float);
MPS_ACC_TYPE(float, float);
MPS_ACC_TYPE(double, float);
MPS_ACC_TYPE(int8_t, int64_t);
MPS_ACC_TYPE(uint8_t, int64_t);
MPS_ACC_TYPE(char, int64_t);
MPS_ACC_TYPE(int16_t, int64_t);
MPS_ACC_TYPE(int32_t, int64_t);
MPS_ACC_TYPE(int64_t, int64_t);
MPS_ACC_TYPE(bool, bool);
MPS_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>);

// The bare `half` type only exists when compiling device code.
#if defined(__CUDACC__) || defined(__HIPCC__)
CUDA_ACC_TYPE(half, float);
#endif
// CUDA table: reduced-precision floats accumulate in float; double stays
// double (unlike MPS).
CUDA_ACC_TYPE(BFloat16, float);
CUDA_ACC_TYPE(Half, float);
CUDA_ACC_TYPE(Float8_e5m2, float);
CUDA_ACC_TYPE(Float8_e4m3fn, float);
CUDA_ACC_TYPE(Float8_e5m2fnuz, float);
CUDA_ACC_TYPE(Float8_e4m3fnuz, float);
CUDA_ACC_TYPE(float, float);
CUDA_ACC_TYPE(double, double);
CUDA_ACC_TYPE(int8_t, int64_t);
CUDA_ACC_TYPE(uint8_t, int64_t);
CUDA_ACC_TYPE(char, int64_t);
CUDA_ACC_TYPE(int16_t, int64_t);
CUDA_ACC_TYPE(int32_t, int64_t);
CUDA_ACC_TYPE(int64_t, int64_t);
CUDA_ACC_TYPE(bool, bool);
CUDA_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
CUDA_ACC_TYPE(c10::complex<float>, c10::complex<float>);
CUDA_ACC_TYPE(c10::complex<double>, c10::complex<double>);

// CPU table: float accumulates in double (fp64 is cheap on CPU).
CPU_ACC_TYPE(BFloat16, float);
CPU_ACC_TYPE(Half, float);
CPU_ACC_TYPE(Float8_e5m2, float);
CPU_ACC_TYPE(Float8_e4m3fn, float);
CPU_ACC_TYPE(Float8_e5m2fnuz, float);
CPU_ACC_TYPE(Float8_e4m3fnuz, float);
CPU_ACC_TYPE(float, double);
CPU_ACC_TYPE(double, double);
CPU_ACC_TYPE(int8_t, int64_t);
CPU_ACC_TYPE(uint8_t, int64_t);
CPU_ACC_TYPE(char, int64_t);
CPU_ACC_TYPE(int16_t, int64_t);
CPU_ACC_TYPE(int32_t, int64_t);
CPU_ACC_TYPE(int64_t, int64_t);
CPU_ACC_TYPE(bool, bool);
CPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
CPU_ACC_TYPE(c10::complex<float>, c10::complex<double>);
CPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);

// Runtime (ScalarType-keyed) equivalents of the compile-time tables above.
TORCH_API c10::ScalarType toAccumulateType(
    c10::ScalarType type,
    c10::DeviceType device);
TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);

} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Backend.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/core/Backend.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CPUFixedAllocator.h
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/Allocator.h>
|
| 4 |
+
#include <c10/util/Exception.h>
|
| 5 |
+
|
| 6 |
+
// This file creates a fake allocator that just throws exceptions if
|
| 7 |
+
// it is actually used.
|
| 8 |
+
|
| 9 |
+
// state passed to the allocator is the std::function<void(void*)> called
|
| 10 |
+
// when the blob is release by ATen
|
| 11 |
+
|
| 12 |
+
namespace at {
|
| 13 |
+
|
| 14 |
+
// Fake allocation entry for tensors viewing an external blob: resizing such a
// tensor is unsupported, so any allocation attempt throws.
// NOTE(review): the original declaration omitted a return type, which is
// ill-formed C++ (there is no implicit int); ``void*`` matches the malloc-like
// shape of this entry. The function never returns normally (AT_ERROR throws).
static void* cpu_fixed_malloc(void*, ptrdiff_t) {
  AT_ERROR("attempting to resize a tensor view of an external blob");
}
|
| 17 |
+
|
| 18 |
+
// Fake reallocation entry: resizing a tensor view of an external blob is
// unsupported, so this always throws.
// NOTE(review): return type added (``void*``) — the original omitted it,
// which is ill-formed C++. The function never returns normally.
static void* cpu_fixed_realloc(void*, void*, ptrdiff_t) {
  AT_ERROR("attempting to resize a tensor view of an external blob");
}
|
| 21 |
+
|
| 22 |
+
// Release entry: `state` is a heap-allocated std::function<void(void*)>
// release callback (see the file header comment); invoke it with the blob
// pointer, then delete the callback itself.
// NOTE(review): return type added (``void``) — the original omitted it, which
// is ill-formed C++.
static void cpu_fixed_free(void* state, void* allocation) {
  auto on_release = static_cast<std::function<void(void*)>*>(state);
  (*on_release)(allocation);
  delete on_release;
}
|
| 27 |
+
|
| 28 |
+
// The fake allocator wired from the three throwing/release entries above.
// NOTE(review): this aggregate-initializes ``Allocator`` from three function
// pointers; presumably it targets a legacy function-table Allocator
// definition rather than the polymorphic c10::Allocator — confirm whether
// this header is still compiled anywhere before relying on it.
static Allocator CPU_fixed_allocator = {
    cpu_fixed_malloc,
    cpu_fixed_realloc,
    cpu_fixed_free};
|
| 32 |
+
|
| 33 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CollapseDims.h
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <c10/util/Exception.h>
|
| 2 |
+
#include <utility>
|
| 3 |
+
|
| 4 |
+
namespace at {
|
| 5 |
+
|
| 6 |
+
/*
|
| 7 |
+
[collapse dims] Updates sizes, and strides to reflect a "collapse" of
|
| 8 |
+
the info, possibly excluding the optional excludeDim. A "collapsed" version
|
| 9 |
+
of the info is the fewest dims that order the tensor's elements in the same
|
| 10 |
+
way as the original info. If excludeDim is specified, the collapse is the
|
| 11 |
+
fewest dims that order the tensor's elements as the original and preserve the
|
| 12 |
+
excluded dimension, unless the tensor collapses to a point.
|
| 13 |
+
|
| 14 |
+
This function returns a pair of values.
|
| 15 |
+
|
| 16 |
+
1) The (new) index of the preserved dimension if excludeDim is
|
| 17 |
+
specified. 0 if the tensor is collapsed to a point. -1
|
| 18 |
+
otherwise.
|
| 19 |
+
|
| 20 |
+
2) The new number of dimensions.
|
| 21 |
+
*/
|
| 22 |
+
template <typename T>
|
| 23 |
+
inline std::pair<int64_t, int64_t> collapse_dims(
|
| 24 |
+
T* sizes,
|
| 25 |
+
T* strides,
|
| 26 |
+
int64_t dims,
|
| 27 |
+
const int excludeDim = -1) {
|
| 28 |
+
TORCH_CHECK(
|
| 29 |
+
excludeDim >= -1 && excludeDim < dims,
|
| 30 |
+
"expected excluded dim between -1 and dims - 1");
|
| 31 |
+
|
| 32 |
+
int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
|
| 33 |
+
int64_t newIndex = -1;
|
| 34 |
+
int64_t oldIndex = 0;
|
| 35 |
+
int64_t remappedExcludedDim = -1;
|
| 36 |
+
|
| 37 |
+
while (oldIndex < dims) {
|
| 38 |
+
// Finds a dimension to collapse into
|
| 39 |
+
for (; oldIndex < stopDim; ++oldIndex) {
|
| 40 |
+
if (sizes[oldIndex] == 1) {
|
| 41 |
+
continue;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
++newIndex;
|
| 45 |
+
sizes[newIndex] = sizes[oldIndex];
|
| 46 |
+
strides[newIndex] = strides[oldIndex];
|
| 47 |
+
++oldIndex;
|
| 48 |
+
break;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
// Collapses dims
|
| 52 |
+
for (; oldIndex < stopDim; ++oldIndex) {
|
| 53 |
+
if (sizes[oldIndex] == 1) {
|
| 54 |
+
continue;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
|
| 58 |
+
sizes[newIndex] *= sizes[oldIndex];
|
| 59 |
+
strides[newIndex] = strides[oldIndex];
|
| 60 |
+
} else {
|
| 61 |
+
++newIndex;
|
| 62 |
+
sizes[newIndex] = sizes[oldIndex];
|
| 63 |
+
strides[newIndex] = strides[oldIndex];
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// Handles excludeDim being set (oldIndex == excludeDim)
|
| 68 |
+
if (oldIndex != dims) {
|
| 69 |
+
// Preserves excluded dimension
|
| 70 |
+
++newIndex;
|
| 71 |
+
sizes[newIndex] = sizes[oldIndex];
|
| 72 |
+
strides[newIndex] = strides[oldIndex];
|
| 73 |
+
remappedExcludedDim = newIndex;
|
| 74 |
+
|
| 75 |
+
// Restarts iteration after excludeDim
|
| 76 |
+
++oldIndex;
|
| 77 |
+
stopDim = dims;
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
// Handles special case of all dims size 1
|
| 82 |
+
if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
|
| 83 |
+
dims = 1;
|
| 84 |
+
sizes[0] = 1;
|
| 85 |
+
strides[0] = 1;
|
| 86 |
+
|
| 87 |
+
return std::pair<int64_t, int64_t>(0, 1);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
dims = newIndex + 1;
|
| 91 |
+
return std::pair<int64_t, int64_t>(remappedExcludedDim, dims);
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CompositeImplicitAutogradFunctions_inl.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_compositeimplicitautogradnestedtensor_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/randn_like_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 20 |
+
#include <ATen/ops/reshape_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 21 |
+
#include <ATen/ops/reshape_as_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 22 |
+
#include <ATen/ops/zeros_like_compositeimplicitautogradnestedtensor_dispatch.h>
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch.h
ADDED
|
@@ -0,0 +1,808 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/DeprecatedTypeProperties.h>
|
| 4 |
+
#include <c10/macros/Macros.h>
|
| 5 |
+
#include <c10/util/Exception.h>
|
| 6 |
+
#include <c10/util/Half.h>
|
| 7 |
+
#include <c10/util/Metaprogramming.h>
|
| 8 |
+
#include <c10/util/complex.h>
|
| 9 |
+
#include <c10/util/string_view.h>
|
| 10 |
+
|
| 11 |
+
#ifdef __CUDACC__
|
| 12 |
+
#include <cuda.h> // For CUDA_VERSION
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
#ifdef TEMPLATE_SELECTIVE_BUILD
|
| 16 |
+
#include <ATen/selected_mobile_ops.h>
|
| 17 |
+
#else
|
| 18 |
+
namespace at {
|
| 19 |
+
/**
|
| 20 |
+
* The method should_include_kernel_dtype() returns true/false
|
| 21 |
+
* based on whether the switching code for a specific dtype should be
|
| 22 |
+
* included based on build time constants generated from tracing model
|
| 23 |
+
* execution. This method will be implmeneted via code-generation and
|
| 24 |
+
* included in this file when code-gen is ready.
|
| 25 |
+
*/
|
| 26 |
+
inline constexpr bool should_include_kernel_dtype(
|
| 27 |
+
const char* /*kernel_tag_str*/,
|
| 28 |
+
at::ScalarType /*scalar_type*/
|
| 29 |
+
) {
|
| 30 |
+
return true;
|
| 31 |
+
}
|
| 32 |
+
} // namespace at
|
| 33 |
+
#endif
|
| 34 |
+
|
| 35 |
+
/**
|
| 36 |
+
* In the Facebook internal build (using BUCK), this macro is enabled by
|
| 37 |
+
* passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
|
| 38 |
+
* binary.
|
| 39 |
+
*/
|
| 40 |
+
#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
|
| 41 |
+
namespace at {
|
| 42 |
+
namespace detail {
|
| 43 |
+
TORCH_API void record_kernel_function_dtype(std::string name);
|
| 44 |
+
}
|
| 45 |
+
} // namespace at
|
| 46 |
+
|
| 47 |
+
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
|
| 48 |
+
at::detail::record_kernel_function_dtype( \
|
| 49 |
+
std::string(NAME) + "$" + toString(enum_type));
|
| 50 |
+
#else
|
| 51 |
+
#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
|
| 52 |
+
#endif
|
| 53 |
+
|
| 54 |
+
#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) \
|
| 55 |
+
do { \
|
| 56 |
+
if constexpr (!at::should_include_kernel_dtype( \
|
| 57 |
+
at_dispatch_name, enum_type)) { \
|
| 58 |
+
AT_ERROR( \
|
| 59 |
+
"dtype '", \
|
| 60 |
+
toString(enum_type), \
|
| 61 |
+
"' not selected for kernel tag ", \
|
| 62 |
+
at_dispatch_name); \
|
| 63 |
+
} \
|
| 64 |
+
} while (0)
|
| 65 |
+
|
| 66 |
+
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
|
| 67 |
+
case enum_type: { \
|
| 68 |
+
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
|
| 69 |
+
using HINT C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
|
| 70 |
+
return __VA_ARGS__(); \
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
#define AT_DISPATCH_CASE(enum_type, ...) \
|
| 74 |
+
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
|
| 75 |
+
|
| 76 |
+
#define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) \
|
| 77 |
+
case enum_type: { \
|
| 78 |
+
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
|
| 79 |
+
using scalar_t = scalar_type; \
|
| 80 |
+
using underlying_t C10_UNUSED = typename scalar_t::underlying; \
|
| 81 |
+
const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
|
| 82 |
+
const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
|
| 83 |
+
return __VA_ARGS__(); \
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 87 |
+
enum_type, scalar_type, bitwidth, qmin, qmax, ...) \
|
| 88 |
+
case enum_type: { \
|
| 89 |
+
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
|
| 90 |
+
using scalar_t = scalar_type; \
|
| 91 |
+
using underlying_t C10_UNUSED = typename scalar_t::underlying; \
|
| 92 |
+
const auto& SCALAR_TYPE C10_UNUSED = enum_type; \
|
| 93 |
+
const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \
|
| 94 |
+
C10_UNUSED int bit_width = bitwidth; \
|
| 95 |
+
C10_UNUSED int64_t quant_min = qmin; \
|
| 96 |
+
C10_UNUSED int64_t quant_max = qmax; \
|
| 97 |
+
return __VA_ARGS__(); \
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
namespace detail {
|
| 101 |
+
|
| 102 |
+
inline at::ScalarType scalar_type(at::ScalarType s) {
|
| 103 |
+
return s;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
C10_DEPRECATED_MESSAGE(
|
| 107 |
+
"passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, "
|
| 108 |
+
"pass an at::ScalarType instead")
|
| 109 |
+
inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
|
| 110 |
+
return t.scalarType();
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
C10_DEPRECATED_MESSAGE(
|
| 114 |
+
"AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, "
|
| 115 |
+
"use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
|
| 116 |
+
inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
|
| 117 |
+
|
| 118 |
+
C10_DEPRECATED_MESSAGE(
|
| 119 |
+
"AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "
|
| 120 |
+
"use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) "
|
| 121 |
+
"instead")
|
| 122 |
+
inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
|
| 123 |
+
|
| 124 |
+
} // namespace detail
|
| 125 |
+
|
| 126 |
+
// The AT_DISPATCH_* family of macros provides the ability to
|
| 127 |
+
// conveniently generate specializations of a kernel over all of the
|
| 128 |
+
// dtypes we care about in PyTorch. We call it "dispatch" because
|
| 129 |
+
// we are "dispatching" to the correct, dtype-specific kernel.
|
| 130 |
+
//
|
| 131 |
+
// A standard usage looks like:
|
| 132 |
+
//
|
| 133 |
+
// AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
|
| 134 |
+
// // Your code here, with 'scalar_t' now defined to
|
| 135 |
+
// // be the dtype in question
|
| 136 |
+
// });
|
| 137 |
+
//
|
| 138 |
+
// There are many variations of this macro, so it's important to
|
| 139 |
+
// understand exactly /which/ dtypes you want to get instantiated, as
|
| 140 |
+
// well as what the "default" set is.
|
| 141 |
+
//
|
| 142 |
+
// The default set of dtypes that are instantiated (e.g., by
|
| 143 |
+
// AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
|
| 144 |
+
// and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
|
| 145 |
+
// but NOT booleans (bool), half-precision floats (Half) or
|
| 146 |
+
// complex number (c10::complex<float>, c10::complex<double>).
|
| 147 |
+
// This "cut" is somewhat historical (the default types are the
|
| 148 |
+
// ones that TH historically supported), but it also reflects the
|
| 149 |
+
// fact that the non-default types are "poorly" behaved (booleans
|
| 150 |
+
// are NOT integers mod 2, half precision operations ~essentially
|
| 151 |
+
// don't exist on CPU, complex numbers are an experimental application).
|
| 152 |
+
//
|
| 153 |
+
// Here are the questions you should generally ask to decide which
|
| 154 |
+
// dispatch you want:
|
| 155 |
+
//
|
| 156 |
+
// 1. Is this an integral or floating point specific operation?
|
| 157 |
+
// (If so, you'll want one of the FLOATING or INTEGRAL macros.)
|
| 158 |
+
//
|
| 159 |
+
// 2. Should half be supported? (If you're on CPU, the answer is almost
|
| 160 |
+
// definitely no. If you do want support, use one of the AND_HALF
|
| 161 |
+
// macros)
|
| 162 |
+
//
|
| 163 |
+
// Much rarer situations:
|
| 164 |
+
//
|
| 165 |
+
// 3. Should bool be supported? (You often have to write your kernel
|
| 166 |
+
// differently if arithmetic operations are involved.) If so,
|
| 167 |
+
// Use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool
|
| 168 |
+
//
|
| 169 |
+
// 4. Should complex be supported? The answer is almost always no,
|
| 170 |
+
// unless you are working on "generic" code that should work on
|
| 171 |
+
// all dtypes.
|
| 172 |
+
//
|
| 173 |
+
// Parameters:
|
| 174 |
+
// -----------
|
| 175 |
+
//
|
| 176 |
+
// 1. The NAME argument is a "tag" that is used to trace and then
|
| 177 |
+
// conditionally compile fragments of the case statements such
|
| 178 |
+
// that the kernel functions are specialized only for the dtypes
|
| 179 |
+
// that are needed. The NAME parameter *must* be a build time
|
| 180 |
+
// const char* (can't be std::string, etc...)
|
| 181 |
+
//
|
| 182 |
+
// Please ensure that the NAME is unique for every implementation
|
| 183 |
+
// or you run the risk of over-including code for the kernel
|
| 184 |
+
// functions. There is no risk of missing out on any code, so
|
| 185 |
+
// it's mostly a risk of a Type-2 error, and not a Type-1 error.
|
| 186 |
+
//
|
| 187 |
+
// Switch-like syntax:
|
| 188 |
+
// -------------------
|
| 189 |
+
// There is also a switch-case like syntax which is useful if a kernel
|
| 190 |
+
// needs to be specialized for particular scalar types
|
| 191 |
+
//
|
| 192 |
+
// AT_DISPATCH_SWITCH(self.scalar_type(), "op_name",
|
| 193 |
+
// AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
|
| 194 |
+
// op_integral<scalar_t>(iter);
|
| 195 |
+
// })
|
| 196 |
+
// AT_DISPATCH_CASE_FLOATING_TYPES([&] {
|
| 197 |
+
// op_floating<scalar_t>(iter);
|
| 198 |
+
// })
|
| 199 |
+
// AT_DISPATCH_CASE(kBool, [&] {
|
| 200 |
+
// op_bool(iter);
|
| 201 |
+
// })
|
| 202 |
+
// );
|
| 203 |
+
//
|
| 204 |
+
// For each AT_DISPATCH_FOO macro, there is a corresponding
|
| 205 |
+
// AT_DISPATCH_CASE_FOO macro which can be used inside of an
|
| 206 |
+
// AT_DISPATCH_SWITCH block.
|
| 207 |
+
|
| 208 |
+
// NB: the the_type variable is not used, but we have kept it for
|
| 209 |
+
// backwards compatibility. It's probably not used by anyone though;
|
| 210 |
+
// but we're just being safe (and it doesn't hurt.) Note we must
|
| 211 |
+
// use it to shut up warnings about unused store.
|
| 212 |
+
|
| 213 |
+
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
|
| 214 |
+
[&] { \
|
| 215 |
+
const auto& the_type = TYPE; \
|
| 216 |
+
constexpr const char* at_dispatch_name = NAME; \
|
| 217 |
+
/* don't use TYPE again in case it is an expensive or side-effect op */ \
|
| 218 |
+
at::ScalarType _st = ::detail::scalar_type(the_type); \
|
| 219 |
+
RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
|
| 220 |
+
switch (_st) { \
|
| 221 |
+
__VA_ARGS__ \
|
| 222 |
+
default: \
|
| 223 |
+
AT_ERROR( \
|
| 224 |
+
'"', \
|
| 225 |
+
at_dispatch_name, \
|
| 226 |
+
"\" not implemented for '", \
|
| 227 |
+
toString(_st), \
|
| 228 |
+
"'"); \
|
| 229 |
+
} \
|
| 230 |
+
}()
|
| 231 |
+
|
| 232 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
|
| 233 |
+
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
| 234 |
+
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
|
| 235 |
+
|
| 236 |
+
#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
| 237 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
| 238 |
+
|
| 239 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \
|
| 240 |
+
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
| 241 |
+
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
| 242 |
+
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
| 243 |
+
|
| 244 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
|
| 245 |
+
AT_DISPATCH_SWITCH( \
|
| 246 |
+
TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
|
| 247 |
+
|
| 248 |
+
#define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...) \
|
| 249 |
+
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
| 250 |
+
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
| 251 |
+
|
| 252 |
+
#define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \
|
| 253 |
+
AT_DISPATCH_SWITCH( \
|
| 254 |
+
TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__))
|
| 255 |
+
|
| 256 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
|
| 257 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 258 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 259 |
+
|
| 260 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 261 |
+
AT_DISPATCH_SWITCH( \
|
| 262 |
+
TYPE, \
|
| 263 |
+
NAME, \
|
| 264 |
+
AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 265 |
+
|
| 266 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
|
| 267 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 268 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 269 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
|
| 270 |
+
|
| 271 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND2( \
|
| 272 |
+
SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
|
| 273 |
+
AT_DISPATCH_SWITCH( \
|
| 274 |
+
TYPE, \
|
| 275 |
+
NAME, \
|
| 276 |
+
AT_DISPATCH_CASE_FLOATING_TYPES_AND2( \
|
| 277 |
+
SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
|
| 278 |
+
|
| 279 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
|
| 280 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
|
| 281 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 282 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 283 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 284 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
|
| 285 |
+
|
| 286 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND3( \
|
| 287 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
|
| 288 |
+
AT_DISPATCH_SWITCH( \
|
| 289 |
+
TYPE, \
|
| 290 |
+
NAME, \
|
| 291 |
+
AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
|
| 292 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
| 293 |
+
|
| 294 |
+
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
|
| 295 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
|
| 296 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 297 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 298 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 299 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 300 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
|
| 301 |
+
|
| 302 |
+
#define AT_DISPATCH_FLOATING_TYPES_AND4( \
|
| 303 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
|
| 304 |
+
AT_DISPATCH_SWITCH( \
|
| 305 |
+
TYPE, \
|
| 306 |
+
NAME, \
|
| 307 |
+
AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
|
| 308 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
|
| 309 |
+
|
| 310 |
+
#define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \
|
| 311 |
+
AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
|
| 312 |
+
AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
|
| 313 |
+
|
| 314 |
+
#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
|
| 315 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))
|
| 316 |
+
|
| 317 |
+
#define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \
|
| 318 |
+
AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) \
|
| 319 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 320 |
+
|
| 321 |
+
#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 322 |
+
AT_DISPATCH_SWITCH( \
|
| 323 |
+
TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 324 |
+
|
| 325 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
|
| 326 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
| 327 |
+
AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
|
| 328 |
+
|
| 329 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
|
| 330 |
+
AT_DISPATCH_SWITCH( \
|
| 331 |
+
TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))
|
| 332 |
+
|
| 333 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \
|
| 334 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 335 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 336 |
+
|
| 337 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \
|
| 338 |
+
SCALARTYPE, TYPE, NAME, ...) \
|
| 339 |
+
AT_DISPATCH_SWITCH( \
|
| 340 |
+
TYPE, \
|
| 341 |
+
NAME, \
|
| 342 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \
|
| 343 |
+
SCALARTYPE, __VA_ARGS__))
|
| 344 |
+
|
| 345 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
|
| 346 |
+
SCALARTYPE1, SCALARTYPE2, ...) \
|
| 347 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 348 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 349 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
|
| 350 |
+
|
| 351 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( \
|
| 352 |
+
SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
|
| 353 |
+
AT_DISPATCH_SWITCH( \
|
| 354 |
+
TYPE, \
|
| 355 |
+
NAME, \
|
| 356 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
|
| 357 |
+
SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
|
| 358 |
+
|
| 359 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
|
| 360 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
|
| 361 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 362 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 363 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 364 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
|
| 365 |
+
|
| 366 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3( \
|
| 367 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
|
| 368 |
+
AT_DISPATCH_SWITCH( \
|
| 369 |
+
TYPE, \
|
| 370 |
+
NAME, \
|
| 371 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
|
| 372 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
| 373 |
+
|
| 374 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
|
| 375 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
|
| 376 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 377 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 378 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 379 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 380 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
|
| 381 |
+
|
| 382 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \
|
| 383 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
|
| 384 |
+
AT_DISPATCH_SWITCH( \
|
| 385 |
+
TYPE, \
|
| 386 |
+
NAME, \
|
| 387 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
|
| 388 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
|
| 389 |
+
|
| 390 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
|
| 391 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
|
| 392 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 393 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 394 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 395 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 396 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 397 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
|
| 398 |
+
|
| 399 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND5( \
|
| 400 |
+
SCALARTYPE1, \
|
| 401 |
+
SCALARTYPE2, \
|
| 402 |
+
SCALARTYPE3, \
|
| 403 |
+
SCALARTYPE4, \
|
| 404 |
+
SCALARTYPE5, \
|
| 405 |
+
TYPE, \
|
| 406 |
+
NAME, \
|
| 407 |
+
...) \
|
| 408 |
+
AT_DISPATCH_SWITCH( \
|
| 409 |
+
TYPE, \
|
| 410 |
+
NAME, \
|
| 411 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
|
| 412 |
+
SCALARTYPE1, \
|
| 413 |
+
SCALARTYPE2, \
|
| 414 |
+
SCALARTYPE3, \
|
| 415 |
+
SCALARTYPE4, \
|
| 416 |
+
SCALARTYPE5, \
|
| 417 |
+
__VA_ARGS__))
|
| 418 |
+
|
| 419 |
+
#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
|
| 420 |
+
SCALARTYPE1, \
|
| 421 |
+
SCALARTYPE2, \
|
| 422 |
+
SCALARTYPE3, \
|
| 423 |
+
SCALARTYPE4, \
|
| 424 |
+
SCALARTYPE5, \
|
| 425 |
+
SCALARTYPE6, \
|
| 426 |
+
...) \
|
| 427 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
|
| 428 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 429 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 430 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 431 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 432 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
|
| 433 |
+
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
|
| 434 |
+
|
| 435 |
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \
|
| 436 |
+
SCALARTYPE1, \
|
| 437 |
+
SCALARTYPE2, \
|
| 438 |
+
SCALARTYPE3, \
|
| 439 |
+
SCALARTYPE4, \
|
| 440 |
+
SCALARTYPE5, \
|
| 441 |
+
SCALARTYPE6, \
|
| 442 |
+
TYPE, \
|
| 443 |
+
NAME, \
|
| 444 |
+
...) \
|
| 445 |
+
AT_DISPATCH_SWITCH( \
|
| 446 |
+
TYPE, \
|
| 447 |
+
NAME, \
|
| 448 |
+
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
|
| 449 |
+
SCALARTYPE1, \
|
| 450 |
+
SCALARTYPE2, \
|
| 451 |
+
SCALARTYPE3, \
|
| 452 |
+
SCALARTYPE4, \
|
| 453 |
+
SCALARTYPE5, \
|
| 454 |
+
SCALARTYPE6, \
|
| 455 |
+
__VA_ARGS__))
|
| 456 |
+
|
| 457 |
+
#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
| 458 |
+
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
| 459 |
+
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
| 460 |
+
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
| 461 |
+
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
|
| 462 |
+
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)
|
| 463 |
+
|
| 464 |
+
#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
| 465 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
| 466 |
+
|
| 467 |
+
#define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \
|
| 468 |
+
AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
|
| 469 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 470 |
+
|
| 471 |
+
#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 472 |
+
AT_DISPATCH_SWITCH( \
|
| 473 |
+
TYPE, \
|
| 474 |
+
NAME, \
|
| 475 |
+
AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 476 |
+
|
| 477 |
+
#define AT_DISPATCH_CASE_ALL_TYPES(...) \
|
| 478 |
+
AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
|
| 479 |
+
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)
|
| 480 |
+
|
| 481 |
+
#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
|
| 482 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
|
| 483 |
+
|
| 484 |
+
#define AT_DISPATCH_CASE_QINT_TYPES(...) \
|
| 485 |
+
AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
|
| 486 |
+
AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) \
|
| 487 |
+
AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)
|
| 488 |
+
|
| 489 |
+
#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
|
| 490 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))
|
| 491 |
+
|
| 492 |
+
#define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \
|
| 493 |
+
AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__) \
|
| 494 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 495 |
+
|
| 496 |
+
#define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 497 |
+
AT_DISPATCH_SWITCH( \
|
| 498 |
+
TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 499 |
+
|
| 500 |
+
#define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...) \
|
| 501 |
+
AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
|
| 502 |
+
AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)
|
| 503 |
+
|
| 504 |
+
#define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
|
| 505 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))
|
| 506 |
+
|
| 507 |
+
#define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...) \
|
| 508 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 509 |
+
at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
|
| 510 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 511 |
+
at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
|
| 512 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 513 |
+
at::kQInt32, \
|
| 514 |
+
at::qint32, \
|
| 515 |
+
CHAR_BIT * sizeof(int), \
|
| 516 |
+
INT_MIN, \
|
| 517 |
+
INT_MAX, \
|
| 518 |
+
__VA_ARGS__) \
|
| 519 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 520 |
+
at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__) \
|
| 521 |
+
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
|
| 522 |
+
at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)
|
| 523 |
+
|
| 524 |
+
#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
|
| 525 |
+
AT_DISPATCH_SWITCH( \
|
| 526 |
+
TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))
|
| 527 |
+
|
| 528 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
|
| 529 |
+
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
|
| 530 |
+
AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
|
| 531 |
+
|
| 532 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
|
| 533 |
+
AT_DISPATCH_SWITCH( \
|
| 534 |
+
TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))
|
| 535 |
+
|
| 536 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \
|
| 537 |
+
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
|
| 538 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 539 |
+
|
| 540 |
+
#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 541 |
+
AT_DISPATCH_SWITCH( \
|
| 542 |
+
TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
|
| 543 |
+
|
| 544 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \
|
| 545 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 546 |
+
AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
|
| 547 |
+
|
| 548 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
|
| 549 |
+
AT_DISPATCH_SWITCH( \
|
| 550 |
+
TYPE, \
|
| 551 |
+
NAME, \
|
| 552 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__))
|
| 553 |
+
|
| 554 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
|
| 555 |
+
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
|
| 556 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 557 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
|
| 558 |
+
|
| 559 |
+
#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
|
| 560 |
+
AT_DISPATCH_SWITCH( \
|
| 561 |
+
TYPE, \
|
| 562 |
+
NAME, \
|
| 563 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
|
| 564 |
+
|
| 565 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
|
| 566 |
+
SCALARTYPE1, SCALARTYPE2, ...) \
|
| 567 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 568 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 569 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
|
| 570 |
+
|
| 571 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
|
| 572 |
+
SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
|
| 573 |
+
AT_DISPATCH_SWITCH( \
|
| 574 |
+
TYPE, \
|
| 575 |
+
NAME, \
|
| 576 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
|
| 577 |
+
SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
|
| 578 |
+
|
| 579 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND3( \
|
| 580 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
|
| 581 |
+
AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
|
| 582 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 583 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 584 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
|
| 585 |
+
|
| 586 |
+
#define AT_DISPATCH_ALL_TYPES_AND3( \
|
| 587 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
|
| 588 |
+
AT_DISPATCH_SWITCH( \
|
| 589 |
+
TYPE, \
|
| 590 |
+
NAME, \
|
| 591 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND3( \
|
| 592 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
| 593 |
+
|
| 594 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
|
| 595 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
|
| 596 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 597 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 598 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 599 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
|
| 600 |
+
|
| 601 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
|
| 602 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
|
| 603 |
+
AT_DISPATCH_SWITCH( \
|
| 604 |
+
TYPE, \
|
| 605 |
+
NAME, \
|
| 606 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
|
| 607 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
| 608 |
+
|
| 609 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
|
| 610 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
|
| 611 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 612 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 613 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 614 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 615 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
|
| 616 |
+
|
| 617 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
|
| 618 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
|
| 619 |
+
AT_DISPATCH_SWITCH( \
|
| 620 |
+
TYPE, \
|
| 621 |
+
NAME, \
|
| 622 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
|
| 623 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
|
| 624 |
+
|
| 625 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
|
| 626 |
+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
|
| 627 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 628 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 629 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 630 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 631 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 632 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
|
| 633 |
+
|
| 634 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
|
| 635 |
+
SCALARTYPE1, \
|
| 636 |
+
SCALARTYPE2, \
|
| 637 |
+
SCALARTYPE3, \
|
| 638 |
+
SCALARTYPE4, \
|
| 639 |
+
SCALARTYPE5, \
|
| 640 |
+
TYPE, \
|
| 641 |
+
NAME, \
|
| 642 |
+
...) \
|
| 643 |
+
AT_DISPATCH_SWITCH( \
|
| 644 |
+
TYPE, \
|
| 645 |
+
NAME, \
|
| 646 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
|
| 647 |
+
SCALARTYPE1, \
|
| 648 |
+
SCALARTYPE2, \
|
| 649 |
+
SCALARTYPE3, \
|
| 650 |
+
SCALARTYPE4, \
|
| 651 |
+
SCALARTYPE5, \
|
| 652 |
+
__VA_ARGS__))
|
| 653 |
+
|
| 654 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
|
| 655 |
+
SCALARTYPE1, \
|
| 656 |
+
SCALARTYPE2, \
|
| 657 |
+
SCALARTYPE3, \
|
| 658 |
+
SCALARTYPE4, \
|
| 659 |
+
SCALARTYPE5, \
|
| 660 |
+
SCALARTYPE6, \
|
| 661 |
+
...) \
|
| 662 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 663 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 664 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 665 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 666 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 667 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
|
| 668 |
+
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
|
| 669 |
+
|
| 670 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
|
| 671 |
+
SCALARTYPE1, \
|
| 672 |
+
SCALARTYPE2, \
|
| 673 |
+
SCALARTYPE3, \
|
| 674 |
+
SCALARTYPE4, \
|
| 675 |
+
SCALARTYPE5, \
|
| 676 |
+
SCALARTYPE6, \
|
| 677 |
+
TYPE, \
|
| 678 |
+
NAME, \
|
| 679 |
+
...) \
|
| 680 |
+
AT_DISPATCH_SWITCH( \
|
| 681 |
+
TYPE, \
|
| 682 |
+
NAME, \
|
| 683 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
|
| 684 |
+
SCALARTYPE1, \
|
| 685 |
+
SCALARTYPE2, \
|
| 686 |
+
SCALARTYPE3, \
|
| 687 |
+
SCALARTYPE4, \
|
| 688 |
+
SCALARTYPE5, \
|
| 689 |
+
SCALARTYPE6, \
|
| 690 |
+
__VA_ARGS__))
|
| 691 |
+
|
| 692 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
|
| 693 |
+
SCALARTYPE1, \
|
| 694 |
+
SCALARTYPE2, \
|
| 695 |
+
SCALARTYPE3, \
|
| 696 |
+
SCALARTYPE4, \
|
| 697 |
+
SCALARTYPE5, \
|
| 698 |
+
SCALARTYPE6, \
|
| 699 |
+
SCALARTYPE7, \
|
| 700 |
+
...) \
|
| 701 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 702 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 703 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 704 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 705 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 706 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
|
| 707 |
+
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
|
| 708 |
+
AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__)
|
| 709 |
+
|
| 710 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7( \
|
| 711 |
+
SCALARTYPE1, \
|
| 712 |
+
SCALARTYPE2, \
|
| 713 |
+
SCALARTYPE3, \
|
| 714 |
+
SCALARTYPE4, \
|
| 715 |
+
SCALARTYPE5, \
|
| 716 |
+
SCALARTYPE6, \
|
| 717 |
+
SCALARTYPE7, \
|
| 718 |
+
TYPE, \
|
| 719 |
+
NAME, \
|
| 720 |
+
...) \
|
| 721 |
+
AT_DISPATCH_SWITCH( \
|
| 722 |
+
TYPE, \
|
| 723 |
+
NAME, \
|
| 724 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
|
| 725 |
+
SCALARTYPE1, \
|
| 726 |
+
SCALARTYPE2, \
|
| 727 |
+
SCALARTYPE3, \
|
| 728 |
+
SCALARTYPE4, \
|
| 729 |
+
SCALARTYPE5, \
|
| 730 |
+
SCALARTYPE6, \
|
| 731 |
+
SCALARTYPE7, \
|
| 732 |
+
__VA_ARGS__))
|
| 733 |
+
|
| 734 |
+
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
|
| 735 |
+
SCALARTYPE1, \
|
| 736 |
+
SCALARTYPE2, \
|
| 737 |
+
SCALARTYPE3, \
|
| 738 |
+
SCALARTYPE4, \
|
| 739 |
+
SCALARTYPE5, \
|
| 740 |
+
SCALARTYPE6, \
|
| 741 |
+
SCALARTYPE7, \
|
| 742 |
+
SCALARTYPE8, \
|
| 743 |
+
...) \
|
| 744 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
| 745 |
+
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
| 746 |
+
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
| 747 |
+
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
| 748 |
+
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
| 749 |
+
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
|
| 750 |
+
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
|
| 751 |
+
AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__) \
|
| 752 |
+
AT_DISPATCH_CASE(SCALARTYPE8, __VA_ARGS__)
|
| 753 |
+
|
| 754 |
+
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \
|
| 755 |
+
SCALARTYPE1, \
|
| 756 |
+
SCALARTYPE2, \
|
| 757 |
+
SCALARTYPE3, \
|
| 758 |
+
SCALARTYPE4, \
|
| 759 |
+
SCALARTYPE5, \
|
| 760 |
+
SCALARTYPE6, \
|
| 761 |
+
SCALARTYPE7, \
|
| 762 |
+
SCALARTYPE8, \
|
| 763 |
+
TYPE, \
|
| 764 |
+
NAME, \
|
| 765 |
+
...) \
|
| 766 |
+
AT_DISPATCH_SWITCH( \
|
| 767 |
+
TYPE, \
|
| 768 |
+
NAME, \
|
| 769 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
|
| 770 |
+
SCALARTYPE1, \
|
| 771 |
+
SCALARTYPE2, \
|
| 772 |
+
SCALARTYPE3, \
|
| 773 |
+
SCALARTYPE4, \
|
| 774 |
+
SCALARTYPE5, \
|
| 775 |
+
SCALARTYPE6, \
|
| 776 |
+
SCALARTYPE7, \
|
| 777 |
+
SCALARTYPE8, \
|
| 778 |
+
__VA_ARGS__))
|
| 779 |
+
|
| 780 |
+
#define AT_DISPATCH_CASE_BIT_TYPES(...) \
|
| 781 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits1x8, __VA_ARGS__) \
|
| 782 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits2x4, __VA_ARGS__) \
|
| 783 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits4x2, __VA_ARGS__) \
|
| 784 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits8, __VA_ARGS__) \
|
| 785 |
+
AT_DISPATCH_CASE(at::ScalarType::Bits16, __VA_ARGS__)
|
| 786 |
+
|
| 787 |
+
#define AT_DISPATCH_BIT_TYPES(TYPE, NAME, ...) \
|
| 788 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_BIT_TYPES(__VA_ARGS__))
|
| 789 |
+
|
| 790 |
+
#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
|
| 791 |
+
AT_DISPATCH_SWITCH( \
|
| 792 |
+
TYPE, \
|
| 793 |
+
NAME, \
|
| 794 |
+
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
| 795 |
+
at::ScalarType::Int, index_t, __VA_ARGS__) \
|
| 796 |
+
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
| 797 |
+
at::ScalarType::Long, index_t, __VA_ARGS__))
|
| 798 |
+
|
| 799 |
+
// ----------------------------------------------------------------------------
|
| 800 |
+
// DEPRECATED MACROS, DON'T USE THESE
|
| 801 |
+
// ----------------------------------------------------------------------------
|
| 802 |
+
|
| 803 |
+
#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
|
| 804 |
+
detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \
|
| 805 |
+
AT_DISPATCH_SWITCH( \
|
| 806 |
+
TYPE, \
|
| 807 |
+
NAME, \
|
| 808 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Dispatch_v2.h
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/Dispatch.h>
|
| 2 |
+
|
| 3 |
+
// This is a new implementation of the AT_DISPATCH macro family from
|
| 4 |
+
// ATen/Dispatch.h
|
| 5 |
+
//
|
| 6 |
+
// The intended usage is:
|
| 7 |
+
//
|
| 8 |
+
// ScalarType scalar_type;
|
| 9 |
+
//
|
| 10 |
+
// AT_DISPATCH_V2(
|
| 11 |
+
// scalar_type,
|
| 12 |
+
// "debug string",
|
| 13 |
+
// AT_WRAP([&] {
|
| 14 |
+
// ... code to specialize with scalar_t ...
|
| 15 |
+
// }),
|
| 16 |
+
// kHalf,
|
| 17 |
+
// AT_EXPAND(AT_ALL_TYPES),
|
| 18 |
+
// ... as many types arguments as needed ...
|
| 19 |
+
// )
|
| 20 |
+
//
|
| 21 |
+
// For example, given an old style:
|
| 22 |
+
//
|
| 23 |
+
// AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
|
| 24 |
+
// kComplexHalf,
|
| 25 |
+
// kHalf,
|
| 26 |
+
// self.scalar_type(),
|
| 27 |
+
// "_local_scalar_dense_cpu",
|
| 28 |
+
// [&] {
|
| 29 |
+
// scalar_t value = *self.data_ptr<scalar_t>();
|
| 30 |
+
// r = Scalar(value);
|
| 31 |
+
// }
|
| 32 |
+
// )
|
| 33 |
+
//
|
| 34 |
+
// You now write:
|
| 35 |
+
//
|
| 36 |
+
// AT_DISPATCH_V2(
|
| 37 |
+
// self.scalar_type(),
|
| 38 |
+
// "_local_scalar_dense_cpu",
|
| 39 |
+
// AT_WRAP([&] {
|
| 40 |
+
// scalar_t value = *self.data_ptr<scalar_t>();
|
| 41 |
+
// r = Scalar(value);
|
| 42 |
+
// }),
|
| 43 |
+
// AT_EXPAND(AT_ALL_TYPES),
|
| 44 |
+
// AT_EXPAND(AT_COMPLEX_TYPES),
|
| 45 |
+
// kComplexHalf,
|
| 46 |
+
// kHalf,
|
| 47 |
+
// )
|
| 48 |
+
//
|
| 49 |
+
// Notably, it sports the following improvements:
|
| 50 |
+
//
|
| 51 |
+
// - It is not necessary to specify the arity (e.g.,
|
| 52 |
+
// AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3,4,...})
|
| 53 |
+
// when using the macro
|
| 54 |
+
//
|
| 55 |
+
// - It is not necessary to specify each dtype individually; if
|
| 56 |
+
// there is a set of related dtypes and you want to dispatch
|
| 57 |
+
// over all of them, you can simply say, e.g., AT_EXPAND(AT_INTEGRAL_TYPES)
|
| 58 |
+
// in your argument list.
|
| 59 |
+
//
|
| 60 |
+
// However, you must remember to wrap the payload body in AT_WRAP, or commas
|
| 61 |
+
// inside your lambda will be improperly handled. Furthermore, if you more
|
| 62 |
+
// entries to ScalarType than can be supported by this macro, it will fail
|
| 63 |
+
// with an obscure error (due to attempting to concatenate AT_AP with
|
| 64 |
+
// something that is not a number).
|
| 65 |
+
//
|
| 66 |
+
// The implementation strategy is to use the count arguments trick
|
| 67 |
+
// (e.g., as described in https://stackoverflow.com/a/2124385/23845)
|
| 68 |
+
// to discover how many dtypes have been passed, and then dispatch to a
|
| 69 |
+
// hand-written macro for each arity that applies as many DISPATCH_CASE as
|
| 70 |
+
// necessary. The hand-written macros can be regenerated for other arities
|
| 71 |
+
// with the script below.
|
| 72 |
+
//
|
| 73 |
+
// There is some delicacy in the implementation in controlling when
|
| 74 |
+
// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
|
| 75 |
+
// relied on GPT4 to help me get it right.
|
| 76 |
+
|
| 77 |
+
// Public API macros
|
| 78 |
+
|
| 79 |
+
// See documentation above
|
| 80 |
+
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
|
| 81 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
|
| 82 |
+
|
| 83 |
+
// This macro lets you pass an arbitrary expression that may contain internal
|
| 84 |
+
// commas to another macro without having the commas causing the expression
|
| 85 |
+
// to be interpreted as being multiple arguments
|
| 86 |
+
#define AT_WRAP(...) __VA_ARGS__
|
| 87 |
+
|
| 88 |
+
#define AT_FLOAT8_TYPES \
|
| 89 |
+
c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
|
| 90 |
+
c10::kFloat8_e4m3fnuz
|
| 91 |
+
|
| 92 |
+
#define AT_INTEGRAL_TYPES \
|
| 93 |
+
c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
|
| 94 |
+
#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
|
| 95 |
+
#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
|
| 96 |
+
#define AT_INTEGRAL_TYPES_V2 \
|
| 97 |
+
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
|
| 98 |
+
#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
|
| 99 |
+
#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
|
| 100 |
+
// NB: not *actually* all types
|
| 101 |
+
#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
|
| 102 |
+
#define AT_ALL_TYPES_AND_COMPLEX \
|
| 103 |
+
AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
|
| 104 |
+
|
| 105 |
+
// Helper macros
|
| 106 |
+
|
| 107 |
+
#define AT_AP_VAR(N, T, ...) \
|
| 108 |
+
AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
|
| 109 |
+
#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
|
| 110 |
+
#define AT_CONCAT_AUX(a, b) a##b
|
| 111 |
+
#define AT_EXPAND(X) X
|
| 112 |
+
|
| 113 |
+
// Ensure we never have too many scalar types for the expansion here to
|
| 114 |
+
// support. To bump this, you must regenerate the macros below.
|
| 115 |
+
static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 45);
|
| 116 |
+
|
| 117 |
+
// Python code to regenerate generate code below:
|
| 118 |
+
#if 0
|
| 119 |
+
|
| 120 |
+
num_args = 45
|
| 121 |
+
|
| 122 |
+
nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
|
| 123 |
+
args = ', '.join(f'_{i}' for i in range(1, num_args+1))
|
| 124 |
+
|
| 125 |
+
print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
|
| 126 |
+
print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
|
| 127 |
+
|
| 128 |
+
for i in range(1, num_args+1):
|
| 129 |
+
args = ', '.join(f'_{i}' for i in range(1, i+1))
|
| 130 |
+
cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
|
| 131 |
+
print(f'#define AT_AP{i}(N, {args}) {cases}')
|
| 132 |
+
|
| 133 |
+
#endif
|
| 134 |
+
|
| 135 |
+
// Begin generated code
|
| 136 |
+
// clang-format off
|
| 137 |
+
|
| 138 |
+
#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
|
| 139 |
+
#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, N, ...) N
|
| 140 |
+
#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
|
| 141 |
+
#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
|
| 142 |
+
#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
|
| 143 |
+
#define AT_AP4(N, _1, _2, _3, _4) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N)
|
| 144 |
+
#define AT_AP5(N, _1, _2, _3, _4, _5) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N)
|
| 145 |
+
#define AT_AP6(N, _1, _2, _3, _4, _5, _6) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N)
|
| 146 |
+
#define AT_AP7(N, _1, _2, _3, _4, _5, _6, _7) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N)
|
| 147 |
+
#define AT_AP8(N, _1, _2, _3, _4, _5, _6, _7, _8) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N)
|
| 148 |
+
#define AT_AP9(N, _1, _2, _3, _4, _5, _6, _7, _8, _9) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N)
|
| 149 |
+
#define AT_AP10(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N)
|
| 150 |
+
#define AT_AP11(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N)
|
| 151 |
+
#define AT_AP12(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N)
|
| 152 |
+
#define AT_AP13(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N)
|
| 153 |
+
#define AT_AP14(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N)
|
| 154 |
+
#define AT_AP15(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N)
|
| 155 |
+
#define AT_AP16(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N)
|
| 156 |
+
#define AT_AP17(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N)
|
| 157 |
+
#define AT_AP18(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N)
|
| 158 |
+
#define AT_AP19(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N)
|
| 159 |
+
#define AT_AP20(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N)
|
| 160 |
+
#define AT_AP21(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N)
|
| 161 |
+
#define AT_AP22(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N)
|
| 162 |
+
#define AT_AP23(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N)
|
| 163 |
+
#define AT_AP24(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N)
|
| 164 |
+
#define AT_AP25(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N)
|
| 165 |
+
#define AT_AP26(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N)
|
| 166 |
+
#define AT_AP27(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N)
|
| 167 |
+
#define AT_AP28(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N)
|
| 168 |
+
#define AT_AP29(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N)
|
| 169 |
+
#define AT_AP30(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N)
|
| 170 |
+
#define AT_AP31(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N)
|
| 171 |
+
#define AT_AP32(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N)
|
| 172 |
+
#define AT_AP33(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N)
|
| 173 |
+
#define AT_AP34(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N)
|
| 174 |
+
#define AT_AP35(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N)
|
| 175 |
+
#define AT_AP36(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N)
|
| 176 |
+
#define AT_AP37(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N)
|
| 177 |
+
#define AT_AP38(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N)
|
| 178 |
+
#define AT_AP39(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N)
|
| 179 |
+
#define AT_AP40(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N)
|
| 180 |
+
#define AT_AP41(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N)
|
| 181 |
+
#define AT_AP42(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N)
|
| 182 |
+
#define AT_AP43(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N)
|
| 183 |
+
#define AT_AP44(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N)
|
| 184 |
+
#define AT_AP45(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N)
|
| 185 |
+
// End generated code
|
| 186 |
+
// clang-format on
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/EmptyTensor.h
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/TensorBase.h>
|
| 3 |
+
|
| 4 |
+
namespace at::detail {
|
| 5 |
+
|
| 6 |
+
inline void check_size_nonnegative(ArrayRef<int64_t> size) {
|
| 7 |
+
for (const auto& x : size) {
|
| 8 |
+
TORCH_CHECK(
|
| 9 |
+
x >= 0,
|
| 10 |
+
"Trying to create tensor with negative dimension ",
|
| 11 |
+
x,
|
| 12 |
+
": ",
|
| 13 |
+
size);
|
| 14 |
+
}
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
inline void check_size_nonnegative(ArrayRef<c10::SymInt> size) {
|
| 18 |
+
for (const auto& x : size) {
|
| 19 |
+
TORCH_CHECK(
|
| 20 |
+
x.expect_size(__FILE__, __LINE__),
|
| 21 |
+
"Trying to create tensor with negative dimension ",
|
| 22 |
+
x,
|
| 23 |
+
": ",
|
| 24 |
+
size);
|
| 25 |
+
}
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
TORCH_API size_t computeStorageNbytesContiguous(
|
| 29 |
+
IntArrayRef sizes,
|
| 30 |
+
size_t itemsize,
|
| 31 |
+
size_t storage_offset = 0);
|
| 32 |
+
TORCH_API SymInt computeStorageNbytesContiguous(
|
| 33 |
+
SymIntArrayRef sizes,
|
| 34 |
+
const SymInt& itemsize,
|
| 35 |
+
const SymInt& storage_offset = 0);
|
| 36 |
+
TORCH_API size_t computeStorageNbytes(
|
| 37 |
+
IntArrayRef sizes,
|
| 38 |
+
IntArrayRef strides,
|
| 39 |
+
size_t itemsize,
|
| 40 |
+
size_t storage_offset = 0);
|
| 41 |
+
TORCH_API SymInt computeStorageNbytes(
|
| 42 |
+
SymIntArrayRef sizes,
|
| 43 |
+
SymIntArrayRef strides,
|
| 44 |
+
const SymInt& itemsize,
|
| 45 |
+
const SymInt& storage_offset = 0);
|
| 46 |
+
|
| 47 |
+
TORCH_API TensorBase empty_generic(
|
| 48 |
+
IntArrayRef size,
|
| 49 |
+
c10::Allocator* allocator,
|
| 50 |
+
c10::DispatchKeySet ks,
|
| 51 |
+
ScalarType scalar_type,
|
| 52 |
+
c10::optional<c10::MemoryFormat> memory_format_opt);
|
| 53 |
+
|
| 54 |
+
TORCH_API TensorBase empty_strided_generic(
|
| 55 |
+
IntArrayRef size,
|
| 56 |
+
IntArrayRef stride,
|
| 57 |
+
c10::Allocator* allocator,
|
| 58 |
+
c10::DispatchKeySet ks,
|
| 59 |
+
ScalarType scalar_type);
|
| 60 |
+
|
| 61 |
+
TORCH_API TensorBase empty_strided_symint_generic(
|
| 62 |
+
SymIntArrayRef size,
|
| 63 |
+
SymIntArrayRef stride,
|
| 64 |
+
c10::Allocator* allocator,
|
| 65 |
+
c10::DispatchKeySet ks,
|
| 66 |
+
ScalarType scalar_type);
|
| 67 |
+
|
| 68 |
+
TORCH_API TensorBase empty_cpu(
|
| 69 |
+
IntArrayRef size,
|
| 70 |
+
ScalarType dtype,
|
| 71 |
+
bool pin_memory = false,
|
| 72 |
+
c10::optional<c10::MemoryFormat> memory_format_opt = c10::nullopt);
|
| 73 |
+
|
| 74 |
+
TORCH_API TensorBase empty_cpu(
|
| 75 |
+
IntArrayRef size,
|
| 76 |
+
c10::optional<ScalarType> dtype_opt,
|
| 77 |
+
c10::optional<Layout> layout_opt,
|
| 78 |
+
c10::optional<Device> device_opt,
|
| 79 |
+
c10::optional<bool> pin_memory_opt,
|
| 80 |
+
c10::optional<c10::MemoryFormat> memory_format_opt);
|
| 81 |
+
|
| 82 |
+
TORCH_API TensorBase empty_cpu(IntArrayRef size, const TensorOptions& options);
|
| 83 |
+
|
| 84 |
+
TORCH_API TensorBase empty_strided_cpu(
|
| 85 |
+
IntArrayRef size,
|
| 86 |
+
IntArrayRef stride,
|
| 87 |
+
ScalarType dtype,
|
| 88 |
+
bool pin_memory = false);
|
| 89 |
+
|
| 90 |
+
TORCH_API TensorBase empty_strided_cpu(
|
| 91 |
+
IntArrayRef size,
|
| 92 |
+
IntArrayRef stride,
|
| 93 |
+
c10::optional<ScalarType> dtype_opt,
|
| 94 |
+
c10::optional<Layout> layout_opt,
|
| 95 |
+
c10::optional<Device> device_opt,
|
| 96 |
+
c10::optional<bool> pin_memory_opt);
|
| 97 |
+
|
| 98 |
+
TORCH_API TensorBase empty_strided_cpu(
|
| 99 |
+
IntArrayRef size,
|
| 100 |
+
IntArrayRef stride,
|
| 101 |
+
const TensorOptions& options);
|
| 102 |
+
|
| 103 |
+
TORCH_API TensorBase empty_meta(
|
| 104 |
+
IntArrayRef size,
|
| 105 |
+
ScalarType dtype,
|
| 106 |
+
c10::optional<c10::MemoryFormat> memory_format_opt = c10::nullopt);
|
| 107 |
+
|
| 108 |
+
TORCH_API TensorBase empty_meta(
|
| 109 |
+
IntArrayRef size,
|
| 110 |
+
c10::optional<ScalarType> dtype_opt,
|
| 111 |
+
c10::optional<Layout> layout_opt,
|
| 112 |
+
c10::optional<Device> device_opt,
|
| 113 |
+
c10::optional<bool> pin_memory_opt,
|
| 114 |
+
c10::optional<c10::MemoryFormat> memory_format_opt);
|
| 115 |
+
|
| 116 |
+
TORCH_API TensorBase empty_symint_meta(
|
| 117 |
+
SymIntArrayRef size,
|
| 118 |
+
c10::optional<ScalarType> dtype_opt,
|
| 119 |
+
c10::optional<Layout> layout_opt,
|
| 120 |
+
c10::optional<Device> device_opt,
|
| 121 |
+
c10::optional<bool> pin_memory_opt,
|
| 122 |
+
c10::optional<c10::MemoryFormat> memory_format_opt);
|
| 123 |
+
|
| 124 |
+
TORCH_API TensorBase empty_meta(IntArrayRef size, const TensorOptions& options);
|
| 125 |
+
|
| 126 |
+
TORCH_API TensorBase
|
| 127 |
+
empty_strided_meta(IntArrayRef size, IntArrayRef stride, ScalarType dtype);
|
| 128 |
+
|
| 129 |
+
TORCH_API TensorBase empty_strided_meta(
|
| 130 |
+
IntArrayRef size,
|
| 131 |
+
IntArrayRef stride,
|
| 132 |
+
c10::optional<ScalarType> dtype_opt,
|
| 133 |
+
c10::optional<Layout> layout_opt,
|
| 134 |
+
c10::optional<Device> device_opt,
|
| 135 |
+
c10::optional<bool> pin_memory_opt);
|
| 136 |
+
|
| 137 |
+
TORCH_API TensorBase empty_strided_meta(
|
| 138 |
+
IntArrayRef size,
|
| 139 |
+
IntArrayRef stride,
|
| 140 |
+
const TensorOptions& options);
|
| 141 |
+
|
| 142 |
+
TORCH_API TensorBase empty_strided_symint_meta(
|
| 143 |
+
SymIntArrayRef size,
|
| 144 |
+
SymIntArrayRef stride,
|
| 145 |
+
ScalarType dtype);
|
| 146 |
+
|
| 147 |
+
TORCH_API TensorBase empty_strided_symint_meta(
|
| 148 |
+
SymIntArrayRef size,
|
| 149 |
+
SymIntArrayRef stride,
|
| 150 |
+
c10::optional<ScalarType> dtype_opt,
|
| 151 |
+
c10::optional<Layout> layout_opt,
|
| 152 |
+
c10::optional<Device> device_opt,
|
| 153 |
+
c10::optional<bool> pin_memory_opt);
|
| 154 |
+
|
| 155 |
+
TORCH_API TensorBase empty_strided_symint_meta(
|
| 156 |
+
SymIntArrayRef size,
|
| 157 |
+
SymIntArrayRef stride,
|
| 158 |
+
const TensorOptions& options);
|
| 159 |
+
|
| 160 |
+
} // namespace at::detail
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ExpandBase.h
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBase.h>
|
| 2 |
+
|
| 3 |
+
// Broadcasting utilities for working with TensorBase
|
| 4 |
+
namespace at {
|
| 5 |
+
namespace internal {
|
| 6 |
+
TORCH_API TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size);
|
| 7 |
+
} // namespace internal
|
| 8 |
+
|
| 9 |
+
inline c10::MaybeOwned<TensorBase> expand_size(
|
| 10 |
+
const TensorBase& self,
|
| 11 |
+
IntArrayRef size) {
|
| 12 |
+
if (size.equals(self.sizes())) {
|
| 13 |
+
return c10::MaybeOwned<TensorBase>::borrowed(self);
|
| 14 |
+
}
|
| 15 |
+
return c10::MaybeOwned<TensorBase>::owned(
|
| 16 |
+
at::internal::expand_slow_path(self, size));
|
| 17 |
+
}
|
| 18 |
+
c10::MaybeOwned<TensorBase> expand_size(TensorBase&& self, IntArrayRef size) =
|
| 19 |
+
delete;
|
| 20 |
+
|
| 21 |
+
inline c10::MaybeOwned<TensorBase> expand_inplace(
|
| 22 |
+
const TensorBase& tensor,
|
| 23 |
+
const TensorBase& to_expand) {
|
| 24 |
+
return expand_size(to_expand, tensor.sizes());
|
| 25 |
+
}
|
| 26 |
+
c10::MaybeOwned<TensorBase> expand_inplace(
|
| 27 |
+
const TensorBase& tensor,
|
| 28 |
+
TensorBase&& to_expand) = delete;
|
| 29 |
+
|
| 30 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalTensorWrapper.h
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
#pragma once
|
| 3 |
+
|
| 4 |
+
#include <ATen/ArrayRef.h>
|
| 5 |
+
#include <ATen/FunctionalStorageImpl.h>
|
| 6 |
+
#include <ATen/core/IListRef.h>
|
| 7 |
+
#include <ATen/core/List.h>
|
| 8 |
+
#include <ATen/core/boxing/BoxedKernel.h>
|
| 9 |
+
#include <ATen/core/boxing/impl/boxing.h>
|
| 10 |
+
#include <ATen/core/dispatch/Dispatcher.h>
|
| 11 |
+
|
| 12 |
+
#include <c10/core/DispatchKey.h>
|
| 13 |
+
|
| 14 |
+
namespace at {
|
| 15 |
+
|
| 16 |
+
// Note [Functionalization Pass In Core]
|
| 17 |
+
// The Functionalization pass is used to remove aliasing from a pytorch program.
|
| 18 |
+
//
|
| 19 |
+
// This is useful for backends that don't support aliasing, like XLA and Vulkan.
|
| 20 |
+
// It's also necessary in order to remove mutation from a program, which is
|
| 21 |
+
// needed in Functorch.
|
| 22 |
+
//
|
| 23 |
+
// Consider this program:
|
| 24 |
+
// a = torch.ones(...)
|
| 25 |
+
// b = a.view(...)
|
| 26 |
+
// b.add_(1)
|
| 27 |
+
//
|
| 28 |
+
// In this program, b is meant to alias with a due to the use of view(). At the
|
| 29 |
+
// end of the program, both a and b are full of 2's. However, backends that
|
| 30 |
+
// don't support aliasing aren't able to correctly implement the view()
|
| 31 |
+
// operator. Instead, they can opt into the Functionalization pass, which will
|
| 32 |
+
// sit between the user and the backend, and provide the necessary aliasing
|
| 33 |
+
// logic.
|
| 34 |
+
//
|
| 35 |
+
// The functionalization pass will turn the above program into a slightly
|
| 36 |
+
// different program that has the same semantics, transparently to the user,
|
| 37 |
+
// that backends like XLA/Vulkan are able to implement a = torch.ones(...) b =
|
| 38 |
+
// a.view_copy(...) # view() replaced with view_copy(). Backends like
|
| 39 |
+
// XLA/Vulkan can implement this! b.add_(1) a.add_(1) # Our functionalization
|
| 40 |
+
// pass machinery knows that a and b are aliased - it applies b's mutation to a
|
| 41 |
+
// too.
|
| 42 |
+
//
|
| 43 |
+
// So, how does the functionalization pass keep track of which tensors are
|
| 44 |
+
// aliased? The pass works by wrapping EVERY tensor in the program inside of a
|
| 45 |
+
// FunctionalTensorWrapper, which knows about its alias'd tensors.
|
| 46 |
+
//
|
| 47 |
+
// See Note [Functionalization: Alias Removal] for details on the aliasing
|
| 48 |
+
// machinery. See Note [Functionalization: Mutation Removal] for details on
|
| 49 |
+
// mutation removal.
|
| 50 |
+
struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
|
| 51 |
+
explicit FunctionalTensorWrapper(const Tensor& value);
|
| 52 |
+
// Additional constructor to create a FunctionalTensorWrapper directly from an
|
| 53 |
+
// underlying tensor that was created from a view. For example, the code b =
|
| 54 |
+
// a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a,
|
| 55 |
+
// view1_meta)
|
| 56 |
+
explicit FunctionalTensorWrapper(
|
| 57 |
+
const Tensor& view_value,
|
| 58 |
+
const FunctionalTensorWrapper* base,
|
| 59 |
+
const functionalization::ViewMeta& meta);
|
| 60 |
+
|
| 61 |
+
// Get the underlying, actual tensor, that doesn't know anything about
|
| 62 |
+
// functionalization.
|
| 63 |
+
const Tensor& value() const {
|
| 64 |
+
return value_;
|
| 65 |
+
};
|
| 66 |
+
// The concept of "level" is only ever important to functorch; it's exposed
|
| 67 |
+
// here as more of a hook for functorch to use.
|
| 68 |
+
int64_t level() const {
|
| 69 |
+
return level_;
|
| 70 |
+
};
|
| 71 |
+
void set_level(int64_t level) {
|
| 72 |
+
level_ = level;
|
| 73 |
+
}
|
| 74 |
+
bool has_metadata_mutation() const {
|
| 75 |
+
return has_metadata_mutation_;
|
| 76 |
+
};
|
| 77 |
+
|
| 78 |
+
// Denotes a mutation that's hidden from autograd,
|
| 79 |
+
// e.g. for the purposes of passing a tensor to a triton kernel
|
| 80 |
+
void mark_mutation_hidden_from_autograd() {
|
| 81 |
+
mutation_hidden_from_autograd_counter_++;
|
| 82 |
+
}
|
| 83 |
+
void mark_mutation_during_no_grad_or_inference_mode() {
|
| 84 |
+
mutation_during_no_grad_or_inference_mode_++;
|
| 85 |
+
}
|
| 86 |
+
// Are all the mutations happening to the tensor hidden from autograd
|
| 87 |
+
bool are_all_mutations_hidden_from_autograd() const {
|
| 88 |
+
return mutation_hidden_from_autograd_counter_ == mutation_counter_;
|
| 89 |
+
}
|
| 90 |
+
// Did all mutations happen under no_grad or inference_mode
|
| 91 |
+
// (We also need to ignore mutations fully hidden from autograd here)
|
| 92 |
+
bool are_all_mutations_under_no_grad_or_inference_mode() const {
|
| 93 |
+
return mutation_hidden_from_autograd_counter_ +
|
| 94 |
+
mutation_during_no_grad_or_inference_mode_ ==
|
| 95 |
+
mutation_counter_;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
// Sync's the underlying tensor with its alias, if it's out of date. This
|
| 99 |
+
// involves two steps: 1) Apply any pending updates/mutations to the alias 2)
|
| 100 |
+
// Replay the views (if any) to regenerate the current tensor off of the
|
| 101 |
+
// updated alias.
|
| 102 |
+
void sync_();
|
| 103 |
+
// Performs step (1) of the sync. This is its own public API because it's
|
| 104 |
+
// needed by view_inplace ops like transpose_. See Note [Functionalization
|
| 105 |
+
// Pass - Inplace View Ops]
|
| 106 |
+
void regenerate_from_base();
|
| 107 |
+
// Performs step (2) of the sync. This is its own public API because it's
|
| 108 |
+
// needed by functorch. functorch wants to make sure that all input tensors to
|
| 109 |
+
// a functionalized program have been properly synced so it can properly
|
| 110 |
+
// propagate mutations to inputs. It can't just call sync_(), because the
|
| 111 |
+
// FunctionalTensorWrapper will look like it has no aliases and sync_ will be
|
| 112 |
+
// a noop. We use the reference count on storage_ to determine if the wrapper
|
| 113 |
+
// is aliased, and by the time functorch is ready to propagate updates to
|
| 114 |
+
// inputs, any intermediate views of the input created by the program will
|
| 115 |
+
// have been deallocated. This function also returns whether or not the base
|
| 116 |
+
// actually had any updates to apply.
|
| 117 |
+
bool apply_updates();
|
| 118 |
+
// Takes the current state of value_ and snapshots it, sending it as a pending
|
| 119 |
+
// update to the alias.
|
| 120 |
+
void commit_update();
|
| 121 |
+
// When any tensor is mutated, the tensor increments its alias's "generation".
|
| 122 |
+
// Separately, each tensor maintains its own "generation" counter, which is
|
| 123 |
+
// used to determine if it's up-to-date with its alias. The act of syncing a
|
| 124 |
+
// tensor will set a tensor's generation equal to its alias's generation.
|
| 125 |
+
bool is_up_to_date() const;
|
| 126 |
+
// Freezes the storage of this tensor, preventing subsequent mutations
|
| 127 |
+
void freeze_storage() const;
|
| 128 |
+
// Every FunctionalTensorWrapper contains a vector<ViewMeta> objects
|
| 129 |
+
// describing the series of view ops that ran to generate the current tensor
|
| 130 |
+
// from the base tensor. This method is used by inplace-view ops like
|
| 131 |
+
// transpose_. It appends a ViewMeta to the existing stack, and refreshes the
|
| 132 |
+
// tensor by replaying the views off of the alias.
|
| 133 |
+
void mutate_view_meta(const at::functionalization::ViewMeta& meta);
|
| 134 |
+
|
| 135 |
+
// Custom implementation of self.set_(src)
|
| 136 |
+
void set__impl(const FunctionalTensorWrapper* other);
|
| 137 |
+
|
| 138 |
+
// Returns whether the current tensor's data was ever mutated
|
| 139 |
+
bool has_data_mutation();
|
| 140 |
+
//
|
| 141 |
+
// Returns whether the current FunctionalTensorWrapper
|
| 142 |
+
// experienced a set_() call.
|
| 143 |
+
bool was_storage_changed() {
|
| 144 |
+
return was_storage_changed_;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
// The functionalization pass can be used to remove mutations.
|
| 148 |
+
// It does so by replacing any mutation op with it's corresponding
|
| 149 |
+
// out-of-place op, followed by a call to replace_(). e.g:
|
| 150 |
+
//
|
| 151 |
+
// a.add_(1)
|
| 152 |
+
//
|
| 153 |
+
// will turn into:
|
| 154 |
+
//
|
| 155 |
+
// tmp = a.add(1)
|
| 156 |
+
// a.replace_(tmp)
|
| 157 |
+
//
|
| 158 |
+
// replace_() swaps out the wrapped tensor, value_, with tmp.
|
| 159 |
+
void replace_(const Tensor& other);
|
| 160 |
+
|
| 161 |
+
bool is_multi_output_view() {
|
| 162 |
+
return is_multi_output_view_;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
// See Note[resize_() in functionalization pass]
|
| 166 |
+
void maybe_replace_storage(const Tensor& other);
|
| 167 |
+
|
| 168 |
+
// Replaces the storage with a new functional storage,
|
| 169 |
+
// and clears the view_metas_ stack.
|
| 170 |
+
// WARNING: Calling this function will sever the aliasing relationship between
|
| 171 |
+
// the current FunctionalTensorWrapper and any of its outstanding aliases.
|
| 172 |
+
// Please only call if you know what you're doing.
|
| 173 |
+
void _unsafe_reset_storage();
|
| 174 |
+
|
| 175 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 176 |
+
const c10::VariableVersion& version_counter,
|
| 177 |
+
bool allow_tensor_metadata_change) const override;
|
| 178 |
+
|
| 179 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
| 180 |
+
c10::VariableVersion&& version_counter,
|
| 181 |
+
bool allow_tensor_metadata_change) const override;
|
| 182 |
+
|
| 183 |
+
~FunctionalTensorWrapper() override = default;
|
| 184 |
+
|
| 185 |
+
// FunctionalTensorWrapper overrides all custom size/stride function,
|
| 186 |
+
// so that if the inner tensor has a custom implementation
|
| 187 |
+
// we make sure to call that implementation.
|
| 188 |
+
at::IntArrayRef sizes_custom() const override;
|
| 189 |
+
at::IntArrayRef strides_custom() const override;
|
| 190 |
+
int64_t dim_custom() const override;
|
| 191 |
+
int64_t numel_custom() const override;
|
| 192 |
+
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
| 193 |
+
c10::SymIntArrayRef sym_sizes_custom() const override;
|
| 194 |
+
c10::SymInt sym_size_custom(int64_t d) const override;
|
| 195 |
+
c10::SymIntArrayRef sym_strides_custom() const override;
|
| 196 |
+
c10::SymInt sym_storage_offset_custom() const override;
|
| 197 |
+
c10::Device device_custom() const override;
|
| 198 |
+
|
| 199 |
+
private:
|
| 200 |
+
const char* tensorimpl_type_name() const override;
|
| 201 |
+
void set_constructor_metadata();
|
| 202 |
+
functionalization::FunctionalStorageImpl* functional_storage_impl() const;
|
| 203 |
+
|
| 204 |
+
// This is used to re-implement shallow_copy_and_detach for
|
| 205 |
+
// FunctionalTensorWrapper. The implementation is identical, but we just need
|
| 206 |
+
// to return a subclass instead of a plain TensorImpl.
|
| 207 |
+
// TODO: maybe it's possible to arrange for that to happen automatically
|
| 208 |
+
// without an override here?
|
| 209 |
+
template <typename VariableVersion>
|
| 210 |
+
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
|
| 211 |
+
VariableVersion&& version_counter,
|
| 212 |
+
bool allow_tensor_metadata_change) const;
|
| 213 |
+
|
| 214 |
+
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
|
| 215 |
+
void copy_tensor_metadata_and_refresh(
|
| 216 |
+
const FunctionalTensorWrapper* src_impl,
|
| 217 |
+
FunctionalTensorWrapper* dest_impl,
|
| 218 |
+
const c10::VariableVersion& version_counter,
|
| 219 |
+
bool allow_tensor_metadata_change) const;
|
| 220 |
+
|
| 221 |
+
// Note that value is not taken by reference: internally, the wrapper will
|
| 222 |
+
// change the value tensor that it points to over time.
|
| 223 |
+
Tensor value_;
|
| 224 |
+
int64_t level_{};
|
| 225 |
+
// These two counters are used for identifying
|
| 226 |
+
// whether all the mutations on a given tensor are hidden from autograd or
|
| 227 |
+
// not. If we have an input mutation that is hidden from autograd, then once
|
| 228 |
+
// we convert the input mutation to a copy_() we know it will be safe to hide
|
| 229 |
+
// the copy_() from autograd as well.
|
| 230 |
+
uint64_t mutation_counter_ = 0;
|
| 231 |
+
uint64_t mutation_hidden_from_autograd_counter_ = 0;
|
| 232 |
+
uint64_t mutation_during_no_grad_or_inference_mode_ = 0;
|
| 233 |
+
bool has_metadata_mutation_ = false;
|
| 234 |
+
bool is_multi_output_view_ = false;
|
| 235 |
+
// Did the tensor experience a set_() call.
|
| 236 |
+
bool was_storage_changed_ = false;
|
| 237 |
+
|
| 238 |
+
size_t generation_ = 0;
|
| 239 |
+
std::vector<at::functionalization::ViewMeta> view_metas_;
|
| 240 |
+
|
| 241 |
+
protected:
|
| 242 |
+
static void copy_tensor_metadata(
|
| 243 |
+
const FunctionalTensorWrapper* src_impl,
|
| 244 |
+
FunctionalTensorWrapper* dest_impl,
|
| 245 |
+
const c10::VariableVersion& version_counter,
|
| 246 |
+
bool allow_tensor_metadata_change);
|
| 247 |
+
};
|
| 248 |
+
|
| 249 |
+
// Utility functions for the functionalization pass.
|
| 250 |
+
|
| 251 |
+
namespace functionalization {
|
| 252 |
+
namespace impl {
|
| 253 |
+
|
| 254 |
+
TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
|
| 255 |
+
const Tensor& tensor) {
|
| 256 |
+
auto functional_impl =
|
| 257 |
+
static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
|
| 258 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
|
| 259 |
+
return functional_impl;
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
|
| 263 |
+
TORCH_API bool isFunctionalTensor(const c10::optional<Tensor>& t);
|
| 264 |
+
TORCH_API bool isFunctionalTensor(
|
| 265 |
+
const c10::List<c10::optional<Tensor>>& t_list);
|
| 266 |
+
TORCH_API bool isFunctionalTensor(ITensorListRef list);
|
| 267 |
+
|
| 268 |
+
TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
|
| 269 |
+
TORCH_API c10::optional<Tensor> to_functional_tensor(
|
| 270 |
+
const c10::optional<Tensor>& tensor);
|
| 271 |
+
TORCH_API c10::List<c10::optional<Tensor>> to_functional_tensor(
|
| 272 |
+
const c10::List<c10::optional<Tensor>>& t_list);
|
| 273 |
+
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);
|
| 274 |
+
|
| 275 |
+
TORCH_API void freeze_functional_tensor(const Tensor& tensor);
|
| 276 |
+
|
| 277 |
+
TORCH_API Tensor
|
| 278 |
+
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
|
| 279 |
+
TORCH_API c10::optional<Tensor> from_functional_tensor(
|
| 280 |
+
const c10::optional<Tensor>& t,
|
| 281 |
+
bool assert_functional = true);
|
| 282 |
+
TORCH_API c10::List<c10::optional<Tensor>> from_functional_tensor(
|
| 283 |
+
const c10::List<c10::optional<Tensor>>& t_list);
|
| 284 |
+
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
|
| 285 |
+
|
| 286 |
+
TORCH_API void sync(const at::Tensor& t);
|
| 287 |
+
TORCH_API void sync(const c10::optional<Tensor>& t);
|
| 288 |
+
TORCH_API void sync(const c10::List<c10::optional<Tensor>>& t_list);
|
| 289 |
+
TORCH_API void sync(ITensorListRef t_list);
|
| 290 |
+
|
| 291 |
+
TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
|
| 292 |
+
TORCH_API void replace_(
|
| 293 |
+
const ITensorListRef functional_tensor,
|
| 294 |
+
ITensorListRef other);
|
| 295 |
+
|
| 296 |
+
TORCH_API void commit_update(const Tensor& functional_tensor);
|
| 297 |
+
TORCH_API void commit_update(ITensorListRef functional_tensor);
|
| 298 |
+
|
| 299 |
+
TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);
|
| 300 |
+
|
| 301 |
+
TORCH_API void mark_mutation_hidden_from_autograd(
|
| 302 |
+
const Tensor& functional_tensor);
|
| 303 |
+
|
| 304 |
+
TORCH_API bool are_all_mutations_hidden_from_autograd(
|
| 305 |
+
const Tensor& functional_tensor);
|
| 306 |
+
|
| 307 |
+
TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
|
| 308 |
+
const Tensor& functional_tensor);
|
| 309 |
+
|
| 310 |
+
// These two methods are XLA-specific logic and are no-ops
|
| 311 |
+
// for the normal functionalization flow.
|
| 312 |
+
TORCH_API void propagate_xla_data(
|
| 313 |
+
const Tensor& functional_tensor,
|
| 314 |
+
const Tensor& other);
|
| 315 |
+
TORCH_API void propagate_xla_data(
|
| 316 |
+
const ITensorListRef functional_tensor,
|
| 317 |
+
ITensorListRef other);
|
| 318 |
+
|
| 319 |
+
Tensor create_functional_tensor_with_view_meta(
|
| 320 |
+
const Tensor& view_to_wrap,
|
| 321 |
+
const Tensor& base,
|
| 322 |
+
functionalization::ViewMeta meta,
|
| 323 |
+
int64_t out_idx = 0);
|
| 324 |
+
std::vector<Tensor> create_functional_tensor_with_view_meta(
|
| 325 |
+
ITensorListRef view_to_wrap,
|
| 326 |
+
const Tensor& base,
|
| 327 |
+
const functionalization::ViewMeta& meta);
|
| 328 |
+
|
| 329 |
+
void mutate_view_meta(
|
| 330 |
+
const Tensor& self,
|
| 331 |
+
const functionalization::ViewMeta& meta);
|
| 332 |
+
|
| 333 |
+
void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
|
| 334 |
+
void set_sizes_strides_offset(
|
| 335 |
+
const std::vector<Tensor>& outs,
|
| 336 |
+
const std::vector<Tensor>& meta_outs);
|
| 337 |
+
|
| 338 |
+
// ~~~~~ TLS used in functionalization ~~~~~
|
| 339 |
+
|
| 340 |
+
TORCH_API bool getFunctionalizationReapplyViewsTLS();
|
| 341 |
+
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
|
| 342 |
+
|
| 343 |
+
class TORCH_API FunctionalizationReapplyViewsGuard {
|
| 344 |
+
public:
|
| 345 |
+
FunctionalizationReapplyViewsGuard(bool reapply_views)
|
| 346 |
+
: prev_(getFunctionalizationReapplyViewsTLS()) {
|
| 347 |
+
setFunctionalizationReapplyViewsTLS(reapply_views);
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
~FunctionalizationReapplyViewsGuard() {
|
| 351 |
+
setFunctionalizationReapplyViewsTLS(prev_);
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
FunctionalizationReapplyViewsGuard(
|
| 355 |
+
const FunctionalizationReapplyViewsGuard&) = delete;
|
| 356 |
+
FunctionalizationReapplyViewsGuard operator=(
|
| 357 |
+
const FunctionalizationReapplyViewsGuard&) = delete;
|
| 358 |
+
FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
|
| 359 |
+
delete;
|
| 360 |
+
FunctionalizationReapplyViewsGuard operator=(
|
| 361 |
+
FunctionalizationReapplyViewsGuard&&) = delete;
|
| 362 |
+
|
| 363 |
+
private:
|
| 364 |
+
bool prev_;
|
| 365 |
+
};
|
| 366 |
+
|
| 367 |
+
} // namespace impl
|
| 368 |
+
|
| 369 |
+
// Helper function to call an out-of-place composite aten kernel that may use
|
| 370 |
+
// mutations / views internally, and functionalize them.
|
| 371 |
+
TORCH_API void functionalize_op_helper(
|
| 372 |
+
const c10::OperatorHandle& op,
|
| 373 |
+
torch::jit::Stack* stack);
|
| 374 |
+
|
| 375 |
+
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
|
| 376 |
+
struct _functionalize_aten_op final {};
|
| 377 |
+
|
| 378 |
+
template <class Op, bool symint, class ReturnType, class... ParameterTypes>
|
| 379 |
+
struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
|
| 380 |
+
static ReturnType call(
|
| 381 |
+
typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
|
| 382 |
+
using FuncType = ReturnType(
|
| 383 |
+
typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
|
| 384 |
+
auto op = c10::Dispatcher::singleton()
|
| 385 |
+
.findSchemaOrThrow(
|
| 386 |
+
(const char*)Op::name, (const char*)Op::overload_name)
|
| 387 |
+
.typed<FuncType>();
|
| 388 |
+
|
| 389 |
+
return c10::impl::BoxedKernelWrapper<FuncType>::call(
|
| 390 |
+
c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
|
| 391 |
+
op,
|
| 392 |
+
// BoxedKernelWrapper knows to ignore this keyset argument,
|
| 393 |
+
// because functionalize_op_helper doesn't take in a DispatchKeySet
|
| 394 |
+
c10::DispatchKeySet(),
|
| 395 |
+
args...);
|
| 396 |
+
}
|
| 397 |
+
};
|
| 398 |
+
|
| 399 |
+
template <class Op>
|
| 400 |
+
using functionalize_aten_op =
|
| 401 |
+
_functionalize_aten_op<Op, false, typename Op::schema>;
|
| 402 |
+
|
| 403 |
+
template <class Op>
|
| 404 |
+
using functionalize_aten_op_symint =
|
| 405 |
+
_functionalize_aten_op<Op, true, typename Op::schema>;
|
| 406 |
+
|
| 407 |
+
} // namespace functionalization
|
| 408 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Generator.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Generator.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedFallback.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/ATen.h>
|
| 3 |
+
#include <ATen/core/op_registration/op_registration.h>
|
| 4 |
+
#include <torch/library.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
// If an operator doesn't have a batching rule implemented then we fallback
|
| 9 |
+
// to this implementation. The fallback only works on out-of-place operators
|
| 10 |
+
// that return only tensors with new memory. (e.g., no in-place operators, no
|
| 11 |
+
// view operations).
|
| 12 |
+
//
|
| 13 |
+
// The fallback effectively takes all of the BatchedTensors in `stack`, slices
|
| 14 |
+
// them, and runs `op` on all of the corresponding slices to produce slices
|
| 15 |
+
// of the outputs. The output slices then get `torch.stack`ed to create the
|
| 16 |
+
// final returns.
|
| 17 |
+
//
|
| 18 |
+
// The performance of the fallback is not very good because it introduces an
|
| 19 |
+
// extra copy from stacking the sliced outputs. Because of this, we prefer to
|
| 20 |
+
// write batching rules for operators whenever possible.
|
| 21 |
+
void batchedTensorForLoopFallback(
|
| 22 |
+
const c10::OperatorHandle& op,
|
| 23 |
+
torch::jit::Stack* stack);
|
| 24 |
+
|
| 25 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <bitset>
|
| 4 |
+
|
| 5 |
+
#include <ATen/ArrayRef.h>
|
| 6 |
+
#include <ATen/SmallVector.h>
|
| 7 |
+
#include <ATen/Tensor.h>
|
| 8 |
+
|
| 9 |
+
namespace at {
|
| 10 |
+
|
| 11 |
+
// We assume this in a few other places in the codebase,
|
| 12 |
+
// but there isn't a centralized definition.
|
| 13 |
+
constexpr int64_t kVmapMaxTensorDims = 64;
|
| 14 |
+
|
| 15 |
+
// The valid vmap levels range from [0, 64). This effectively means that we
|
| 16 |
+
// support a maximum of 64 nested vmaps.
|
| 17 |
+
constexpr int64_t kVmapNumLevels = 64;
|
| 18 |
+
|
| 19 |
+
// Store this number of elements of BatchDims on the stack. Most people will
|
| 20 |
+
// probably use <= 5 nested vmaps, but adjust this number as necessary.
|
| 21 |
+
constexpr int64_t kBatchDimsStackSize = 5;
|
| 22 |
+
|
| 23 |
+
// a BatchDim represents a "private" dimension on a Tensor created inside of
|
| 24 |
+
// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
|
| 25 |
+
// is being vmap'ed over and the `level` being an identifier for which vmap
|
| 26 |
+
// said dimension was created inside. The `dim` corresponds to a "physical
|
| 27 |
+
// dim" - it is a dimension index on the underlying physical tensor that is
|
| 28 |
+
// being vmapped over.
|
| 29 |
+
struct BatchDim {
|
| 30 |
+
BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
|
| 31 |
+
int64_t dim() const {
|
| 32 |
+
return dim_;
|
| 33 |
+
}
|
| 34 |
+
int64_t level() const {
|
| 35 |
+
return level_;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
private:
|
| 39 |
+
int64_t dim_;
|
| 40 |
+
int64_t level_;
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
|
| 44 |
+
using BatchDimsRef = ArrayRef<BatchDim>;
|
| 45 |
+
|
| 46 |
+
// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim
|
| 47 |
+
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
|
| 48 |
+
// BatchedTensorImpl.
|
| 49 |
+
//
|
| 50 |
+
// The batch dimensions are treated as being "private"; they are not
|
| 51 |
+
// user-visible. For example, in the following Tensor,
|
| 52 |
+
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
|
| 53 |
+
// dimensions 0 and 1 are batch dimensions.
|
| 54 |
+
//
|
| 55 |
+
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
|
| 56 |
+
// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7)
|
| 57 |
+
// tensor.
|
| 58 |
+
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
|
| 59 |
+
explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
|
| 60 |
+
|
| 61 |
+
// Returns a reference to BatchDims that represent which dimensions of this
|
| 62 |
+
// tensor are private.
|
| 63 |
+
BatchDimsRef bdims() const {
|
| 64 |
+
return bdims_;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// BatchedTensorImpl wraps a Tensor
|
| 68 |
+
const Tensor& value() const {
|
| 69 |
+
return value_;
|
| 70 |
+
};
|
| 71 |
+
|
| 72 |
+
// Given a public dimension index, return the dimension index in the
|
| 73 |
+
// underlying value() tensor. For example, if we have
|
| 74 |
+
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2,
|
| 75 |
+
// dim=2)])
|
| 76 |
+
// bt.actualDim(0) -> 1
|
| 77 |
+
// bt.actualDim(1) -> 3
|
| 78 |
+
// bt.actualDim(2) -> Error
|
| 79 |
+
int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
|
| 80 |
+
|
| 81 |
+
// We have to override this because we opted into CustomStrides
|
| 82 |
+
IntArrayRef strides_custom() const override;
|
| 83 |
+
// Override a bunch of methods inherited from TensorImpl to return error
|
| 84 |
+
// messages.
|
| 85 |
+
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
| 86 |
+
void set_size(int64_t dim, int64_t new_size) override;
|
| 87 |
+
void set_stride(int64_t dim, int64_t new_stride) override;
|
| 88 |
+
void set_storage_offset(int64_t storage_offset) override;
|
| 89 |
+
#ifdef DEBUG
|
| 90 |
+
bool has_storage() const override;
|
| 91 |
+
#endif
|
| 92 |
+
|
| 93 |
+
private:
|
| 94 |
+
// see NOTE: [BatchedTensorImpl levels invariant]
|
| 95 |
+
void checkInvariants() const;
|
| 96 |
+
const char* tensorimpl_type_name() const override;
|
| 97 |
+
|
| 98 |
+
Tensor value_;
|
| 99 |
+
|
| 100 |
+
// Note: [BatchedTensorImpl levels invariant]
|
| 101 |
+
// There is an invariant that the BatchDims must be stored in increasing
|
| 102 |
+
// `level` order. That is, for i < j, bdims_[i].level must be less than
|
| 103 |
+
// bdims_[j].level.
|
| 104 |
+
BatchDims bdims_;
|
| 105 |
+
};
|
| 106 |
+
|
| 107 |
+
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
|
| 108 |
+
// BatchedTensorImpl.
|
| 109 |
+
inline bool isBatchedTensor(const Tensor& tensor) {
|
| 110 |
+
return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
// It is unsafe to call this on a Tensor that is not backed by a
|
| 114 |
+
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
|
| 115 |
+
inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
|
| 116 |
+
return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
|
| 120 |
+
if (!isBatchedTensor(tensor)) {
|
| 121 |
+
return nullptr;
|
| 122 |
+
}
|
| 123 |
+
return unsafeGetBatchedImpl(tensor);
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
|
| 127 |
+
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
|
| 128 |
+
BatchDimsRef bdims) {
|
| 129 |
+
std::bitset<kVmapMaxTensorDims> is_bdim;
|
| 130 |
+
for (const auto& bdim : bdims) {
|
| 131 |
+
is_bdim.set(bdim.dim());
|
| 132 |
+
}
|
| 133 |
+
return is_bdim;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
// Creates a bitset for all of the levels present in `bdims`
|
| 137 |
+
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
|
| 138 |
+
std::bitset<kVmapNumLevels> result;
|
| 139 |
+
for (const auto& bdim : bdims) {
|
| 140 |
+
result.set(bdim.level());
|
| 141 |
+
}
|
| 142 |
+
return result;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
|
| 146 |
+
out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
|
| 147 |
+
return out;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
// Use this to construct a BatchedTensor from a regular Tensor
|
| 151 |
+
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
|
| 152 |
+
|
| 153 |
+
// Adds a batch dim to `tensor`, returning a BatchedTensor
|
| 154 |
+
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
|
| 155 |
+
|
| 156 |
+
// Checks if an inplace operation on self and other is "vmap compatible".
|
| 157 |
+
// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
|
| 158 |
+
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
|
| 159 |
+
|
| 160 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MapAllocator.h
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/Allocator.h>
|
| 4 |
+
#include <c10/util/string_view.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
// Bit flags controlling how MapAllocator maps a file; values are powers of
// two so they can be combined with `|` in the `flags` constructor argument.
enum MappedAllocatorModes {
  ALLOCATOR_MAPPED_SHARED = 1,      // shared mapping (writes visible to others)
  ALLOCATOR_MAPPED_SHAREDMEM = 2,   // map via shared-memory facility
  ALLOCATOR_MAPPED_EXCLUSIVE = 4,   // fail if the target already exists
  ALLOCATOR_MAPPED_NOCREATE = 8,    // do not create the file if missing
  ALLOCATOR_MAPPED_KEEPFD = 16,     // keep the file descriptor open after mapping
  ALLOCATOR_MAPPED_FROMFD = 32,     // map from a caller-supplied fd
  ALLOCATOR_MAPPED_UNLINK = 64      // unlink the file after mapping
};
|
| 17 |
+
|
| 18 |
+
// Sentinel value/type to help distinguish the file descriptor constructor from
|
| 19 |
+
// the non-file descriptor constructor
|
| 20 |
+
enum WithFd { WITH_FD };
|
| 21 |
+
|
| 22 |
+
TORCH_API std::string NewProcessWideShmHandle();
|
| 23 |
+
|
| 24 |
+
// Allocator that backs tensor storage with a memory-mapped file. The
// constructor establishes the mapping (implementation in the .cpp); close()
// or the destructor tears it down. Non-copyable / non-movable because the
// object owns OS mapping state.
class TORCH_API MapAllocator {
 public:
  // Map `filename` with the given ALLOCATOR_MAPPED_* `flags`; the mapping
  // spans `size` bytes.
  MapAllocator(c10::string_view filename, int flags, size_t size);
  // Same, but maps an already-open file descriptor `fd` (see WITH_FD tag);
  // `filename` is retained for bookkeeping/diagnostics.
  MapAllocator(
      WithFd,
      c10::string_view filename,
      int fd,
      int flags,
      size_t size);
  MapAllocator(const MapAllocator&) = delete;
  MapAllocator& operator=(const MapAllocator&) = delete;
  MapAllocator(MapAllocator&&) = delete;
  MapAllocator& operator=(MapAllocator&&) = delete;

  // Name of the mapped file (empty semantics depend on construction mode).
  const char* filename() const {
    return filename_.c_str();
  }
  // File descriptor of the mapping; hard error on Windows, which has no fd_.
  int fd() const {
#ifdef _WIN32
    TORCH_CHECK(false, "MapAllocator::fd() is unsupported on Windows");
#else
    return fd_;
#endif
  }
  // Size of the mapped region in bytes.
  ptrdiff_t size() const {
    return size_;
  }
  // Return a pointer to the actual data for this allocator
  // (in the case of the refcounted allocator, this is offset
  // from the base pointer.)
  virtual void* data() const {
    return base_ptr_;
  }

  // Recovers the MapAllocator that owns a DataPtr, or (presumably) nullptr if
  // the DataPtr was not produced by one — TODO confirm against the .cpp.
  static MapAllocator* fromDataPtr(const at::DataPtr&);
  // Factory: map `filename` and wrap the mapping in a DataPtr; the actual
  // mapped size is reported through `actual_size_out`.
  static at::DataPtr makeDataPtr(
      c10::string_view filename,
      int flags,
      size_t size,
      size_t* actual_size_out);
  // Factory variant taking an existing file descriptor.
  static at::DataPtr makeDataPtr(
      WithFd,
      const char* filename,
      int fd,
      int flags,
      size_t size,
      size_t* actual_size_out);

  // Closes the data. Helps us avoid destructor shenanigans
  virtual void close();

  // This is very dangerous. You have to redefine this destructor for each
  // subclass
  virtual ~MapAllocator();

 protected:
  bool closed_ = false;       // set once close() has released the mapping
  std::string filename_;      // file backing the mapping
  int flags_ = 0;             // ALLOCATOR_MAPPED_* flags used at construction
  ptrdiff_t size_; /* mapped size */
#ifdef _WIN32
  void* handle_;              // Windows file-mapping handle
  void* event_;               // Windows event handle
  std::string eventname_;     // name of the associated event object
#else
  int fd_ = -1;               // file descriptor of the mapping (-1 if released)
#endif
  void* base_ptr_ = nullptr;  // start of the mapped region
};
|
| 93 |
+
|
| 94 |
+
// Base-from-member idiom: inherited *before* MapAllocator so that its
// constructor can validate `flags` prior to the mapping being created.
struct TORCH_API RefcountedMapAllocatorArgCheck {
  RefcountedMapAllocatorArgCheck(int flags);
};
|
| 98 |
+
|
| 99 |
+
// Reference-counted variant of MapAllocator: incref()/decref() adjust a
// shared reference count (presumably stored in the mapped region so multiple
// processes can share it — see the .cpp for the exact layout), and data()
// is offset relative to MapAllocator's base pointer.
class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck,
                                         public MapAllocator {
 public:
  RefcountedMapAllocator(const char* filename, int flags, size_t size);
  RefcountedMapAllocator(
      WithFd,
      const char* filename,
      int fd,
      int flags,
      size_t size);

  static RefcountedMapAllocator* fromDataPtr(const at::DataPtr&);
  static at::DataPtr makeDataPtr(
      const char* filename,
      int flags,
      size_t size,
      size_t* actual_size_out);
  static at::DataPtr makeDataPtr(
      WithFd,
      const char* filename,
      int fd,
      int flags,
      size_t size,
      size_t* actual_size_out);

  // Offset data pointer (see MapAllocator::data() comment).
  void* data() const override;

  void incref();
  int decref();
  void close() override;

  ~RefcountedMapAllocator() override {
    // Qualified call: avoids virtual dispatch into a partially-destroyed
    // object during destruction.
    RefcountedMapAllocator::close();
  }

 protected:
  void checkFlags();
  void initializeAlloc();
};
|
| 138 |
+
|
| 139 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensor.h
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/NamedTensor.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NestedTensorImpl.h
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/MemoryOverlap.h>
|
| 3 |
+
#include <ATen/Tensor.h>
|
| 4 |
+
#include <c10/core/DispatchKey.h>
|
| 5 |
+
#include <c10/core/DispatchKeySet.h>
|
| 6 |
+
#include <c10/core/MemoryFormat.h>
|
| 7 |
+
#include <c10/core/TensorImpl.h>
|
| 8 |
+
#include <c10/util/ArrayRef.h>
|
| 9 |
+
#include <c10/util/Exception.h>
|
| 10 |
+
#include <c10/util/Metaprogramming.h>
|
| 11 |
+
#include <c10/util/irange.h>
|
| 12 |
+
|
| 13 |
+
namespace at::native {
|
| 14 |
+
struct NestedTensorImpl;
|
| 15 |
+
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
|
| 16 |
+
int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
|
| 17 |
+
|
| 18 |
+
// TensorImpl for nested tensors: a flat storage buffer plus per-component
// metadata tensors (sizes, strides, storage offsets) describing each
// constituent tensor packed inside the buffer.
struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
  explicit NestedTensorImpl(
      Storage storage,
      c10::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);

  explicit NestedTensorImpl(
      const at::Tensor& buffer,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);
  // assume contiguous, `nested_strides` and `offsets`
  // can be inferred from `nested_sizes`
  explicit NestedTensorImpl(
      const at::Tensor& buffer,
      const at::Tensor& nested_sizes);

  // This constructor is used creating view tensors from nested tensors
  explicit NestedTensorImpl(
      c10::TensorImpl::ImplType impl_type,
      const at::Tensor& base_tensor,
      at::Tensor nested_sizes,
      at::Tensor nested_strides,
      at::Tensor storage_offsets);

  // TODO: don't expose private implementation details like this; in
  // particular, resizing this tensor will mess up our dim() and
  // callers cannot fix it.
  const Tensor& get_nested_sizes() const {
    return nested_sizes_;
  }
  // TODO: don't expose private implementation details like this
  const Tensor& get_nested_strides() const {
    return nested_strides_;
  }
  const Tensor& get_storage_offsets() const {
    return storage_offsets_;
  }
  // Returns nullopt if the ith dimension is irregular. The ith dimension
  // of a NestedTensor is regular if the unbound tensors match in
  // size at the (i-1)th dimension.
  c10::optional<int64_t> opt_size(int64_t d) const;

  // Like opt_size(), but a hard error on irregular dimensions.
  int64_t size(int64_t d) const {
    c10::optional<int64_t> optional_size = this->opt_size(d);
    TORCH_CHECK(
        optional_size.has_value(),
        "Given dimension ",
        d,
        " is irregular and does not have a size.");
    return *optional_size;
  }
  /**
   * Return a view of the nested tensor as a 1 dimensional contiguous tensor.
   *
   * The buffer tensor created by this function shares the same storage_impl as
   * the original nested tensor, and therefore can be seen as a view.
   *
   * @return A newly constructed view tensor
   */
  at::Tensor get_buffer() const {
    TORCH_CHECK(
        nested_tensor_impl_is_contiguous(this),
        "NestedTensor must be contiguous to get buffer.");
    return get_unsafe_storage_as_tensor();
  }
  /**
   * If possible use get_buffer() instead. This function returns the storage
   * as a tensor directly, which is not safe to use in general. If using this
   * function, The caller must ensure to account for nested_sizes,
   * nested_strides and storage_offsets.
   *
   * @return A newly constructed view tensor
   */
  at::Tensor get_unsafe_storage_as_tensor() const {
    auto buffer_key_set_ = generate_buffer_key_set();
    const auto buffer_size = get_buffer_size();
    // Wrap the raw storage in a 1-D dense TensorImpl marked as a VIEW so it
    // shares (not owns) the nested tensor's storage.
    auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>(
        c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_);
    buffer_tensor_impl->set_sizes_contiguous(
        c10::makeArrayRef(static_cast<int64_t>(buffer_size)));
    return Tensor(buffer_tensor_impl);
  }

  // Number of elements the underlying storage can hold, in units of dtype.
  size_t get_buffer_size() const {
    return storage_.nbytes() / data_type_.itemsize();
  }

 protected:
  const char* tensorimpl_type_name() const override;

  // TODO: numel_custom and is_contiguous_custom can be profitably overridden
  // with real implementations
  int64_t numel_custom() const override;
  c10::SymInt sym_numel_custom() const override;
  bool is_contiguous_custom(MemoryFormat) const override;
  int64_t size_custom(int64_t d) const override {
    return this->size(d);
  }
  c10::SymInt sym_size_custom(int64_t d) const override {
    return c10::SymInt{this->size(d)};
  }
  IntArrayRef sizes_custom() const override;
  c10::SymIntArrayRef sym_sizes_custom() const override;
  IntArrayRef strides_custom() const override;
  c10::SymIntArrayRef sym_strides_custom() const override;

  // this one is real
  int64_t dim_custom() const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
    copy_tensor_metadata(
        /*src_impl=*/impl.get(),
        /*dest_impl=*/this,
        /*version_counter=*/version_counter(),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  }

 private:
  // Must be called after any changes to our dim() to sync the state
  // to TensorImpl.
  void refresh_dim();

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const at::Tensor nested_sizes_, nested_strides_;
  // The starting positions of the underlying tensors in contiguous buffer
  // i.e. the buffer memory offsets to get the underlying tensors
  // The reason to keep this metadata is that, without strong enough constraint
  // it cannot be derived from `nested_sizes_`
  // and `nested_strides_`:
  // 1. when buffer has blanks, e.g. [tensor1, blank, tensor2]
  //    this can happen e.g. after slicing a nested tensor
  // 2. when multiple tensors share a same memory
  // 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2]
  // Some strong enough constraints are:
  // 1. every underlying tensor is contiguous in memory
  //    && nesting in ascending order
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const at::Tensor storage_offsets_;
  // NOTE: -1 here means the size is missing
  // Optional to allow it to be computed lazily from nested.
  // TODO: maybe we can remove this metadata since
  // we can compute it from `nested_sizes_`
  mutable c10::optional<std::vector<int64_t>> opt_sizes_;

  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  /**
   * Generates a non-nested key_set from a nested tensor.
   *
   * For many nested tensor kernel implementations a buffer tensor
   * is generated and redispatched to a non-nested kernel this function
   * generates the key set used by that buffer tensor
   *
   * @return Appropriate key set for non-nested tensor
   */
  inline c10::DispatchKeySet generate_buffer_key_set() const {
    auto buffer_key_set = this->key_set();
    const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset);
    // Remove nested tensor specific keys
    buffer_key_set = buffer_key_set -
        c10::DispatchKeySet{
            c10::DispatchKey::NestedTensor,
            c10::DispatchKey::AutogradNestedTensor};

    // Add dense tensor specific keys
    buffer_key_set =
        buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense};
    buffer_key_set = Autograd
        ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set
        : buffer_key_set;

    return buffer_key_set;
  }
};
|
| 207 |
+
|
| 208 |
+
// Returns the NestedTensorImpl backing `tensor`, or nullptr when `tensor`
// is not a nested tensor.
inline NestedTensorImpl* get_nested_tensor_impl_or_null(
    const at::Tensor& tensor) {
  if (!tensor.is_nested()) {
    return nullptr;
  }
  return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
}
|
| 215 |
+
|
| 216 |
+
// Returns the NestedTensorImpl backing `tensor`; errors if it is not nested.
inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
  auto* impl = tensor.unsafeGetTensorImpl();
  return static_cast<NestedTensorImpl*>(impl);
}
|
| 221 |
+
|
| 222 |
+
// Checks whether a nested tensor's buffer is laid out contiguously: every
// component tensor must itself be contiguous, and the components must be
// packed back-to-back in order, with no gaps between them.
inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
  int64_t ntensors = nt->size(0);
  // An empty nested tensor is trivially contiguous.
  if (ntensors == 0) {
    return true;
  }
  const Tensor &sizemat = nt->get_nested_sizes(),
               &stridemat = nt->get_nested_strides();
  int64_t* offsets_ptr = nt->get_storage_offsets().data_ptr<int64_t>();
  // Each row of sizemat/stridemat describes one component tensor.
  int64_t orig_dim = sizemat.size(1);
  // nesting scalars
  if (orig_dim == 0) {
    // each scalar must be contiguous
    // if there is blank memory between underlying scalars
    for (int64_t i = 0; i < ntensors; i++) {
      if (offsets_ptr[i] != i) {
        return false;
      }
    }
  }
  // nesting tensors
  else {
    // if any underlying tensor is non-contiguous
    const int64_t *sizemat_ptr = sizemat.data_ptr<int64_t>(),
                  *stridemat_ptr = stridemat.data_ptr<int64_t>();
    for (int64_t i = 0; i < ntensors; i++) {
      // Innermost stride must be 1, and each outer stride must equal the
      // product of the inner sizes (row-major contiguity per component).
      if (stridemat_ptr[orig_dim - 1] != 1) {
        return false;
      }
      int64_t product = sizemat_ptr[orig_dim - 1];
      for (int64_t j = orig_dim - 2; j >= 0; j--) {
        if (stridemat_ptr[j] != product) {
          return false;
        }
        product *= sizemat_ptr[j];
      }
      // Advance to the next component's metadata row.
      sizemat_ptr += orig_dim;
      stridemat_ptr += orig_dim;
    }
    // if there is blank memory between underlying tensors
    if (offsets_ptr[0] != 0) {
      return false;
    }
    sizemat_ptr = sizemat.data_ptr<int64_t>();
    stridemat_ptr = stridemat.data_ptr<int64_t>();
    for (int64_t i = 1; i < ntensors; i++) {
      // Each component must start exactly where the previous one ended:
      // previous offset + (outermost size * outermost stride).
      if (offsets_ptr[i] !=
          offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) {
        return false;
      }
      sizemat_ptr += orig_dim;
      stridemat_ptr += orig_dim;
    }
  }
  // everything is fine
  return true;
}
|
| 278 |
+
|
| 279 |
+
// Convenience accessor for the nested-sizes metadata tensor of a nested
// tensor; errors (via get_nested_tensor_impl) if `tensor` is not nested.
inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) {
  auto* impl = get_nested_tensor_impl(tensor);
  return impl->get_nested_sizes();
}
|
| 282 |
+
|
| 283 |
+
} // namespace at::native
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/PadNd.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/util/Exception.h>
|
| 3 |
+
#include <c10/util/string_view.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
|
| 7 |
+
// Padding modes accepted by the N-dimensional padding operations.
enum class padding_mode {
  reflect,    // mirror values across the edge, excluding the edge element
  replicate,  // repeat the edge value
  circular,   // wrap around to the opposite edge
  constant,   // fill with a constant value
};
|
| 13 |
+
|
| 14 |
+
// Maps a padding_mode enumerator to its canonical lowercase name; hard error
// for any value outside the enumeration.
static inline c10::string_view padding_mode_string(padding_mode m) {
  if (m == padding_mode::reflect) {
    return "reflect";
  }
  if (m == padding_mode::replicate) {
    return "replicate";
  }
  if (m == padding_mode::circular) {
    return "circular";
  }
  if (m == padding_mode::constant) {
    return "constant";
  }
  TORCH_CHECK(false, "Invalid padding mode (", static_cast<int64_t>(m), ")");
}
|
| 27 |
+
|
| 28 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Parallel-inl.h
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/Exception.h>
|
| 4 |
+
#include <c10/util/ParallelGuard.h>
|
| 5 |
+
#include <c10/util/SmallVector.h>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
|
| 9 |
+
// Applies `f(chunk_begin, chunk_end)` over [begin, end), possibly in
// parallel. See the documentation block in ATen/Parallel.h for the full
// contract. Runs sequentially when the range is small relative to
// `grain_size`, when already inside a parallel region, or when only one
// thread is available.
template <class F>
inline void parallel_for(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const F& f) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
  // Empty range: nothing to do.
  if (begin >= end) {
    return;
  }

#ifdef INTRA_OP_PARALLEL
  at::internal::lazy_init_num_threads();
  const auto numiter = end - begin;
  const bool use_parallel =
      (numiter > grain_size && numiter > 1 && !at::in_parallel_region() &&
       at::get_num_threads() > 1);
  if (!use_parallel) {
    // Sequential fallback: run the whole range on this thread as thread 0.
    internal::ThreadIdGuard tid_guard(0);
    c10::ParallelGuard guard(true);
    f(begin, end);
    return;
  }

  internal::invoke_parallel(
      begin, end, grain_size, [&](int64_t begin, int64_t end) {
        c10::ParallelGuard guard(true);
        f(begin, end);
      });
#else
  // Built without intra-op parallelism: always run sequentially.
  internal::ThreadIdGuard tid_guard(0);
  c10::ParallelGuard guard(true);
  f(begin, end);
#endif
}
|
| 44 |
+
|
| 45 |
+
// Parallel reduction over [begin, end): `f` reduces each chunk starting from
// the identity `ident`, and `sf` combines partial results. See the
// documentation block in ATen/Parallel.h for the full contract. Falls back
// to a single sequential call when parallelism is unavailable or the range
// is too small.
template <class scalar_t, class F, class SF>
inline scalar_t parallel_reduce(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const scalar_t ident,
    const F& f,
    const SF& sf) {
  TORCH_CHECK(grain_size >= 0);
  // Empty range: the reduction of nothing is the identity.
  if (begin >= end) {
    return ident;
  }

#ifdef INTRA_OP_PARALLEL
  at::internal::lazy_init_num_threads();
  const auto max_threads = at::get_num_threads();
  const bool use_parallel =
      ((end - begin) > grain_size && !at::in_parallel_region() &&
       max_threads > 1);
  if (!use_parallel) {
    // Sequential fallback on this thread as thread 0.
    internal::ThreadIdGuard tid_guard(0);
    c10::ParallelGuard guard(true);
    return f(begin, end, ident);
  }

  // One slot per possible worker thread, pre-filled with the identity so a
  // thread that receives no work contributes nothing to the final combine.
  c10::SmallVector<scalar_t, 64> results(max_threads, ident);
  internal::invoke_parallel(
      begin,
      end,
      grain_size,
      [&](const int64_t my_begin, const int64_t my_end) {
        const auto tid = at::get_thread_num();
        c10::ParallelGuard guard(true);
        results[tid] = f(my_begin, my_end, ident);
      });

  // Fold the per-thread partial results into a single value.
  scalar_t result = ident;
  for (auto partial_result : results) {
    result = sf(result, partial_result);
  }
  return result;
#else
  // Built without intra-op parallelism: always run sequentially.
  internal::ThreadIdGuard tid_guard(0);
  c10::ParallelGuard guard(true);
  return f(begin, end, ident);
#endif
}
|
| 92 |
+
|
| 93 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Parallel.h
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/Config.h>
|
| 3 |
+
#include <c10/macros/Macros.h>
|
| 4 |
+
#include <functional>
|
| 5 |
+
#include <string>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
|
| 9 |
+
// Integer ceiling division: smallest q such that q * y >= x (for positive y).
inline int64_t divup(int64_t x, int64_t y) {
  const int64_t rounded_up = x + y - 1;
  return rounded_up / y;
}
|
| 12 |
+
|
| 13 |
+
// Called during new thread initialization
|
| 14 |
+
TORCH_API void init_num_threads();
|
| 15 |
+
|
| 16 |
+
// Sets the number of threads to be used in parallel region
|
| 17 |
+
TORCH_API void set_num_threads(int);
|
| 18 |
+
|
| 19 |
+
// Returns the maximum number of threads that may be used in a parallel region
|
| 20 |
+
TORCH_API int get_num_threads();
|
| 21 |
+
|
| 22 |
+
// Returns the current thread number (starting from 0)
|
| 23 |
+
// in the current parallel region, or 0 in the sequential region
|
| 24 |
+
TORCH_API int get_thread_num();
|
| 25 |
+
|
| 26 |
+
// Checks whether the code runs in parallel region
|
| 27 |
+
TORCH_API bool in_parallel_region();
|
| 28 |
+
|
| 29 |
+
namespace internal {
|
| 30 |
+
|
| 31 |
+
// Initialise num_threads lazily at first parallel call
|
| 32 |
+
inline void lazy_init_num_threads() {
|
| 33 |
+
thread_local bool init = false;
|
| 34 |
+
if (C10_UNLIKELY(!init)) {
|
| 35 |
+
at::init_num_threads();
|
| 36 |
+
init = true;
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
TORCH_API void set_thread_num(int);
|
| 41 |
+
|
| 42 |
+
// RAII guard that sets the calling thread's intra-op thread id for its
// lifetime and restores the previous id on destruction.
class TORCH_API ThreadIdGuard {
 public:
  ThreadIdGuard(int new_id) : old_id_(at::get_thread_num()) {
    set_thread_num(new_id);
  }

  ~ThreadIdGuard() {
    // Restore whatever id was active when the guard was constructed.
    set_thread_num(old_id_);
  }

 private:
  int old_id_;  // thread id saved at construction
};
|
| 55 |
+
|
| 56 |
+
} // namespace internal
|
| 57 |
+
|
| 58 |
+
/*
|
| 59 |
+
parallel_for
|
| 60 |
+
|
| 61 |
+
begin: index at which to start applying user function
|
| 62 |
+
|
| 63 |
+
end: index at which to stop applying user function
|
| 64 |
+
|
| 65 |
+
grain_size: number of elements per chunk. impacts the degree of parallelization
|
| 66 |
+
|
| 67 |
+
f: user function applied in parallel to the chunks, signature:
|
| 68 |
+
void f(int64_t begin, int64_t end)
|
| 69 |
+
|
| 70 |
+
Warning: parallel_for does NOT copy thread local
|
| 71 |
+
states from the current thread to the worker threads.
|
| 72 |
+
This means for example that Tensor operations CANNOT be used in the
|
| 73 |
+
body of your function, only data pointers.
|
| 74 |
+
*/
|
| 75 |
+
template <class F>
|
| 76 |
+
inline void parallel_for(
|
| 77 |
+
const int64_t begin,
|
| 78 |
+
const int64_t end,
|
| 79 |
+
const int64_t grain_size,
|
| 80 |
+
const F& f);
|
| 81 |
+
|
| 82 |
+
/*
|
| 83 |
+
parallel_reduce
|
| 84 |
+
|
| 85 |
+
begin: index at which to start applying reduction
|
| 86 |
+
|
| 87 |
+
end: index at which to stop applying reduction
|
| 88 |
+
|
| 89 |
+
grain_size: number of elements per chunk. impacts number of elements in
|
| 90 |
+
intermediate results tensor and degree of parallelization.
|
| 91 |
+
|
| 92 |
+
ident: identity for binary combination function sf. sf(ident, x) needs to return
|
| 93 |
+
x.
|
| 94 |
+
|
| 95 |
+
f: function for reduction over a chunk. f needs to be of signature scalar_t
|
| 96 |
+
f(int64_t partial_begin, int64_t partial_end, scalar_t identity)
|
| 97 |
+
|
| 98 |
+
sf: function to combine two partial results. sf needs to be of signature
|
| 99 |
+
scalar_t sf(scalar_t x, scalar_t y)
|
| 100 |
+
|
| 101 |
+
For example, you might have a tensor of 10000 entires and want to sum together
|
| 102 |
+
all the elements. Parallel_reduce with a grain_size of 2500 will then allocate
|
| 103 |
+
an intermediate result tensor with 4 elements. Then it will execute the function
|
| 104 |
+
"f" you provide and pass the beginning and end index of these chunks, so
|
| 105 |
+
0-2499, 2500-4999, etc. and the combination identity. It will then write out
|
| 106 |
+
the result from each of these chunks into the intermediate result tensor. After
|
| 107 |
+
that it'll reduce the partial results from each chunk into a single number using
|
| 108 |
+
the combination function sf and the identity ident. For a total summation this
|
| 109 |
+
would be "+" and 0 respectively. This is similar to tbb's approach [1], where
|
| 110 |
+
you need to provide a function to accumulate a subrange, a function to combine
|
| 111 |
+
two partial results and an identity.
|
| 112 |
+
|
| 113 |
+
Warning: parallel_reduce does NOT copy thread local
|
| 114 |
+
states from the current thread to the worker threads.
|
| 115 |
+
This means for example that Tensor operations CANNOT be used in the
|
| 116 |
+
body of your function, only data pointers.
|
| 117 |
+
|
| 118 |
+
[1] https://software.intel.com/en-us/node/506154
|
| 119 |
+
*/
|
| 120 |
+
template <class scalar_t, class F, class SF>
|
| 121 |
+
inline scalar_t parallel_reduce(
|
| 122 |
+
const int64_t begin,
|
| 123 |
+
const int64_t end,
|
| 124 |
+
const int64_t grain_size,
|
| 125 |
+
const scalar_t ident,
|
| 126 |
+
const F& f,
|
| 127 |
+
const SF& sf);
|
| 128 |
+
|
| 129 |
+
// Returns a detailed string describing parallelization settings
|
| 130 |
+
TORCH_API std::string get_parallel_info();
|
| 131 |
+
|
| 132 |
+
// Sets number of threads used for inter-op parallelism
|
| 133 |
+
TORCH_API void set_num_interop_threads(int);
|
| 134 |
+
|
| 135 |
+
// Returns the number of threads used for inter-op parallelism
|
| 136 |
+
TORCH_API int get_num_interop_threads();
|
| 137 |
+
|
| 138 |
+
// Launches inter-op parallel task
|
| 139 |
+
TORCH_API void launch(std::function<void()> func);
|
| 140 |
+
namespace internal {
|
| 141 |
+
void launch_no_thread_state(std::function<void()> fn);
|
| 142 |
+
} // namespace internal
|
| 143 |
+
|
| 144 |
+
// Launches intra-op parallel task
|
| 145 |
+
TORCH_API void intraop_launch(std::function<void()> func);
|
| 146 |
+
|
| 147 |
+
// Returns number of intra-op threads used by default
|
| 148 |
+
TORCH_API int intraop_default_num_threads();
|
| 149 |
+
|
| 150 |
+
} // namespace at
|
| 151 |
+
|
| 152 |
+
#if AT_PARALLEL_OPENMP
|
| 153 |
+
#include <ATen/ParallelOpenMP.h> // IWYU pragma: keep
|
| 154 |
+
#elif AT_PARALLEL_NATIVE
|
| 155 |
+
#include <ATen/ParallelNative.h> // IWYU pragma: keep
|
| 156 |
+
#elif AT_PARALLEL_NATIVE_TBB
|
| 157 |
+
#include <ATen/ParallelNativeTBB.h> // IWYU pragma: keep
|
| 158 |
+
#endif
|
| 159 |
+
|
| 160 |
+
#include <ATen/Parallel-inl.h> // IWYU pragma: keep
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelNative.h
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <algorithm>
|
| 4 |
+
#include <cstddef>
|
| 5 |
+
#include <exception>
|
| 6 |
+
|
| 7 |
+
#include <c10/util/Exception.h>
|
| 8 |
+
|
| 9 |
+
#define INTRA_OP_PARALLEL
|
| 10 |
+
|
| 11 |
+
namespace at::internal {
|
| 12 |
+
|
| 13 |
+
TORCH_API void invoke_parallel(
|
| 14 |
+
const int64_t begin,
|
| 15 |
+
const int64_t end,
|
| 16 |
+
const int64_t grain_size,
|
| 17 |
+
const std::function<void(int64_t, int64_t)>& f);
|
| 18 |
+
|
| 19 |
+
} // namespace at::internal
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SavedTensorHooks.h
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/macros/Export.h>
|
| 4 |
+
#include <c10/util/Optional.h>
|
| 5 |
+
#include <c10/util/python_stub.h>
|
| 6 |
+
#include <stack>
|
| 7 |
+
#include <string>
|
| 8 |
+
|
| 9 |
+
#include <utility>
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
|
| 13 |
+
namespace impl {
|
| 14 |
+
|
| 15 |
+
// Thread-local state backing SavedTensorDefaultHooks (below): the stack of
// currently-pushed (pack_hook, unpack_hook) pairs plus the disabled flag,
// bundled so it can be saved/restored as one unit via
// SavedTensorDefaultHooks::get_tls_state / set_tls_state.
struct TORCH_API SavedTensorDefaultHooksTLS {
  // Stack of (pack_hook, unpack_hook) pairs.
  // PyObject is defined in c10/util/python_stub.h
  std::stack<std::pair<PyObject*, PyObject*>> stack;

  // See NOTE: [Disabling SavedTensorDefaultHooks] for context
  // NOTE: [disabled_error_message invariant]
  // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled
  // We did this for efficiency (so we didn't have to keep a separate bool
  // around)
  c10::optional<std::string> disabled_error_message;
};
|
| 26 |
+
|
| 27 |
+
} // namespace impl
|
| 28 |
+
|
| 29 |
+
// Static-only interface for manipulating the thread-local stack of saved
// tensor pack/unpack hooks (see SavedTensorDefaultHooksTLS). All members are
// static; implementations live in the corresponding .cpp.
struct TORCH_API SavedTensorDefaultHooks {
  // Push a (pack_hook, unpack_hook) pair onto the TLS stack.
  static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook);
  // Pop the most recently pushed pair.
  static void pop_hooks();
  // Return the top-of-stack pair; behavior on an empty stack is defined by the
  // implementation (not visible here).
  static std::pair<PyObject*, PyObject*> get_hooks();
  static void lazy_initialize();
  // Copy out / replace the whole hooks stack.
  static std::stack<std::pair<PyObject*, PyObject*>> get_stack();
  static void set_stack(std::stack<std::pair<PyObject*, PyObject*>>);

  // Access the full thread-local state (stack + disabled message) as a unit.
  static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
  static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);

  // NOTE: [Disabling SavedTensorDefaultHooks]
  // A developer of a PyTorch feature may choose to disable SavedTensorDefault
  // hooks, especially if their feature does not work with it. If they are
  // disabled, then the following will raise an error:
  // - Attempting to push_hooks
  // - calling disable(message) with a non-zero stack (from get_stack) size
  static void disable(const std::string& error_message);
  static void enable();
  static bool is_enabled();
  // nullopt iff hooks are enabled; see NOTE: [disabled_error_message
  // invariant] on SavedTensorDefaultHooksTLS.
  static const c10::optional<std::string>& get_disabled_error_message();
};
|
| 51 |
+
|
| 52 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorAccessor.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/TensorAccessor.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorIteratorInternal.h
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/native/TensorIterator.h>
|
| 3 |
+
#include <c10/util/SmallBuffer.h>
|
| 4 |
+
#include <c10/util/irange.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
// Multi-dimensional counter used by internal::serial_for_each (below) to walk
// a TensorIterator `shape` over the linear index range `range`, advancing in
// 2-D steps (see max_2d_step / increment).
struct DimCounter {
  DimCounter(IntArrayRef shape, Range range);

  // Advance by the given (dim0, dim1) step.
  void increment(const std::array<int64_t, 2>& step);
  // True once the whole `range` has been consumed.
  bool is_done() const;
  // Largest 2-D step that can be taken from the current position.
  std::array<int64_t, 2> max_2d_step() const;

  IntArrayRef shape;
  Range range;
  // Current per-dimension counters; passed as `counter` to get_data_ptrs.
  // NOTE(review): inline capacity of 4 presumably covers the common ndim —
  // confirm against DimCounter's implementation.
  c10::SmallBuffer<int64_t, 4> values;
  int64_t offset;
};
|
| 20 |
+
|
| 21 |
+
namespace internal {
|
| 22 |
+
|
| 23 |
+
inline void get_data_ptrs(
|
| 24 |
+
char** ptrs,
|
| 25 |
+
ArrayRef<char*> base,
|
| 26 |
+
IntArrayRef strides,
|
| 27 |
+
IntArrayRef counter) {
|
| 28 |
+
const auto ntensors = base.size();
|
| 29 |
+
const auto ndim = counter.size();
|
| 30 |
+
std::copy(base.begin(), base.end(), ptrs);
|
| 31 |
+
for (const auto dim : c10::irange(ndim)) {
|
| 32 |
+
int64_t value = counter[dim];
|
| 33 |
+
for (const auto arg : c10::irange(ntensors)) {
|
| 34 |
+
ptrs[arg] += value * strides[dim * ntensors + arg];
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
// Run `loop` serially over the linear element range `range` of an iteration
// space described by `shape`/`strides` (strides are operand-major per
// dimension; debug-asserted to have ntensors * max(2, ndim) entries).
//
// Fast paths: for ndim <= 1 the whole range is one contiguous 2-D call —
// either directly on base_ptrs (range starts at 0) or on pointers offset to
// range.begin. Otherwise a DimCounter walks the shape, taking the largest
// 2-D step each iteration and recomputing operand pointers per step.
inline void serial_for_each(
    IntArrayRef shape,
    IntArrayRef strides,
    char** base_ptrs,
    size_t ntensors,
    typename TensorIteratorBase::loop2d_t loop,
    Range range) {
  const auto ndim = shape.size();
  // loop2d_t always consumes at least 2 dims' worth of strides, hence max(2, ndim).
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      strides.size() == ntensors * std::max(size_t{2}, ndim));

  if (ndim <= 1) {
    if (range.begin == 0) {
      // Base pointers already point at the start of the range.
      loop(base_ptrs, strides.data(), range.size(), 1);
    } else {
      // Offset each operand pointer to range.begin along dim 0.
      c10::SmallBuffer<char*, 4> ptrs(ntensors);
      get_data_ptrs(ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin});
      loop(ptrs.data(), strides.data(), range.size(), 1);
    }
  } else {
    // General case: iterate the shape in maximal 2-D chunks.
    c10::SmallBuffer<char*, 4> ptrs(ntensors);
    auto counter = DimCounter(shape, range);
    while (!counter.is_done()) {
      get_data_ptrs(
          ptrs.data(), {base_ptrs, ntensors}, strides, counter.values);
      auto step = counter.max_2d_step();
      loop(ptrs.data(), strides.data(), step[0], step[1]);
      counter.increment(step);
    }
  }
}
|
| 70 |
+
|
| 71 |
+
} // namespace internal
|
| 72 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorMeta.h
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/DimVector.h>
|
| 4 |
+
#include <ATen/core/Dimname.h>
|
| 5 |
+
#include <c10/core/TensorOptions.h>
|
| 6 |
+
#include <c10/util/strides.h>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
class Tensor;
|
| 11 |
+
|
| 12 |
+
namespace impl {
|
| 13 |
+
|
| 14 |
+
// Use this to define the prototype for a meta function. There are two
|
| 15 |
+
// versions; one that takes one argument (just the operator name), or FUNC2
|
| 16 |
+
// variant that takes two arguments (operator name and overload name).
|
| 17 |
+
//
|
| 18 |
+
// Example usage:
|
| 19 |
+
//
|
| 20 |
+
// TORCH_META_FUNC2(add, Tensor) (
|
| 21 |
+
// const Tensor& self, const Tensor& other
|
| 22 |
+
// ) {
|
| 23 |
+
// ... compute sizes and options ...
|
| 24 |
+
// set_output(sizes, options);
|
| 25 |
+
// }
|
| 26 |
+
//
|
| 27 |
+
#define TORCH_META_FUNC(name) void structured_##name::meta
|
| 28 |
+
#define TORCH_META_FUNC2(name, overload) \
|
| 29 |
+
void structured_##name##_##overload::meta
|
| 30 |
+
|
| 31 |
+
// These are versions of TORCH_META_FUNC(2) that include a precompute_out struct
|
| 32 |
+
// as a return value. They should be used when the kernel in question has
|
| 33 |
+
// precomputed values declared in native_functions.yaml and the corresponding
|
| 34 |
+
// implementation should return an instance of the aforementioned struct.
|
| 35 |
+
#define TORCH_PRECOMPUTE_META_FUNC(name) \
|
| 36 |
+
structured_##name::meta_return_ty structured_##name::meta
|
| 37 |
+
#define TORCH_PRECOMPUTE_META_FUNC2(name, overload) \
|
| 38 |
+
structured_##name##_##overload::meta_return_ty \
|
| 39 |
+
structured_##name##_##overload::meta
|
| 40 |
+
|
| 41 |
+
// Use this to create a precompute struct in a meta function.
|
| 42 |
+
#define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<>
|
| 43 |
+
#define TORCH_PRECOMPUTE_STRUCT2(name, overload) \
|
| 44 |
+
structured_##name##_##overload::precompute_out<>
|
| 45 |
+
|
| 46 |
+
// Use this to define the prototype for an implementation. This takes only
|
| 47 |
+
// one argument, which is the name of the dispatch key entry you're
|
| 48 |
+
// implementing.
|
| 49 |
+
//
|
| 50 |
+
// Example usage:
|
| 51 |
+
//
|
| 52 |
+
// TORCH_IMPL_FUNC(add_cpu) (
|
| 53 |
+
// Tensor& result, const Tensor& self, const Tensor& other
|
| 54 |
+
// ) {
|
| 55 |
+
// ... do the actual implementation ...
|
| 56 |
+
// }
|
| 57 |
+
//
|
| 58 |
+
#define TORCH_IMPL_FUNC(name) void structured_##name::impl
|
| 59 |
+
|
| 60 |
+
// Base class for all structured kernel classes. The set_output virtual
|
| 61 |
+
// method is varied depending whether or not the operator is
|
| 62 |
+
// functional/out/inplace, and could also be specialized for CPU/CUDA/etc
|
| 63 |
+
// (although presently it isn't).
|
| 64 |
+
//
|
| 65 |
+
// A notable subclass of this interface is TensorIteratorBase.
|
| 66 |
+
struct TORCH_API MetaBase {
|
| 67 |
+
MetaBase() = default;
|
| 68 |
+
MetaBase(const MetaBase&) = default;
|
| 69 |
+
MetaBase& operator=(const MetaBase&) = default;
|
| 70 |
+
MetaBase(MetaBase&&) noexcept = default;
|
| 71 |
+
MetaBase& operator=(MetaBase&&) noexcept = default;
|
| 72 |
+
virtual const Tensor& maybe_get_output(int64_t output_idx) = 0;
|
| 73 |
+
|
| 74 |
+
// Note: [set_output_*]
|
| 75 |
+
// See: https://github.com/pytorch/pytorch/issues/69813
|
| 76 |
+
// Whenever defining the output properties in the META function of a
|
| 77 |
+
// structured kernel (what was usually done with `set_output`), use one of
|
| 78 |
+
// these 3 variants, instead. In order to decide which variant to use, check
|
| 79 |
+
// the following decision tree:
|
| 80 |
+
//
|
| 81 |
+
// - Can the kernel you are going to implement support output tensors
|
| 82 |
+
// with arbitrary strides?
|
| 83 |
+
// |
|
| 84 |
+
// -- YES: `set_output_raw_strided`
|
| 85 |
+
// |
|
| 86 |
+
// -- NO: Should the output tensor strides be contiguous?
|
| 87 |
+
// |
|
| 88 |
+
// -- YES: `set_output_contiguous`
|
| 89 |
+
// |
|
| 90 |
+
// -- NO: `set_output_strided`
|
| 91 |
+
//
|
| 92 |
+
// Use this function whenever the kernel requires specific strides for the
|
| 93 |
+
// output. If `strides` does not match the given output strides, proxy outputs
|
| 94 |
+
// will be created and passed to the IMPL function.
|
| 95 |
+
virtual void set_output_strided(
|
| 96 |
+
int64_t output_idx,
|
| 97 |
+
IntArrayRef sizes,
|
| 98 |
+
IntArrayRef strides,
|
| 99 |
+
TensorOptions options,
|
| 100 |
+
DimnameList names = {}) {
|
| 101 |
+
TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
// Use this function whenever the kernel knows how to handle arbitrary strided
|
| 105 |
+
// outputs. This function has the same behavior as the old `set_output`: it
|
| 106 |
+
// will only re-stride if the given output was resized.
|
| 107 |
+
virtual void set_output_raw_strided(
|
| 108 |
+
int64_t output_idx,
|
| 109 |
+
IntArrayRef sizes,
|
| 110 |
+
IntArrayRef strides_hint,
|
| 111 |
+
TensorOptions options,
|
| 112 |
+
DimnameList names = {}) {
|
| 113 |
+
TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
// Use this function if the kernel requires contiguous strides.
|
| 117 |
+
// Alias for `set_output_strided`, but with contiguous strides.
|
| 118 |
+
void set_output_contiguous(
|
| 119 |
+
int64_t output_idx,
|
| 120 |
+
IntArrayRef sizes,
|
| 121 |
+
TensorOptions options,
|
| 122 |
+
DimnameList names = {}) {
|
| 123 |
+
auto strides = c10::contiguous_strides(sizes);
|
| 124 |
+
set_output_strided(output_idx, sizes, strides, options, names);
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
// Returns a reference to an undefined tensor if there is no presupplied
|
| 128 |
+
// output
|
| 129 |
+
const Tensor& maybe_get_output() {
|
| 130 |
+
return maybe_get_output(0);
|
| 131 |
+
}
|
| 132 |
+
virtual ~MetaBase() = default;
|
| 133 |
+
};
|
| 134 |
+
|
| 135 |
+
} // namespace impl
|
| 136 |
+
|
| 137 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorNames.h
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/WrapDimUtils.h>
|
| 4 |
+
|
| 5 |
+
namespace at::namedinference {
|
| 6 |
+
|
| 7 |
+
// TensorName and TensorNames are wrappers around Dimname and DimnameList
|
| 8 |
+
// that contain helper functions to make writing name inference rules easier.
|
| 9 |
+
//
|
| 10 |
+
// A TensorName represents a Dimname associated with some DimnameList (from a
|
| 11 |
+
// Tensor). This encapsulates all the information that is needed to check if
|
| 12 |
+
// names *match* and to *unify* names.
|
| 13 |
+
//
|
| 14 |
+
// Definition: Two names in two tensors *match* if they are equal, or if at
|
| 15 |
+
// least one of them is a wildcard that can be *refined* to the other name.
|
| 16 |
+
//
|
| 17 |
+
// Definition: unify(name, other) fails if the names do not match. Otherwise,
|
| 18 |
+
// it returns the most refined of name and other.
|
| 19 |
+
//
|
| 20 |
+
// Here is an example of checking if two names match.
|
| 21 |
+
// tensor: Tensor[A, None]
|
| 22 |
+
// other: Tensor[A]
|
| 23 |
+
//
|
| 24 |
+
// Let's say we wish to check if tensor.names[-1] matches other.names[-1].
|
| 25 |
+
// None (in tensor) cannot match A (in other) because if the None were refined
|
| 26 |
+
// to A, `tensor` would have duplicate names [A, A]. Therefore we need to check
|
| 27 |
+
// tensor.names [A, None] for the existence of A.
|
| 28 |
+
// A Dimname together with the DimnameList it came from and its (possibly
// negative, as passed in) index there — everything needed to check whether
// two names *match* and to *unify* them (see file comment above).
struct TORCH_API TensorName {
  // `origin_idx` may be negative; it is wrapped against origin.size() to pick
  // out the referenced name, but stored unwrapped in origin_idx_.
  explicit TensorName(ArrayRef<Dimname> origin, int origin_idx)
      : origin_(origin),
        name_(origin[maybe_wrap_dim(
            origin_idx,
            static_cast<int64_t>(origin.size()))]),
        origin_idx_(origin_idx) {}

  // Return the more refined of this name and `other`, or fail if they don't
  // match. op_name is only used for error reporting.
  const TensorName& unify(const TensorName& other, const char* op_name) const;
  Dimname toDimname() const;

 private:
  ArrayRef<Dimname> origin_;
  Dimname name_;
  int origin_idx_; // A named tensor can have at most 64 dims.

  TORCH_API friend std::ostream& operator<<(
      std::ostream& out,
      const TensorName& tensorname);
};
|
| 49 |
+
|
| 50 |
+
using TensorNameVec = SmallVector<TensorName, 10>;
|
| 51 |
+
|
| 52 |
+
// A sequence of TensorName helpers for writing name-inference rules over a
// tensor's DimnameList (see file comment above).
struct TORCH_API TensorNames {
  explicit TensorNames(ArrayRef<Dimname> names);

  // Create TensorNames from names[start:end]. Each individual TensorName stores
  // `names`, NOT names[start:end], because the original tensor's names are
  // `names`.
  explicit TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end);

  // Unify with `other`, aligning from the right (broadcasting-style), mutating
  // this object. op_name is only used for error reporting.
  TensorNames& unifyFromRightInplace(
      const TensorNames& other,
      const char* op_name = "unify");
  // Error (reported under op_name) if the stored names contain duplicates.
  void checkUnique(const char* op_name) const;

  void append(TensorName name);
  // Materialize the stored names as plain Dimnames.
  std::vector<Dimname> toDimnameVec() const;

 private:
  explicit TensorNames(TensorNameVec&& names) : names_(std::move(names)){};

  TensorNameVec names_;
};
|
| 74 |
+
|
| 75 |
+
} // namespace at::namedinference
|