Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/freezing.py +266 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/quantized_lowerings.py +15 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Backtrace.h +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CPUFunctions_inl.h +576 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CPUGeneratorImpl.h +49 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CUDAFunctions.h +29 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CUDAFunctions_inl.h +614 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CachedTensorUtils.h +24 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions.h +29 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h +29 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h +323 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h +500 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h +29 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DLConvertor.h +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Device.h +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DeviceGuard.h +41 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ExpandUtils.h +527 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalStorageImpl.h +126 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/InitialTensorOptions.h +15 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Layout.h +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/LinalgBackend.h +31 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MatrixRef.h +109 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions.h +29 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MethodOperators.h +443 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensorUtils.h +215 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NativeFunctions.h +1317 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelOpenMP.h +54 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/RedispatchFunctions.h +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SmallVector.h +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SparseCsrTensorUtils.h +411 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorGeometry.h +144 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TracerMode.h +132 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/VmapGeneratedPlumbing.h +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/autocast_mode.h +647 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cpp_custom_type_hack.h +110 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ApplyGridUtils.cuh +47 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/AsmUtils.cuh +149 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDABlas.h +375 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAConfig.h +19 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContextLight.h +95 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADevice.h +23 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAEvent.h +208 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDASparse.h +76 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDASparseDescriptors.h +290 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAUtils.h +20 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CachingHostAllocator.h +37 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/DeviceUtils.cuh +121 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/EmptyTensor.h +44 -0
.gitattributes
CHANGED
|
@@ -75,3 +75,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_r
|
|
| 75 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 filter=lfs diff=lfs merge=lfs -text
|
| 76 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 77 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 75 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 filter=lfs diff=lfs merge=lfs -text
|
| 76 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 77 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.11 filter=lfs diff=lfs merge=lfs -text
|
| 78 |
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/t64.exe
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:81a618f21cb87db9076134e70388b6e9cb7c2106739011b6a51772d22cae06b7
|
| 3 |
+
size 108032
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/freezing.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import itertools
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
import weakref
|
| 7 |
+
from typing import Any, List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.utils._pytree as pytree
|
| 11 |
+
from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code
|
| 12 |
+
from torch._functorch.aot_autograd import MutationType
|
| 13 |
+
from torch._functorch.compile_utils import fx_graph_cse
|
| 14 |
+
from torch._inductor.constant_folding import constant_fold, replace_node_with_constant
|
| 15 |
+
|
| 16 |
+
from torch._inductor.fx_passes.freezing_patterns import freezing_passes
|
| 17 |
+
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
| 18 |
+
|
| 19 |
+
from . import config
|
| 20 |
+
|
| 21 |
+
aten = torch.ops.aten
|
| 22 |
+
prims = torch.ops.prims
|
| 23 |
+
|
| 24 |
+
log = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def replace_params_with_constants(
|
| 28 |
+
gm: torch.fx.GraphModule,
|
| 29 |
+
flat_params: list[Any],
|
| 30 |
+
fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta,
|
| 31 |
+
) -> List[int]:
|
| 32 |
+
"""
|
| 33 |
+
Replaces the parameters of a PyTorch GraphModule with constants wherever possible.
|
| 34 |
+
Returns a list of indices representing the input parameters that were not converted to constants.
|
| 35 |
+
"""
|
| 36 |
+
params = [node for node in gm.graph.nodes if node.op == "placeholder"]
|
| 37 |
+
fake_inp_nodes = params[: len(params)]
|
| 38 |
+
preserved_arg_indices = []
|
| 39 |
+
aliased_input_args = [
|
| 40 |
+
out_info.base_idx
|
| 41 |
+
for out_info in fw_metadata.output_info
|
| 42 |
+
if out_info.base_idx is not None
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
# TODO (tmanlaibaatar) figure out why this is different
|
| 46 |
+
# from mutated_inp_runtime_indices
|
| 47 |
+
mutated_inps = [
|
| 48 |
+
i
|
| 49 |
+
for i, m in enumerate(fw_metadata.input_info)
|
| 50 |
+
if m.mutation_type
|
| 51 |
+
in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH)
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)):
|
| 55 |
+
if i in mutated_inps or i in aliased_input_args:
|
| 56 |
+
preserved_arg_indices.append(i)
|
| 57 |
+
continue
|
| 58 |
+
replace_node_with_constant(gm, node, real_input)
|
| 59 |
+
# add on non param inputs
|
| 60 |
+
preserved_arg_indices.extend(range(len(flat_params), len(params)))
|
| 61 |
+
# is this necessary ?
|
| 62 |
+
gm.recompile()
|
| 63 |
+
return preserved_arg_indices
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def freeze(
|
| 67 |
+
dynamo_gm: torch.fx.GraphModule,
|
| 68 |
+
aot_autograd_gm: torch.fx.GraphModule,
|
| 69 |
+
example_inputs: List[torch._subclasses.FakeTensor],
|
| 70 |
+
) -> Tuple[torch.fx.GraphModule, List[int]]:
|
| 71 |
+
"""
|
| 72 |
+
Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation
|
| 73 |
+
and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency.
|
| 74 |
+
|
| 75 |
+
Assumes that this function is run in dynamo tracing post aot_autograd.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
dynamo_gm (torch.fx.GraphModule): The Dynamo constructed GraphModule.
|
| 79 |
+
aot_autograd_gm (torch.fx.GraphModule): The aot_autograd constructed GraphModule to be frozen.
|
| 80 |
+
example_inputs (List[torch.Tensor]): A list of example input tensors to be used in the freezing process.
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen GraphModule and a list of indices
|
| 84 |
+
of the inputs that were preserved (not turned into constants).
|
| 85 |
+
"""
|
| 86 |
+
# We have convert conv's weight to channels last which may meet error for .view
|
| 87 |
+
# when doing fake_tensor_prop. So we need to convert view to reshape first.
|
| 88 |
+
# See the details in fx_codegen_and_compile of compile_fx.py.
|
| 89 |
+
view_to_reshape(aot_autograd_gm)
|
| 90 |
+
|
| 91 |
+
if tracing_context := torch._guards.TracingContext.try_get():
|
| 92 |
+
fw_metadata = tracing_context.fw_metadata
|
| 93 |
+
params_flat = tracing_context.params_flat
|
| 94 |
+
assert fw_metadata is not None and params_flat is not None
|
| 95 |
+
|
| 96 |
+
preserved_arg_indices = replace_params_with_constants(
|
| 97 |
+
aot_autograd_gm, params_flat, fw_metadata
|
| 98 |
+
)
|
| 99 |
+
else:
|
| 100 |
+
inputs = [
|
| 101 |
+
node for node in aot_autograd_gm.graph.nodes if node.op == "placeholder"
|
| 102 |
+
]
|
| 103 |
+
preserved_arg_indices = list(range(len(inputs)))
|
| 104 |
+
|
| 105 |
+
# TODO - further restrict cse ? right now needed to dedup aliasing ops
|
| 106 |
+
cse_graph = fx_graph_cse(aot_autograd_gm.graph)
|
| 107 |
+
aot_autograd_gm.graph = cse_graph
|
| 108 |
+
aot_autograd_gm.recompile()
|
| 109 |
+
|
| 110 |
+
aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices]
|
| 111 |
+
freezing_passes(aot_autograd_gm, aot_example_inputs)
|
| 112 |
+
|
| 113 |
+
constant_fold(aot_autograd_gm)
|
| 114 |
+
# invalidate nn Modules
|
| 115 |
+
if config.freezing_discard_parameters:
|
| 116 |
+
invalidate_eager_modules()
|
| 117 |
+
discard_traced_gm_params(dynamo_gm)
|
| 118 |
+
|
| 119 |
+
log.debug("%s", lazy_format_graph_code("FROZEN GRAPH", aot_autograd_gm))
|
| 120 |
+
|
| 121 |
+
return aot_autograd_gm, preserved_arg_indices
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class ErasedTensor(torch.Tensor):
|
| 125 |
+
@staticmethod
|
| 126 |
+
def __new__(cls, elem, name, owning_mod):
|
| 127 |
+
return super().__new__(cls, elem.to(device="meta"))
|
| 128 |
+
|
| 129 |
+
def __init__(self, elem, name: Optional[str], mod):
|
| 130 |
+
self.erased_name = name
|
| 131 |
+
self.owning_mod_ref = weakref.ref(mod)
|
| 132 |
+
|
| 133 |
+
@classmethod
|
| 134 |
+
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
| 135 |
+
erased_tensors = [
|
| 136 |
+
e
|
| 137 |
+
for e in pytree.arg_tree_leaves(*args, **kwargs)
|
| 138 |
+
if isinstance(e, ErasedTensor)
|
| 139 |
+
]
|
| 140 |
+
assert len(erased_tensors) > 0
|
| 141 |
+
e = erased_tensors[0]
|
| 142 |
+
|
| 143 |
+
raise RuntimeError(
|
| 144 |
+
f"Trying to run Pytorch Eager Module after Dynamo Freezing. "
|
| 145 |
+
"The original parameters have been discarded for memory efficiency. "
|
| 146 |
+
f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}"
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@torch.utils._python_dispatch._disable_current_modes()
|
| 151 |
+
def invalidate_eager_modules():
|
| 152 |
+
for mod in torch._guards.TracingContext.get().module_context.nn_modules.values():
|
| 153 |
+
if not isinstance(mod, torch.nn.Module):
|
| 154 |
+
continue
|
| 155 |
+
|
| 156 |
+
for attr_name, tensor in list(
|
| 157 |
+
itertools.chain(
|
| 158 |
+
mod.named_parameters(recurse=False), mod.named_buffers(recurse=False)
|
| 159 |
+
)
|
| 160 |
+
):
|
| 161 |
+
with torch._dispatch.python.no_python_dispatcher():
|
| 162 |
+
e_t = ErasedTensor(tensor, attr_name, mod)
|
| 163 |
+
if isinstance(tensor, torch.nn.Parameter):
|
| 164 |
+
e_t.requires_grad_(True)
|
| 165 |
+
e_t._is_param = True # type: ignore[attr-defined]
|
| 166 |
+
setattr(mod, attr_name, e_t)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@torch.utils._python_dispatch._disable_current_modes()
|
| 170 |
+
def discard_traced_gm_params(mod: torch.fx.GraphModule):
|
| 171 |
+
for attr_name, tensor in list(
|
| 172 |
+
itertools.chain(
|
| 173 |
+
mod.named_parameters(recurse=False), mod.named_buffers(recurse=False)
|
| 174 |
+
)
|
| 175 |
+
):
|
| 176 |
+
with torch._dispatch.python.no_python_dispatcher():
|
| 177 |
+
e_t = ErasedTensor(tensor, attr_name, mod)
|
| 178 |
+
if isinstance(tensor, torch.nn.Parameter):
|
| 179 |
+
e_t.requires_grad_(True)
|
| 180 |
+
e_t._is_param = True # type: ignore[attr-defined]
|
| 181 |
+
setattr(mod, attr_name, e_t)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def enforce_output_layout(gm: torch.fx.GraphModule):
|
| 185 |
+
"""
|
| 186 |
+
Make sure the output node's layout does not change due to compiler optimizations
|
| 187 |
+
by adding aten.as_strided nodes with the expected strides.
|
| 188 |
+
|
| 189 |
+
Only used for inference so we can assume all graph outputs are model outputs.
|
| 190 |
+
"""
|
| 191 |
+
*_, output_node = gm.graph.nodes
|
| 192 |
+
out_list = output_node.args[0]
|
| 193 |
+
with gm.graph.inserting_before(output_node):
|
| 194 |
+
for n in out_list:
|
| 195 |
+
if not isinstance(
|
| 196 |
+
n.meta["val"], torch.Tensor
|
| 197 |
+
) or not torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]):
|
| 198 |
+
continue
|
| 199 |
+
|
| 200 |
+
# add a node to enforce eager layout
|
| 201 |
+
ft = n.meta["val"]
|
| 202 |
+
new_node = gm.graph.call_function(
|
| 203 |
+
prims.inductor_force_stride_order.default, (n, ft.stride())
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# can not call
|
| 207 |
+
# n.replace_all_uses_with(new_node)
|
| 208 |
+
# since it will replace the usage of n in new_node itself.
|
| 209 |
+
output_node.replace_input_with(n, new_node)
|
| 210 |
+
|
| 211 |
+
gm.graph.lint()
|
| 212 |
+
gm.recompile()
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def enforce_as_strided_input_layout(gm: torch.fx.GraphModule):
|
| 216 |
+
"""
|
| 217 |
+
Make sure the as_strided node's input's layout does not change due to compiler
|
| 218 |
+
optimizations, because the as_strided strides info depends on input tensor stride info.
|
| 219 |
+
"""
|
| 220 |
+
|
| 221 |
+
as_strided_ops = [
|
| 222 |
+
torch.ops.aten.as_strided.default,
|
| 223 |
+
torch.ops.aten.as_strided_.default,
|
| 224 |
+
torch.ops.aten.as_strided_scatter.default,
|
| 225 |
+
]
|
| 226 |
+
strided_nodes = [n for n in gm.graph.nodes if n.target in as_strided_ops]
|
| 227 |
+
for n in strided_nodes:
|
| 228 |
+
with gm.graph.inserting_before(n):
|
| 229 |
+
# add a node to enforce eager layout
|
| 230 |
+
ft = n.args[0].meta["val"]
|
| 231 |
+
new_node = gm.graph.call_function(
|
| 232 |
+
prims.inductor_force_stride_order.default, (n.args[0], ft.stride())
|
| 233 |
+
)
|
| 234 |
+
n.replace_input_with(n.args[0], new_node)
|
| 235 |
+
|
| 236 |
+
gm.graph.lint()
|
| 237 |
+
gm.recompile()
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
@dynamo_timed
|
| 241 |
+
def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule):
|
| 242 |
+
"""
|
| 243 |
+
Convert 4d convolution weight tensor to channels last format.
|
| 244 |
+
|
| 245 |
+
This pass is performed before freezing so the added nodes can be constant
|
| 246 |
+
folded by freezing.
|
| 247 |
+
"""
|
| 248 |
+
convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
|
| 249 |
+
for conv in convs:
|
| 250 |
+
weight_node = conv.args[1]
|
| 251 |
+
if len(weight_node.meta["val"].size()) != 4 or weight_node.meta[
|
| 252 |
+
"val"
|
| 253 |
+
].is_contiguous(memory_format=torch.channels_last):
|
| 254 |
+
# not a 4d tensor or already channels last, skip
|
| 255 |
+
continue
|
| 256 |
+
|
| 257 |
+
with gm.graph.inserting_before(conv):
|
| 258 |
+
new_node = gm.graph.call_function(
|
| 259 |
+
aten.clone.default,
|
| 260 |
+
(weight_node,),
|
| 261 |
+
{"memory_format": torch.channels_last},
|
| 262 |
+
)
|
| 263 |
+
conv.replace_input_with(weight_node, new_node)
|
| 264 |
+
|
| 265 |
+
enforce_as_strided_input_layout(gm)
|
| 266 |
+
enforce_output_layout(gm)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/quantized_lowerings.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def register_quantized_ops():
|
| 5 |
+
from . import lowering
|
| 6 |
+
|
| 7 |
+
quantized = torch.ops.quantized
|
| 8 |
+
|
| 9 |
+
lowering.add_needs_realized_inputs(
|
| 10 |
+
[
|
| 11 |
+
quantized.max_pool2d,
|
| 12 |
+
]
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
lowering.make_fallback(quantized.max_pool2d)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Backtrace.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/Backtrace.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CPUFunctions_inl.h
ADDED
|
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_cpu_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_adaptive_avg_pool2d_cpu_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_cpu_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_adaptive_avg_pool3d_cpu_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_cpu_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_add_relu_cpu_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_addmm_activation_cpu_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_aminmax_cpu_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_cpu_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_amp_update_scale_cpu_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_assert_async_cpu_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_cdist_backward_cpu_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_cdist_forward_cpu_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_cholesky_solve_helper_cpu_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_compute_linear_combination_cpu_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_cpu_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_cpu_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_convert_weight_to_int4pack_cpu_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_ctc_loss_cpu_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_ctc_loss_backward_cpu_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_cummax_helper_cpu_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_cummin_helper_cpu_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_dirichlet_grad_cpu_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_efficientzerotensor_cpu_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_embedding_bag_cpu_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_embedding_bag_dense_backward_cpu_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_embedding_bag_forward_only_cpu_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_cpu_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_empty_affine_quantized_cpu_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized_cpu_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_cpu_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_cpu_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_cpu_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_cpu_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_cpu_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_fft_c2c_cpu_dispatch.h>
|
| 54 |
+
#include <ATen/ops/_fft_c2r_cpu_dispatch.h>
|
| 55 |
+
#include <ATen/ops/_fft_r2c_cpu_dispatch.h>
|
| 56 |
+
#include <ATen/ops/_foobar_cpu_dispatch.h>
|
| 57 |
+
#include <ATen/ops/_foreach_abs_cpu_dispatch.h>
|
| 58 |
+
#include <ATen/ops/_foreach_acos_cpu_dispatch.h>
|
| 59 |
+
#include <ATen/ops/_foreach_add_cpu_dispatch.h>
|
| 60 |
+
#include <ATen/ops/_foreach_addcdiv_cpu_dispatch.h>
|
| 61 |
+
#include <ATen/ops/_foreach_addcmul_cpu_dispatch.h>
|
| 62 |
+
#include <ATen/ops/_foreach_asin_cpu_dispatch.h>
|
| 63 |
+
#include <ATen/ops/_foreach_atan_cpu_dispatch.h>
|
| 64 |
+
#include <ATen/ops/_foreach_ceil_cpu_dispatch.h>
|
| 65 |
+
#include <ATen/ops/_foreach_clamp_max_cpu_dispatch.h>
|
| 66 |
+
#include <ATen/ops/_foreach_clamp_min_cpu_dispatch.h>
|
| 67 |
+
#include <ATen/ops/_foreach_copy_cpu_dispatch.h>
|
| 68 |
+
#include <ATen/ops/_foreach_cos_cpu_dispatch.h>
|
| 69 |
+
#include <ATen/ops/_foreach_cosh_cpu_dispatch.h>
|
| 70 |
+
#include <ATen/ops/_foreach_div_cpu_dispatch.h>
|
| 71 |
+
#include <ATen/ops/_foreach_erf_cpu_dispatch.h>
|
| 72 |
+
#include <ATen/ops/_foreach_erfc_cpu_dispatch.h>
|
| 73 |
+
#include <ATen/ops/_foreach_exp_cpu_dispatch.h>
|
| 74 |
+
#include <ATen/ops/_foreach_expm1_cpu_dispatch.h>
|
| 75 |
+
#include <ATen/ops/_foreach_floor_cpu_dispatch.h>
|
| 76 |
+
#include <ATen/ops/_foreach_frac_cpu_dispatch.h>
|
| 77 |
+
#include <ATen/ops/_foreach_lerp_cpu_dispatch.h>
|
| 78 |
+
#include <ATen/ops/_foreach_lgamma_cpu_dispatch.h>
|
| 79 |
+
#include <ATen/ops/_foreach_log_cpu_dispatch.h>
|
| 80 |
+
#include <ATen/ops/_foreach_log10_cpu_dispatch.h>
|
| 81 |
+
#include <ATen/ops/_foreach_log1p_cpu_dispatch.h>
|
| 82 |
+
#include <ATen/ops/_foreach_log2_cpu_dispatch.h>
|
| 83 |
+
#include <ATen/ops/_foreach_maximum_cpu_dispatch.h>
|
| 84 |
+
#include <ATen/ops/_foreach_minimum_cpu_dispatch.h>
|
| 85 |
+
#include <ATen/ops/_foreach_mul_cpu_dispatch.h>
|
| 86 |
+
#include <ATen/ops/_foreach_neg_cpu_dispatch.h>
|
| 87 |
+
#include <ATen/ops/_foreach_norm_cpu_dispatch.h>
|
| 88 |
+
#include <ATen/ops/_foreach_pow_cpu_dispatch.h>
|
| 89 |
+
#include <ATen/ops/_foreach_reciprocal_cpu_dispatch.h>
|
| 90 |
+
#include <ATen/ops/_foreach_round_cpu_dispatch.h>
|
| 91 |
+
#include <ATen/ops/_foreach_sigmoid_cpu_dispatch.h>
|
| 92 |
+
#include <ATen/ops/_foreach_sign_cpu_dispatch.h>
|
| 93 |
+
#include <ATen/ops/_foreach_sin_cpu_dispatch.h>
|
| 94 |
+
#include <ATen/ops/_foreach_sinh_cpu_dispatch.h>
|
| 95 |
+
#include <ATen/ops/_foreach_sqrt_cpu_dispatch.h>
|
| 96 |
+
#include <ATen/ops/_foreach_sub_cpu_dispatch.h>
|
| 97 |
+
#include <ATen/ops/_foreach_tan_cpu_dispatch.h>
|
| 98 |
+
#include <ATen/ops/_foreach_tanh_cpu_dispatch.h>
|
| 99 |
+
#include <ATen/ops/_foreach_trunc_cpu_dispatch.h>
|
| 100 |
+
#include <ATen/ops/_foreach_zero_cpu_dispatch.h>
|
| 101 |
+
#include <ATen/ops/_functional_assert_async_cpu_dispatch.h>
|
| 102 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_cpu_dispatch.h>
|
| 103 |
+
#include <ATen/ops/_fused_sdp_choice_cpu_dispatch.h>
|
| 104 |
+
#include <ATen/ops/_histogramdd_bin_edges_cpu_dispatch.h>
|
| 105 |
+
#include <ATen/ops/_histogramdd_from_bin_cts_cpu_dispatch.h>
|
| 106 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors_cpu_dispatch.h>
|
| 107 |
+
#include <ATen/ops/_index_put_impl_cpu_dispatch.h>
|
| 108 |
+
#include <ATen/ops/_linalg_det_cpu_dispatch.h>
|
| 109 |
+
#include <ATen/ops/_linalg_eigh_cpu_dispatch.h>
|
| 110 |
+
#include <ATen/ops/_linalg_eigvals_cpu_dispatch.h>
|
| 111 |
+
#include <ATen/ops/_linalg_slogdet_cpu_dispatch.h>
|
| 112 |
+
#include <ATen/ops/_linalg_solve_ex_cpu_dispatch.h>
|
| 113 |
+
#include <ATen/ops/_linalg_svd_cpu_dispatch.h>
|
| 114 |
+
#include <ATen/ops/_local_scalar_dense_cpu_dispatch.h>
|
| 115 |
+
#include <ATen/ops/_log_softmax_cpu_dispatch.h>
|
| 116 |
+
#include <ATen/ops/_log_softmax_backward_data_cpu_dispatch.h>
|
| 117 |
+
#include <ATen/ops/_logcumsumexp_cpu_dispatch.h>
|
| 118 |
+
#include <ATen/ops/_make_dep_token_cpu_dispatch.h>
|
| 119 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_cpu_dispatch.h>
|
| 120 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_cpu_dispatch.h>
|
| 121 |
+
#include <ATen/ops/_masked_softmax_cpu_dispatch.h>
|
| 122 |
+
#include <ATen/ops/_masked_softmax_backward_cpu_dispatch.h>
|
| 123 |
+
#include <ATen/ops/_native_batch_norm_legit_cpu_dispatch.h>
|
| 124 |
+
#include <ATen/ops/_native_multi_head_attention_cpu_dispatch.h>
|
| 125 |
+
#include <ATen/ops/_nested_from_padded_cpu_dispatch.h>
|
| 126 |
+
#include <ATen/ops/_nested_tensor_from_mask_cpu_dispatch.h>
|
| 127 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_cpu_dispatch.h>
|
| 128 |
+
#include <ATen/ops/_nested_view_from_buffer_cpu_dispatch.h>
|
| 129 |
+
#include <ATen/ops/_pdist_backward_cpu_dispatch.h>
|
| 130 |
+
#include <ATen/ops/_pdist_forward_cpu_dispatch.h>
|
| 131 |
+
#include <ATen/ops/_prelu_kernel_cpu_dispatch.h>
|
| 132 |
+
#include <ATen/ops/_prelu_kernel_backward_cpu_dispatch.h>
|
| 133 |
+
#include <ATen/ops/_reshape_alias_cpu_dispatch.h>
|
| 134 |
+
#include <ATen/ops/_sample_dirichlet_cpu_dispatch.h>
|
| 135 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_cpu_dispatch.h>
|
| 136 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_cpu_dispatch.h>
|
| 137 |
+
#include <ATen/ops/_segment_reduce_backward_cpu_dispatch.h>
|
| 138 |
+
#include <ATen/ops/_slow_conv2d_backward_cpu_dispatch.h>
|
| 139 |
+
#include <ATen/ops/_slow_conv2d_forward_cpu_dispatch.h>
|
| 140 |
+
#include <ATen/ops/_softmax_cpu_dispatch.h>
|
| 141 |
+
#include <ATen/ops/_softmax_backward_data_cpu_dispatch.h>
|
| 142 |
+
#include <ATen/ops/_spdiags_cpu_dispatch.h>
|
| 143 |
+
#include <ATen/ops/_stack_cpu_dispatch.h>
|
| 144 |
+
#include <ATen/ops/_standard_gamma_cpu_dispatch.h>
|
| 145 |
+
#include <ATen/ops/_standard_gamma_grad_cpu_dispatch.h>
|
| 146 |
+
#include <ATen/ops/_test_functorch_fallback_cpu_dispatch.h>
|
| 147 |
+
#include <ATen/ops/_test_optional_filled_intlist_cpu_dispatch.h>
|
| 148 |
+
#include <ATen/ops/_test_optional_floatlist_cpu_dispatch.h>
|
| 149 |
+
#include <ATen/ops/_test_optional_intlist_cpu_dispatch.h>
|
| 150 |
+
#include <ATen/ops/_to_sparse_cpu_dispatch.h>
|
| 151 |
+
#include <ATen/ops/_to_sparse_bsc_cpu_dispatch.h>
|
| 152 |
+
#include <ATen/ops/_to_sparse_bsr_cpu_dispatch.h>
|
| 153 |
+
#include <ATen/ops/_to_sparse_csc_cpu_dispatch.h>
|
| 154 |
+
#include <ATen/ops/_to_sparse_csr_cpu_dispatch.h>
|
| 155 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_cpu_dispatch.h>
|
| 156 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_cpu_dispatch.h>
|
| 157 |
+
#include <ATen/ops/_unique_cpu_dispatch.h>
|
| 158 |
+
#include <ATen/ops/_unique2_cpu_dispatch.h>
|
| 159 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_cpu_dispatch.h>
|
| 160 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_cpu_dispatch.h>
|
| 161 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_cpu_dispatch.h>
|
| 162 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_cpu_dispatch.h>
|
| 163 |
+
#include <ATen/ops/_upsample_nearest_exact1d_cpu_dispatch.h>
|
| 164 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_cpu_dispatch.h>
|
| 165 |
+
#include <ATen/ops/_upsample_nearest_exact2d_cpu_dispatch.h>
|
| 166 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_cpu_dispatch.h>
|
| 167 |
+
#include <ATen/ops/_upsample_nearest_exact3d_cpu_dispatch.h>
|
| 168 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_cpu_dispatch.h>
|
| 169 |
+
#include <ATen/ops/_validate_compressed_sparse_indices_cpu_dispatch.h>
|
| 170 |
+
#include <ATen/ops/_weight_int4pack_mm_cpu_dispatch.h>
|
| 171 |
+
#include <ATen/ops/_weight_int8pack_mm_cpu_dispatch.h>
|
| 172 |
+
#include <ATen/ops/_weight_norm_interface_cpu_dispatch.h>
|
| 173 |
+
#include <ATen/ops/_weight_norm_interface_backward_cpu_dispatch.h>
|
| 174 |
+
#include <ATen/ops/abs_cpu_dispatch.h>
|
| 175 |
+
#include <ATen/ops/acos_cpu_dispatch.h>
|
| 176 |
+
#include <ATen/ops/acosh_cpu_dispatch.h>
|
| 177 |
+
#include <ATen/ops/adaptive_avg_pool2d_cpu_dispatch.h>
|
| 178 |
+
#include <ATen/ops/adaptive_avg_pool3d_cpu_dispatch.h>
|
| 179 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward_cpu_dispatch.h>
|
| 180 |
+
#include <ATen/ops/adaptive_max_pool2d_cpu_dispatch.h>
|
| 181 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_cpu_dispatch.h>
|
| 182 |
+
#include <ATen/ops/adaptive_max_pool3d_cpu_dispatch.h>
|
| 183 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_cpu_dispatch.h>
|
| 184 |
+
#include <ATen/ops/add_cpu_dispatch.h>
|
| 185 |
+
#include <ATen/ops/addbmm_cpu_dispatch.h>
|
| 186 |
+
#include <ATen/ops/addcdiv_cpu_dispatch.h>
|
| 187 |
+
#include <ATen/ops/addcmul_cpu_dispatch.h>
|
| 188 |
+
#include <ATen/ops/addmm_cpu_dispatch.h>
|
| 189 |
+
#include <ATen/ops/addmv_cpu_dispatch.h>
|
| 190 |
+
#include <ATen/ops/addr_cpu_dispatch.h>
|
| 191 |
+
#include <ATen/ops/all_cpu_dispatch.h>
|
| 192 |
+
#include <ATen/ops/amax_cpu_dispatch.h>
|
| 193 |
+
#include <ATen/ops/amin_cpu_dispatch.h>
|
| 194 |
+
#include <ATen/ops/aminmax_cpu_dispatch.h>
|
| 195 |
+
#include <ATen/ops/angle_cpu_dispatch.h>
|
| 196 |
+
#include <ATen/ops/any_cpu_dispatch.h>
|
| 197 |
+
#include <ATen/ops/arange_cpu_dispatch.h>
|
| 198 |
+
#include <ATen/ops/argmax_cpu_dispatch.h>
|
| 199 |
+
#include <ATen/ops/argmin_cpu_dispatch.h>
|
| 200 |
+
#include <ATen/ops/argsort_cpu_dispatch.h>
|
| 201 |
+
#include <ATen/ops/as_strided_cpu_dispatch.h>
|
| 202 |
+
#include <ATen/ops/asin_cpu_dispatch.h>
|
| 203 |
+
#include <ATen/ops/asinh_cpu_dispatch.h>
|
| 204 |
+
#include <ATen/ops/atan_cpu_dispatch.h>
|
| 205 |
+
#include <ATen/ops/atan2_cpu_dispatch.h>
|
| 206 |
+
#include <ATen/ops/atanh_cpu_dispatch.h>
|
| 207 |
+
#include <ATen/ops/avg_pool2d_cpu_dispatch.h>
|
| 208 |
+
#include <ATen/ops/avg_pool2d_backward_cpu_dispatch.h>
|
| 209 |
+
#include <ATen/ops/avg_pool3d_cpu_dispatch.h>
|
| 210 |
+
#include <ATen/ops/avg_pool3d_backward_cpu_dispatch.h>
|
| 211 |
+
#include <ATen/ops/baddbmm_cpu_dispatch.h>
|
| 212 |
+
#include <ATen/ops/batch_norm_update_stats_cpu_dispatch.h>
|
| 213 |
+
#include <ATen/ops/bernoulli_cpu_dispatch.h>
|
| 214 |
+
#include <ATen/ops/binary_cross_entropy_cpu_dispatch.h>
|
| 215 |
+
#include <ATen/ops/binary_cross_entropy_backward_cpu_dispatch.h>
|
| 216 |
+
#include <ATen/ops/bincount_cpu_dispatch.h>
|
| 217 |
+
#include <ATen/ops/binomial_cpu_dispatch.h>
|
| 218 |
+
#include <ATen/ops/bitwise_and_cpu_dispatch.h>
|
| 219 |
+
#include <ATen/ops/bitwise_left_shift_cpu_dispatch.h>
|
| 220 |
+
#include <ATen/ops/bitwise_not_cpu_dispatch.h>
|
| 221 |
+
#include <ATen/ops/bitwise_or_cpu_dispatch.h>
|
| 222 |
+
#include <ATen/ops/bitwise_right_shift_cpu_dispatch.h>
|
| 223 |
+
#include <ATen/ops/bitwise_xor_cpu_dispatch.h>
|
| 224 |
+
#include <ATen/ops/bmm_cpu_dispatch.h>
|
| 225 |
+
#include <ATen/ops/bucketize_cpu_dispatch.h>
|
| 226 |
+
#include <ATen/ops/cat_cpu_dispatch.h>
|
| 227 |
+
#include <ATen/ops/cauchy_cpu_dispatch.h>
|
| 228 |
+
#include <ATen/ops/ceil_cpu_dispatch.h>
|
| 229 |
+
#include <ATen/ops/channel_shuffle_cpu_dispatch.h>
|
| 230 |
+
#include <ATen/ops/cholesky_cpu_dispatch.h>
|
| 231 |
+
#include <ATen/ops/cholesky_inverse_cpu_dispatch.h>
|
| 232 |
+
#include <ATen/ops/clamp_cpu_dispatch.h>
|
| 233 |
+
#include <ATen/ops/clamp_max_cpu_dispatch.h>
|
| 234 |
+
#include <ATen/ops/clamp_min_cpu_dispatch.h>
|
| 235 |
+
#include <ATen/ops/col2im_cpu_dispatch.h>
|
| 236 |
+
#include <ATen/ops/complex_cpu_dispatch.h>
|
| 237 |
+
#include <ATen/ops/conj_physical_cpu_dispatch.h>
|
| 238 |
+
#include <ATen/ops/copysign_cpu_dispatch.h>
|
| 239 |
+
#include <ATen/ops/cos_cpu_dispatch.h>
|
| 240 |
+
#include <ATen/ops/cosh_cpu_dispatch.h>
|
| 241 |
+
#include <ATen/ops/count_nonzero_cpu_dispatch.h>
|
| 242 |
+
#include <ATen/ops/cumprod_cpu_dispatch.h>
|
| 243 |
+
#include <ATen/ops/cumsum_cpu_dispatch.h>
|
| 244 |
+
#include <ATen/ops/dense_dim_cpu_dispatch.h>
|
| 245 |
+
#include <ATen/ops/dequantize_cpu_dispatch.h>
|
| 246 |
+
#include <ATen/ops/digamma_cpu_dispatch.h>
|
| 247 |
+
#include <ATen/ops/div_cpu_dispatch.h>
|
| 248 |
+
#include <ATen/ops/dot_cpu_dispatch.h>
|
| 249 |
+
#include <ATen/ops/elu_cpu_dispatch.h>
|
| 250 |
+
#include <ATen/ops/elu_backward_cpu_dispatch.h>
|
| 251 |
+
#include <ATen/ops/embedding_dense_backward_cpu_dispatch.h>
|
| 252 |
+
#include <ATen/ops/embedding_renorm_cpu_dispatch.h>
|
| 253 |
+
#include <ATen/ops/empty_cpu_dispatch.h>
|
| 254 |
+
#include <ATen/ops/empty_strided_cpu_dispatch.h>
|
| 255 |
+
#include <ATen/ops/eq_cpu_dispatch.h>
|
| 256 |
+
#include <ATen/ops/equal_cpu_dispatch.h>
|
| 257 |
+
#include <ATen/ops/erf_cpu_dispatch.h>
|
| 258 |
+
#include <ATen/ops/erfc_cpu_dispatch.h>
|
| 259 |
+
#include <ATen/ops/erfinv_cpu_dispatch.h>
|
| 260 |
+
#include <ATen/ops/exp_cpu_dispatch.h>
|
| 261 |
+
#include <ATen/ops/exp2_cpu_dispatch.h>
|
| 262 |
+
#include <ATen/ops/expm1_cpu_dispatch.h>
|
| 263 |
+
#include <ATen/ops/exponential_cpu_dispatch.h>
|
| 264 |
+
#include <ATen/ops/eye_cpu_dispatch.h>
|
| 265 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_cpu_dispatch.h>
|
| 266 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_cpu_dispatch.h>
|
| 267 |
+
#include <ATen/ops/fill_cpu_dispatch.h>
|
| 268 |
+
#include <ATen/ops/flip_cpu_dispatch.h>
|
| 269 |
+
#include <ATen/ops/floor_cpu_dispatch.h>
|
| 270 |
+
#include <ATen/ops/floor_divide_cpu_dispatch.h>
|
| 271 |
+
#include <ATen/ops/fmax_cpu_dispatch.h>
|
| 272 |
+
#include <ATen/ops/fmin_cpu_dispatch.h>
|
| 273 |
+
#include <ATen/ops/fmod_cpu_dispatch.h>
|
| 274 |
+
#include <ATen/ops/frac_cpu_dispatch.h>
|
| 275 |
+
#include <ATen/ops/fractional_max_pool2d_cpu_dispatch.h>
|
| 276 |
+
#include <ATen/ops/fractional_max_pool2d_backward_cpu_dispatch.h>
|
| 277 |
+
#include <ATen/ops/fractional_max_pool3d_cpu_dispatch.h>
|
| 278 |
+
#include <ATen/ops/fractional_max_pool3d_backward_cpu_dispatch.h>
|
| 279 |
+
#include <ATen/ops/frexp_cpu_dispatch.h>
|
| 280 |
+
#include <ATen/ops/from_file_cpu_dispatch.h>
|
| 281 |
+
#include <ATen/ops/gather_cpu_dispatch.h>
|
| 282 |
+
#include <ATen/ops/gcd_cpu_dispatch.h>
|
| 283 |
+
#include <ATen/ops/ge_cpu_dispatch.h>
|
| 284 |
+
#include <ATen/ops/gelu_cpu_dispatch.h>
|
| 285 |
+
#include <ATen/ops/gelu_backward_cpu_dispatch.h>
|
| 286 |
+
#include <ATen/ops/geometric_cpu_dispatch.h>
|
| 287 |
+
#include <ATen/ops/geqrf_cpu_dispatch.h>
|
| 288 |
+
#include <ATen/ops/glu_cpu_dispatch.h>
|
| 289 |
+
#include <ATen/ops/glu_backward_cpu_dispatch.h>
|
| 290 |
+
#include <ATen/ops/glu_backward_jvp_cpu_dispatch.h>
|
| 291 |
+
#include <ATen/ops/glu_jvp_cpu_dispatch.h>
|
| 292 |
+
#include <ATen/ops/grid_sampler_2d_cpu_dispatch.h>
|
| 293 |
+
#include <ATen/ops/grid_sampler_2d_backward_cpu_dispatch.h>
|
| 294 |
+
#include <ATen/ops/grid_sampler_3d_cpu_dispatch.h>
|
| 295 |
+
#include <ATen/ops/grid_sampler_3d_backward_cpu_dispatch.h>
|
| 296 |
+
#include <ATen/ops/gt_cpu_dispatch.h>
|
| 297 |
+
#include <ATen/ops/hardshrink_cpu_dispatch.h>
|
| 298 |
+
#include <ATen/ops/hardshrink_backward_cpu_dispatch.h>
|
| 299 |
+
#include <ATen/ops/hardsigmoid_cpu_dispatch.h>
|
| 300 |
+
#include <ATen/ops/hardsigmoid_backward_cpu_dispatch.h>
|
| 301 |
+
#include <ATen/ops/hardswish_cpu_dispatch.h>
|
| 302 |
+
#include <ATen/ops/hardswish_backward_cpu_dispatch.h>
|
| 303 |
+
#include <ATen/ops/hardtanh_cpu_dispatch.h>
|
| 304 |
+
#include <ATen/ops/hardtanh_backward_cpu_dispatch.h>
|
| 305 |
+
#include <ATen/ops/heaviside_cpu_dispatch.h>
|
| 306 |
+
#include <ATen/ops/histc_cpu_dispatch.h>
|
| 307 |
+
#include <ATen/ops/histogram_cpu_dispatch.h>
|
| 308 |
+
#include <ATen/ops/huber_loss_cpu_dispatch.h>
|
| 309 |
+
#include <ATen/ops/huber_loss_backward_cpu_dispatch.h>
|
| 310 |
+
#include <ATen/ops/hypot_cpu_dispatch.h>
|
| 311 |
+
#include <ATen/ops/i0_cpu_dispatch.h>
|
| 312 |
+
#include <ATen/ops/igamma_cpu_dispatch.h>
|
| 313 |
+
#include <ATen/ops/igammac_cpu_dispatch.h>
|
| 314 |
+
#include <ATen/ops/im2col_cpu_dispatch.h>
|
| 315 |
+
#include <ATen/ops/index_cpu_dispatch.h>
|
| 316 |
+
#include <ATen/ops/index_add_cpu_dispatch.h>
|
| 317 |
+
#include <ATen/ops/index_copy_cpu_dispatch.h>
|
| 318 |
+
#include <ATen/ops/index_fill_cpu_dispatch.h>
|
| 319 |
+
#include <ATen/ops/index_reduce_cpu_dispatch.h>
|
| 320 |
+
#include <ATen/ops/index_select_cpu_dispatch.h>
|
| 321 |
+
#include <ATen/ops/is_set_to_cpu_dispatch.h>
|
| 322 |
+
#include <ATen/ops/isin_cpu_dispatch.h>
|
| 323 |
+
#include <ATen/ops/isnan_cpu_dispatch.h>
|
| 324 |
+
#include <ATen/ops/isneginf_cpu_dispatch.h>
|
| 325 |
+
#include <ATen/ops/isposinf_cpu_dispatch.h>
|
| 326 |
+
#include <ATen/ops/kthvalue_cpu_dispatch.h>
|
| 327 |
+
#include <ATen/ops/lcm_cpu_dispatch.h>
|
| 328 |
+
#include <ATen/ops/le_cpu_dispatch.h>
|
| 329 |
+
#include <ATen/ops/leaky_relu_cpu_dispatch.h>
|
| 330 |
+
#include <ATen/ops/leaky_relu_backward_cpu_dispatch.h>
|
| 331 |
+
#include <ATen/ops/lerp_cpu_dispatch.h>
|
| 332 |
+
#include <ATen/ops/lgamma_cpu_dispatch.h>
|
| 333 |
+
#include <ATen/ops/linalg_cholesky_ex_cpu_dispatch.h>
|
| 334 |
+
#include <ATen/ops/linalg_cross_cpu_dispatch.h>
|
| 335 |
+
#include <ATen/ops/linalg_eig_cpu_dispatch.h>
|
| 336 |
+
#include <ATen/ops/linalg_eigvals_cpu_dispatch.h>
|
| 337 |
+
#include <ATen/ops/linalg_householder_product_cpu_dispatch.h>
|
| 338 |
+
#include <ATen/ops/linalg_inv_ex_cpu_dispatch.h>
|
| 339 |
+
#include <ATen/ops/linalg_ldl_factor_ex_cpu_dispatch.h>
|
| 340 |
+
#include <ATen/ops/linalg_ldl_solve_cpu_dispatch.h>
|
| 341 |
+
#include <ATen/ops/linalg_lstsq_cpu_dispatch.h>
|
| 342 |
+
#include <ATen/ops/linalg_lu_cpu_dispatch.h>
|
| 343 |
+
#include <ATen/ops/linalg_lu_factor_ex_cpu_dispatch.h>
|
| 344 |
+
#include <ATen/ops/linalg_lu_solve_cpu_dispatch.h>
|
| 345 |
+
#include <ATen/ops/linalg_matrix_exp_cpu_dispatch.h>
|
| 346 |
+
#include <ATen/ops/linalg_qr_cpu_dispatch.h>
|
| 347 |
+
#include <ATen/ops/linalg_solve_triangular_cpu_dispatch.h>
|
| 348 |
+
#include <ATen/ops/linalg_vector_norm_cpu_dispatch.h>
|
| 349 |
+
#include <ATen/ops/linspace_cpu_dispatch.h>
|
| 350 |
+
#include <ATen/ops/log_cpu_dispatch.h>
|
| 351 |
+
#include <ATen/ops/log10_cpu_dispatch.h>
|
| 352 |
+
#include <ATen/ops/log1p_cpu_dispatch.h>
|
| 353 |
+
#include <ATen/ops/log2_cpu_dispatch.h>
|
| 354 |
+
#include <ATen/ops/log_normal_cpu_dispatch.h>
|
| 355 |
+
#include <ATen/ops/log_sigmoid_backward_cpu_dispatch.h>
|
| 356 |
+
#include <ATen/ops/log_sigmoid_forward_cpu_dispatch.h>
|
| 357 |
+
#include <ATen/ops/logaddexp_cpu_dispatch.h>
|
| 358 |
+
#include <ATen/ops/logaddexp2_cpu_dispatch.h>
|
| 359 |
+
#include <ATen/ops/logical_and_cpu_dispatch.h>
|
| 360 |
+
#include <ATen/ops/logical_not_cpu_dispatch.h>
|
| 361 |
+
#include <ATen/ops/logical_or_cpu_dispatch.h>
|
| 362 |
+
#include <ATen/ops/logical_xor_cpu_dispatch.h>
|
| 363 |
+
#include <ATen/ops/logit_cpu_dispatch.h>
|
| 364 |
+
#include <ATen/ops/logit_backward_cpu_dispatch.h>
|
| 365 |
+
#include <ATen/ops/logspace_cpu_dispatch.h>
|
| 366 |
+
#include <ATen/ops/lshift_cpu_dispatch.h>
|
| 367 |
+
#include <ATen/ops/lt_cpu_dispatch.h>
|
| 368 |
+
#include <ATen/ops/lu_unpack_cpu_dispatch.h>
|
| 369 |
+
#include <ATen/ops/masked_fill_cpu_dispatch.h>
|
| 370 |
+
#include <ATen/ops/masked_scatter_cpu_dispatch.h>
|
| 371 |
+
#include <ATen/ops/masked_select_cpu_dispatch.h>
|
| 372 |
+
#include <ATen/ops/max_cpu_dispatch.h>
|
| 373 |
+
#include <ATen/ops/max_pool2d_with_indices_cpu_dispatch.h>
|
| 374 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_cpu_dispatch.h>
|
| 375 |
+
#include <ATen/ops/max_pool3d_with_indices_cpu_dispatch.h>
|
| 376 |
+
#include <ATen/ops/max_pool3d_with_indices_backward_cpu_dispatch.h>
|
| 377 |
+
#include <ATen/ops/max_unpool2d_cpu_dispatch.h>
|
| 378 |
+
#include <ATen/ops/max_unpool3d_cpu_dispatch.h>
|
| 379 |
+
#include <ATen/ops/maximum_cpu_dispatch.h>
|
| 380 |
+
#include <ATen/ops/mean_cpu_dispatch.h>
|
| 381 |
+
#include <ATen/ops/median_cpu_dispatch.h>
|
| 382 |
+
#include <ATen/ops/min_cpu_dispatch.h>
|
| 383 |
+
#include <ATen/ops/minimum_cpu_dispatch.h>
|
| 384 |
+
#include <ATen/ops/mish_cpu_dispatch.h>
|
| 385 |
+
#include <ATen/ops/mish_backward_cpu_dispatch.h>
|
| 386 |
+
#include <ATen/ops/mkldnn_rnn_layer_cpu_dispatch.h>
|
| 387 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward_cpu_dispatch.h>
|
| 388 |
+
#include <ATen/ops/mm_cpu_dispatch.h>
|
| 389 |
+
#include <ATen/ops/mode_cpu_dispatch.h>
|
| 390 |
+
#include <ATen/ops/mse_loss_cpu_dispatch.h>
|
| 391 |
+
#include <ATen/ops/mse_loss_backward_cpu_dispatch.h>
|
| 392 |
+
#include <ATen/ops/mul_cpu_dispatch.h>
|
| 393 |
+
#include <ATen/ops/multi_margin_loss_cpu_dispatch.h>
|
| 394 |
+
#include <ATen/ops/multi_margin_loss_backward_cpu_dispatch.h>
|
| 395 |
+
#include <ATen/ops/multilabel_margin_loss_backward_cpu_dispatch.h>
|
| 396 |
+
#include <ATen/ops/multilabel_margin_loss_forward_cpu_dispatch.h>
|
| 397 |
+
#include <ATen/ops/multinomial_cpu_dispatch.h>
|
| 398 |
+
#include <ATen/ops/mvlgamma_cpu_dispatch.h>
|
| 399 |
+
#include <ATen/ops/nan_to_num_cpu_dispatch.h>
|
| 400 |
+
#include <ATen/ops/nanmedian_cpu_dispatch.h>
|
| 401 |
+
#include <ATen/ops/nansum_cpu_dispatch.h>
|
| 402 |
+
#include <ATen/ops/narrow_copy_cpu_dispatch.h>
|
| 403 |
+
#include <ATen/ops/native_batch_norm_cpu_dispatch.h>
|
| 404 |
+
#include <ATen/ops/native_batch_norm_backward_cpu_dispatch.h>
|
| 405 |
+
#include <ATen/ops/native_channel_shuffle_cpu_dispatch.h>
|
| 406 |
+
#include <ATen/ops/native_dropout_cpu_dispatch.h>
|
| 407 |
+
#include <ATen/ops/native_dropout_backward_cpu_dispatch.h>
|
| 408 |
+
#include <ATen/ops/native_group_norm_cpu_dispatch.h>
|
| 409 |
+
#include <ATen/ops/native_group_norm_backward_cpu_dispatch.h>
|
| 410 |
+
#include <ATen/ops/native_layer_norm_cpu_dispatch.h>
|
| 411 |
+
#include <ATen/ops/native_layer_norm_backward_cpu_dispatch.h>
|
| 412 |
+
#include <ATen/ops/ne_cpu_dispatch.h>
|
| 413 |
+
#include <ATen/ops/neg_cpu_dispatch.h>
|
| 414 |
+
#include <ATen/ops/nextafter_cpu_dispatch.h>
|
| 415 |
+
#include <ATen/ops/nll_loss2d_backward_cpu_dispatch.h>
|
| 416 |
+
#include <ATen/ops/nll_loss2d_forward_cpu_dispatch.h>
|
| 417 |
+
#include <ATen/ops/nll_loss_backward_cpu_dispatch.h>
|
| 418 |
+
#include <ATen/ops/nll_loss_forward_cpu_dispatch.h>
|
| 419 |
+
#include <ATen/ops/nonzero_cpu_dispatch.h>
|
| 420 |
+
#include <ATen/ops/nonzero_static_cpu_dispatch.h>
|
| 421 |
+
#include <ATen/ops/norm_cpu_dispatch.h>
|
| 422 |
+
#include <ATen/ops/normal_cpu_dispatch.h>
|
| 423 |
+
#include <ATen/ops/ormqr_cpu_dispatch.h>
|
| 424 |
+
#include <ATen/ops/pixel_shuffle_cpu_dispatch.h>
|
| 425 |
+
#include <ATen/ops/pixel_unshuffle_cpu_dispatch.h>
|
| 426 |
+
#include <ATen/ops/poisson_cpu_dispatch.h>
|
| 427 |
+
#include <ATen/ops/polar_cpu_dispatch.h>
|
| 428 |
+
#include <ATen/ops/polygamma_cpu_dispatch.h>
|
| 429 |
+
#include <ATen/ops/pow_cpu_dispatch.h>
|
| 430 |
+
#include <ATen/ops/prod_cpu_dispatch.h>
|
| 431 |
+
#include <ATen/ops/put_cpu_dispatch.h>
|
| 432 |
+
#include <ATen/ops/quantize_per_channel_cpu_dispatch.h>
|
| 433 |
+
#include <ATen/ops/quantize_per_tensor_cpu_dispatch.h>
|
| 434 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_cpu_dispatch.h>
|
| 435 |
+
#include <ATen/ops/random_cpu_dispatch.h>
|
| 436 |
+
#include <ATen/ops/randperm_cpu_dispatch.h>
|
| 437 |
+
#include <ATen/ops/range_cpu_dispatch.h>
|
| 438 |
+
#include <ATen/ops/reciprocal_cpu_dispatch.h>
|
| 439 |
+
#include <ATen/ops/reflection_pad1d_cpu_dispatch.h>
|
| 440 |
+
#include <ATen/ops/reflection_pad1d_backward_cpu_dispatch.h>
|
| 441 |
+
#include <ATen/ops/reflection_pad2d_cpu_dispatch.h>
|
| 442 |
+
#include <ATen/ops/reflection_pad2d_backward_cpu_dispatch.h>
|
| 443 |
+
#include <ATen/ops/reflection_pad3d_cpu_dispatch.h>
|
| 444 |
+
#include <ATen/ops/reflection_pad3d_backward_cpu_dispatch.h>
|
| 445 |
+
#include <ATen/ops/relu_cpu_dispatch.h>
|
| 446 |
+
#include <ATen/ops/remainder_cpu_dispatch.h>
|
| 447 |
+
#include <ATen/ops/renorm_cpu_dispatch.h>
|
| 448 |
+
#include <ATen/ops/repeat_interleave_cpu_dispatch.h>
|
| 449 |
+
#include <ATen/ops/replication_pad1d_cpu_dispatch.h>
|
| 450 |
+
#include <ATen/ops/replication_pad1d_backward_cpu_dispatch.h>
|
| 451 |
+
#include <ATen/ops/replication_pad2d_cpu_dispatch.h>
|
| 452 |
+
#include <ATen/ops/replication_pad2d_backward_cpu_dispatch.h>
|
| 453 |
+
#include <ATen/ops/replication_pad3d_cpu_dispatch.h>
|
| 454 |
+
#include <ATen/ops/replication_pad3d_backward_cpu_dispatch.h>
|
| 455 |
+
#include <ATen/ops/resize_cpu_dispatch.h>
|
| 456 |
+
#include <ATen/ops/roll_cpu_dispatch.h>
|
| 457 |
+
#include <ATen/ops/round_cpu_dispatch.h>
|
| 458 |
+
#include <ATen/ops/rrelu_with_noise_cpu_dispatch.h>
|
| 459 |
+
#include <ATen/ops/rshift_cpu_dispatch.h>
|
| 460 |
+
#include <ATen/ops/rsqrt_cpu_dispatch.h>
|
| 461 |
+
#include <ATen/ops/rsub_cpu_dispatch.h>
|
| 462 |
+
#include <ATen/ops/scatter_cpu_dispatch.h>
|
| 463 |
+
#include <ATen/ops/scatter_add_cpu_dispatch.h>
|
| 464 |
+
#include <ATen/ops/scatter_reduce_cpu_dispatch.h>
|
| 465 |
+
#include <ATen/ops/searchsorted_cpu_dispatch.h>
|
| 466 |
+
#include <ATen/ops/segment_reduce_cpu_dispatch.h>
|
| 467 |
+
#include <ATen/ops/set_cpu_dispatch.h>
|
| 468 |
+
#include <ATen/ops/sgn_cpu_dispatch.h>
|
| 469 |
+
#include <ATen/ops/sigmoid_cpu_dispatch.h>
|
| 470 |
+
#include <ATen/ops/sigmoid_backward_cpu_dispatch.h>
|
| 471 |
+
#include <ATen/ops/sign_cpu_dispatch.h>
|
| 472 |
+
#include <ATen/ops/signbit_cpu_dispatch.h>
|
| 473 |
+
#include <ATen/ops/silu_cpu_dispatch.h>
|
| 474 |
+
#include <ATen/ops/silu_backward_cpu_dispatch.h>
|
| 475 |
+
#include <ATen/ops/sin_cpu_dispatch.h>
|
| 476 |
+
#include <ATen/ops/sinc_cpu_dispatch.h>
|
| 477 |
+
#include <ATen/ops/sinh_cpu_dispatch.h>
|
| 478 |
+
#include <ATen/ops/slow_conv3d_forward_cpu_dispatch.h>
|
| 479 |
+
#include <ATen/ops/slow_conv_dilated2d_cpu_dispatch.h>
|
| 480 |
+
#include <ATen/ops/slow_conv_dilated3d_cpu_dispatch.h>
|
| 481 |
+
#include <ATen/ops/slow_conv_transpose2d_cpu_dispatch.h>
|
| 482 |
+
#include <ATen/ops/slow_conv_transpose3d_cpu_dispatch.h>
|
| 483 |
+
#include <ATen/ops/smooth_l1_loss_cpu_dispatch.h>
|
| 484 |
+
#include <ATen/ops/smooth_l1_loss_backward_cpu_dispatch.h>
|
| 485 |
+
#include <ATen/ops/softplus_cpu_dispatch.h>
|
| 486 |
+
#include <ATen/ops/softplus_backward_cpu_dispatch.h>
|
| 487 |
+
#include <ATen/ops/softshrink_cpu_dispatch.h>
|
| 488 |
+
#include <ATen/ops/softshrink_backward_cpu_dispatch.h>
|
| 489 |
+
#include <ATen/ops/sort_cpu_dispatch.h>
|
| 490 |
+
#include <ATen/ops/sparse_dim_cpu_dispatch.h>
|
| 491 |
+
#include <ATen/ops/special_airy_ai_cpu_dispatch.h>
|
| 492 |
+
#include <ATen/ops/special_bessel_j0_cpu_dispatch.h>
|
| 493 |
+
#include <ATen/ops/special_bessel_j1_cpu_dispatch.h>
|
| 494 |
+
#include <ATen/ops/special_bessel_y0_cpu_dispatch.h>
|
| 495 |
+
#include <ATen/ops/special_bessel_y1_cpu_dispatch.h>
|
| 496 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_cpu_dispatch.h>
|
| 497 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_cpu_dispatch.h>
|
| 498 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_cpu_dispatch.h>
|
| 499 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_cpu_dispatch.h>
|
| 500 |
+
#include <ATen/ops/special_entr_cpu_dispatch.h>
|
| 501 |
+
#include <ATen/ops/special_erfcx_cpu_dispatch.h>
|
| 502 |
+
#include <ATen/ops/special_hermite_polynomial_h_cpu_dispatch.h>
|
| 503 |
+
#include <ATen/ops/special_hermite_polynomial_he_cpu_dispatch.h>
|
| 504 |
+
#include <ATen/ops/special_i0e_cpu_dispatch.h>
|
| 505 |
+
#include <ATen/ops/special_i1_cpu_dispatch.h>
|
| 506 |
+
#include <ATen/ops/special_i1e_cpu_dispatch.h>
|
| 507 |
+
#include <ATen/ops/special_laguerre_polynomial_l_cpu_dispatch.h>
|
| 508 |
+
#include <ATen/ops/special_legendre_polynomial_p_cpu_dispatch.h>
|
| 509 |
+
#include <ATen/ops/special_log_ndtr_cpu_dispatch.h>
|
| 510 |
+
#include <ATen/ops/special_modified_bessel_i0_cpu_dispatch.h>
|
| 511 |
+
#include <ATen/ops/special_modified_bessel_i1_cpu_dispatch.h>
|
| 512 |
+
#include <ATen/ops/special_modified_bessel_k0_cpu_dispatch.h>
|
| 513 |
+
#include <ATen/ops/special_modified_bessel_k1_cpu_dispatch.h>
|
| 514 |
+
#include <ATen/ops/special_ndtri_cpu_dispatch.h>
|
| 515 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_cpu_dispatch.h>
|
| 516 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_cpu_dispatch.h>
|
| 517 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_cpu_dispatch.h>
|
| 518 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_cpu_dispatch.h>
|
| 519 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_cpu_dispatch.h>
|
| 520 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_cpu_dispatch.h>
|
| 521 |
+
#include <ATen/ops/special_spherical_bessel_j0_cpu_dispatch.h>
|
| 522 |
+
#include <ATen/ops/special_xlog1py_cpu_dispatch.h>
|
| 523 |
+
#include <ATen/ops/special_zeta_cpu_dispatch.h>
|
| 524 |
+
#include <ATen/ops/sqrt_cpu_dispatch.h>
|
| 525 |
+
#include <ATen/ops/sspaddmm_cpu_dispatch.h>
|
| 526 |
+
#include <ATen/ops/std_cpu_dispatch.h>
|
| 527 |
+
#include <ATen/ops/std_mean_cpu_dispatch.h>
|
| 528 |
+
#include <ATen/ops/sub_cpu_dispatch.h>
|
| 529 |
+
#include <ATen/ops/sum_cpu_dispatch.h>
|
| 530 |
+
#include <ATen/ops/take_cpu_dispatch.h>
|
| 531 |
+
#include <ATen/ops/tan_cpu_dispatch.h>
|
| 532 |
+
#include <ATen/ops/tanh_cpu_dispatch.h>
|
| 533 |
+
#include <ATen/ops/tanh_backward_cpu_dispatch.h>
|
| 534 |
+
#include <ATen/ops/threshold_cpu_dispatch.h>
|
| 535 |
+
#include <ATen/ops/threshold_backward_cpu_dispatch.h>
|
| 536 |
+
#include <ATen/ops/to_mkldnn_cpu_dispatch.h>
|
| 537 |
+
#include <ATen/ops/topk_cpu_dispatch.h>
|
| 538 |
+
#include <ATen/ops/trace_cpu_dispatch.h>
|
| 539 |
+
#include <ATen/ops/triangular_solve_cpu_dispatch.h>
|
| 540 |
+
#include <ATen/ops/tril_cpu_dispatch.h>
|
| 541 |
+
#include <ATen/ops/tril_indices_cpu_dispatch.h>
|
| 542 |
+
#include <ATen/ops/triu_cpu_dispatch.h>
|
| 543 |
+
#include <ATen/ops/triu_indices_cpu_dispatch.h>
|
| 544 |
+
#include <ATen/ops/trunc_cpu_dispatch.h>
|
| 545 |
+
#include <ATen/ops/unfold_cpu_dispatch.h>
|
| 546 |
+
#include <ATen/ops/unfold_backward_cpu_dispatch.h>
|
| 547 |
+
#include <ATen/ops/uniform_cpu_dispatch.h>
|
| 548 |
+
#include <ATen/ops/unique_consecutive_cpu_dispatch.h>
|
| 549 |
+
#include <ATen/ops/unique_dim_cpu_dispatch.h>
|
| 550 |
+
#include <ATen/ops/unique_dim_consecutive_cpu_dispatch.h>
|
| 551 |
+
#include <ATen/ops/upsample_bicubic2d_cpu_dispatch.h>
|
| 552 |
+
#include <ATen/ops/upsample_bicubic2d_backward_cpu_dispatch.h>
|
| 553 |
+
#include <ATen/ops/upsample_bilinear2d_cpu_dispatch.h>
|
| 554 |
+
#include <ATen/ops/upsample_bilinear2d_backward_cpu_dispatch.h>
|
| 555 |
+
#include <ATen/ops/upsample_linear1d_cpu_dispatch.h>
|
| 556 |
+
#include <ATen/ops/upsample_linear1d_backward_cpu_dispatch.h>
|
| 557 |
+
#include <ATen/ops/upsample_nearest1d_cpu_dispatch.h>
|
| 558 |
+
#include <ATen/ops/upsample_nearest1d_backward_cpu_dispatch.h>
|
| 559 |
+
#include <ATen/ops/upsample_nearest2d_cpu_dispatch.h>
|
| 560 |
+
#include <ATen/ops/upsample_nearest2d_backward_cpu_dispatch.h>
|
| 561 |
+
#include <ATen/ops/upsample_nearest3d_cpu_dispatch.h>
|
| 562 |
+
#include <ATen/ops/upsample_nearest3d_backward_cpu_dispatch.h>
|
| 563 |
+
#include <ATen/ops/upsample_trilinear3d_cpu_dispatch.h>
|
| 564 |
+
#include <ATen/ops/upsample_trilinear3d_backward_cpu_dispatch.h>
|
| 565 |
+
#include <ATen/ops/var_cpu_dispatch.h>
|
| 566 |
+
#include <ATen/ops/var_mean_cpu_dispatch.h>
|
| 567 |
+
#include <ATen/ops/vdot_cpu_dispatch.h>
|
| 568 |
+
#include <ATen/ops/view_cpu_dispatch.h>
|
| 569 |
+
#include <ATen/ops/view_as_complex_cpu_dispatch.h>
|
| 570 |
+
#include <ATen/ops/view_as_real_cpu_dispatch.h>
|
| 571 |
+
#include <ATen/ops/where_cpu_dispatch.h>
|
| 572 |
+
#include <ATen/ops/xlogy_cpu_dispatch.h>
|
| 573 |
+
#include <ATen/ops/zero_cpu_dispatch.h>
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CPUGeneratorImpl.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/Generator.h>
|
| 4 |
+
#include <ATen/core/MT19937RNGEngine.h>
|
| 5 |
+
#include <c10/core/GeneratorImpl.h>
|
| 6 |
+
#include <c10/util/Optional.h>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
|
| 11 |
+
// Constructors
|
| 12 |
+
CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
|
| 13 |
+
~CPUGeneratorImpl() override = default;
|
| 14 |
+
|
| 15 |
+
// CPUGeneratorImpl methods
|
| 16 |
+
std::shared_ptr<CPUGeneratorImpl> clone() const;
|
| 17 |
+
void set_current_seed(uint64_t seed) override;
|
| 18 |
+
void set_offset(uint64_t offset) override;
|
| 19 |
+
uint64_t get_offset() const override;
|
| 20 |
+
uint64_t current_seed() const override;
|
| 21 |
+
uint64_t seed() override;
|
| 22 |
+
void set_state(const c10::TensorImpl& new_state) override;
|
| 23 |
+
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
|
| 24 |
+
static c10::DeviceType device_type();
|
| 25 |
+
uint32_t random();
|
| 26 |
+
uint64_t random64();
|
| 27 |
+
c10::optional<float> next_float_normal_sample();
|
| 28 |
+
c10::optional<double> next_double_normal_sample();
|
| 29 |
+
void set_next_float_normal_sample(c10::optional<float> randn);
|
| 30 |
+
void set_next_double_normal_sample(c10::optional<double> randn);
|
| 31 |
+
at::mt19937 engine();
|
| 32 |
+
void set_engine(at::mt19937 engine);
|
| 33 |
+
|
| 34 |
+
private:
|
| 35 |
+
CPUGeneratorImpl* clone_impl() const override;
|
| 36 |
+
at::mt19937 engine_;
|
| 37 |
+
c10::optional<float> next_float_normal_sample_;
|
| 38 |
+
c10::optional<double> next_double_normal_sample_;
|
| 39 |
+
};
|
| 40 |
+
|
| 41 |
+
namespace detail {
|
| 42 |
+
|
| 43 |
+
TORCH_API const Generator& getDefaultCPUGenerator();
|
| 44 |
+
TORCH_API Generator
|
| 45 |
+
createCPUGenerator(uint64_t seed_val = default_rng_seed_val);
|
| 46 |
+
|
| 47 |
+
} // namespace detail
|
| 48 |
+
|
| 49 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CUDAFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CUDAFunctions_inl.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CUDAFunctions_inl.h
ADDED
|
@@ -0,0 +1,614 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_cuda_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_adaptive_avg_pool2d_cuda_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_cuda_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_adaptive_avg_pool3d_cuda_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_cuda_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_addmm_activation_cuda_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_aminmax_cuda_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_cuda_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_amp_update_scale_cuda_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_assert_async_cuda_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_cdist_backward_cuda_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_cdist_forward_cuda_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_cholesky_solve_helper_cuda_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_compute_linear_combination_cuda_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_conv_depthwise2d_cuda_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_cuda_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_cuda_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_convert_weight_to_int4pack_cuda_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_cslt_compress_cuda_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_cslt_sparse_mm_cuda_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_cslt_sparse_mm_search_cuda_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_ctc_loss_cuda_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_ctc_loss_backward_cuda_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_cudnn_ctc_loss_cuda_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_cudnn_init_dropout_state_cuda_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_cudnn_rnn_cuda_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_cudnn_rnn_backward_cuda_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight_cuda_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_cummax_helper_cuda_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_cummin_helper_cuda_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_dirichlet_grad_cuda_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_efficient_attention_backward_cuda_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_efficient_attention_forward_cuda_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_efficientzerotensor_cuda_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_embedding_bag_cuda_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_embedding_bag_dense_backward_cuda_dispatch.h>
|
| 54 |
+
#include <ATen/ops/_embedding_bag_forward_only_cuda_dispatch.h>
|
| 55 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_cuda_dispatch.h>
|
| 56 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_cuda_dispatch.h>
|
| 57 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_cuda_dispatch.h>
|
| 58 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_cuda_dispatch.h>
|
| 59 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_cuda_dispatch.h>
|
| 60 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_cuda_dispatch.h>
|
| 61 |
+
#include <ATen/ops/_fft_c2c_cuda_dispatch.h>
|
| 62 |
+
#include <ATen/ops/_fft_c2r_cuda_dispatch.h>
|
| 63 |
+
#include <ATen/ops/_fft_r2c_cuda_dispatch.h>
|
| 64 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask_cuda_dispatch.h>
|
| 65 |
+
#include <ATen/ops/_flash_attention_backward_cuda_dispatch.h>
|
| 66 |
+
#include <ATen/ops/_flash_attention_forward_cuda_dispatch.h>
|
| 67 |
+
#include <ATen/ops/_foreach_abs_cuda_dispatch.h>
|
| 68 |
+
#include <ATen/ops/_foreach_acos_cuda_dispatch.h>
|
| 69 |
+
#include <ATen/ops/_foreach_add_cuda_dispatch.h>
|
| 70 |
+
#include <ATen/ops/_foreach_addcdiv_cuda_dispatch.h>
|
| 71 |
+
#include <ATen/ops/_foreach_addcmul_cuda_dispatch.h>
|
| 72 |
+
#include <ATen/ops/_foreach_asin_cuda_dispatch.h>
|
| 73 |
+
#include <ATen/ops/_foreach_atan_cuda_dispatch.h>
|
| 74 |
+
#include <ATen/ops/_foreach_ceil_cuda_dispatch.h>
|
| 75 |
+
#include <ATen/ops/_foreach_clamp_max_cuda_dispatch.h>
|
| 76 |
+
#include <ATen/ops/_foreach_clamp_min_cuda_dispatch.h>
|
| 77 |
+
#include <ATen/ops/_foreach_copy_cuda_dispatch.h>
|
| 78 |
+
#include <ATen/ops/_foreach_cos_cuda_dispatch.h>
|
| 79 |
+
#include <ATen/ops/_foreach_cosh_cuda_dispatch.h>
|
| 80 |
+
#include <ATen/ops/_foreach_div_cuda_dispatch.h>
|
| 81 |
+
#include <ATen/ops/_foreach_erf_cuda_dispatch.h>
|
| 82 |
+
#include <ATen/ops/_foreach_erfc_cuda_dispatch.h>
|
| 83 |
+
#include <ATen/ops/_foreach_exp_cuda_dispatch.h>
|
| 84 |
+
#include <ATen/ops/_foreach_expm1_cuda_dispatch.h>
|
| 85 |
+
#include <ATen/ops/_foreach_floor_cuda_dispatch.h>
|
| 86 |
+
#include <ATen/ops/_foreach_frac_cuda_dispatch.h>
|
| 87 |
+
#include <ATen/ops/_foreach_lerp_cuda_dispatch.h>
|
| 88 |
+
#include <ATen/ops/_foreach_lgamma_cuda_dispatch.h>
|
| 89 |
+
#include <ATen/ops/_foreach_log_cuda_dispatch.h>
|
| 90 |
+
#include <ATen/ops/_foreach_log10_cuda_dispatch.h>
|
| 91 |
+
#include <ATen/ops/_foreach_log1p_cuda_dispatch.h>
|
| 92 |
+
#include <ATen/ops/_foreach_log2_cuda_dispatch.h>
|
| 93 |
+
#include <ATen/ops/_foreach_maximum_cuda_dispatch.h>
|
| 94 |
+
#include <ATen/ops/_foreach_minimum_cuda_dispatch.h>
|
| 95 |
+
#include <ATen/ops/_foreach_mul_cuda_dispatch.h>
|
| 96 |
+
#include <ATen/ops/_foreach_neg_cuda_dispatch.h>
|
| 97 |
+
#include <ATen/ops/_foreach_norm_cuda_dispatch.h>
|
| 98 |
+
#include <ATen/ops/_foreach_pow_cuda_dispatch.h>
|
| 99 |
+
#include <ATen/ops/_foreach_reciprocal_cuda_dispatch.h>
|
| 100 |
+
#include <ATen/ops/_foreach_round_cuda_dispatch.h>
|
| 101 |
+
#include <ATen/ops/_foreach_sigmoid_cuda_dispatch.h>
|
| 102 |
+
#include <ATen/ops/_foreach_sign_cuda_dispatch.h>
|
| 103 |
+
#include <ATen/ops/_foreach_sin_cuda_dispatch.h>
|
| 104 |
+
#include <ATen/ops/_foreach_sinh_cuda_dispatch.h>
|
| 105 |
+
#include <ATen/ops/_foreach_sqrt_cuda_dispatch.h>
|
| 106 |
+
#include <ATen/ops/_foreach_sub_cuda_dispatch.h>
|
| 107 |
+
#include <ATen/ops/_foreach_tan_cuda_dispatch.h>
|
| 108 |
+
#include <ATen/ops/_foreach_tanh_cuda_dispatch.h>
|
| 109 |
+
#include <ATen/ops/_foreach_trunc_cuda_dispatch.h>
|
| 110 |
+
#include <ATen/ops/_foreach_zero_cuda_dispatch.h>
|
| 111 |
+
#include <ATen/ops/_fused_adam_cuda_dispatch.h>
|
| 112 |
+
#include <ATen/ops/_fused_adamw_cuda_dispatch.h>
|
| 113 |
+
#include <ATen/ops/_fused_dropout_cuda_dispatch.h>
|
| 114 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_cuda_dispatch.h>
|
| 115 |
+
#include <ATen/ops/_fused_sdp_choice_cuda_dispatch.h>
|
| 116 |
+
#include <ATen/ops/_fused_sgd_cuda_dispatch.h>
|
| 117 |
+
#include <ATen/ops/_index_put_impl_cuda_dispatch.h>
|
| 118 |
+
#include <ATen/ops/_int_mm_cuda_dispatch.h>
|
| 119 |
+
#include <ATen/ops/_linalg_det_cuda_dispatch.h>
|
| 120 |
+
#include <ATen/ops/_linalg_eigh_cuda_dispatch.h>
|
| 121 |
+
#include <ATen/ops/_linalg_eigvals_cuda_dispatch.h>
|
| 122 |
+
#include <ATen/ops/_linalg_slogdet_cuda_dispatch.h>
|
| 123 |
+
#include <ATen/ops/_linalg_solve_ex_cuda_dispatch.h>
|
| 124 |
+
#include <ATen/ops/_linalg_svd_cuda_dispatch.h>
|
| 125 |
+
#include <ATen/ops/_local_scalar_dense_cuda_dispatch.h>
|
| 126 |
+
#include <ATen/ops/_log_softmax_cuda_dispatch.h>
|
| 127 |
+
#include <ATen/ops/_log_softmax_backward_data_cuda_dispatch.h>
|
| 128 |
+
#include <ATen/ops/_logcumsumexp_cuda_dispatch.h>
|
| 129 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_cuda_dispatch.h>
|
| 130 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_cuda_dispatch.h>
|
| 131 |
+
#include <ATen/ops/_masked_scale_cuda_dispatch.h>
|
| 132 |
+
#include <ATen/ops/_masked_softmax_cuda_dispatch.h>
|
| 133 |
+
#include <ATen/ops/_masked_softmax_backward_cuda_dispatch.h>
|
| 134 |
+
#include <ATen/ops/_mixed_dtypes_linear_cuda_dispatch.h>
|
| 135 |
+
#include <ATen/ops/_native_batch_norm_legit_cuda_dispatch.h>
|
| 136 |
+
#include <ATen/ops/_native_multi_head_attention_cuda_dispatch.h>
|
| 137 |
+
#include <ATen/ops/_nested_from_padded_cuda_dispatch.h>
|
| 138 |
+
#include <ATen/ops/_nested_tensor_from_mask_cuda_dispatch.h>
|
| 139 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_cuda_dispatch.h>
|
| 140 |
+
#include <ATen/ops/_nested_view_from_buffer_cuda_dispatch.h>
|
| 141 |
+
#include <ATen/ops/_pdist_backward_cuda_dispatch.h>
|
| 142 |
+
#include <ATen/ops/_pdist_forward_cuda_dispatch.h>
|
| 143 |
+
#include <ATen/ops/_pin_memory_cuda_dispatch.h>
|
| 144 |
+
#include <ATen/ops/_prelu_kernel_cuda_dispatch.h>
|
| 145 |
+
#include <ATen/ops/_prelu_kernel_backward_cuda_dispatch.h>
|
| 146 |
+
#include <ATen/ops/_reshape_alias_cuda_dispatch.h>
|
| 147 |
+
#include <ATen/ops/_sample_dirichlet_cuda_dispatch.h>
|
| 148 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_cuda_dispatch.h>
|
| 149 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_cuda_dispatch.h>
|
| 150 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward_cuda_dispatch.h>
|
| 151 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_cuda_dispatch.h>
|
| 152 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_cuda_dispatch.h>
|
| 153 |
+
#include <ATen/ops/_scaled_mm_cuda_dispatch.h>
|
| 154 |
+
#include <ATen/ops/_segment_reduce_backward_cuda_dispatch.h>
|
| 155 |
+
#include <ATen/ops/_slow_conv2d_backward_cuda_dispatch.h>
|
| 156 |
+
#include <ATen/ops/_slow_conv2d_forward_cuda_dispatch.h>
|
| 157 |
+
#include <ATen/ops/_softmax_cuda_dispatch.h>
|
| 158 |
+
#include <ATen/ops/_softmax_backward_data_cuda_dispatch.h>
|
| 159 |
+
#include <ATen/ops/_sparse_semi_structured_linear_cuda_dispatch.h>
|
| 160 |
+
#include <ATen/ops/_standard_gamma_cuda_dispatch.h>
|
| 161 |
+
#include <ATen/ops/_standard_gamma_grad_cuda_dispatch.h>
|
| 162 |
+
#include <ATen/ops/_thnn_fused_gru_cell_cuda_dispatch.h>
|
| 163 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward_cuda_dispatch.h>
|
| 164 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_cuda_dispatch.h>
|
| 165 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_cuda_dispatch.h>
|
| 166 |
+
#include <ATen/ops/_to_sparse_cuda_dispatch.h>
|
| 167 |
+
#include <ATen/ops/_to_sparse_bsc_cuda_dispatch.h>
|
| 168 |
+
#include <ATen/ops/_to_sparse_bsr_cuda_dispatch.h>
|
| 169 |
+
#include <ATen/ops/_to_sparse_csc_cuda_dispatch.h>
|
| 170 |
+
#include <ATen/ops/_to_sparse_csr_cuda_dispatch.h>
|
| 171 |
+
#include <ATen/ops/_to_sparse_semi_structured_cuda_dispatch.h>
|
| 172 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_cuda_dispatch.h>
|
| 173 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_cuda_dispatch.h>
|
| 174 |
+
#include <ATen/ops/_triton_multi_head_attention_cuda_dispatch.h>
|
| 175 |
+
#include <ATen/ops/_triton_scaled_dot_attention_cuda_dispatch.h>
|
| 176 |
+
#include <ATen/ops/_unique_cuda_dispatch.h>
|
| 177 |
+
#include <ATen/ops/_unique2_cuda_dispatch.h>
|
| 178 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_cuda_dispatch.h>
|
| 179 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_cuda_dispatch.h>
|
| 180 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_cuda_dispatch.h>
|
| 181 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_cuda_dispatch.h>
|
| 182 |
+
#include <ATen/ops/_upsample_nearest_exact1d_cuda_dispatch.h>
|
| 183 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_cuda_dispatch.h>
|
| 184 |
+
#include <ATen/ops/_upsample_nearest_exact2d_cuda_dispatch.h>
|
| 185 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_cuda_dispatch.h>
|
| 186 |
+
#include <ATen/ops/_upsample_nearest_exact3d_cuda_dispatch.h>
|
| 187 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_cuda_dispatch.h>
|
| 188 |
+
#include <ATen/ops/_use_cudnn_ctc_loss_cuda_dispatch.h>
|
| 189 |
+
#include <ATen/ops/_validate_compressed_sparse_indices_cuda_dispatch.h>
|
| 190 |
+
#include <ATen/ops/_weight_int4pack_mm_cuda_dispatch.h>
|
| 191 |
+
#include <ATen/ops/_weight_norm_interface_cuda_dispatch.h>
|
| 192 |
+
#include <ATen/ops/_weight_norm_interface_backward_cuda_dispatch.h>
|
| 193 |
+
#include <ATen/ops/abs_cuda_dispatch.h>
|
| 194 |
+
#include <ATen/ops/acos_cuda_dispatch.h>
|
| 195 |
+
#include <ATen/ops/acosh_cuda_dispatch.h>
|
| 196 |
+
#include <ATen/ops/adaptive_avg_pool2d_cuda_dispatch.h>
|
| 197 |
+
#include <ATen/ops/adaptive_avg_pool3d_cuda_dispatch.h>
|
| 198 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward_cuda_dispatch.h>
|
| 199 |
+
#include <ATen/ops/adaptive_max_pool2d_cuda_dispatch.h>
|
| 200 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_cuda_dispatch.h>
|
| 201 |
+
#include <ATen/ops/adaptive_max_pool3d_cuda_dispatch.h>
|
| 202 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_cuda_dispatch.h>
|
| 203 |
+
#include <ATen/ops/add_cuda_dispatch.h>
|
| 204 |
+
#include <ATen/ops/addbmm_cuda_dispatch.h>
|
| 205 |
+
#include <ATen/ops/addcdiv_cuda_dispatch.h>
|
| 206 |
+
#include <ATen/ops/addcmul_cuda_dispatch.h>
|
| 207 |
+
#include <ATen/ops/addmm_cuda_dispatch.h>
|
| 208 |
+
#include <ATen/ops/addmv_cuda_dispatch.h>
|
| 209 |
+
#include <ATen/ops/addr_cuda_dispatch.h>
|
| 210 |
+
#include <ATen/ops/all_cuda_dispatch.h>
|
| 211 |
+
#include <ATen/ops/amax_cuda_dispatch.h>
|
| 212 |
+
#include <ATen/ops/amin_cuda_dispatch.h>
|
| 213 |
+
#include <ATen/ops/aminmax_cuda_dispatch.h>
|
| 214 |
+
#include <ATen/ops/angle_cuda_dispatch.h>
|
| 215 |
+
#include <ATen/ops/any_cuda_dispatch.h>
|
| 216 |
+
#include <ATen/ops/arange_cuda_dispatch.h>
|
| 217 |
+
#include <ATen/ops/argmax_cuda_dispatch.h>
|
| 218 |
+
#include <ATen/ops/argmin_cuda_dispatch.h>
|
| 219 |
+
#include <ATen/ops/argsort_cuda_dispatch.h>
|
| 220 |
+
#include <ATen/ops/as_strided_cuda_dispatch.h>
|
| 221 |
+
#include <ATen/ops/asin_cuda_dispatch.h>
|
| 222 |
+
#include <ATen/ops/asinh_cuda_dispatch.h>
|
| 223 |
+
#include <ATen/ops/atan_cuda_dispatch.h>
|
| 224 |
+
#include <ATen/ops/atan2_cuda_dispatch.h>
|
| 225 |
+
#include <ATen/ops/atanh_cuda_dispatch.h>
|
| 226 |
+
#include <ATen/ops/avg_pool2d_cuda_dispatch.h>
|
| 227 |
+
#include <ATen/ops/avg_pool2d_backward_cuda_dispatch.h>
|
| 228 |
+
#include <ATen/ops/avg_pool3d_cuda_dispatch.h>
|
| 229 |
+
#include <ATen/ops/avg_pool3d_backward_cuda_dispatch.h>
|
| 230 |
+
#include <ATen/ops/baddbmm_cuda_dispatch.h>
|
| 231 |
+
#include <ATen/ops/batch_norm_backward_elemt_cuda_dispatch.h>
|
| 232 |
+
#include <ATen/ops/batch_norm_backward_reduce_cuda_dispatch.h>
|
| 233 |
+
#include <ATen/ops/batch_norm_elemt_cuda_dispatch.h>
|
| 234 |
+
#include <ATen/ops/batch_norm_gather_stats_cuda_dispatch.h>
|
| 235 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts_cuda_dispatch.h>
|
| 236 |
+
#include <ATen/ops/batch_norm_stats_cuda_dispatch.h>
|
| 237 |
+
#include <ATen/ops/batch_norm_update_stats_cuda_dispatch.h>
|
| 238 |
+
#include <ATen/ops/bernoulli_cuda_dispatch.h>
|
| 239 |
+
#include <ATen/ops/binary_cross_entropy_cuda_dispatch.h>
|
| 240 |
+
#include <ATen/ops/binary_cross_entropy_backward_cuda_dispatch.h>
|
| 241 |
+
#include <ATen/ops/bincount_cuda_dispatch.h>
|
| 242 |
+
#include <ATen/ops/binomial_cuda_dispatch.h>
|
| 243 |
+
#include <ATen/ops/bitwise_and_cuda_dispatch.h>
|
| 244 |
+
#include <ATen/ops/bitwise_left_shift_cuda_dispatch.h>
|
| 245 |
+
#include <ATen/ops/bitwise_not_cuda_dispatch.h>
|
| 246 |
+
#include <ATen/ops/bitwise_or_cuda_dispatch.h>
|
| 247 |
+
#include <ATen/ops/bitwise_right_shift_cuda_dispatch.h>
|
| 248 |
+
#include <ATen/ops/bitwise_xor_cuda_dispatch.h>
|
| 249 |
+
#include <ATen/ops/bmm_cuda_dispatch.h>
|
| 250 |
+
#include <ATen/ops/bucketize_cuda_dispatch.h>
|
| 251 |
+
#include <ATen/ops/cat_cuda_dispatch.h>
|
| 252 |
+
#include <ATen/ops/cauchy_cuda_dispatch.h>
|
| 253 |
+
#include <ATen/ops/ceil_cuda_dispatch.h>
|
| 254 |
+
#include <ATen/ops/channel_shuffle_cuda_dispatch.h>
|
| 255 |
+
#include <ATen/ops/cholesky_cuda_dispatch.h>
|
| 256 |
+
#include <ATen/ops/cholesky_inverse_cuda_dispatch.h>
|
| 257 |
+
#include <ATen/ops/clamp_cuda_dispatch.h>
|
| 258 |
+
#include <ATen/ops/clamp_max_cuda_dispatch.h>
|
| 259 |
+
#include <ATen/ops/clamp_min_cuda_dispatch.h>
|
| 260 |
+
#include <ATen/ops/col2im_cuda_dispatch.h>
|
| 261 |
+
#include <ATen/ops/complex_cuda_dispatch.h>
|
| 262 |
+
#include <ATen/ops/conj_physical_cuda_dispatch.h>
|
| 263 |
+
#include <ATen/ops/conv_depthwise3d_cuda_dispatch.h>
|
| 264 |
+
#include <ATen/ops/convolution_backward_cuda_dispatch.h>
|
| 265 |
+
#include <ATen/ops/copysign_cuda_dispatch.h>
|
| 266 |
+
#include <ATen/ops/cos_cuda_dispatch.h>
|
| 267 |
+
#include <ATen/ops/cosh_cuda_dispatch.h>
|
| 268 |
+
#include <ATen/ops/count_nonzero_cuda_dispatch.h>
|
| 269 |
+
#include <ATen/ops/cudnn_affine_grid_generator_cuda_dispatch.h>
|
| 270 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward_cuda_dispatch.h>
|
| 271 |
+
#include <ATen/ops/cudnn_batch_norm_cuda_dispatch.h>
|
| 272 |
+
#include <ATen/ops/cudnn_batch_norm_backward_cuda_dispatch.h>
|
| 273 |
+
#include <ATen/ops/cudnn_convolution_cuda_dispatch.h>
|
| 274 |
+
#include <ATen/ops/cudnn_convolution_add_relu_cuda_dispatch.h>
|
| 275 |
+
#include <ATen/ops/cudnn_convolution_relu_cuda_dispatch.h>
|
| 276 |
+
#include <ATen/ops/cudnn_convolution_transpose_cuda_dispatch.h>
|
| 277 |
+
#include <ATen/ops/cudnn_grid_sampler_cuda_dispatch.h>
|
| 278 |
+
#include <ATen/ops/cudnn_grid_sampler_backward_cuda_dispatch.h>
|
| 279 |
+
#include <ATen/ops/cumprod_cuda_dispatch.h>
|
| 280 |
+
#include <ATen/ops/cumsum_cuda_dispatch.h>
|
| 281 |
+
#include <ATen/ops/dense_dim_cuda_dispatch.h>
|
| 282 |
+
#include <ATen/ops/dequantize_cuda_dispatch.h>
|
| 283 |
+
#include <ATen/ops/digamma_cuda_dispatch.h>
|
| 284 |
+
#include <ATen/ops/div_cuda_dispatch.h>
|
| 285 |
+
#include <ATen/ops/dot_cuda_dispatch.h>
|
| 286 |
+
#include <ATen/ops/elu_cuda_dispatch.h>
|
| 287 |
+
#include <ATen/ops/elu_backward_cuda_dispatch.h>
|
| 288 |
+
#include <ATen/ops/embedding_dense_backward_cuda_dispatch.h>
|
| 289 |
+
#include <ATen/ops/embedding_renorm_cuda_dispatch.h>
|
| 290 |
+
#include <ATen/ops/empty_cuda_dispatch.h>
|
| 291 |
+
#include <ATen/ops/empty_strided_cuda_dispatch.h>
|
| 292 |
+
#include <ATen/ops/eq_cuda_dispatch.h>
|
| 293 |
+
#include <ATen/ops/equal_cuda_dispatch.h>
|
| 294 |
+
#include <ATen/ops/erf_cuda_dispatch.h>
|
| 295 |
+
#include <ATen/ops/erfc_cuda_dispatch.h>
|
| 296 |
+
#include <ATen/ops/erfinv_cuda_dispatch.h>
|
| 297 |
+
#include <ATen/ops/exp_cuda_dispatch.h>
|
| 298 |
+
#include <ATen/ops/exp2_cuda_dispatch.h>
|
| 299 |
+
#include <ATen/ops/expm1_cuda_dispatch.h>
|
| 300 |
+
#include <ATen/ops/exponential_cuda_dispatch.h>
|
| 301 |
+
#include <ATen/ops/eye_cuda_dispatch.h>
|
| 302 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_cuda_dispatch.h>
|
| 303 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_cuda_dispatch.h>
|
| 304 |
+
#include <ATen/ops/fill_cuda_dispatch.h>
|
| 305 |
+
#include <ATen/ops/flip_cuda_dispatch.h>
|
| 306 |
+
#include <ATen/ops/floor_cuda_dispatch.h>
|
| 307 |
+
#include <ATen/ops/floor_divide_cuda_dispatch.h>
|
| 308 |
+
#include <ATen/ops/fmax_cuda_dispatch.h>
|
| 309 |
+
#include <ATen/ops/fmin_cuda_dispatch.h>
|
| 310 |
+
#include <ATen/ops/fmod_cuda_dispatch.h>
|
| 311 |
+
#include <ATen/ops/frac_cuda_dispatch.h>
|
| 312 |
+
#include <ATen/ops/fractional_max_pool2d_cuda_dispatch.h>
|
| 313 |
+
#include <ATen/ops/fractional_max_pool2d_backward_cuda_dispatch.h>
|
| 314 |
+
#include <ATen/ops/fractional_max_pool3d_cuda_dispatch.h>
|
| 315 |
+
#include <ATen/ops/fractional_max_pool3d_backward_cuda_dispatch.h>
|
| 316 |
+
#include <ATen/ops/frexp_cuda_dispatch.h>
|
| 317 |
+
#include <ATen/ops/gather_cuda_dispatch.h>
|
| 318 |
+
#include <ATen/ops/gcd_cuda_dispatch.h>
|
| 319 |
+
#include <ATen/ops/ge_cuda_dispatch.h>
|
| 320 |
+
#include <ATen/ops/gelu_cuda_dispatch.h>
|
| 321 |
+
#include <ATen/ops/gelu_backward_cuda_dispatch.h>
|
| 322 |
+
#include <ATen/ops/geometric_cuda_dispatch.h>
|
| 323 |
+
#include <ATen/ops/geqrf_cuda_dispatch.h>
|
| 324 |
+
#include <ATen/ops/glu_cuda_dispatch.h>
|
| 325 |
+
#include <ATen/ops/glu_backward_cuda_dispatch.h>
|
| 326 |
+
#include <ATen/ops/glu_backward_jvp_cuda_dispatch.h>
|
| 327 |
+
#include <ATen/ops/glu_jvp_cuda_dispatch.h>
|
| 328 |
+
#include <ATen/ops/grid_sampler_2d_cuda_dispatch.h>
|
| 329 |
+
#include <ATen/ops/grid_sampler_2d_backward_cuda_dispatch.h>
|
| 330 |
+
#include <ATen/ops/grid_sampler_3d_cuda_dispatch.h>
|
| 331 |
+
#include <ATen/ops/grid_sampler_3d_backward_cuda_dispatch.h>
|
| 332 |
+
#include <ATen/ops/gt_cuda_dispatch.h>
|
| 333 |
+
#include <ATen/ops/hardshrink_cuda_dispatch.h>
|
| 334 |
+
#include <ATen/ops/hardshrink_backward_cuda_dispatch.h>
|
| 335 |
+
#include <ATen/ops/hardsigmoid_cuda_dispatch.h>
|
| 336 |
+
#include <ATen/ops/hardsigmoid_backward_cuda_dispatch.h>
|
| 337 |
+
#include <ATen/ops/hardswish_cuda_dispatch.h>
|
| 338 |
+
#include <ATen/ops/hardswish_backward_cuda_dispatch.h>
|
| 339 |
+
#include <ATen/ops/hardtanh_cuda_dispatch.h>
|
| 340 |
+
#include <ATen/ops/hardtanh_backward_cuda_dispatch.h>
|
| 341 |
+
#include <ATen/ops/heaviside_cuda_dispatch.h>
|
| 342 |
+
#include <ATen/ops/histc_cuda_dispatch.h>
|
| 343 |
+
#include <ATen/ops/huber_loss_cuda_dispatch.h>
|
| 344 |
+
#include <ATen/ops/huber_loss_backward_cuda_dispatch.h>
|
| 345 |
+
#include <ATen/ops/hypot_cuda_dispatch.h>
|
| 346 |
+
#include <ATen/ops/i0_cuda_dispatch.h>
|
| 347 |
+
#include <ATen/ops/igamma_cuda_dispatch.h>
|
| 348 |
+
#include <ATen/ops/igammac_cuda_dispatch.h>
|
| 349 |
+
#include <ATen/ops/im2col_cuda_dispatch.h>
|
| 350 |
+
#include <ATen/ops/index_cuda_dispatch.h>
|
| 351 |
+
#include <ATen/ops/index_add_cuda_dispatch.h>
|
| 352 |
+
#include <ATen/ops/index_copy_cuda_dispatch.h>
|
| 353 |
+
#include <ATen/ops/index_fill_cuda_dispatch.h>
|
| 354 |
+
#include <ATen/ops/index_reduce_cuda_dispatch.h>
|
| 355 |
+
#include <ATen/ops/index_select_cuda_dispatch.h>
|
| 356 |
+
#include <ATen/ops/is_pinned_cuda_dispatch.h>
|
| 357 |
+
#include <ATen/ops/is_set_to_cuda_dispatch.h>
|
| 358 |
+
#include <ATen/ops/isin_cuda_dispatch.h>
|
| 359 |
+
#include <ATen/ops/isnan_cuda_dispatch.h>
|
| 360 |
+
#include <ATen/ops/isneginf_cuda_dispatch.h>
|
| 361 |
+
#include <ATen/ops/isposinf_cuda_dispatch.h>
|
| 362 |
+
#include <ATen/ops/kthvalue_cuda_dispatch.h>
|
| 363 |
+
#include <ATen/ops/lcm_cuda_dispatch.h>
|
| 364 |
+
#include <ATen/ops/le_cuda_dispatch.h>
|
| 365 |
+
#include <ATen/ops/leaky_relu_cuda_dispatch.h>
|
| 366 |
+
#include <ATen/ops/leaky_relu_backward_cuda_dispatch.h>
|
| 367 |
+
#include <ATen/ops/lerp_cuda_dispatch.h>
|
| 368 |
+
#include <ATen/ops/lgamma_cuda_dispatch.h>
|
| 369 |
+
#include <ATen/ops/linalg_cholesky_ex_cuda_dispatch.h>
|
| 370 |
+
#include <ATen/ops/linalg_cross_cuda_dispatch.h>
|
| 371 |
+
#include <ATen/ops/linalg_eig_cuda_dispatch.h>
|
| 372 |
+
#include <ATen/ops/linalg_eigvals_cuda_dispatch.h>
|
| 373 |
+
#include <ATen/ops/linalg_householder_product_cuda_dispatch.h>
|
| 374 |
+
#include <ATen/ops/linalg_inv_ex_cuda_dispatch.h>
|
| 375 |
+
#include <ATen/ops/linalg_ldl_factor_ex_cuda_dispatch.h>
|
| 376 |
+
#include <ATen/ops/linalg_ldl_solve_cuda_dispatch.h>
|
| 377 |
+
#include <ATen/ops/linalg_lstsq_cuda_dispatch.h>
|
| 378 |
+
#include <ATen/ops/linalg_lu_cuda_dispatch.h>
|
| 379 |
+
#include <ATen/ops/linalg_lu_factor_ex_cuda_dispatch.h>
|
| 380 |
+
#include <ATen/ops/linalg_lu_solve_cuda_dispatch.h>
|
| 381 |
+
#include <ATen/ops/linalg_matrix_exp_cuda_dispatch.h>
|
| 382 |
+
#include <ATen/ops/linalg_qr_cuda_dispatch.h>
|
| 383 |
+
#include <ATen/ops/linalg_solve_triangular_cuda_dispatch.h>
|
| 384 |
+
#include <ATen/ops/linalg_vector_norm_cuda_dispatch.h>
|
| 385 |
+
#include <ATen/ops/linspace_cuda_dispatch.h>
|
| 386 |
+
#include <ATen/ops/log_cuda_dispatch.h>
|
| 387 |
+
#include <ATen/ops/log10_cuda_dispatch.h>
|
| 388 |
+
#include <ATen/ops/log1p_cuda_dispatch.h>
|
| 389 |
+
#include <ATen/ops/log2_cuda_dispatch.h>
|
| 390 |
+
#include <ATen/ops/log_normal_cuda_dispatch.h>
|
| 391 |
+
#include <ATen/ops/log_sigmoid_backward_cuda_dispatch.h>
|
| 392 |
+
#include <ATen/ops/log_sigmoid_forward_cuda_dispatch.h>
|
| 393 |
+
#include <ATen/ops/logaddexp_cuda_dispatch.h>
|
| 394 |
+
#include <ATen/ops/logaddexp2_cuda_dispatch.h>
|
| 395 |
+
#include <ATen/ops/logical_and_cuda_dispatch.h>
|
| 396 |
+
#include <ATen/ops/logical_not_cuda_dispatch.h>
|
| 397 |
+
#include <ATen/ops/logical_or_cuda_dispatch.h>
|
| 398 |
+
#include <ATen/ops/logical_xor_cuda_dispatch.h>
|
| 399 |
+
#include <ATen/ops/logit_cuda_dispatch.h>
|
| 400 |
+
#include <ATen/ops/logit_backward_cuda_dispatch.h>
|
| 401 |
+
#include <ATen/ops/logspace_cuda_dispatch.h>
|
| 402 |
+
#include <ATen/ops/lshift_cuda_dispatch.h>
|
| 403 |
+
#include <ATen/ops/lt_cuda_dispatch.h>
|
| 404 |
+
#include <ATen/ops/lu_unpack_cuda_dispatch.h>
|
| 405 |
+
#include <ATen/ops/masked_fill_cuda_dispatch.h>
|
| 406 |
+
#include <ATen/ops/masked_scatter_cuda_dispatch.h>
|
| 407 |
+
#include <ATen/ops/masked_select_cuda_dispatch.h>
|
| 408 |
+
#include <ATen/ops/max_cuda_dispatch.h>
|
| 409 |
+
#include <ATen/ops/max_pool2d_with_indices_cuda_dispatch.h>
|
| 410 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_cuda_dispatch.h>
|
| 411 |
+
#include <ATen/ops/max_pool3d_with_indices_cuda_dispatch.h>
|
| 412 |
+
#include <ATen/ops/max_pool3d_with_indices_backward_cuda_dispatch.h>
|
| 413 |
+
#include <ATen/ops/max_unpool2d_cuda_dispatch.h>
|
| 414 |
+
#include <ATen/ops/max_unpool3d_cuda_dispatch.h>
|
| 415 |
+
#include <ATen/ops/maximum_cuda_dispatch.h>
|
| 416 |
+
#include <ATen/ops/mean_cuda_dispatch.h>
|
| 417 |
+
#include <ATen/ops/median_cuda_dispatch.h>
|
| 418 |
+
#include <ATen/ops/min_cuda_dispatch.h>
|
| 419 |
+
#include <ATen/ops/minimum_cuda_dispatch.h>
|
| 420 |
+
#include <ATen/ops/miopen_batch_norm_cuda_dispatch.h>
|
| 421 |
+
#include <ATen/ops/miopen_batch_norm_backward_cuda_dispatch.h>
|
| 422 |
+
#include <ATen/ops/miopen_convolution_cuda_dispatch.h>
|
| 423 |
+
#include <ATen/ops/miopen_convolution_add_relu_cuda_dispatch.h>
|
| 424 |
+
#include <ATen/ops/miopen_convolution_relu_cuda_dispatch.h>
|
| 425 |
+
#include <ATen/ops/miopen_convolution_transpose_cuda_dispatch.h>
|
| 426 |
+
#include <ATen/ops/miopen_depthwise_convolution_cuda_dispatch.h>
|
| 427 |
+
#include <ATen/ops/miopen_rnn_cuda_dispatch.h>
|
| 428 |
+
#include <ATen/ops/miopen_rnn_backward_cuda_dispatch.h>
|
| 429 |
+
#include <ATen/ops/mish_cuda_dispatch.h>
|
| 430 |
+
#include <ATen/ops/mish_backward_cuda_dispatch.h>
|
| 431 |
+
#include <ATen/ops/mm_cuda_dispatch.h>
|
| 432 |
+
#include <ATen/ops/mode_cuda_dispatch.h>
|
| 433 |
+
#include <ATen/ops/mse_loss_cuda_dispatch.h>
|
| 434 |
+
#include <ATen/ops/mse_loss_backward_cuda_dispatch.h>
|
| 435 |
+
#include <ATen/ops/mul_cuda_dispatch.h>
|
| 436 |
+
#include <ATen/ops/multi_margin_loss_cuda_dispatch.h>
|
| 437 |
+
#include <ATen/ops/multi_margin_loss_backward_cuda_dispatch.h>
|
| 438 |
+
#include <ATen/ops/multilabel_margin_loss_backward_cuda_dispatch.h>
|
| 439 |
+
#include <ATen/ops/multilabel_margin_loss_forward_cuda_dispatch.h>
|
| 440 |
+
#include <ATen/ops/multinomial_cuda_dispatch.h>
|
| 441 |
+
#include <ATen/ops/mvlgamma_cuda_dispatch.h>
|
| 442 |
+
#include <ATen/ops/nan_to_num_cuda_dispatch.h>
|
| 443 |
+
#include <ATen/ops/nanmedian_cuda_dispatch.h>
|
| 444 |
+
#include <ATen/ops/nansum_cuda_dispatch.h>
|
| 445 |
+
#include <ATen/ops/native_batch_norm_cuda_dispatch.h>
|
| 446 |
+
#include <ATen/ops/native_batch_norm_backward_cuda_dispatch.h>
|
| 447 |
+
#include <ATen/ops/native_dropout_cuda_dispatch.h>
|
| 448 |
+
#include <ATen/ops/native_dropout_backward_cuda_dispatch.h>
|
| 449 |
+
#include <ATen/ops/native_group_norm_cuda_dispatch.h>
|
| 450 |
+
#include <ATen/ops/native_group_norm_backward_cuda_dispatch.h>
|
| 451 |
+
#include <ATen/ops/native_layer_norm_cuda_dispatch.h>
|
| 452 |
+
#include <ATen/ops/native_layer_norm_backward_cuda_dispatch.h>
|
| 453 |
+
#include <ATen/ops/ne_cuda_dispatch.h>
|
| 454 |
+
#include <ATen/ops/neg_cuda_dispatch.h>
|
| 455 |
+
#include <ATen/ops/nextafter_cuda_dispatch.h>
|
| 456 |
+
#include <ATen/ops/nll_loss2d_backward_cuda_dispatch.h>
|
| 457 |
+
#include <ATen/ops/nll_loss2d_forward_cuda_dispatch.h>
|
| 458 |
+
#include <ATen/ops/nll_loss_backward_cuda_dispatch.h>
|
| 459 |
+
#include <ATen/ops/nll_loss_forward_cuda_dispatch.h>
|
| 460 |
+
#include <ATen/ops/nonzero_cuda_dispatch.h>
|
| 461 |
+
#include <ATen/ops/norm_cuda_dispatch.h>
|
| 462 |
+
#include <ATen/ops/normal_cuda_dispatch.h>
|
| 463 |
+
#include <ATen/ops/ormqr_cuda_dispatch.h>
|
| 464 |
+
#include <ATen/ops/poisson_cuda_dispatch.h>
|
| 465 |
+
#include <ATen/ops/polar_cuda_dispatch.h>
|
| 466 |
+
#include <ATen/ops/polygamma_cuda_dispatch.h>
|
| 467 |
+
#include <ATen/ops/pow_cuda_dispatch.h>
|
| 468 |
+
#include <ATen/ops/prod_cuda_dispatch.h>
|
| 469 |
+
#include <ATen/ops/put_cuda_dispatch.h>
|
| 470 |
+
#include <ATen/ops/quantize_per_channel_cuda_dispatch.h>
|
| 471 |
+
#include <ATen/ops/quantize_per_tensor_cuda_dispatch.h>
|
| 472 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_cuda_dispatch.h>
|
| 473 |
+
#include <ATen/ops/random_cuda_dispatch.h>
|
| 474 |
+
#include <ATen/ops/randperm_cuda_dispatch.h>
|
| 475 |
+
#include <ATen/ops/range_cuda_dispatch.h>
|
| 476 |
+
#include <ATen/ops/reciprocal_cuda_dispatch.h>
|
| 477 |
+
#include <ATen/ops/record_stream_cuda_dispatch.h>
|
| 478 |
+
#include <ATen/ops/reflection_pad1d_cuda_dispatch.h>
|
| 479 |
+
#include <ATen/ops/reflection_pad1d_backward_cuda_dispatch.h>
|
| 480 |
+
#include <ATen/ops/reflection_pad2d_cuda_dispatch.h>
|
| 481 |
+
#include <ATen/ops/reflection_pad2d_backward_cuda_dispatch.h>
|
| 482 |
+
#include <ATen/ops/reflection_pad3d_cuda_dispatch.h>
|
| 483 |
+
#include <ATen/ops/reflection_pad3d_backward_cuda_dispatch.h>
|
| 484 |
+
#include <ATen/ops/relu_cuda_dispatch.h>
|
| 485 |
+
#include <ATen/ops/remainder_cuda_dispatch.h>
|
| 486 |
+
#include <ATen/ops/renorm_cuda_dispatch.h>
|
| 487 |
+
#include <ATen/ops/repeat_interleave_cuda_dispatch.h>
|
| 488 |
+
#include <ATen/ops/replication_pad1d_cuda_dispatch.h>
|
| 489 |
+
#include <ATen/ops/replication_pad1d_backward_cuda_dispatch.h>
|
| 490 |
+
#include <ATen/ops/replication_pad2d_cuda_dispatch.h>
|
| 491 |
+
#include <ATen/ops/replication_pad2d_backward_cuda_dispatch.h>
|
| 492 |
+
#include <ATen/ops/replication_pad3d_cuda_dispatch.h>
|
| 493 |
+
#include <ATen/ops/replication_pad3d_backward_cuda_dispatch.h>
|
| 494 |
+
#include <ATen/ops/resize_cuda_dispatch.h>
|
| 495 |
+
#include <ATen/ops/roll_cuda_dispatch.h>
|
| 496 |
+
#include <ATen/ops/round_cuda_dispatch.h>
|
| 497 |
+
#include <ATen/ops/rrelu_with_noise_cuda_dispatch.h>
|
| 498 |
+
#include <ATen/ops/rshift_cuda_dispatch.h>
|
| 499 |
+
#include <ATen/ops/rsqrt_cuda_dispatch.h>
|
| 500 |
+
#include <ATen/ops/rsub_cuda_dispatch.h>
|
| 501 |
+
#include <ATen/ops/scatter_cuda_dispatch.h>
|
| 502 |
+
#include <ATen/ops/scatter_add_cuda_dispatch.h>
|
| 503 |
+
#include <ATen/ops/scatter_reduce_cuda_dispatch.h>
|
| 504 |
+
#include <ATen/ops/searchsorted_cuda_dispatch.h>
|
| 505 |
+
#include <ATen/ops/segment_reduce_cuda_dispatch.h>
|
| 506 |
+
#include <ATen/ops/set_cuda_dispatch.h>
|
| 507 |
+
#include <ATen/ops/sgn_cuda_dispatch.h>
|
| 508 |
+
#include <ATen/ops/sigmoid_cuda_dispatch.h>
|
| 509 |
+
#include <ATen/ops/sigmoid_backward_cuda_dispatch.h>
|
| 510 |
+
#include <ATen/ops/sign_cuda_dispatch.h>
|
| 511 |
+
#include <ATen/ops/signbit_cuda_dispatch.h>
|
| 512 |
+
#include <ATen/ops/silu_cuda_dispatch.h>
|
| 513 |
+
#include <ATen/ops/silu_backward_cuda_dispatch.h>
|
| 514 |
+
#include <ATen/ops/sin_cuda_dispatch.h>
|
| 515 |
+
#include <ATen/ops/sinc_cuda_dispatch.h>
|
| 516 |
+
#include <ATen/ops/sinh_cuda_dispatch.h>
|
| 517 |
+
#include <ATen/ops/slow_conv_dilated2d_cuda_dispatch.h>
|
| 518 |
+
#include <ATen/ops/slow_conv_dilated3d_cuda_dispatch.h>
|
| 519 |
+
#include <ATen/ops/slow_conv_transpose2d_cuda_dispatch.h>
|
| 520 |
+
#include <ATen/ops/slow_conv_transpose3d_cuda_dispatch.h>
|
| 521 |
+
#include <ATen/ops/smooth_l1_loss_cuda_dispatch.h>
|
| 522 |
+
#include <ATen/ops/smooth_l1_loss_backward_cuda_dispatch.h>
|
| 523 |
+
#include <ATen/ops/softplus_cuda_dispatch.h>
|
| 524 |
+
#include <ATen/ops/softplus_backward_cuda_dispatch.h>
|
| 525 |
+
#include <ATen/ops/softshrink_cuda_dispatch.h>
|
| 526 |
+
#include <ATen/ops/softshrink_backward_cuda_dispatch.h>
|
| 527 |
+
#include <ATen/ops/sort_cuda_dispatch.h>
|
| 528 |
+
#include <ATen/ops/sparse_dim_cuda_dispatch.h>
|
| 529 |
+
#include <ATen/ops/special_airy_ai_cuda_dispatch.h>
|
| 530 |
+
#include <ATen/ops/special_bessel_j0_cuda_dispatch.h>
|
| 531 |
+
#include <ATen/ops/special_bessel_j1_cuda_dispatch.h>
|
| 532 |
+
#include <ATen/ops/special_bessel_y0_cuda_dispatch.h>
|
| 533 |
+
#include <ATen/ops/special_bessel_y1_cuda_dispatch.h>
|
| 534 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_cuda_dispatch.h>
|
| 535 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_cuda_dispatch.h>
|
| 536 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_cuda_dispatch.h>
|
| 537 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_cuda_dispatch.h>
|
| 538 |
+
#include <ATen/ops/special_entr_cuda_dispatch.h>
|
| 539 |
+
#include <ATen/ops/special_erfcx_cuda_dispatch.h>
|
| 540 |
+
#include <ATen/ops/special_hermite_polynomial_h_cuda_dispatch.h>
|
| 541 |
+
#include <ATen/ops/special_hermite_polynomial_he_cuda_dispatch.h>
|
| 542 |
+
#include <ATen/ops/special_i0e_cuda_dispatch.h>
|
| 543 |
+
#include <ATen/ops/special_i1_cuda_dispatch.h>
|
| 544 |
+
#include <ATen/ops/special_i1e_cuda_dispatch.h>
|
| 545 |
+
#include <ATen/ops/special_laguerre_polynomial_l_cuda_dispatch.h>
|
| 546 |
+
#include <ATen/ops/special_legendre_polynomial_p_cuda_dispatch.h>
|
| 547 |
+
#include <ATen/ops/special_log_ndtr_cuda_dispatch.h>
|
| 548 |
+
#include <ATen/ops/special_modified_bessel_i0_cuda_dispatch.h>
|
| 549 |
+
#include <ATen/ops/special_modified_bessel_i1_cuda_dispatch.h>
|
| 550 |
+
#include <ATen/ops/special_modified_bessel_k0_cuda_dispatch.h>
|
| 551 |
+
#include <ATen/ops/special_modified_bessel_k1_cuda_dispatch.h>
|
| 552 |
+
#include <ATen/ops/special_ndtri_cuda_dispatch.h>
|
| 553 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_cuda_dispatch.h>
|
| 554 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_cuda_dispatch.h>
|
| 555 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_cuda_dispatch.h>
|
| 556 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_cuda_dispatch.h>
|
| 557 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_cuda_dispatch.h>
|
| 558 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_cuda_dispatch.h>
|
| 559 |
+
#include <ATen/ops/special_spherical_bessel_j0_cuda_dispatch.h>
|
| 560 |
+
#include <ATen/ops/special_xlog1py_cuda_dispatch.h>
|
| 561 |
+
#include <ATen/ops/special_zeta_cuda_dispatch.h>
|
| 562 |
+
#include <ATen/ops/split_with_sizes_copy_cuda_dispatch.h>
|
| 563 |
+
#include <ATen/ops/sqrt_cuda_dispatch.h>
|
| 564 |
+
#include <ATen/ops/sspaddmm_cuda_dispatch.h>
|
| 565 |
+
#include <ATen/ops/std_cuda_dispatch.h>
|
| 566 |
+
#include <ATen/ops/std_mean_cuda_dispatch.h>
|
| 567 |
+
#include <ATen/ops/sub_cuda_dispatch.h>
|
| 568 |
+
#include <ATen/ops/sum_cuda_dispatch.h>
|
| 569 |
+
#include <ATen/ops/take_cuda_dispatch.h>
|
| 570 |
+
#include <ATen/ops/tan_cuda_dispatch.h>
|
| 571 |
+
#include <ATen/ops/tanh_cuda_dispatch.h>
|
| 572 |
+
#include <ATen/ops/tanh_backward_cuda_dispatch.h>
|
| 573 |
+
#include <ATen/ops/threshold_cuda_dispatch.h>
|
| 574 |
+
#include <ATen/ops/threshold_backward_cuda_dispatch.h>
|
| 575 |
+
#include <ATen/ops/topk_cuda_dispatch.h>
|
| 576 |
+
#include <ATen/ops/trace_cuda_dispatch.h>
|
| 577 |
+
#include <ATen/ops/triangular_solve_cuda_dispatch.h>
|
| 578 |
+
#include <ATen/ops/tril_cuda_dispatch.h>
|
| 579 |
+
#include <ATen/ops/tril_indices_cuda_dispatch.h>
|
| 580 |
+
#include <ATen/ops/triu_cuda_dispatch.h>
|
| 581 |
+
#include <ATen/ops/triu_indices_cuda_dispatch.h>
|
| 582 |
+
#include <ATen/ops/trunc_cuda_dispatch.h>
|
| 583 |
+
#include <ATen/ops/unfold_cuda_dispatch.h>
|
| 584 |
+
#include <ATen/ops/unfold_backward_cuda_dispatch.h>
|
| 585 |
+
#include <ATen/ops/uniform_cuda_dispatch.h>
|
| 586 |
+
#include <ATen/ops/unique_consecutive_cuda_dispatch.h>
|
| 587 |
+
#include <ATen/ops/unique_dim_cuda_dispatch.h>
|
| 588 |
+
#include <ATen/ops/unique_dim_consecutive_cuda_dispatch.h>
|
| 589 |
+
#include <ATen/ops/upsample_bicubic2d_cuda_dispatch.h>
|
| 590 |
+
#include <ATen/ops/upsample_bicubic2d_backward_cuda_dispatch.h>
|
| 591 |
+
#include <ATen/ops/upsample_bilinear2d_cuda_dispatch.h>
|
| 592 |
+
#include <ATen/ops/upsample_bilinear2d_backward_cuda_dispatch.h>
|
| 593 |
+
#include <ATen/ops/upsample_linear1d_cuda_dispatch.h>
|
| 594 |
+
#include <ATen/ops/upsample_linear1d_backward_cuda_dispatch.h>
|
| 595 |
+
#include <ATen/ops/upsample_nearest1d_cuda_dispatch.h>
|
| 596 |
+
#include <ATen/ops/upsample_nearest1d_backward_cuda_dispatch.h>
|
| 597 |
+
#include <ATen/ops/upsample_nearest2d_cuda_dispatch.h>
|
| 598 |
+
#include <ATen/ops/upsample_nearest2d_backward_cuda_dispatch.h>
|
| 599 |
+
#include <ATen/ops/upsample_nearest3d_cuda_dispatch.h>
|
| 600 |
+
#include <ATen/ops/upsample_nearest3d_backward_cuda_dispatch.h>
|
| 601 |
+
#include <ATen/ops/upsample_trilinear3d_cuda_dispatch.h>
|
| 602 |
+
#include <ATen/ops/upsample_trilinear3d_backward_cuda_dispatch.h>
|
| 603 |
+
#include <ATen/ops/var_cuda_dispatch.h>
|
| 604 |
+
#include <ATen/ops/var_mean_cuda_dispatch.h>
|
| 605 |
+
#include <ATen/ops/vdot_cuda_dispatch.h>
|
| 606 |
+
#include <ATen/ops/view_cuda_dispatch.h>
|
| 607 |
+
#include <ATen/ops/view_as_complex_cuda_dispatch.h>
|
| 608 |
+
#include <ATen/ops/view_as_real_cuda_dispatch.h>
|
| 609 |
+
#include <ATen/ops/where_cuda_dispatch.h>
|
| 610 |
+
#include <ATen/ops/xlogy_cuda_dispatch.h>
|
| 611 |
+
#include <ATen/ops/zero_cuda_dispatch.h>
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CachedTensorUtils.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/ATen.h>
|
| 4 |
+
|
| 5 |
+
namespace at::caching {
|
| 6 |
+
|
| 7 |
+
// Some systems (just cudagraphs currently) will persist a static tensor output
|
| 8 |
+
// whose TensorImpl does not change across iterations. For these tensors caching
|
| 9 |
+
// dtype conversions is invalid. Additionally, there will be an extra reference
|
| 10 |
+
// count to these cached tensors that would prevent buffer inplacing and other
|
| 11 |
+
// checks on tensor uniqueness. If we are not using these systems the enabled
|
| 12 |
+
// flag will be false and we will avoid the hash lookup.
|
| 13 |
+
|
| 14 |
+
TORCH_API bool is_cached_tensor(const at::Tensor& t);
|
| 15 |
+
TORCH_API void add_cached_tensor(const at::Tensor& t);
|
| 16 |
+
TORCH_API void remove_cached_tensor(const at::Tensor& t);
|
| 17 |
+
TORCH_API void set_cached_tensors_enabled(bool enable);
|
| 18 |
+
|
| 19 |
+
// For gradient buffer stealing we will adjust the use count of tensors
|
| 20 |
+
// which are persisted by cudagraphs, just as we need to adjust reference
|
| 21 |
+
// count of tensors with hooks.
|
| 22 |
+
TORCH_API size_t adjusted_use_count(const at::Tensor& t);
|
| 23 |
+
|
| 24 |
+
} // namespace at::caching
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CompositeExplicitAutogradFunctions_inl.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_compositeexplicitautogradnonfunctional_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_addmm_activation_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_conj_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_fw_primal_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_linalg_det_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_linalg_eigh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_linalg_slogdet_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_linalg_solve_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_linalg_svd_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_log_softmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_log_softmax_backward_data_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_make_dual_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_neg_view_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_nested_get_values_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_nested_view_from_buffer_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_nested_view_from_jagged_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_reshape_alias_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_softmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_softmax_backward_data_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_sparse_broadcast_to_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_trilinear_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_upsample_nearest_exact1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_upsample_nearest_exact2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_upsample_nearest_exact3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_values_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 54 |
+
#include <ATen/ops/acos_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 55 |
+
#include <ATen/ops/acosh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 56 |
+
#include <ATen/ops/adaptive_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 57 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 58 |
+
#include <ATen/ops/adaptive_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 59 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 60 |
+
#include <ATen/ops/add_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 61 |
+
#include <ATen/ops/addcdiv_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 62 |
+
#include <ATen/ops/addcmul_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 63 |
+
#include <ATen/ops/addmm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 64 |
+
#include <ATen/ops/addmv_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 65 |
+
#include <ATen/ops/alias_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 66 |
+
#include <ATen/ops/all_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 67 |
+
#include <ATen/ops/amax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 68 |
+
#include <ATen/ops/amin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 69 |
+
#include <ATen/ops/aminmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 70 |
+
#include <ATen/ops/any_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 71 |
+
#include <ATen/ops/argmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 72 |
+
#include <ATen/ops/argmin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 73 |
+
#include <ATen/ops/as_strided_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 74 |
+
#include <ATen/ops/as_strided_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 75 |
+
#include <ATen/ops/as_strided_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 76 |
+
#include <ATen/ops/asin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 77 |
+
#include <ATen/ops/asinh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 78 |
+
#include <ATen/ops/atan_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 79 |
+
#include <ATen/ops/atan2_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 80 |
+
#include <ATen/ops/atanh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 81 |
+
#include <ATen/ops/avg_pool2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 82 |
+
#include <ATen/ops/avg_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 83 |
+
#include <ATen/ops/avg_pool3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 84 |
+
#include <ATen/ops/avg_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 85 |
+
#include <ATen/ops/baddbmm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 86 |
+
#include <ATen/ops/bernoulli_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 87 |
+
#include <ATen/ops/bitwise_and_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 88 |
+
#include <ATen/ops/bitwise_left_shift_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 89 |
+
#include <ATen/ops/bitwise_not_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 90 |
+
#include <ATen/ops/bitwise_or_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 91 |
+
#include <ATen/ops/bitwise_right_shift_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 92 |
+
#include <ATen/ops/bitwise_xor_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 93 |
+
#include <ATen/ops/bmm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 94 |
+
#include <ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 95 |
+
#include <ATen/ops/ccol_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 96 |
+
#include <ATen/ops/ceil_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 97 |
+
#include <ATen/ops/clamp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 98 |
+
#include <ATen/ops/clamp_max_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 99 |
+
#include <ATen/ops/clamp_min_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 100 |
+
#include <ATen/ops/col_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 101 |
+
#include <ATen/ops/copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 102 |
+
#include <ATen/ops/copysign_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 103 |
+
#include <ATen/ops/cos_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 104 |
+
#include <ATen/ops/cosh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 105 |
+
#include <ATen/ops/crow_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 106 |
+
#include <ATen/ops/cumprod_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 107 |
+
#include <ATen/ops/cumsum_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 108 |
+
#include <ATen/ops/detach_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 109 |
+
#include <ATen/ops/diag_embed_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 110 |
+
#include <ATen/ops/diagonal_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 111 |
+
#include <ATen/ops/diagonal_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 112 |
+
#include <ATen/ops/digamma_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 113 |
+
#include <ATen/ops/div_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 114 |
+
#include <ATen/ops/elu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 115 |
+
#include <ATen/ops/elu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 116 |
+
#include <ATen/ops/eq_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 117 |
+
#include <ATen/ops/erf_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 118 |
+
#include <ATen/ops/erfc_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 119 |
+
#include <ATen/ops/erfinv_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 120 |
+
#include <ATen/ops/exp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 121 |
+
#include <ATen/ops/exp2_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 122 |
+
#include <ATen/ops/expand_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 123 |
+
#include <ATen/ops/expm1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 124 |
+
#include <ATen/ops/floor_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 125 |
+
#include <ATen/ops/fmax_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 126 |
+
#include <ATen/ops/fmin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 127 |
+
#include <ATen/ops/fmod_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 128 |
+
#include <ATen/ops/frac_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 129 |
+
#include <ATen/ops/fractional_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 130 |
+
#include <ATen/ops/fractional_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 131 |
+
#include <ATen/ops/fractional_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 132 |
+
#include <ATen/ops/gather_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 133 |
+
#include <ATen/ops/gcd_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 134 |
+
#include <ATen/ops/ge_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 135 |
+
#include <ATen/ops/gelu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 136 |
+
#include <ATen/ops/gelu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 137 |
+
#include <ATen/ops/glu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 138 |
+
#include <ATen/ops/gt_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 139 |
+
#include <ATen/ops/hardshrink_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 140 |
+
#include <ATen/ops/hardshrink_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 141 |
+
#include <ATen/ops/hardsigmoid_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 142 |
+
#include <ATen/ops/hardsigmoid_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 143 |
+
#include <ATen/ops/heaviside_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 144 |
+
#include <ATen/ops/hypot_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 145 |
+
#include <ATen/ops/i0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 146 |
+
#include <ATen/ops/igamma_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 147 |
+
#include <ATen/ops/igammac_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 148 |
+
#include <ATen/ops/index_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 149 |
+
#include <ATen/ops/index_add_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 150 |
+
#include <ATen/ops/index_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 151 |
+
#include <ATen/ops/index_reduce_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 152 |
+
#include <ATen/ops/indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 153 |
+
#include <ATen/ops/isin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 154 |
+
#include <ATen/ops/isneginf_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 155 |
+
#include <ATen/ops/isposinf_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 156 |
+
#include <ATen/ops/lcm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 157 |
+
#include <ATen/ops/le_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 158 |
+
#include <ATen/ops/leaky_relu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 159 |
+
#include <ATen/ops/leaky_relu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 160 |
+
#include <ATen/ops/lerp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 161 |
+
#include <ATen/ops/lgamma_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 162 |
+
#include <ATen/ops/lift_fresh_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 163 |
+
#include <ATen/ops/linalg_cholesky_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 164 |
+
#include <ATen/ops/linalg_cross_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 165 |
+
#include <ATen/ops/linalg_inv_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 166 |
+
#include <ATen/ops/linalg_ldl_factor_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 167 |
+
#include <ATen/ops/linalg_ldl_solve_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 168 |
+
#include <ATen/ops/linalg_lu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 169 |
+
#include <ATen/ops/linalg_lu_factor_ex_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 170 |
+
#include <ATen/ops/linalg_lu_solve_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 171 |
+
#include <ATen/ops/linalg_pinv_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 172 |
+
#include <ATen/ops/linalg_qr_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 173 |
+
#include <ATen/ops/linalg_vector_norm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 174 |
+
#include <ATen/ops/log_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 175 |
+
#include <ATen/ops/log10_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 176 |
+
#include <ATen/ops/log1p_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 177 |
+
#include <ATen/ops/log2_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 178 |
+
#include <ATen/ops/logaddexp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 179 |
+
#include <ATen/ops/logaddexp2_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 180 |
+
#include <ATen/ops/logit_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 181 |
+
#include <ATen/ops/logsumexp_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 182 |
+
#include <ATen/ops/lt_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 183 |
+
#include <ATen/ops/lu_unpack_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 184 |
+
#include <ATen/ops/max_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 185 |
+
#include <ATen/ops/max_pool2d_with_indices_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 186 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 187 |
+
#include <ATen/ops/maximum_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 188 |
+
#include <ATen/ops/mean_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 189 |
+
#include <ATen/ops/min_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 190 |
+
#include <ATen/ops/minimum_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 191 |
+
#include <ATen/ops/mish_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 192 |
+
#include <ATen/ops/mm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 193 |
+
#include <ATen/ops/mse_loss_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 194 |
+
#include <ATen/ops/mul_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 195 |
+
#include <ATen/ops/narrow_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 196 |
+
#include <ATen/ops/ne_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 197 |
+
#include <ATen/ops/neg_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 198 |
+
#include <ATen/ops/new_empty_strided_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 199 |
+
#include <ATen/ops/nextafter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 200 |
+
#include <ATen/ops/nll_loss_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 201 |
+
#include <ATen/ops/nll_loss_forward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 202 |
+
#include <ATen/ops/norm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 203 |
+
#include <ATen/ops/permute_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 204 |
+
#include <ATen/ops/pixel_shuffle_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 205 |
+
#include <ATen/ops/pixel_unshuffle_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 206 |
+
#include <ATen/ops/polygamma_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 207 |
+
#include <ATen/ops/pow_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 208 |
+
#include <ATen/ops/prod_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 209 |
+
#include <ATen/ops/reciprocal_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 210 |
+
#include <ATen/ops/reflection_pad1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 211 |
+
#include <ATen/ops/reflection_pad1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 212 |
+
#include <ATen/ops/reflection_pad3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 213 |
+
#include <ATen/ops/reflection_pad3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 214 |
+
#include <ATen/ops/remainder_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 215 |
+
#include <ATen/ops/renorm_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 216 |
+
#include <ATen/ops/replication_pad1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 217 |
+
#include <ATen/ops/replication_pad1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 218 |
+
#include <ATen/ops/replication_pad2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 219 |
+
#include <ATen/ops/replication_pad3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 220 |
+
#include <ATen/ops/round_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 221 |
+
#include <ATen/ops/row_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 222 |
+
#include <ATen/ops/rsqrt_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 223 |
+
#include <ATen/ops/scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 224 |
+
#include <ATen/ops/scatter_add_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 225 |
+
#include <ATen/ops/scatter_reduce_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 226 |
+
#include <ATen/ops/select_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 227 |
+
#include <ATen/ops/select_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 228 |
+
#include <ATen/ops/select_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 229 |
+
#include <ATen/ops/sgn_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 230 |
+
#include <ATen/ops/sigmoid_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 231 |
+
#include <ATen/ops/sigmoid_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 232 |
+
#include <ATen/ops/sign_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 233 |
+
#include <ATen/ops/signbit_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 234 |
+
#include <ATen/ops/silu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 235 |
+
#include <ATen/ops/silu_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 236 |
+
#include <ATen/ops/sin_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 237 |
+
#include <ATen/ops/sinc_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 238 |
+
#include <ATen/ops/sinh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 239 |
+
#include <ATen/ops/slice_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 240 |
+
#include <ATen/ops/slice_scatter_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 241 |
+
#include <ATen/ops/slow_conv_transpose2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 242 |
+
#include <ATen/ops/smooth_l1_loss_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 243 |
+
#include <ATen/ops/softplus_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 244 |
+
#include <ATen/ops/softplus_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 245 |
+
#include <ATen/ops/softshrink_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 246 |
+
#include <ATen/ops/softshrink_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 247 |
+
#include <ATen/ops/sort_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 248 |
+
#include <ATen/ops/special_airy_ai_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 249 |
+
#include <ATen/ops/special_bessel_j0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 250 |
+
#include <ATen/ops/special_bessel_j1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 251 |
+
#include <ATen/ops/special_bessel_y0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 252 |
+
#include <ATen/ops/special_bessel_y1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 253 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 254 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 255 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 256 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 257 |
+
#include <ATen/ops/special_entr_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 258 |
+
#include <ATen/ops/special_erfcx_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 259 |
+
#include <ATen/ops/special_hermite_polynomial_h_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 260 |
+
#include <ATen/ops/special_hermite_polynomial_he_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 261 |
+
#include <ATen/ops/special_i0e_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 262 |
+
#include <ATen/ops/special_i1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 263 |
+
#include <ATen/ops/special_i1e_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 264 |
+
#include <ATen/ops/special_laguerre_polynomial_l_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 265 |
+
#include <ATen/ops/special_legendre_polynomial_p_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 266 |
+
#include <ATen/ops/special_log_ndtr_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 267 |
+
#include <ATen/ops/special_modified_bessel_i0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 268 |
+
#include <ATen/ops/special_modified_bessel_i1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 269 |
+
#include <ATen/ops/special_modified_bessel_k0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 270 |
+
#include <ATen/ops/special_modified_bessel_k1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 271 |
+
#include <ATen/ops/special_ndtri_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 272 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 273 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 274 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 275 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 276 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 277 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 278 |
+
#include <ATen/ops/special_spherical_bessel_j0_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 279 |
+
#include <ATen/ops/special_xlog1py_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 280 |
+
#include <ATen/ops/special_zeta_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 281 |
+
#include <ATen/ops/split_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 282 |
+
#include <ATen/ops/split_with_sizes_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 283 |
+
#include <ATen/ops/sqrt_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 284 |
+
#include <ATen/ops/squeeze_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 285 |
+
#include <ATen/ops/sub_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 286 |
+
#include <ATen/ops/sum_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 287 |
+
#include <ATen/ops/t_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 288 |
+
#include <ATen/ops/tan_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 289 |
+
#include <ATen/ops/tanh_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 290 |
+
#include <ATen/ops/tanh_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 291 |
+
#include <ATen/ops/threshold_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 292 |
+
#include <ATen/ops/threshold_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 293 |
+
#include <ATen/ops/topk_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 294 |
+
#include <ATen/ops/transpose_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 295 |
+
#include <ATen/ops/triangular_solve_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 296 |
+
#include <ATen/ops/tril_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 297 |
+
#include <ATen/ops/triu_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 298 |
+
#include <ATen/ops/trunc_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 299 |
+
#include <ATen/ops/unbind_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 300 |
+
#include <ATen/ops/unfold_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 301 |
+
#include <ATen/ops/unsqueeze_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 302 |
+
#include <ATen/ops/upsample_bicubic2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 303 |
+
#include <ATen/ops/upsample_bicubic2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 304 |
+
#include <ATen/ops/upsample_bilinear2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 305 |
+
#include <ATen/ops/upsample_bilinear2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 306 |
+
#include <ATen/ops/upsample_linear1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 307 |
+
#include <ATen/ops/upsample_linear1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 308 |
+
#include <ATen/ops/upsample_nearest1d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 309 |
+
#include <ATen/ops/upsample_nearest1d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 310 |
+
#include <ATen/ops/upsample_nearest2d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 311 |
+
#include <ATen/ops/upsample_nearest2d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 312 |
+
#include <ATen/ops/upsample_nearest3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 313 |
+
#include <ATen/ops/upsample_nearest3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 314 |
+
#include <ATen/ops/upsample_trilinear3d_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 315 |
+
#include <ATen/ops/upsample_trilinear3d_backward_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 316 |
+
#include <ATen/ops/values_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 317 |
+
#include <ATen/ops/view_as_complex_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 318 |
+
#include <ATen/ops/view_as_real_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 319 |
+
#include <ATen/ops/view_copy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 320 |
+
#include <ATen/ops/xlogy_compositeexplicitautogradnonfunctional_dispatch.h>
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h
|
| 3 |
+
|
| 4 |
+
// NB: The implementing C++ file is RegisterDispatchKey.cpp
|
| 5 |
+
|
| 6 |
+
// The only #includes we need are for custom classes that have defaults in the C++ API
|
| 7 |
+
#include <c10/core/MemoryFormat.h>
|
| 8 |
+
#include <c10/core/Scalar.h>
|
| 9 |
+
#include <ATen/core/Reduction.h>
|
| 10 |
+
|
| 11 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 12 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 13 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 14 |
+
Consider including a specific operator from \
|
| 15 |
+
<ATen/ops/{my_operator}_compositeimplicitautograd_dispatch.h>. \
|
| 16 |
+
See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#include <ATen/ops/_add_batch_dim_compositeimplicitautograd_dispatch.h>
|
| 20 |
+
#include <ATen/ops/_assert_tensor_metadata_compositeimplicitautograd_dispatch.h>
|
| 21 |
+
#include <ATen/ops/_autocast_to_full_precision_compositeimplicitautograd_dispatch.h>
|
| 22 |
+
#include <ATen/ops/_autocast_to_reduced_precision_compositeimplicitautograd_dispatch.h>
|
| 23 |
+
#include <ATen/ops/_backward_compositeimplicitautograd_dispatch.h>
|
| 24 |
+
#include <ATen/ops/_batch_norm_impl_index_compositeimplicitautograd_dispatch.h>
|
| 25 |
+
#include <ATen/ops/_batch_norm_impl_index_backward_compositeimplicitautograd_dispatch.h>
|
| 26 |
+
#include <ATen/ops/_cast_Byte_compositeimplicitautograd_dispatch.h>
|
| 27 |
+
#include <ATen/ops/_cast_Char_compositeimplicitautograd_dispatch.h>
|
| 28 |
+
#include <ATen/ops/_cast_Double_compositeimplicitautograd_dispatch.h>
|
| 29 |
+
#include <ATen/ops/_cast_Float_compositeimplicitautograd_dispatch.h>
|
| 30 |
+
#include <ATen/ops/_cast_Half_compositeimplicitautograd_dispatch.h>
|
| 31 |
+
#include <ATen/ops/_cast_Int_compositeimplicitautograd_dispatch.h>
|
| 32 |
+
#include <ATen/ops/_cast_Long_compositeimplicitautograd_dispatch.h>
|
| 33 |
+
#include <ATen/ops/_cast_Short_compositeimplicitautograd_dispatch.h>
|
| 34 |
+
#include <ATen/ops/_choose_qparams_per_tensor_compositeimplicitautograd_dispatch.h>
|
| 35 |
+
#include <ATen/ops/_convolution_compositeimplicitautograd_dispatch.h>
|
| 36 |
+
#include <ATen/ops/_convolution_double_backward_compositeimplicitautograd_dispatch.h>
|
| 37 |
+
#include <ATen/ops/_convolution_mode_compositeimplicitautograd_dispatch.h>
|
| 38 |
+
#include <ATen/ops/_cufft_clear_plan_cache_compositeimplicitautograd_dispatch.h>
|
| 39 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size_compositeimplicitautograd_dispatch.h>
|
| 40 |
+
#include <ATen/ops/_cufft_get_plan_cache_size_compositeimplicitautograd_dispatch.h>
|
| 41 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size_compositeimplicitautograd_dispatch.h>
|
| 42 |
+
#include <ATen/ops/_debug_has_internal_overlap_compositeimplicitautograd_dispatch.h>
|
| 43 |
+
#include <ATen/ops/_dim_arange_compositeimplicitautograd_dispatch.h>
|
| 44 |
+
#include <ATen/ops/_embedding_bag_backward_compositeimplicitautograd_dispatch.h>
|
| 45 |
+
#include <ATen/ops/_embedding_bag_sparse_backward_compositeimplicitautograd_dispatch.h>
|
| 46 |
+
#include <ATen/ops/_gather_sparse_backward_compositeimplicitautograd_dispatch.h>
|
| 47 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_compositeimplicitautograd_dispatch.h>
|
| 48 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type_compositeimplicitautograd_dispatch.h>
|
| 49 |
+
#include <ATen/ops/_is_zerotensor_compositeimplicitautograd_dispatch.h>
|
| 50 |
+
#include <ATen/ops/_lu_with_info_compositeimplicitautograd_dispatch.h>
|
| 51 |
+
#include <ATen/ops/_nnpack_available_compositeimplicitautograd_dispatch.h>
|
| 52 |
+
#include <ATen/ops/_pack_padded_sequence_backward_compositeimplicitautograd_dispatch.h>
|
| 53 |
+
#include <ATen/ops/_pad_circular_compositeimplicitautograd_dispatch.h>
|
| 54 |
+
#include <ATen/ops/_pad_enum_compositeimplicitautograd_dispatch.h>
|
| 55 |
+
#include <ATen/ops/_pad_packed_sequence_compositeimplicitautograd_dispatch.h>
|
| 56 |
+
#include <ATen/ops/_propagate_xla_data_compositeimplicitautograd_dispatch.h>
|
| 57 |
+
#include <ATen/ops/_remove_batch_dim_compositeimplicitautograd_dispatch.h>
|
| 58 |
+
#include <ATen/ops/_reshape_from_tensor_compositeimplicitautograd_dispatch.h>
|
| 59 |
+
#include <ATen/ops/_rowwise_prune_compositeimplicitautograd_dispatch.h>
|
| 60 |
+
#include <ATen/ops/_saturate_weight_to_fp16_compositeimplicitautograd_dispatch.h>
|
| 61 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_compositeimplicitautograd_dispatch.h>
|
| 62 |
+
#include <ATen/ops/_shape_as_tensor_compositeimplicitautograd_dispatch.h>
|
| 63 |
+
#include <ATen/ops/_sobol_engine_draw_compositeimplicitautograd_dispatch.h>
|
| 64 |
+
#include <ATen/ops/_sobol_engine_ff_compositeimplicitautograd_dispatch.h>
|
| 65 |
+
#include <ATen/ops/_sobol_engine_initialize_state_compositeimplicitautograd_dispatch.h>
|
| 66 |
+
#include <ATen/ops/_sobol_engine_scramble_compositeimplicitautograd_dispatch.h>
|
| 67 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 68 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 69 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 70 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 71 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 72 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe_compositeimplicitautograd_dispatch.h>
|
| 73 |
+
#include <ATen/ops/_sparse_log_softmax_compositeimplicitautograd_dispatch.h>
|
| 74 |
+
#include <ATen/ops/_sparse_mm_compositeimplicitautograd_dispatch.h>
|
| 75 |
+
#include <ATen/ops/_sparse_softmax_compositeimplicitautograd_dispatch.h>
|
| 76 |
+
#include <ATen/ops/_sparse_sum_compositeimplicitautograd_dispatch.h>
|
| 77 |
+
#include <ATen/ops/_test_ambiguous_defaults_compositeimplicitautograd_dispatch.h>
|
| 78 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_compositeimplicitautograd_dispatch.h>
|
| 79 |
+
#include <ATen/ops/_test_check_tensor_compositeimplicitautograd_dispatch.h>
|
| 80 |
+
#include <ATen/ops/_test_serialization_subcmul_compositeimplicitautograd_dispatch.h>
|
| 81 |
+
#include <ATen/ops/_test_string_default_compositeimplicitautograd_dispatch.h>
|
| 82 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward_compositeimplicitautograd_dispatch.h>
|
| 83 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward_compositeimplicitautograd_dispatch.h>
|
| 84 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_compositeimplicitautograd_dispatch.h>
|
| 85 |
+
#include <ATen/ops/_to_cpu_compositeimplicitautograd_dispatch.h>
|
| 86 |
+
#include <ATen/ops/_unpack_dual_compositeimplicitautograd_dispatch.h>
|
| 87 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_compositeimplicitautograd_dispatch.h>
|
| 88 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_compositeimplicitautograd_dispatch.h>
|
| 89 |
+
#include <ATen/ops/_upsample_nearest_exact1d_compositeimplicitautograd_dispatch.h>
|
| 90 |
+
#include <ATen/ops/_upsample_nearest_exact2d_compositeimplicitautograd_dispatch.h>
|
| 91 |
+
#include <ATen/ops/_upsample_nearest_exact3d_compositeimplicitautograd_dispatch.h>
|
| 92 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight_compositeimplicitautograd_dispatch.h>
|
| 93 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 94 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 95 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 96 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 97 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 98 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args_compositeimplicitautograd_dispatch.h>
|
| 99 |
+
#include <ATen/ops/_version_compositeimplicitautograd_dispatch.h>
|
| 100 |
+
#include <ATen/ops/_weight_norm_compositeimplicitautograd_dispatch.h>
|
| 101 |
+
#include <ATen/ops/_weight_norm_differentiable_backward_compositeimplicitautograd_dispatch.h>
|
| 102 |
+
#include <ATen/ops/absolute_compositeimplicitautograd_dispatch.h>
|
| 103 |
+
#include <ATen/ops/adaptive_avg_pool1d_compositeimplicitautograd_dispatch.h>
|
| 104 |
+
#include <ATen/ops/adaptive_avg_pool2d_compositeimplicitautograd_dispatch.h>
|
| 105 |
+
#include <ATen/ops/adaptive_avg_pool3d_compositeimplicitautograd_dispatch.h>
|
| 106 |
+
#include <ATen/ops/adaptive_max_pool1d_compositeimplicitautograd_dispatch.h>
|
| 107 |
+
#include <ATen/ops/adjoint_compositeimplicitautograd_dispatch.h>
|
| 108 |
+
#include <ATen/ops/affine_grid_generator_backward_compositeimplicitautograd_dispatch.h>
|
| 109 |
+
#include <ATen/ops/align_as_compositeimplicitautograd_dispatch.h>
|
| 110 |
+
#include <ATen/ops/align_tensors_compositeimplicitautograd_dispatch.h>
|
| 111 |
+
#include <ATen/ops/align_to_compositeimplicitautograd_dispatch.h>
|
| 112 |
+
#include <ATen/ops/all_compositeimplicitautograd_dispatch.h>
|
| 113 |
+
#include <ATen/ops/alpha_dropout_compositeimplicitautograd_dispatch.h>
|
| 114 |
+
#include <ATen/ops/and_compositeimplicitautograd_dispatch.h>
|
| 115 |
+
#include <ATen/ops/any_compositeimplicitautograd_dispatch.h>
|
| 116 |
+
#include <ATen/ops/arccos_compositeimplicitautograd_dispatch.h>
|
| 117 |
+
#include <ATen/ops/arccosh_compositeimplicitautograd_dispatch.h>
|
| 118 |
+
#include <ATen/ops/arcsin_compositeimplicitautograd_dispatch.h>
|
| 119 |
+
#include <ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h>
|
| 120 |
+
#include <ATen/ops/arctan_compositeimplicitautograd_dispatch.h>
|
| 121 |
+
#include <ATen/ops/arctan2_compositeimplicitautograd_dispatch.h>
|
| 122 |
+
#include <ATen/ops/arctanh_compositeimplicitautograd_dispatch.h>
|
| 123 |
+
#include <ATen/ops/argsort_compositeimplicitautograd_dispatch.h>
|
| 124 |
+
#include <ATen/ops/argwhere_compositeimplicitautograd_dispatch.h>
|
| 125 |
+
#include <ATen/ops/atleast_1d_compositeimplicitautograd_dispatch.h>
|
| 126 |
+
#include <ATen/ops/atleast_2d_compositeimplicitautograd_dispatch.h>
|
| 127 |
+
#include <ATen/ops/atleast_3d_compositeimplicitautograd_dispatch.h>
|
| 128 |
+
#include <ATen/ops/avg_pool1d_compositeimplicitautograd_dispatch.h>
|
| 129 |
+
#include <ATen/ops/batch_norm_compositeimplicitautograd_dispatch.h>
|
| 130 |
+
#include <ATen/ops/bilinear_compositeimplicitautograd_dispatch.h>
|
| 131 |
+
#include <ATen/ops/broadcast_tensors_compositeimplicitautograd_dispatch.h>
|
| 132 |
+
#include <ATen/ops/broadcast_to_compositeimplicitautograd_dispatch.h>
|
| 133 |
+
#include <ATen/ops/can_cast_compositeimplicitautograd_dispatch.h>
|
| 134 |
+
#include <ATen/ops/cartesian_prod_compositeimplicitautograd_dispatch.h>
|
| 135 |
+
#include <ATen/ops/cat_compositeimplicitautograd_dispatch.h>
|
| 136 |
+
#include <ATen/ops/cdist_compositeimplicitautograd_dispatch.h>
|
| 137 |
+
#include <ATen/ops/chain_matmul_compositeimplicitautograd_dispatch.h>
|
| 138 |
+
#include <ATen/ops/chalf_compositeimplicitautograd_dispatch.h>
|
| 139 |
+
#include <ATen/ops/choose_qparams_optimized_compositeimplicitautograd_dispatch.h>
|
| 140 |
+
#include <ATen/ops/chunk_compositeimplicitautograd_dispatch.h>
|
| 141 |
+
#include <ATen/ops/clip_compositeimplicitautograd_dispatch.h>
|
| 142 |
+
#include <ATen/ops/coalesce_compositeimplicitautograd_dispatch.h>
|
| 143 |
+
#include <ATen/ops/column_stack_compositeimplicitautograd_dispatch.h>
|
| 144 |
+
#include <ATen/ops/combinations_compositeimplicitautograd_dispatch.h>
|
| 145 |
+
#include <ATen/ops/concat_compositeimplicitautograd_dispatch.h>
|
| 146 |
+
#include <ATen/ops/concatenate_compositeimplicitautograd_dispatch.h>
|
| 147 |
+
#include <ATen/ops/conj_compositeimplicitautograd_dispatch.h>
|
| 148 |
+
#include <ATen/ops/conj_physical_compositeimplicitautograd_dispatch.h>
|
| 149 |
+
#include <ATen/ops/contiguous_compositeimplicitautograd_dispatch.h>
|
| 150 |
+
#include <ATen/ops/conv1d_compositeimplicitautograd_dispatch.h>
|
| 151 |
+
#include <ATen/ops/conv2d_compositeimplicitautograd_dispatch.h>
|
| 152 |
+
#include <ATen/ops/conv3d_compositeimplicitautograd_dispatch.h>
|
| 153 |
+
#include <ATen/ops/conv_tbc_backward_compositeimplicitautograd_dispatch.h>
|
| 154 |
+
#include <ATen/ops/conv_transpose1d_compositeimplicitautograd_dispatch.h>
|
| 155 |
+
#include <ATen/ops/conv_transpose2d_compositeimplicitautograd_dispatch.h>
|
| 156 |
+
#include <ATen/ops/conv_transpose3d_compositeimplicitautograd_dispatch.h>
|
| 157 |
+
#include <ATen/ops/corrcoef_compositeimplicitautograd_dispatch.h>
|
| 158 |
+
#include <ATen/ops/cosine_embedding_loss_compositeimplicitautograd_dispatch.h>
|
| 159 |
+
#include <ATen/ops/cosine_similarity_compositeimplicitautograd_dispatch.h>
|
| 160 |
+
#include <ATen/ops/cov_compositeimplicitautograd_dispatch.h>
|
| 161 |
+
#include <ATen/ops/cross_compositeimplicitautograd_dispatch.h>
|
| 162 |
+
#include <ATen/ops/cross_entropy_loss_compositeimplicitautograd_dispatch.h>
|
| 163 |
+
#include <ATen/ops/ctc_loss_compositeimplicitautograd_dispatch.h>
|
| 164 |
+
#include <ATen/ops/cudnn_is_acceptable_compositeimplicitautograd_dispatch.h>
|
| 165 |
+
#include <ATen/ops/cummax_compositeimplicitautograd_dispatch.h>
|
| 166 |
+
#include <ATen/ops/cummaxmin_backward_compositeimplicitautograd_dispatch.h>
|
| 167 |
+
#include <ATen/ops/cummin_compositeimplicitautograd_dispatch.h>
|
| 168 |
+
#include <ATen/ops/cumprod_compositeimplicitautograd_dispatch.h>
|
| 169 |
+
#include <ATen/ops/cumprod_backward_compositeimplicitautograd_dispatch.h>
|
| 170 |
+
#include <ATen/ops/cumsum_compositeimplicitautograd_dispatch.h>
|
| 171 |
+
#include <ATen/ops/cumulative_trapezoid_compositeimplicitautograd_dispatch.h>
|
| 172 |
+
#include <ATen/ops/data_compositeimplicitautograd_dispatch.h>
|
| 173 |
+
#include <ATen/ops/det_compositeimplicitautograd_dispatch.h>
|
| 174 |
+
#include <ATen/ops/diag_compositeimplicitautograd_dispatch.h>
|
| 175 |
+
#include <ATen/ops/diagflat_compositeimplicitautograd_dispatch.h>
|
| 176 |
+
#include <ATen/ops/diagonal_compositeimplicitautograd_dispatch.h>
|
| 177 |
+
#include <ATen/ops/diff_compositeimplicitautograd_dispatch.h>
|
| 178 |
+
#include <ATen/ops/divide_compositeimplicitautograd_dispatch.h>
|
| 179 |
+
#include <ATen/ops/dropout_compositeimplicitautograd_dispatch.h>
|
| 180 |
+
#include <ATen/ops/dsplit_compositeimplicitautograd_dispatch.h>
|
| 181 |
+
#include <ATen/ops/dstack_compositeimplicitautograd_dispatch.h>
|
| 182 |
+
#include <ATen/ops/einsum_compositeimplicitautograd_dispatch.h>
|
| 183 |
+
#include <ATen/ops/embedding_backward_compositeimplicitautograd_dispatch.h>
|
| 184 |
+
#include <ATen/ops/embedding_bag_compositeimplicitautograd_dispatch.h>
|
| 185 |
+
#include <ATen/ops/embedding_sparse_backward_compositeimplicitautograd_dispatch.h>
|
| 186 |
+
#include <ATen/ops/empty_compositeimplicitautograd_dispatch.h>
|
| 187 |
+
#include <ATen/ops/expand_as_compositeimplicitautograd_dispatch.h>
|
| 188 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_compositeimplicitautograd_dispatch.h>
|
| 189 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_compositeimplicitautograd_dispatch.h>
|
| 190 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_compositeimplicitautograd_dispatch.h>
|
| 191 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_compositeimplicitautograd_dispatch.h>
|
| 192 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_compositeimplicitautograd_dispatch.h>
|
| 193 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_compositeimplicitautograd_dispatch.h>
|
| 194 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_compositeimplicitautograd_dispatch.h>
|
| 195 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_compositeimplicitautograd_dispatch.h>
|
| 196 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight_compositeimplicitautograd_dispatch.h>
|
| 197 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_compositeimplicitautograd_dispatch.h>
|
| 198 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix_compositeimplicitautograd_dispatch.h>
|
| 199 |
+
#include <ATen/ops/feature_alpha_dropout_compositeimplicitautograd_dispatch.h>
|
| 200 |
+
#include <ATen/ops/feature_dropout_compositeimplicitautograd_dispatch.h>
|
| 201 |
+
#include <ATen/ops/fft_fft_compositeimplicitautograd_dispatch.h>
|
| 202 |
+
#include <ATen/ops/fft_fft2_compositeimplicitautograd_dispatch.h>
|
| 203 |
+
#include <ATen/ops/fft_fftn_compositeimplicitautograd_dispatch.h>
|
| 204 |
+
#include <ATen/ops/fft_fftshift_compositeimplicitautograd_dispatch.h>
|
| 205 |
+
#include <ATen/ops/fft_hfft_compositeimplicitautograd_dispatch.h>
|
| 206 |
+
#include <ATen/ops/fft_hfft2_compositeimplicitautograd_dispatch.h>
|
| 207 |
+
#include <ATen/ops/fft_hfftn_compositeimplicitautograd_dispatch.h>
|
| 208 |
+
#include <ATen/ops/fft_ifft_compositeimplicitautograd_dispatch.h>
|
| 209 |
+
#include <ATen/ops/fft_ifft2_compositeimplicitautograd_dispatch.h>
|
| 210 |
+
#include <ATen/ops/fft_ifftn_compositeimplicitautograd_dispatch.h>
|
| 211 |
+
#include <ATen/ops/fft_ifftshift_compositeimplicitautograd_dispatch.h>
|
| 212 |
+
#include <ATen/ops/fft_ihfft_compositeimplicitautograd_dispatch.h>
|
| 213 |
+
#include <ATen/ops/fft_ihfft2_compositeimplicitautograd_dispatch.h>
|
| 214 |
+
#include <ATen/ops/fft_ihfftn_compositeimplicitautograd_dispatch.h>
|
| 215 |
+
#include <ATen/ops/fft_irfft_compositeimplicitautograd_dispatch.h>
|
| 216 |
+
#include <ATen/ops/fft_irfft2_compositeimplicitautograd_dispatch.h>
|
| 217 |
+
#include <ATen/ops/fft_irfftn_compositeimplicitautograd_dispatch.h>
|
| 218 |
+
#include <ATen/ops/fft_rfft_compositeimplicitautograd_dispatch.h>
|
| 219 |
+
#include <ATen/ops/fft_rfft2_compositeimplicitautograd_dispatch.h>
|
| 220 |
+
#include <ATen/ops/fft_rfftn_compositeimplicitautograd_dispatch.h>
|
| 221 |
+
#include <ATen/ops/fill_diagonal_compositeimplicitautograd_dispatch.h>
|
| 222 |
+
#include <ATen/ops/fix_compositeimplicitautograd_dispatch.h>
|
| 223 |
+
#include <ATen/ops/flatten_compositeimplicitautograd_dispatch.h>
|
| 224 |
+
#include <ATen/ops/flatten_dense_tensors_compositeimplicitautograd_dispatch.h>
|
| 225 |
+
#include <ATen/ops/fliplr_compositeimplicitautograd_dispatch.h>
|
| 226 |
+
#include <ATen/ops/flipud_compositeimplicitautograd_dispatch.h>
|
| 227 |
+
#include <ATen/ops/float_power_compositeimplicitautograd_dispatch.h>
|
| 228 |
+
#include <ATen/ops/frobenius_norm_compositeimplicitautograd_dispatch.h>
|
| 229 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant_compositeimplicitautograd_dispatch.h>
|
| 230 |
+
#include <ATen/ops/gather_compositeimplicitautograd_dispatch.h>
|
| 231 |
+
#include <ATen/ops/gather_backward_compositeimplicitautograd_dispatch.h>
|
| 232 |
+
#include <ATen/ops/ger_compositeimplicitautograd_dispatch.h>
|
| 233 |
+
#include <ATen/ops/gradient_compositeimplicitautograd_dispatch.h>
|
| 234 |
+
#include <ATen/ops/greater_compositeimplicitautograd_dispatch.h>
|
| 235 |
+
#include <ATen/ops/greater_equal_compositeimplicitautograd_dispatch.h>
|
| 236 |
+
#include <ATen/ops/grid_sampler_compositeimplicitautograd_dispatch.h>
|
| 237 |
+
#include <ATen/ops/group_norm_compositeimplicitautograd_dispatch.h>
|
| 238 |
+
#include <ATen/ops/gru_compositeimplicitautograd_dispatch.h>
|
| 239 |
+
#include <ATen/ops/gru_cell_compositeimplicitautograd_dispatch.h>
|
| 240 |
+
#include <ATen/ops/hinge_embedding_loss_compositeimplicitautograd_dispatch.h>
|
| 241 |
+
#include <ATen/ops/histogramdd_compositeimplicitautograd_dispatch.h>
|
| 242 |
+
#include <ATen/ops/hsplit_compositeimplicitautograd_dispatch.h>
|
| 243 |
+
#include <ATen/ops/hstack_compositeimplicitautograd_dispatch.h>
|
| 244 |
+
#include <ATen/ops/imag_compositeimplicitautograd_dispatch.h>
|
| 245 |
+
#include <ATen/ops/index_add_compositeimplicitautograd_dispatch.h>
|
| 246 |
+
#include <ATen/ops/index_copy_compositeimplicitautograd_dispatch.h>
|
| 247 |
+
#include <ATen/ops/index_fill_compositeimplicitautograd_dispatch.h>
|
| 248 |
+
#include <ATen/ops/index_select_compositeimplicitautograd_dispatch.h>
|
| 249 |
+
#include <ATen/ops/index_select_backward_compositeimplicitautograd_dispatch.h>
|
| 250 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward_compositeimplicitautograd_dispatch.h>
|
| 251 |
+
#include <ATen/ops/inner_compositeimplicitautograd_dispatch.h>
|
| 252 |
+
#include <ATen/ops/instance_norm_compositeimplicitautograd_dispatch.h>
|
| 253 |
+
#include <ATen/ops/inverse_compositeimplicitautograd_dispatch.h>
|
| 254 |
+
#include <ATen/ops/is_complex_compositeimplicitautograd_dispatch.h>
|
| 255 |
+
#include <ATen/ops/is_conj_compositeimplicitautograd_dispatch.h>
|
| 256 |
+
#include <ATen/ops/is_distributed_compositeimplicitautograd_dispatch.h>
|
| 257 |
+
#include <ATen/ops/is_floating_point_compositeimplicitautograd_dispatch.h>
|
| 258 |
+
#include <ATen/ops/is_inference_compositeimplicitautograd_dispatch.h>
|
| 259 |
+
#include <ATen/ops/is_leaf_compositeimplicitautograd_dispatch.h>
|
| 260 |
+
#include <ATen/ops/is_neg_compositeimplicitautograd_dispatch.h>
|
| 261 |
+
#include <ATen/ops/is_nonzero_compositeimplicitautograd_dispatch.h>
|
| 262 |
+
#include <ATen/ops/is_signed_compositeimplicitautograd_dispatch.h>
|
| 263 |
+
#include <ATen/ops/is_vulkan_available_compositeimplicitautograd_dispatch.h>
|
| 264 |
+
#include <ATen/ops/isclose_compositeimplicitautograd_dispatch.h>
|
| 265 |
+
#include <ATen/ops/isfinite_compositeimplicitautograd_dispatch.h>
|
| 266 |
+
#include <ATen/ops/isreal_compositeimplicitautograd_dispatch.h>
|
| 267 |
+
#include <ATen/ops/istft_compositeimplicitautograd_dispatch.h>
|
| 268 |
+
#include <ATen/ops/item_compositeimplicitautograd_dispatch.h>
|
| 269 |
+
#include <ATen/ops/kl_div_compositeimplicitautograd_dispatch.h>
|
| 270 |
+
#include <ATen/ops/kron_compositeimplicitautograd_dispatch.h>
|
| 271 |
+
#include <ATen/ops/kthvalue_compositeimplicitautograd_dispatch.h>
|
| 272 |
+
#include <ATen/ops/l1_loss_compositeimplicitautograd_dispatch.h>
|
| 273 |
+
#include <ATen/ops/layer_norm_compositeimplicitautograd_dispatch.h>
|
| 274 |
+
#include <ATen/ops/ldexp_compositeimplicitautograd_dispatch.h>
|
| 275 |
+
#include <ATen/ops/less_compositeimplicitautograd_dispatch.h>
|
| 276 |
+
#include <ATen/ops/less_equal_compositeimplicitautograd_dispatch.h>
|
| 277 |
+
#include <ATen/ops/linalg_cholesky_compositeimplicitautograd_dispatch.h>
|
| 278 |
+
#include <ATen/ops/linalg_cond_compositeimplicitautograd_dispatch.h>
|
| 279 |
+
#include <ATen/ops/linalg_det_compositeimplicitautograd_dispatch.h>
|
| 280 |
+
#include <ATen/ops/linalg_diagonal_compositeimplicitautograd_dispatch.h>
|
| 281 |
+
#include <ATen/ops/linalg_eigh_compositeimplicitautograd_dispatch.h>
|
| 282 |
+
#include <ATen/ops/linalg_eigvals_compositeimplicitautograd_dispatch.h>
|
| 283 |
+
#include <ATen/ops/linalg_eigvalsh_compositeimplicitautograd_dispatch.h>
|
| 284 |
+
#include <ATen/ops/linalg_inv_compositeimplicitautograd_dispatch.h>
|
| 285 |
+
#include <ATen/ops/linalg_ldl_factor_compositeimplicitautograd_dispatch.h>
|
| 286 |
+
#include <ATen/ops/linalg_lu_factor_compositeimplicitautograd_dispatch.h>
|
| 287 |
+
#include <ATen/ops/linalg_matmul_compositeimplicitautograd_dispatch.h>
|
| 288 |
+
#include <ATen/ops/linalg_matrix_norm_compositeimplicitautograd_dispatch.h>
|
| 289 |
+
#include <ATen/ops/linalg_matrix_power_compositeimplicitautograd_dispatch.h>
|
| 290 |
+
#include <ATen/ops/linalg_matrix_rank_compositeimplicitautograd_dispatch.h>
|
| 291 |
+
#include <ATen/ops/linalg_multi_dot_compositeimplicitautograd_dispatch.h>
|
| 292 |
+
#include <ATen/ops/linalg_norm_compositeimplicitautograd_dispatch.h>
|
| 293 |
+
#include <ATen/ops/linalg_pinv_compositeimplicitautograd_dispatch.h>
|
| 294 |
+
#include <ATen/ops/linalg_slogdet_compositeimplicitautograd_dispatch.h>
|
| 295 |
+
#include <ATen/ops/linalg_solve_compositeimplicitautograd_dispatch.h>
|
| 296 |
+
#include <ATen/ops/linalg_solve_ex_compositeimplicitautograd_dispatch.h>
|
| 297 |
+
#include <ATen/ops/linalg_svd_compositeimplicitautograd_dispatch.h>
|
| 298 |
+
#include <ATen/ops/linalg_svdvals_compositeimplicitautograd_dispatch.h>
|
| 299 |
+
#include <ATen/ops/linalg_tensorinv_compositeimplicitautograd_dispatch.h>
|
| 300 |
+
#include <ATen/ops/linalg_tensorsolve_compositeimplicitautograd_dispatch.h>
|
| 301 |
+
#include <ATen/ops/linalg_vander_compositeimplicitautograd_dispatch.h>
|
| 302 |
+
#include <ATen/ops/linalg_vecdot_compositeimplicitautograd_dispatch.h>
|
| 303 |
+
#include <ATen/ops/linear_compositeimplicitautograd_dispatch.h>
|
| 304 |
+
#include <ATen/ops/log_sigmoid_compositeimplicitautograd_dispatch.h>
|
| 305 |
+
#include <ATen/ops/log_softmax_compositeimplicitautograd_dispatch.h>
|
| 306 |
+
#include <ATen/ops/logcumsumexp_compositeimplicitautograd_dispatch.h>
|
| 307 |
+
#include <ATen/ops/logdet_compositeimplicitautograd_dispatch.h>
|
| 308 |
+
#include <ATen/ops/logsumexp_compositeimplicitautograd_dispatch.h>
|
| 309 |
+
#include <ATen/ops/lstm_compositeimplicitautograd_dispatch.h>
|
| 310 |
+
#include <ATen/ops/lstm_cell_compositeimplicitautograd_dispatch.h>
|
| 311 |
+
#include <ATen/ops/lu_solve_compositeimplicitautograd_dispatch.h>
|
| 312 |
+
#include <ATen/ops/mH_compositeimplicitautograd_dispatch.h>
|
| 313 |
+
#include <ATen/ops/mT_compositeimplicitautograd_dispatch.h>
|
| 314 |
+
#include <ATen/ops/margin_ranking_loss_compositeimplicitautograd_dispatch.h>
|
| 315 |
+
#include <ATen/ops/masked_select_backward_compositeimplicitautograd_dispatch.h>
|
| 316 |
+
#include <ATen/ops/matmul_compositeimplicitautograd_dispatch.h>
|
| 317 |
+
#include <ATen/ops/matrix_H_compositeimplicitautograd_dispatch.h>
|
| 318 |
+
#include <ATen/ops/matrix_exp_compositeimplicitautograd_dispatch.h>
|
| 319 |
+
#include <ATen/ops/matrix_exp_backward_compositeimplicitautograd_dispatch.h>
|
| 320 |
+
#include <ATen/ops/matrix_power_compositeimplicitautograd_dispatch.h>
|
| 321 |
+
#include <ATen/ops/max_compositeimplicitautograd_dispatch.h>
|
| 322 |
+
#include <ATen/ops/max_pool1d_compositeimplicitautograd_dispatch.h>
|
| 323 |
+
#include <ATen/ops/max_pool1d_with_indices_compositeimplicitautograd_dispatch.h>
|
| 324 |
+
#include <ATen/ops/max_pool2d_compositeimplicitautograd_dispatch.h>
|
| 325 |
+
#include <ATen/ops/max_pool3d_compositeimplicitautograd_dispatch.h>
|
| 326 |
+
#include <ATen/ops/mean_compositeimplicitautograd_dispatch.h>
|
| 327 |
+
#include <ATen/ops/median_compositeimplicitautograd_dispatch.h>
|
| 328 |
+
#include <ATen/ops/meshgrid_compositeimplicitautograd_dispatch.h>
|
| 329 |
+
#include <ATen/ops/min_compositeimplicitautograd_dispatch.h>
|
| 330 |
+
#include <ATen/ops/mish_backward_compositeimplicitautograd_dispatch.h>
|
| 331 |
+
#include <ATen/ops/mode_compositeimplicitautograd_dispatch.h>
|
| 332 |
+
#include <ATen/ops/moveaxis_compositeimplicitautograd_dispatch.h>
|
| 333 |
+
#include <ATen/ops/movedim_compositeimplicitautograd_dispatch.h>
|
| 334 |
+
#include <ATen/ops/msort_compositeimplicitautograd_dispatch.h>
|
| 335 |
+
#include <ATen/ops/multilabel_margin_loss_compositeimplicitautograd_dispatch.h>
|
| 336 |
+
#include <ATen/ops/multiply_compositeimplicitautograd_dispatch.h>
|
| 337 |
+
#include <ATen/ops/nanmean_compositeimplicitautograd_dispatch.h>
|
| 338 |
+
#include <ATen/ops/nanmedian_compositeimplicitautograd_dispatch.h>
|
| 339 |
+
#include <ATen/ops/nanquantile_compositeimplicitautograd_dispatch.h>
|
| 340 |
+
#include <ATen/ops/narrow_compositeimplicitautograd_dispatch.h>
|
| 341 |
+
#include <ATen/ops/native_channel_shuffle_compositeimplicitautograd_dispatch.h>
|
| 342 |
+
#include <ATen/ops/negative_compositeimplicitautograd_dispatch.h>
|
| 343 |
+
#include <ATen/ops/nested_to_padded_tensor_compositeimplicitautograd_dispatch.h>
|
| 344 |
+
#include <ATen/ops/nll_loss_compositeimplicitautograd_dispatch.h>
|
| 345 |
+
#include <ATen/ops/nll_loss2d_compositeimplicitautograd_dispatch.h>
|
| 346 |
+
#include <ATen/ops/nll_loss_nd_compositeimplicitautograd_dispatch.h>
|
| 347 |
+
#include <ATen/ops/nonzero_numpy_compositeimplicitautograd_dispatch.h>
|
| 348 |
+
#include <ATen/ops/norm_compositeimplicitautograd_dispatch.h>
|
| 349 |
+
#include <ATen/ops/norm_except_dim_compositeimplicitautograd_dispatch.h>
|
| 350 |
+
#include <ATen/ops/not_equal_compositeimplicitautograd_dispatch.h>
|
| 351 |
+
#include <ATen/ops/nuclear_norm_compositeimplicitautograd_dispatch.h>
|
| 352 |
+
#include <ATen/ops/numpy_T_compositeimplicitautograd_dispatch.h>
|
| 353 |
+
#include <ATen/ops/one_hot_compositeimplicitautograd_dispatch.h>
|
| 354 |
+
#include <ATen/ops/or_compositeimplicitautograd_dispatch.h>
|
| 355 |
+
#include <ATen/ops/orgqr_compositeimplicitautograd_dispatch.h>
|
| 356 |
+
#include <ATen/ops/outer_compositeimplicitautograd_dispatch.h>
|
| 357 |
+
#include <ATen/ops/output_nr_compositeimplicitautograd_dispatch.h>
|
| 358 |
+
#include <ATen/ops/pad_compositeimplicitautograd_dispatch.h>
|
| 359 |
+
#include <ATen/ops/pad_sequence_compositeimplicitautograd_dispatch.h>
|
| 360 |
+
#include <ATen/ops/pairwise_distance_compositeimplicitautograd_dispatch.h>
|
| 361 |
+
#include <ATen/ops/pdist_compositeimplicitautograd_dispatch.h>
|
| 362 |
+
#include <ATen/ops/pin_memory_compositeimplicitautograd_dispatch.h>
|
| 363 |
+
#include <ATen/ops/pinverse_compositeimplicitautograd_dispatch.h>
|
| 364 |
+
#include <ATen/ops/poisson_nll_loss_compositeimplicitautograd_dispatch.h>
|
| 365 |
+
#include <ATen/ops/positive_compositeimplicitautograd_dispatch.h>
|
| 366 |
+
#include <ATen/ops/prelu_compositeimplicitautograd_dispatch.h>
|
| 367 |
+
#include <ATen/ops/prod_compositeimplicitautograd_dispatch.h>
|
| 368 |
+
#include <ATen/ops/promote_types_compositeimplicitautograd_dispatch.h>
|
| 369 |
+
#include <ATen/ops/qr_compositeimplicitautograd_dispatch.h>
|
| 370 |
+
#include <ATen/ops/quantile_compositeimplicitautograd_dispatch.h>
|
| 371 |
+
#include <ATen/ops/quantized_gru_cell_compositeimplicitautograd_dispatch.h>
|
| 372 |
+
#include <ATen/ops/quantized_lstm_cell_compositeimplicitautograd_dispatch.h>
|
| 373 |
+
#include <ATen/ops/quantized_rnn_relu_cell_compositeimplicitautograd_dispatch.h>
|
| 374 |
+
#include <ATen/ops/quantized_rnn_tanh_cell_compositeimplicitautograd_dispatch.h>
|
| 375 |
+
#include <ATen/ops/rand_compositeimplicitautograd_dispatch.h>
|
| 376 |
+
#include <ATen/ops/randn_compositeimplicitautograd_dispatch.h>
|
| 377 |
+
#include <ATen/ops/ravel_compositeimplicitautograd_dispatch.h>
|
| 378 |
+
#include <ATen/ops/real_compositeimplicitautograd_dispatch.h>
|
| 379 |
+
#include <ATen/ops/refine_names_compositeimplicitautograd_dispatch.h>
|
| 380 |
+
#include <ATen/ops/relu6_compositeimplicitautograd_dispatch.h>
|
| 381 |
+
#include <ATen/ops/rename_compositeimplicitautograd_dispatch.h>
|
| 382 |
+
#include <ATen/ops/repeat_interleave_compositeimplicitautograd_dispatch.h>
|
| 383 |
+
#include <ATen/ops/requires_grad_compositeimplicitautograd_dispatch.h>
|
| 384 |
+
#include <ATen/ops/reshape_compositeimplicitautograd_dispatch.h>
|
| 385 |
+
#include <ATen/ops/reshape_as_compositeimplicitautograd_dispatch.h>
|
| 386 |
+
#include <ATen/ops/resolve_conj_compositeimplicitautograd_dispatch.h>
|
| 387 |
+
#include <ATen/ops/resolve_neg_compositeimplicitautograd_dispatch.h>
|
| 388 |
+
#include <ATen/ops/result_type_compositeimplicitautograd_dispatch.h>
|
| 389 |
+
#include <ATen/ops/retain_grad_compositeimplicitautograd_dispatch.h>
|
| 390 |
+
#include <ATen/ops/retains_grad_compositeimplicitautograd_dispatch.h>
|
| 391 |
+
#include <ATen/ops/rnn_relu_compositeimplicitautograd_dispatch.h>
|
| 392 |
+
#include <ATen/ops/rnn_relu_cell_compositeimplicitautograd_dispatch.h>
|
| 393 |
+
#include <ATen/ops/rnn_tanh_compositeimplicitautograd_dispatch.h>
|
| 394 |
+
#include <ATen/ops/rnn_tanh_cell_compositeimplicitautograd_dispatch.h>
|
| 395 |
+
#include <ATen/ops/row_stack_compositeimplicitautograd_dispatch.h>
|
| 396 |
+
#include <ATen/ops/rrelu_compositeimplicitautograd_dispatch.h>
|
| 397 |
+
#include <ATen/ops/scaled_dot_product_attention_compositeimplicitautograd_dispatch.h>
|
| 398 |
+
#include <ATen/ops/scatter_compositeimplicitautograd_dispatch.h>
|
| 399 |
+
#include <ATen/ops/scatter_add_compositeimplicitautograd_dispatch.h>
|
| 400 |
+
#include <ATen/ops/select_compositeimplicitautograd_dispatch.h>
|
| 401 |
+
#include <ATen/ops/selu_compositeimplicitautograd_dispatch.h>
|
| 402 |
+
#include <ATen/ops/set_compositeimplicitautograd_dispatch.h>
|
| 403 |
+
#include <ATen/ops/set_data_compositeimplicitautograd_dispatch.h>
|
| 404 |
+
#include <ATen/ops/silu_backward_compositeimplicitautograd_dispatch.h>
|
| 405 |
+
#include <ATen/ops/size_compositeimplicitautograd_dispatch.h>
|
| 406 |
+
#include <ATen/ops/slogdet_compositeimplicitautograd_dispatch.h>
|
| 407 |
+
#include <ATen/ops/slow_conv3d_compositeimplicitautograd_dispatch.h>
|
| 408 |
+
#include <ATen/ops/smm_compositeimplicitautograd_dispatch.h>
|
| 409 |
+
#include <ATen/ops/softmax_compositeimplicitautograd_dispatch.h>
|
| 410 |
+
#include <ATen/ops/sort_compositeimplicitautograd_dispatch.h>
|
| 411 |
+
#include <ATen/ops/sparse_bsc_tensor_compositeimplicitautograd_dispatch.h>
|
| 412 |
+
#include <ATen/ops/sparse_bsr_tensor_compositeimplicitautograd_dispatch.h>
|
| 413 |
+
#include <ATen/ops/sparse_coo_tensor_compositeimplicitautograd_dispatch.h>
|
| 414 |
+
#include <ATen/ops/sparse_csc_tensor_compositeimplicitautograd_dispatch.h>
|
| 415 |
+
#include <ATen/ops/sparse_csr_tensor_compositeimplicitautograd_dispatch.h>
|
| 416 |
+
#include <ATen/ops/special_digamma_compositeimplicitautograd_dispatch.h>
|
| 417 |
+
#include <ATen/ops/special_erf_compositeimplicitautograd_dispatch.h>
|
| 418 |
+
#include <ATen/ops/special_erfc_compositeimplicitautograd_dispatch.h>
|
| 419 |
+
#include <ATen/ops/special_erfinv_compositeimplicitautograd_dispatch.h>
|
| 420 |
+
#include <ATen/ops/special_exp2_compositeimplicitautograd_dispatch.h>
|
| 421 |
+
#include <ATen/ops/special_expit_compositeimplicitautograd_dispatch.h>
|
| 422 |
+
#include <ATen/ops/special_expm1_compositeimplicitautograd_dispatch.h>
|
| 423 |
+
#include <ATen/ops/special_gammainc_compositeimplicitautograd_dispatch.h>
|
| 424 |
+
#include <ATen/ops/special_gammaincc_compositeimplicitautograd_dispatch.h>
|
| 425 |
+
#include <ATen/ops/special_gammaln_compositeimplicitautograd_dispatch.h>
|
| 426 |
+
#include <ATen/ops/special_i0_compositeimplicitautograd_dispatch.h>
|
| 427 |
+
#include <ATen/ops/special_log1p_compositeimplicitautograd_dispatch.h>
|
| 428 |
+
#include <ATen/ops/special_log_softmax_compositeimplicitautograd_dispatch.h>
|
| 429 |
+
#include <ATen/ops/special_logit_compositeimplicitautograd_dispatch.h>
|
| 430 |
+
#include <ATen/ops/special_logsumexp_compositeimplicitautograd_dispatch.h>
|
| 431 |
+
#include <ATen/ops/special_multigammaln_compositeimplicitautograd_dispatch.h>
|
| 432 |
+
#include <ATen/ops/special_ndtr_compositeimplicitautograd_dispatch.h>
|
| 433 |
+
#include <ATen/ops/special_polygamma_compositeimplicitautograd_dispatch.h>
|
| 434 |
+
#include <ATen/ops/special_psi_compositeimplicitautograd_dispatch.h>
|
| 435 |
+
#include <ATen/ops/special_round_compositeimplicitautograd_dispatch.h>
|
| 436 |
+
#include <ATen/ops/special_sinc_compositeimplicitautograd_dispatch.h>
|
| 437 |
+
#include <ATen/ops/special_softmax_compositeimplicitautograd_dispatch.h>
|
| 438 |
+
#include <ATen/ops/special_xlogy_compositeimplicitautograd_dispatch.h>
|
| 439 |
+
#include <ATen/ops/split_compositeimplicitautograd_dispatch.h>
|
| 440 |
+
#include <ATen/ops/square_compositeimplicitautograd_dispatch.h>
|
| 441 |
+
#include <ATen/ops/squeeze_compositeimplicitautograd_dispatch.h>
|
| 442 |
+
#include <ATen/ops/sspaddmm_compositeimplicitautograd_dispatch.h>
|
| 443 |
+
#include <ATen/ops/std_compositeimplicitautograd_dispatch.h>
|
| 444 |
+
#include <ATen/ops/std_mean_compositeimplicitautograd_dispatch.h>
|
| 445 |
+
#include <ATen/ops/stft_compositeimplicitautograd_dispatch.h>
|
| 446 |
+
#include <ATen/ops/stride_compositeimplicitautograd_dispatch.h>
|
| 447 |
+
#include <ATen/ops/subtract_compositeimplicitautograd_dispatch.h>
|
| 448 |
+
#include <ATen/ops/sum_compositeimplicitautograd_dispatch.h>
|
| 449 |
+
#include <ATen/ops/sum_to_size_compositeimplicitautograd_dispatch.h>
|
| 450 |
+
#include <ATen/ops/svd_compositeimplicitautograd_dispatch.h>
|
| 451 |
+
#include <ATen/ops/swapaxes_compositeimplicitautograd_dispatch.h>
|
| 452 |
+
#include <ATen/ops/swapdims_compositeimplicitautograd_dispatch.h>
|
| 453 |
+
#include <ATen/ops/sym_numel_compositeimplicitautograd_dispatch.h>
|
| 454 |
+
#include <ATen/ops/sym_size_compositeimplicitautograd_dispatch.h>
|
| 455 |
+
#include <ATen/ops/sym_storage_offset_compositeimplicitautograd_dispatch.h>
|
| 456 |
+
#include <ATen/ops/sym_stride_compositeimplicitautograd_dispatch.h>
|
| 457 |
+
#include <ATen/ops/take_along_dim_compositeimplicitautograd_dispatch.h>
|
| 458 |
+
#include <ATen/ops/tensor_split_compositeimplicitautograd_dispatch.h>
|
| 459 |
+
#include <ATen/ops/tensordot_compositeimplicitautograd_dispatch.h>
|
| 460 |
+
#include <ATen/ops/thnn_conv2d_compositeimplicitautograd_dispatch.h>
|
| 461 |
+
#include <ATen/ops/tile_compositeimplicitautograd_dispatch.h>
|
| 462 |
+
#include <ATen/ops/to_compositeimplicitautograd_dispatch.h>
|
| 463 |
+
#include <ATen/ops/to_dense_compositeimplicitautograd_dispatch.h>
|
| 464 |
+
#include <ATen/ops/to_dense_backward_compositeimplicitautograd_dispatch.h>
|
| 465 |
+
#include <ATen/ops/to_mkldnn_backward_compositeimplicitautograd_dispatch.h>
|
| 466 |
+
#include <ATen/ops/to_sparse_compositeimplicitautograd_dispatch.h>
|
| 467 |
+
#include <ATen/ops/to_sparse_bsc_compositeimplicitautograd_dispatch.h>
|
| 468 |
+
#include <ATen/ops/to_sparse_bsr_compositeimplicitautograd_dispatch.h>
|
| 469 |
+
#include <ATen/ops/to_sparse_csc_compositeimplicitautograd_dispatch.h>
|
| 470 |
+
#include <ATen/ops/to_sparse_csr_compositeimplicitautograd_dispatch.h>
|
| 471 |
+
#include <ATen/ops/trace_backward_compositeimplicitautograd_dispatch.h>
|
| 472 |
+
#include <ATen/ops/transpose_compositeimplicitautograd_dispatch.h>
|
| 473 |
+
#include <ATen/ops/trapezoid_compositeimplicitautograd_dispatch.h>
|
| 474 |
+
#include <ATen/ops/trapz_compositeimplicitautograd_dispatch.h>
|
| 475 |
+
#include <ATen/ops/triplet_margin_loss_compositeimplicitautograd_dispatch.h>
|
| 476 |
+
#include <ATen/ops/true_divide_compositeimplicitautograd_dispatch.h>
|
| 477 |
+
#include <ATen/ops/type_as_compositeimplicitautograd_dispatch.h>
|
| 478 |
+
#include <ATen/ops/unbind_compositeimplicitautograd_dispatch.h>
|
| 479 |
+
#include <ATen/ops/unflatten_compositeimplicitautograd_dispatch.h>
|
| 480 |
+
#include <ATen/ops/unflatten_dense_tensors_compositeimplicitautograd_dispatch.h>
|
| 481 |
+
#include <ATen/ops/unsafe_chunk_compositeimplicitautograd_dispatch.h>
|
| 482 |
+
#include <ATen/ops/upsample_bicubic2d_compositeimplicitautograd_dispatch.h>
|
| 483 |
+
#include <ATen/ops/upsample_bilinear2d_compositeimplicitautograd_dispatch.h>
|
| 484 |
+
#include <ATen/ops/upsample_linear1d_compositeimplicitautograd_dispatch.h>
|
| 485 |
+
#include <ATen/ops/upsample_nearest1d_compositeimplicitautograd_dispatch.h>
|
| 486 |
+
#include <ATen/ops/upsample_nearest2d_compositeimplicitautograd_dispatch.h>
|
| 487 |
+
#include <ATen/ops/upsample_nearest3d_compositeimplicitautograd_dispatch.h>
|
| 488 |
+
#include <ATen/ops/upsample_trilinear3d_compositeimplicitautograd_dispatch.h>
|
| 489 |
+
#include <ATen/ops/value_selecting_reduction_backward_compositeimplicitautograd_dispatch.h>
|
| 490 |
+
#include <ATen/ops/vander_compositeimplicitautograd_dispatch.h>
|
| 491 |
+
#include <ATen/ops/var_compositeimplicitautograd_dispatch.h>
|
| 492 |
+
#include <ATen/ops/var_mean_compositeimplicitautograd_dispatch.h>
|
| 493 |
+
#include <ATen/ops/view_as_compositeimplicitautograd_dispatch.h>
|
| 494 |
+
#include <ATen/ops/vsplit_compositeimplicitautograd_dispatch.h>
|
| 495 |
+
#include <ATen/ops/vstack_compositeimplicitautograd_dispatch.h>
|
| 496 |
+
#include <ATen/ops/where_compositeimplicitautograd_dispatch.h>
|
| 497 |
+
#include <ATen/ops/xor_compositeimplicitautograd_dispatch.h>
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that, in the static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DLConvertor.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/ATen.h>
|
| 4 |
+
#include <ATen/Tensor.h>
|
| 5 |
+
#include <ATen/dlpack.h>
|
| 6 |
+
|
| 7 |
+
// This converter will:
|
| 8 |
+
// 1) take a Tensor object and wrap it in the DLPack tensor
|
| 9 |
+
// 2) take a dlpack tensor and convert it to the ATen Tensor
|
| 10 |
+
|
| 11 |
+
namespace at {
|
| 12 |
+
|
| 13 |
+
TORCH_API ScalarType toScalarType(const DLDataType& dtype);
|
| 14 |
+
TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
|
| 15 |
+
TORCH_API Tensor fromDLPack(DLManagedTensor* src);
|
| 16 |
+
// Deprecated const overload: forwards to the non-const fromDLPack(DLManagedTensor*).
// The cast is safe only because the non-const variant takes ownership per the
// DLPack protocol; callers should migrate to passing a mutable pointer.
C10_DEPRECATED_MESSAGE("Please migrate to a non-const variant")
inline Tensor fromDLPack(const DLManagedTensor* src) {
  return fromDLPack(const_cast<DLManagedTensor*>(src));
}
|
| 20 |
+
TORCH_API Tensor
|
| 21 |
+
fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter);
|
| 22 |
+
TORCH_API DLDataType getDLDataType(const Tensor& t);
|
| 23 |
+
TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
|
| 24 |
+
|
| 25 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Device.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/core/Device.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/DeviceGuard.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/IListRef.h>
|
| 4 |
+
#include <ATen/core/Tensor.h>
|
| 5 |
+
#include <c10/core/DeviceGuard.h>
|
| 6 |
+
#include <c10/core/ScalarType.h> // TensorList whyyyyy
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
// Are you here because you're wondering why DeviceGuard(tensor) no
|
| 11 |
+
// longer works? For code organization reasons, we have temporarily(?)
|
| 12 |
+
// removed this constructor from DeviceGuard. The new way to
|
| 13 |
+
// spell it is:
|
| 14 |
+
//
|
| 15 |
+
// OptionalDeviceGuard guard(device_of(tensor));
|
| 16 |
+
|
| 17 |
+
/// Return the Device of a Tensor, if the Tensor is defined.
|
| 18 |
+
/// Return the Device of a Tensor, or nullopt if the Tensor is undefined.
inline c10::optional<Device> device_of(const Tensor& t) {
  // Undefined tensors carry no device; bail out early.
  if (!t.defined()) {
    return c10::nullopt;
  }
  return c10::make_optional(t.device());
}
|
| 25 |
+
|
| 26 |
+
/// Return the Device of an optional Tensor: nullopt when the optional is
/// empty (or the contained Tensor is undefined).
inline c10::optional<Device> device_of(const c10::optional<Tensor>& t) {
  if (!t.has_value()) {
    return c10::nullopt;
  }
  // Delegate to the Tensor overload, which also handles undefined tensors.
  return device_of(*t);
}
|
| 29 |
+
|
| 30 |
+
/// Return the Device of a TensorList, if the list is non-empty and
|
| 31 |
+
/// the first Tensor is defined. (This function implicitly assumes
|
| 32 |
+
/// that all tensors in the list have the same device.)
|
| 33 |
+
/// Return the Device of a TensorList, if the list is non-empty and
/// the first Tensor is defined. (This function implicitly assumes
/// that all tensors in the list have the same device.)
inline c10::optional<Device> device_of(ITensorListRef t) {
  if (t.empty()) {
    return c10::nullopt;
  }
  // Only the first element is consulted; see the note above.
  return device_of(t.front());
}
|
| 40 |
+
|
| 41 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ExpandUtils.h
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 4 |
+
#include <ATen/Functions.h>
|
| 5 |
+
#else
|
| 6 |
+
#include <ATen/ops/view.h>
|
| 7 |
+
#include <ATen/ops/view_copy.h>
|
| 8 |
+
#endif
|
| 9 |
+
|
| 10 |
+
#include <ATen/Tensor.h>
|
| 11 |
+
#include <ATen/core/DimVector.h>
|
| 12 |
+
#include <c10/util/Exception.h>
|
| 13 |
+
#include <c10/util/MaybeOwned.h>
|
| 14 |
+
#include <c10/util/irange.h>
|
| 15 |
+
|
| 16 |
+
#include <functional>
|
| 17 |
+
#include <sstream>
|
| 18 |
+
#include <tuple>
|
| 19 |
+
#include <utility>
|
| 20 |
+
|
| 21 |
+
namespace at {
|
| 22 |
+
|
| 23 |
+
TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
|
| 24 |
+
TORCH_API std::vector<SymInt> infer_size_symint(
|
| 25 |
+
SymIntArrayRef a,
|
| 26 |
+
SymIntArrayRef b);
|
| 27 |
+
TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
|
| 28 |
+
TORCH_API SymDimVector
|
| 29 |
+
infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
|
| 30 |
+
|
| 31 |
+
// Named type instead of a pair/tuple so that we can be sure to
|
| 32 |
+
// construct the vectors in place and get NRVO.
|
| 33 |
+
// Result of expansion-geometry inference: parallel containers holding the
// expanded sizes and the matching strides. A named struct (rather than a
// pair/tuple) so both containers are constructed in place and benefit
// from NRVO.
template <typename Container>
struct InferExpandGeometryResult {
  Container sizes;
  Container strides;
  // Allocate `ndim` default-initialized entries in both containers.
  explicit InferExpandGeometryResult(size_t ndim)
      : sizes(ndim), strides(ndim) {}
  // Copy the given sizes and allocate `ndim` stride slots to be filled in.
  explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim)
      : sizes(sizes_.begin(), sizes_.end()), strides(ndim) {}
};
|
| 42 |
+
|
| 43 |
+
TORCH_API std::tuple<std::vector<int64_t>, std::vector<int64_t>>
|
| 44 |
+
inferExpandGeometry(
|
| 45 |
+
IntArrayRef tensor_sizes,
|
| 46 |
+
IntArrayRef tensor_strides,
|
| 47 |
+
IntArrayRef sizes);
|
| 48 |
+
|
| 49 |
+
TORCH_API InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
|
| 50 |
+
IntArrayRef tensor_sizes,
|
| 51 |
+
IntArrayRef tensor_strides,
|
| 52 |
+
IntArrayRef sizes);
|
| 53 |
+
|
| 54 |
+
TORCH_API std::vector<int64_t> infer_dense_strides(
|
| 55 |
+
IntArrayRef tensor_sizes,
|
| 56 |
+
IntArrayRef tensor_strides);
|
| 57 |
+
|
| 58 |
+
// True if input shapes are expandable
|
| 59 |
+
// NOTE: infer_size did a similar check, please keep them sync if change is
|
| 60 |
+
// needed
|
| 61 |
+
inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) {
  const size_t ndim1 = shape1.size();
  const size_t ndim2 = shape2.size();
  const size_t common = ndim1 < ndim2 ? ndim1 : ndim2;

  // Walk the trailing (rightmost) dimensions pairwise: a pair is
  // broadcast-compatible when the extents match or either extent is 1.
  for (size_t offset = 1; offset <= common; ++offset) {
    const int64_t d1 = shape1[ndim1 - offset];
    const int64_t d2 = shape2[ndim2 - offset];
    if (d1 != d2 && d1 != 1 && d2 != 1) {
      return false;
    }
  }
  return true;
}
|
| 75 |
+
|
| 76 |
+
// avoid copy-construction of Tensor by using a reference_wrapper.
|
| 77 |
+
inline void check_defined(
|
| 78 |
+
std::initializer_list<std::reference_wrapper<const Tensor>> tensors,
|
| 79 |
+
const char* api_name) {
|
| 80 |
+
for (auto& t : tensors) {
|
| 81 |
+
if (!t.get().defined()) {
|
| 82 |
+
AT_ERROR(api_name, "(...) called with an undefined Tensor");
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
// NOTE [ ExpandUtils Borrowing ]
|
| 88 |
+
//
|
| 89 |
+
// Functions in ExpandUtils return `c10::MaybeOwned<Tensor>` because
|
| 90 |
+
// expansion may not actually be needed, in which case we can improve
|
| 91 |
+
// efficiency by returning
|
| 92 |
+
// `c10::MaybeOwned<Tensor>::borrowed(to_expand)`. However, this means
|
| 93 |
+
// that you need to be careful: the returned `c10::MaybeOwned<Tensor>`
|
| 94 |
+
// must not outlive the original `Tensor` object that `to_expand`
|
| 95 |
+
// referred to! The deleted rvalue reference overloads of these
|
| 96 |
+
// functions help with this by preventing trivial use of a temporary
|
| 97 |
+
// resulting from a function call, but it is still possible to make a
|
| 98 |
+
// mistake.
|
| 99 |
+
|
| 100 |
+
// Expand `to_expand` to match `tensor`'s (symbolic) shape for an in-place op.
// Borrows when the shapes already agree; see NOTE [ ExpandUtils Borrowing ].
inline c10::MaybeOwned<Tensor> expand_inplace(
    const Tensor& tensor,
    const Tensor& to_expand) {
  if (!tensor.sym_sizes().equals(to_expand.sym_sizes())) {
    return c10::MaybeOwned<Tensor>::owned(
        to_expand.expand_symint(tensor.sym_sizes()));
  }
  return c10::MaybeOwned<Tensor>::borrowed(to_expand);
}
|
| 109 |
+
|
| 110 |
+
inline c10::MaybeOwned<Tensor> expand_inplace(
|
| 111 |
+
const Tensor& tensor,
|
| 112 |
+
Tensor&& to_expand) = delete;
|
| 113 |
+
|
| 114 |
+
// Same as expand_inplace(tensor, to_expand), but first verifies both tensors
// are defined, attributing any error to `api_name`.
inline c10::MaybeOwned<Tensor> expand_inplace(
    const Tensor& tensor,
    const Tensor& to_expand,
    const char* api_name) {
  check_defined({tensor, to_expand}, api_name);
  return expand_inplace(tensor, to_expand);
}
|
| 121 |
+
|
| 122 |
+
inline c10::MaybeOwned<Tensor> expand_inplace(
|
| 123 |
+
const Tensor& tensor,
|
| 124 |
+
Tensor&& to_expand,
|
| 125 |
+
const char* api_name) = delete;
|
| 126 |
+
|
| 127 |
+
// Expand two tensors to match `tensor`'s shape for an in-place op.
// Both are borrowed when both shapes already agree; otherwise both are
// expanded (and owned). NOTE: this overload compares concrete sizes(),
// unlike the single-tensor overload above which uses sym_sizes().
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
    const Tensor& tensor,
    const Tensor& to_expand1,
    const Tensor& to_expand2) {
  if (tensor.sizes().equals(to_expand1.sizes()) &&
      tensor.sizes().equals((to_expand2.sizes()))) {
    return std::make_tuple(
        c10::MaybeOwned<Tensor>::borrowed(to_expand1),
        c10::MaybeOwned<Tensor>::borrowed(to_expand2));
  }

  return std::make_tuple(
      c10::MaybeOwned<Tensor>::owned(to_expand1.expand(tensor.sizes())),
      c10::MaybeOwned<Tensor>::owned(to_expand2.expand(tensor.sizes())));
}
|
| 143 |
+
|
| 144 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 145 |
+
expand_inplace(
|
| 146 |
+
const Tensor& tensor,
|
| 147 |
+
Tensor&& to_expand1,
|
| 148 |
+
const Tensor& to_expand2) = delete;
|
| 149 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 150 |
+
expand_inplace(
|
| 151 |
+
const Tensor& tensor,
|
| 152 |
+
const Tensor& to_expand1,
|
| 153 |
+
Tensor&& to_expand2) = delete;
|
| 154 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 155 |
+
expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) =
|
| 156 |
+
delete;
|
| 157 |
+
|
| 158 |
+
// Checked variant of the two-tensor expand_inplace: verifies all three
// tensors are defined (errors attributed to `api_name`) before expanding.
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
    const Tensor& tensor,
    const Tensor& to_expand1,
    const Tensor& to_expand2,
    const char* api_name) {
  check_defined({tensor, to_expand1, to_expand2}, api_name);
  return expand_inplace(tensor, to_expand1, to_expand2);
}
|
| 167 |
+
|
| 168 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 169 |
+
expand_inplace(
|
| 170 |
+
const Tensor& tensor,
|
| 171 |
+
Tensor&& to_expand1,
|
| 172 |
+
const Tensor& to_expand2,
|
| 173 |
+
const char* api_name) = delete;
|
| 174 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 175 |
+
expand_inplace(
|
| 176 |
+
const Tensor& tensor,
|
| 177 |
+
const Tensor& to_expand1,
|
| 178 |
+
Tensor&& to_expand2,
|
| 179 |
+
const char* api_name) = delete;
|
| 180 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 181 |
+
expand_inplace(
|
| 182 |
+
const Tensor& tensor,
|
| 183 |
+
Tensor&& to_expand1,
|
| 184 |
+
Tensor&& to_expand2,
|
| 185 |
+
const char* api_name) = delete;
|
| 186 |
+
|
| 187 |
+
// See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
|
| 188 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 189 |
+
expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
|
| 190 |
+
auto s1 = to_expand1.sym_sizes();
|
| 191 |
+
auto s2 = to_expand2.sym_sizes();
|
| 192 |
+
if (s1.equals(s2)) {
|
| 193 |
+
return std::make_tuple(
|
| 194 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand1),
|
| 195 |
+
c10::MaybeOwned<Tensor>::borrowed(to_expand2));
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
auto expanded_size = infer_size_symdimvector(s1, s2);
|
| 199 |
+
return std::make_tuple(
|
| 200 |
+
c10::MaybeOwned<Tensor>::owned(to_expand1.expand_symint(expanded_size)),
|
| 201 |
+
c10::MaybeOwned<Tensor>::owned(to_expand2.expand_symint(expanded_size)));
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 205 |
+
expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete;
|
| 206 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 207 |
+
expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete;
|
| 208 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 209 |
+
expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete;
|
| 210 |
+
|
| 211 |
+
// Checked variant: verifies both tensors are defined (errors attributed to
// `api_name`) before broadcasting them to a common shape.
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(
    const Tensor& to_expand1,
    const Tensor& to_expand2,
    const char* api_name) {
  check_defined({to_expand1, to_expand2}, api_name);
  return expand_outplace(to_expand1, to_expand2);
}
|
| 219 |
+
|
| 220 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 221 |
+
expand_outplace(
|
| 222 |
+
Tensor&& to_expand1,
|
| 223 |
+
const Tensor& to_expand2,
|
| 224 |
+
const char* api_name) = delete;
|
| 225 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 226 |
+
expand_outplace(
|
| 227 |
+
const Tensor& to_expand1,
|
| 228 |
+
Tensor&& to_expand2,
|
| 229 |
+
const char* api_name) = delete;
|
| 230 |
+
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
|
| 231 |
+
expand_outplace(
|
| 232 |
+
Tensor&& to_expand1,
|
| 233 |
+
Tensor&& to_expand2,
|
| 234 |
+
const char* api_name) = delete;
|
| 235 |
+
|
| 236 |
+
// Broadcast three operands to their common shape; borrows all three when the
// shapes already agree. The common shape is computed pairwise: first from
// operands 1 and 2, then combined with operand 3. NOTE: this overload uses
// concrete sizes(), unlike the two-tensor overload which uses sym_sizes().
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    const Tensor& to_expand1,
    const Tensor& to_expand2,
    const Tensor& to_expand3) {
  if (to_expand1.sizes().equals(to_expand2.sizes()) &&
      to_expand1.sizes().equals(to_expand3.sizes())) {
    return std::make_tuple(
        c10::MaybeOwned<Tensor>::borrowed(to_expand1),
        c10::MaybeOwned<Tensor>::borrowed(to_expand2),
        c10::MaybeOwned<Tensor>::borrowed(to_expand3));
  }

  auto expanded_size12 =
      infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
  auto expanded_size =
      infer_size_dimvector(expanded_size12, to_expand3.sizes());
  return std::make_tuple(
      c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
      c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)),
      c10::MaybeOwned<Tensor>::owned(to_expand3.expand(expanded_size)));
}
|
| 261 |
+
|
| 262 |
+
inline std::tuple<
|
| 263 |
+
c10::MaybeOwned<Tensor>,
|
| 264 |
+
c10::MaybeOwned<Tensor>,
|
| 265 |
+
c10::MaybeOwned<Tensor>>
|
| 266 |
+
expand_outplace(
|
| 267 |
+
Tensor&& to_expand1,
|
| 268 |
+
const Tensor& to_expand2,
|
| 269 |
+
const Tensor& to_expand3) = delete;
|
| 270 |
+
inline std::tuple<
|
| 271 |
+
c10::MaybeOwned<Tensor>,
|
| 272 |
+
c10::MaybeOwned<Tensor>,
|
| 273 |
+
c10::MaybeOwned<Tensor>>
|
| 274 |
+
expand_outplace(
|
| 275 |
+
const Tensor& to_expand1,
|
| 276 |
+
Tensor&& to_expand2,
|
| 277 |
+
const Tensor& to_expand3) = delete;
|
| 278 |
+
inline std::tuple<
|
| 279 |
+
c10::MaybeOwned<Tensor>,
|
| 280 |
+
c10::MaybeOwned<Tensor>,
|
| 281 |
+
c10::MaybeOwned<Tensor>>
|
| 282 |
+
expand_outplace(
|
| 283 |
+
Tensor&& to_expand1,
|
| 284 |
+
Tensor&& to_expand2,
|
| 285 |
+
const Tensor& to_expand3) = delete;
|
| 286 |
+
inline std::tuple<
|
| 287 |
+
c10::MaybeOwned<Tensor>,
|
| 288 |
+
c10::MaybeOwned<Tensor>,
|
| 289 |
+
c10::MaybeOwned<Tensor>>
|
| 290 |
+
expand_outplace(
|
| 291 |
+
const Tensor& to_expand1,
|
| 292 |
+
const Tensor& to_expand2,
|
| 293 |
+
Tensor&& to_expand3) = delete;
|
| 294 |
+
inline std::tuple<
|
| 295 |
+
c10::MaybeOwned<Tensor>,
|
| 296 |
+
c10::MaybeOwned<Tensor>,
|
| 297 |
+
c10::MaybeOwned<Tensor>>
|
| 298 |
+
expand_outplace(
|
| 299 |
+
Tensor&& to_expand1,
|
| 300 |
+
const Tensor& to_expand2,
|
| 301 |
+
Tensor&& to_expand3) = delete;
|
| 302 |
+
inline std::tuple<
|
| 303 |
+
c10::MaybeOwned<Tensor>,
|
| 304 |
+
c10::MaybeOwned<Tensor>,
|
| 305 |
+
c10::MaybeOwned<Tensor>>
|
| 306 |
+
expand_outplace(
|
| 307 |
+
const Tensor& to_expand1,
|
| 308 |
+
Tensor&& to_expand2,
|
| 309 |
+
Tensor&& to_expand3) = delete;
|
| 310 |
+
inline std::tuple<
|
| 311 |
+
c10::MaybeOwned<Tensor>,
|
| 312 |
+
c10::MaybeOwned<Tensor>,
|
| 313 |
+
c10::MaybeOwned<Tensor>>
|
| 314 |
+
expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) =
|
| 315 |
+
delete;
|
| 316 |
+
|
| 317 |
+
// Checked variant of the three-tensor expand_outplace: verifies all three
// tensors are defined (errors attributed to `api_name`) before broadcasting.
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    const Tensor& to_expand1,
    const Tensor& to_expand2,
    const Tensor& to_expand3,
    const char* api_name) {
  check_defined({to_expand1, to_expand2, to_expand3}, api_name);
  return expand_outplace(to_expand1, to_expand2, to_expand3);
}
|
| 329 |
+
|
| 330 |
+
inline std::tuple<
|
| 331 |
+
c10::MaybeOwned<Tensor>,
|
| 332 |
+
c10::MaybeOwned<Tensor>,
|
| 333 |
+
c10::MaybeOwned<Tensor>>
|
| 334 |
+
expand_outplace(
|
| 335 |
+
Tensor&& to_expand1,
|
| 336 |
+
const Tensor& to_expand2,
|
| 337 |
+
const Tensor& to_expand3,
|
| 338 |
+
const char* api_name) = delete;
|
| 339 |
+
inline std::tuple<
|
| 340 |
+
c10::MaybeOwned<Tensor>,
|
| 341 |
+
c10::MaybeOwned<Tensor>,
|
| 342 |
+
c10::MaybeOwned<Tensor>>
|
| 343 |
+
expand_outplace(
|
| 344 |
+
const Tensor& to_expand1,
|
| 345 |
+
Tensor&& to_expand2,
|
| 346 |
+
const Tensor& to_expand3,
|
| 347 |
+
const char* api_name) = delete;
|
| 348 |
+
inline std::tuple<
|
| 349 |
+
c10::MaybeOwned<Tensor>,
|
| 350 |
+
c10::MaybeOwned<Tensor>,
|
| 351 |
+
c10::MaybeOwned<Tensor>>
|
| 352 |
+
expand_outplace(
|
| 353 |
+
Tensor&& to_expand1,
|
| 354 |
+
Tensor&& to_expand2,
|
| 355 |
+
const Tensor& to_expand3,
|
| 356 |
+
const char* api_name) = delete;
|
| 357 |
+
inline std::tuple<
|
| 358 |
+
c10::MaybeOwned<Tensor>,
|
| 359 |
+
c10::MaybeOwned<Tensor>,
|
| 360 |
+
c10::MaybeOwned<Tensor>>
|
| 361 |
+
expand_outplace(
|
| 362 |
+
const Tensor& to_expand1,
|
| 363 |
+
const Tensor& to_expand2,
|
| 364 |
+
Tensor&& to_expand3,
|
| 365 |
+
const char* api_name) = delete;
|
| 366 |
+
inline std::tuple<
|
| 367 |
+
c10::MaybeOwned<Tensor>,
|
| 368 |
+
c10::MaybeOwned<Tensor>,
|
| 369 |
+
c10::MaybeOwned<Tensor>>
|
| 370 |
+
expand_outplace(
|
| 371 |
+
Tensor&& to_expand1,
|
| 372 |
+
const Tensor& to_expand2,
|
| 373 |
+
Tensor&& to_expand3,
|
| 374 |
+
const char* api_name) = delete;
|
| 375 |
+
inline std::tuple<
|
| 376 |
+
c10::MaybeOwned<Tensor>,
|
| 377 |
+
c10::MaybeOwned<Tensor>,
|
| 378 |
+
c10::MaybeOwned<Tensor>>
|
| 379 |
+
expand_outplace(
|
| 380 |
+
const Tensor& to_expand1,
|
| 381 |
+
Tensor&& to_expand2,
|
| 382 |
+
Tensor&& to_expand3,
|
| 383 |
+
const char* api_name) = delete;
|
| 384 |
+
inline std::tuple<
|
| 385 |
+
c10::MaybeOwned<Tensor>,
|
| 386 |
+
c10::MaybeOwned<Tensor>,
|
| 387 |
+
c10::MaybeOwned<Tensor>>
|
| 388 |
+
expand_outplace(
|
| 389 |
+
Tensor&& to_expand1,
|
| 390 |
+
Tensor&& to_expand2,
|
| 391 |
+
Tensor&& to_expand3,
|
| 392 |
+
const char* api_name) = delete;
|
| 393 |
+
|
| 394 |
+
// Expand `to_expand` to exactly `sizes`; borrows when no expansion is
// needed. See NOTE [ ExpandUtils Borrowing ].
inline c10::MaybeOwned<Tensor> expand_size(
    const Tensor& to_expand,
    IntArrayRef sizes) {
  if (!to_expand.sizes().equals(sizes)) {
    return c10::MaybeOwned<Tensor>::owned(to_expand.expand(sizes));
  }
  return c10::MaybeOwned<Tensor>::borrowed(to_expand);
}
|
| 403 |
+
|
| 404 |
+
inline c10::MaybeOwned<Tensor> expand_size(
|
| 405 |
+
Tensor&& to_expand,
|
| 406 |
+
IntArrayRef sizes) = delete;
|
| 407 |
+
|
| 408 |
+
// Checked variant: verifies `to_expand` is defined (errors attributed to
// `api_name`) before expanding it to `sizes`.
inline c10::MaybeOwned<Tensor> expand_size(
    const Tensor& to_expand,
    IntArrayRef sizes,
    const char* api_name) {
  check_defined({to_expand}, api_name);
  return expand_size(to_expand, sizes);
}
|
| 415 |
+
|
| 416 |
+
inline c10::MaybeOwned<Tensor> expand_size(
|
| 417 |
+
Tensor&& to_expand,
|
| 418 |
+
IntArrayRef sizes,
|
| 419 |
+
const char* api_name) = delete;
|
| 420 |
+
|
| 421 |
+
inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
|
| 422 |
+
// expands a list of Tensors; ignores undefined (null) tensors
|
| 423 |
+
bool first = true;
|
| 424 |
+
DimVector sizes;
|
| 425 |
+
for (const auto i : c10::irange(to_expand.size())) {
|
| 426 |
+
if (!to_expand[i].defined()) {
|
| 427 |
+
continue;
|
| 428 |
+
} else if (first) {
|
| 429 |
+
sizes = to_expand[i].sizes();
|
| 430 |
+
first = false;
|
| 431 |
+
} else {
|
| 432 |
+
sizes = infer_size_dimvector(sizes, to_expand[i].sizes());
|
| 433 |
+
}
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
std::vector<Tensor> result(to_expand.size());
|
| 437 |
+
for (const auto i : c10::irange(to_expand.size())) {
|
| 438 |
+
if (!to_expand[i].defined()) {
|
| 439 |
+
continue;
|
| 440 |
+
} else if (to_expand[i].sizes().equals(sizes)) {
|
| 441 |
+
result[i] = to_expand[i];
|
| 442 |
+
} else {
|
| 443 |
+
result[i] = to_expand[i].expand(sizes);
|
| 444 |
+
}
|
| 445 |
+
}
|
| 446 |
+
return result;
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
// Reduce `tensor` so its shape becomes `shape` (the inverse of broadcasting):
// leading extra dimensions are summed away, and dimensions where `shape`
// is 1 but the tensor's extent isn't are summed with keepdim. Templated on
// T (int64_t or SymInt) so it works for both concrete and symbolic shapes.
// Precondition (not checked here): `shape` must be expandable to the
// tensor's sizes.
template <typename T>
inline Tensor _sum_to(
    Tensor tensor,
    const c10::ArrayRef<T> shape,
    bool always_return_non_view = false) {
  // Empty target shape means reduce to a scalar.
  if (shape.size() == 0) {
    return tensor.sum();
  }

  auto sizes = at::symint::sizes<T>(tensor);
  c10::SmallVector<int64_t, 8> reduce_dims;
  // Dimensions the tensor has beyond the target rank are all reduced.
  const int64_t leading_dims = sizes.size() - shape.size();
  for (const auto i : c10::irange(leading_dims)) {
    reduce_dims.push_back(i);
  }
  // Within the aligned trailing dims, reduce wherever the target is 1 but
  // the tensor's extent isn't (i.e. where broadcasting had expanded).
  for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
    if (shape[i - leading_dims] == 1 && sizes[i] != 1) {
      reduce_dims.push_back(i);
    }
  }

  // keepdim=true so the leading dims can be dropped with a single view below.
  if (!reduce_dims.empty()) {
    tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
  }

  if (always_return_non_view) {
    // This is only actually used by the functionalization pass.
    // We want to be able to guarantee that this function doesn't return a view
    // of the input.
    return leading_dims > 0 ? at::symint::view_copy<T>(tensor, shape)
                            : tensor.clone();
  } else {
    return leading_dims > 0 ? at::symint::view<T>(tensor, shape) : tensor;
  }
}
|
| 484 |
+
|
| 485 |
+
// Symbolic-shape entry point for _sum_to: sums `tensor` down to `shape`.
// See the IntArrayRef overload below for the precondition.
inline Tensor sum_to(
    Tensor tensor,
    const c10::SymIntArrayRef shape,
    bool always_return_non_view = false) {
  return _sum_to(std::move(tensor), shape, always_return_non_view);
}
|
| 491 |
+
|
| 492 |
+
// Sums `tensor` repeatedly to produce a tensor of shape `shape`.
// Precondition: is_expandable_to(shape, tensor.sizes()) must be true.
// If `always_return_non_view` is set, the result is guaranteed not to
// alias the input (used by the functionalization pass).
inline Tensor sum_to(
    Tensor tensor,
    const IntArrayRef shape,
    bool always_return_non_view = false) {
  return _sum_to(std::move(tensor), shape, always_return_non_view);
}
|
| 500 |
+
|
| 501 |
+
// True if `shape` can be broadcast-expanded to `desired`: it must not have
// more dimensions, and every trailing dimension must either match the
// corresponding desired extent or be 1.
static inline bool is_expandable_to(
    SymIntArrayRef shape,
    c10::SymIntArrayRef desired) {
  const size_t src_ndim = shape.size();
  const size_t dst_ndim = desired.size();
  if (src_ndim > dst_ndim) {
    return false;
  }
  // Compare from the rightmost (trailing) dimension inward.
  for (size_t offset = 1; offset <= src_ndim; ++offset) {
    const auto& src = shape[src_ndim - offset];
    const auto& dst = desired[dst_ndim - offset];
    if (src != dst && src != 1) {
      return false;
    }
  }
  return true;
}
|
| 518 |
+
|
| 519 |
+
// Concrete-int overload: reinterprets the int64_t arrays as SymInt arrays and
// defers to the symbolic overload. NOTE(review): this type-pun presumably
// relies on c10::SymInt being layout-compatible with int64_t for plain
// (non-symbolic) values — confirm against c10/core/SymInt.h before touching.
static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
  auto sym_shape = c10::SymIntArrayRef(
      reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
  auto sym_desired = c10::SymIntArrayRef(
      reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size());
  return is_expandable_to(sym_shape, sym_desired);
}
|
| 526 |
+
|
| 527 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/FunctionalStorageImpl.h
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Tensor.h>
|
| 4 |
+
|
| 5 |
+
namespace at::functionalization {
|
| 6 |
+
|
| 7 |
+
// See Note [Functionalization Pass In Core]
|
| 8 |
+
|
| 9 |
+
// ViewMeta is a class used by the functionalization pass to navigate between
|
| 10 |
+
// a base tensor and a view tensor.
|
| 11 |
+
// For example, if I call `b = a.view1(...)`
|
| 12 |
+
// the functionalization pass will generate and store a ViewMeta on b that looks
|
| 13 |
+
// like:
|
| 14 |
+
//
|
| 15 |
+
// ViewMeta(
|
| 16 |
+
// [<captures>](const Tensor& base, int64_t mutated_view_idx) {
|
| 17 |
+
// return base.view1(...);
|
| 18 |
+
// },
|
| 19 |
+
// [<captures>](const at::Tensor& base, const at::Tensor& mutated_view,
|
| 20 |
+
// int64_t mutated_view_idx) -> at::Tensor {
|
| 21 |
+
// return at::functionalization::impl::view1_inverse(base, mutated_view,
|
| 22 |
+
// ...);
|
| 23 |
+
// }
|
| 24 |
+
//
|
| 25 |
+
// The forward_fn lambda describes how to replay view1 on a tensor.
|
| 26 |
+
//
|
| 27 |
+
// The reverse_fn lambda describes how, given a tensor that is already a view,
|
| 28 |
+
// how to get the corresponding base tensor. See Note [Functionalization Pass:
|
| 29 |
+
// View Inverses] for details.
|
| 30 |
+
struct ViewMeta {
  // `forward` replays the view op on a base tensor; `reverse` maps a mutated
  // view back onto its base (see Note [Functionalization Pass: View
  // Inverses]). `out_idx` selects which output this ViewMeta tracks when the
  // view op produces multiple outputs.
  ViewMeta(
      std::function<Tensor(const Tensor&, int64_t)> forward,
      std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
      bool is_multi_output = false,
      int64_t out_idx = 0)
      : forward_fn(std::move(forward)),
        reverse_fn(std::move(reverse)),
        out_index(out_idx),
        is_multi_output(is_multi_output) {}

  // Replays the view: (base, mutated_view_idx) -> view tensor.
  std::function<Tensor(const Tensor&, int64_t)> forward_fn;
  // Inverts the view: (base, mutated_view, mutated_view_idx) -> updated base.
  std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
  // See Note [out_idx in ViewMeta]
  int64_t out_index;

  // Tells us if this is a multi-output view
  bool is_multi_output;

  // Returns a copy of the current ViewMeta, if out_idx matches the current
  // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
  // functions, but a new out index.
  ViewMeta to_out_idx(int64_t out_idx);
};
|
| 54 |
+
|
| 55 |
+
// FunctionalStorageImpl is a subclass of StorageImpl used by the
|
| 56 |
+
// functionalization pass. It has no underlying data (similar to meta storage).
|
| 57 |
+
// It also knows how to reflect mutations to tensors in the absence of a valid
|
| 58 |
+
// data pointer.
|
| 59 |
+
//
|
| 60 |
+
// A storage represents the state shared by (potentially multiple) views of the
|
| 61 |
+
// same tensor. For example, in the following code:
|
| 62 |
+
//
|
| 63 |
+
// b = a.view1(...)
|
| 64 |
+
// c = b.view2(...)
|
| 65 |
+
// b.add_(1)
|
| 66 |
+
// --> storage.add_update(b, {view1_meta})
|
| 67 |
+
//
|
| 68 |
+
// The call to add_(1) will result in a call to alias.add_update(b,
|
| 69 |
+
// {view1_meta}), queueing up the mutation from b onto the alias. Later, suppose
|
| 70 |
+
// c is used in an expression (e.g. you try to print c, or pass it to an
|
| 71 |
+
// operator). Doing so will involve "syncing" c. First we apply any pending
|
| 72 |
+
// updates to the alias, and then we regenerate c by replaying its views off of
|
| 73 |
+
// the updated alias. E.g:
|
| 74 |
+
//
|
| 75 |
+
// print(str(c))
|
| 76 |
+
// --> c.sync_()
|
| 77 |
+
// --> alias.apply_updates() // after this, the alias will be updated to
|
| 78 |
+
// reflect the mutation to b
|
| 79 |
+
struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
|
| 80 |
+
public:
|
| 81 |
+
struct Update {
|
| 82 |
+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
| 83 |
+
const at::Tensor new_val;
|
| 84 |
+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
| 85 |
+
const std::vector<ViewMeta> view_metas;
|
| 86 |
+
};
|
| 87 |
+
|
| 88 |
+
explicit FunctionalStorageImpl(const Tensor& value);
|
| 89 |
+
|
| 90 |
+
void add_update(
|
| 91 |
+
const Tensor& updated_val,
|
| 92 |
+
const std::vector<ViewMeta>& view_metas);
|
| 93 |
+
bool apply_updates();
|
| 94 |
+
const Tensor& base() {
|
| 95 |
+
return base_;
|
| 96 |
+
}
|
| 97 |
+
size_t generation() const {
|
| 98 |
+
return generation_;
|
| 99 |
+
}
|
| 100 |
+
void freeze() {
|
| 101 |
+
frozen_ = true;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
~FunctionalStorageImpl() override = default;
|
| 105 |
+
|
| 106 |
+
private:
|
| 107 |
+
// NB: base_ should always point to a tensor BELOW the current
|
| 108 |
+
// functionalization layer. This is mainly to avoid reference cycles. e.g.
|
| 109 |
+
// given `b = a.view(...)` Both a.storage_ and b.storage_ are a
|
| 110 |
+
// FunctionStorageImpl containing an Walualias, with contains a Tensor
|
| 111 |
+
// `base_`. In this case (where a and b are FunctionalTensorWrapper's), base_
|
| 112 |
+
// should point not to a, but to a's unwrapped value, a.value_` See Note
|
| 113 |
+
// [Functionalization: Walualias Removal] for a diagram that shows this
|
| 114 |
+
// visually.
|
| 115 |
+
at::Tensor base_;
|
| 116 |
+
std::vector<Update> updates_;
|
| 117 |
+
// generation_ gets incremented every time a mutation is queued onto the
|
| 118 |
+
// alias. It is used to determine if a given tensor is "up to date", or if it
|
| 119 |
+
// needs to be regenerated from the alias.
|
| 120 |
+
size_t generation_ = 0;
|
| 121 |
+
// If frozen, no more mutations are allowed on this storage. Once frozen, a
|
| 122 |
+
// storage cannot be unfrozen.
|
| 123 |
+
bool frozen_ = false;
|
| 124 |
+
};
|
| 125 |
+
|
| 126 |
+
} // namespace at::functionalization
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/InitialTensorOptions.h
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/TensorOptions.h>
|
| 4 |
+
|
| 5 |
+
namespace at {
|
| 6 |
+
|
| 7 |
+
// Represents the initial TensorOptions, before the "defaults" are ever changed.
|
| 8 |
+
// This is designed to be used in library code, where the explicit devices,
|
| 9 |
+
// dtypes, etc. are known. NOTE: this is not a stable API.
|
| 10 |
+
inline TensorOptions initialTensorOptions() {
|
| 11 |
+
return TensorOptions(kCPU).dtype(kFloat).layout(kStrided).requires_grad(
|
| 12 |
+
false);
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/Layout.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/core/Layout.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/LinalgBackend.h
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/util/Exception.h>
|
| 4 |
+
|
| 5 |
+
#include <ostream>
|
| 6 |
+
#include <string>
|
| 7 |
+
|
| 8 |
+
namespace at {
|
| 9 |
+
|
| 10 |
+
enum class LinalgBackend : int8_t { Default, Cusolver, Magma };
|
| 11 |
+
|
| 12 |
+
inline std::string LinalgBackendToString(at::LinalgBackend backend) {
|
| 13 |
+
switch (backend) {
|
| 14 |
+
case LinalgBackend::Default:
|
| 15 |
+
return "at::LinalgBackend::Default";
|
| 16 |
+
case LinalgBackend::Cusolver:
|
| 17 |
+
return "at::LinalgBackend::Cusolver";
|
| 18 |
+
case LinalgBackend::Magma:
|
| 19 |
+
return "at::LinalgBackend::Magma";
|
| 20 |
+
default:
|
| 21 |
+
TORCH_CHECK(false, "Unknown linalg backend");
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
inline std::ostream& operator<<(
|
| 26 |
+
std::ostream& stream,
|
| 27 |
+
at::LinalgBackend backend) {
|
| 28 |
+
return stream << LinalgBackendToString(backend);
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MatrixRef.h
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/Utils.h>
|
| 3 |
+
#include <c10/util/ArrayRef.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
|
| 7 |
+
namespace at {
|
| 8 |
+
/// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that
|
| 9 |
+
/// we can easily view it as a multidimensional array.
|
| 10 |
+
///
|
| 11 |
+
/// Like ArrayRef, this class does not own the underlying data, it is expected
|
| 12 |
+
/// to be used in situations where the data resides in some other buffer.
|
| 13 |
+
///
|
| 14 |
+
/// This is intended to be trivially copyable, so it should be passed by
|
| 15 |
+
/// value.
|
| 16 |
+
///
|
| 17 |
+
/// For now, 2D only (so the copies are actually cheap, without having
|
| 18 |
+
/// to write a SmallVector class) and contiguous only (so we can
|
| 19 |
+
/// return non-strided ArrayRef on index).
|
| 20 |
+
///
|
| 21 |
+
/// P.S. dimension 0 indexes rows, dimension 1 indexes columns
|
| 22 |
+
template <typename T>
|
| 23 |
+
class MatrixRef {
|
| 24 |
+
public:
|
| 25 |
+
typedef size_t size_type;
|
| 26 |
+
|
| 27 |
+
private:
|
| 28 |
+
/// Underlying ArrayRef
|
| 29 |
+
ArrayRef<T> arr;
|
| 30 |
+
|
| 31 |
+
/// Stride of dim 0 (outer dimension)
|
| 32 |
+
size_type stride0;
|
| 33 |
+
|
| 34 |
+
// Stride of dim 1 is assumed to be 1
|
| 35 |
+
|
| 36 |
+
public:
|
| 37 |
+
/// Construct an empty Matrixref.
|
| 38 |
+
/*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {}
|
| 39 |
+
|
| 40 |
+
/// Construct an MatrixRef from an ArrayRef and outer stride.
|
| 41 |
+
/*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
|
| 42 |
+
: arr(arr), stride0(stride0) {
|
| 43 |
+
TORCH_CHECK(
|
| 44 |
+
arr.size() % stride0 == 0,
|
| 45 |
+
"MatrixRef: ArrayRef size ",
|
| 46 |
+
arr.size(),
|
| 47 |
+
" not divisible by stride ",
|
| 48 |
+
stride0)
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
/// @}
|
| 52 |
+
/// @name Simple Operations
|
| 53 |
+
/// @{
|
| 54 |
+
|
| 55 |
+
/// empty - Check if the matrix is empty.
|
| 56 |
+
bool empty() const {
|
| 57 |
+
return arr.empty();
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
const T* data() const {
|
| 61 |
+
return arr.data();
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
/// size - Get size a dimension
|
| 65 |
+
size_t size(size_t dim) const {
|
| 66 |
+
if (dim == 0) {
|
| 67 |
+
return arr.size() / stride0;
|
| 68 |
+
} else if (dim == 1) {
|
| 69 |
+
return stride0;
|
| 70 |
+
} else {
|
| 71 |
+
TORCH_CHECK(
|
| 72 |
+
0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1");
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
size_t numel() const {
|
| 77 |
+
return arr.size();
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
/// equals - Check for element-wise equality.
|
| 81 |
+
bool equals(MatrixRef RHS) const {
|
| 82 |
+
return stride0 == RHS.stride0 && arr.equals(RHS.arr);
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
/// @}
|
| 86 |
+
/// @name Operator Overloads
|
| 87 |
+
/// @{
|
| 88 |
+
ArrayRef<T> operator[](size_t Index) const {
|
| 89 |
+
return arr.slice(Index * stride0, stride0);
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
/// Disallow accidental assignment from a temporary.
|
| 93 |
+
///
|
| 94 |
+
/// The declaration here is extra complicated so that "arrayRef = {}"
|
| 95 |
+
/// continues to select the move assignment operator.
|
| 96 |
+
template <typename U>
|
| 97 |
+
std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
|
| 98 |
+
U&& Temporary) = delete;
|
| 99 |
+
|
| 100 |
+
/// Disallow accidental assignment from a temporary.
|
| 101 |
+
///
|
| 102 |
+
/// The declaration here is extra complicated so that "arrayRef = {}"
|
| 103 |
+
/// continues to select the move assignment operator.
|
| 104 |
+
template <typename U>
|
| 105 |
+
std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
|
| 106 |
+
std::initializer_list<U>) = delete;
|
| 107 |
+
};
|
| 108 |
+
|
| 109 |
+
} // end namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MetaFunctions.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/core/TensorBody.h>
|
| 2 |
+
|
| 3 |
+
// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
|
| 4 |
+
// Code introduced to avoid cyclic dependency in static dispatch is no longer
|
| 5 |
+
// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
|
| 6 |
+
// to Operators.cpp for supporting multiple backends with multiple kernels.
|
| 7 |
+
//
|
| 8 |
+
// Note [Avoiding Include Cycles In Static Dispatch]
|
| 9 |
+
// In order to avoid #include cycles in the static dispatch build, we've carefully split out
|
| 10 |
+
// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
|
| 11 |
+
//
|
| 12 |
+
// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
|
| 13 |
+
// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
|
| 14 |
+
// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
|
| 15 |
+
// directly inlined into TensorBody.h.
|
| 16 |
+
// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
|
| 17 |
+
// which include functions that have defaultable optional<Tensor> arguments.
|
| 18 |
+
// That requires knowing the full Tensor class definition.
|
| 19 |
+
//
|
| 20 |
+
// We break the cycle by doing the following:
|
| 21 |
+
// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h
|
| 22 |
+
// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.,
|
| 23 |
+
// - CPUFunctions_inl.h includes everything else
|
| 24 |
+
// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
|
| 25 |
+
// and then it includes CPUFunctions_inl.h.
|
| 26 |
+
// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
|
| 27 |
+
// - This also means that static dispatch build, CPUFunctions.h only needs to
|
| 28 |
+
// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
|
| 29 |
+
#include <ATen/MetaFunctions_inl.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/MethodOperators.h
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from MethodOperators.h
|
| 4 |
+
|
| 5 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 6 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 7 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 8 |
+
is changed or added. Consider if your change would be better placed in \
|
| 9 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 10 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
// Forward declarations of any types needed in the operator signatures.
|
| 14 |
+
// We can't directly include these classes because it will cause circular include dependencies.
|
| 15 |
+
// This file is included by TensorBody.h, which defines the Tensor class.
|
| 16 |
+
#include <ATen/core/ATen_fwd.h>
|
| 17 |
+
|
| 18 |
+
#include <ATen/ops/_addmm_activation_ops.h>
|
| 19 |
+
#include <ATen/ops/_autocast_to_full_precision_ops.h>
|
| 20 |
+
#include <ATen/ops/_autocast_to_reduced_precision_ops.h>
|
| 21 |
+
#include <ATen/ops/_backward_ops.h>
|
| 22 |
+
#include <ATen/ops/_coalesced_ops.h>
|
| 23 |
+
#include <ATen/ops/_conj_ops.h>
|
| 24 |
+
#include <ATen/ops/_conj_physical_ops.h>
|
| 25 |
+
#include <ATen/ops/_dimI_ops.h>
|
| 26 |
+
#include <ATen/ops/_dimV_ops.h>
|
| 27 |
+
#include <ATen/ops/_fw_primal_ops.h>
|
| 28 |
+
#include <ATen/ops/_indices_ops.h>
|
| 29 |
+
#include <ATen/ops/_is_all_true_ops.h>
|
| 30 |
+
#include <ATen/ops/_is_any_true_ops.h>
|
| 31 |
+
#include <ATen/ops/_is_zerotensor_ops.h>
|
| 32 |
+
#include <ATen/ops/_lazy_clone_ops.h>
|
| 33 |
+
#include <ATen/ops/_neg_view_ops.h>
|
| 34 |
+
#include <ATen/ops/_nested_tensor_size_ops.h>
|
| 35 |
+
#include <ATen/ops/_nested_tensor_storage_offsets_ops.h>
|
| 36 |
+
#include <ATen/ops/_nested_tensor_strides_ops.h>
|
| 37 |
+
#include <ATen/ops/_nnz_ops.h>
|
| 38 |
+
#include <ATen/ops/_reshape_alias_ops.h>
|
| 39 |
+
#include <ATen/ops/_sparse_mask_projection_ops.h>
|
| 40 |
+
#include <ATen/ops/_to_dense_ops.h>
|
| 41 |
+
#include <ATen/ops/_to_sparse_bsc_ops.h>
|
| 42 |
+
#include <ATen/ops/_to_sparse_bsr_ops.h>
|
| 43 |
+
#include <ATen/ops/_to_sparse_csc_ops.h>
|
| 44 |
+
#include <ATen/ops/_to_sparse_csr_ops.h>
|
| 45 |
+
#include <ATen/ops/_to_sparse_ops.h>
|
| 46 |
+
#include <ATen/ops/_values_ops.h>
|
| 47 |
+
#include <ATen/ops/_version_ops.h>
|
| 48 |
+
#include <ATen/ops/abs_ops.h>
|
| 49 |
+
#include <ATen/ops/absolute_ops.h>
|
| 50 |
+
#include <ATen/ops/acos_ops.h>
|
| 51 |
+
#include <ATen/ops/acosh_ops.h>
|
| 52 |
+
#include <ATen/ops/add_ops.h>
|
| 53 |
+
#include <ATen/ops/addbmm_ops.h>
|
| 54 |
+
#include <ATen/ops/addcdiv_ops.h>
|
| 55 |
+
#include <ATen/ops/addcmul_ops.h>
|
| 56 |
+
#include <ATen/ops/addmm_ops.h>
|
| 57 |
+
#include <ATen/ops/addmv_ops.h>
|
| 58 |
+
#include <ATen/ops/addr_ops.h>
|
| 59 |
+
#include <ATen/ops/adjoint_ops.h>
|
| 60 |
+
#include <ATen/ops/alias_ops.h>
|
| 61 |
+
#include <ATen/ops/align_as_ops.h>
|
| 62 |
+
#include <ATen/ops/align_to_ops.h>
|
| 63 |
+
#include <ATen/ops/all_ops.h>
|
| 64 |
+
#include <ATen/ops/allclose_ops.h>
|
| 65 |
+
#include <ATen/ops/amax_ops.h>
|
| 66 |
+
#include <ATen/ops/amin_ops.h>
|
| 67 |
+
#include <ATen/ops/aminmax_ops.h>
|
| 68 |
+
#include <ATen/ops/and_ops.h>
|
| 69 |
+
#include <ATen/ops/angle_ops.h>
|
| 70 |
+
#include <ATen/ops/any_ops.h>
|
| 71 |
+
#include <ATen/ops/arccos_ops.h>
|
| 72 |
+
#include <ATen/ops/arccosh_ops.h>
|
| 73 |
+
#include <ATen/ops/arcsin_ops.h>
|
| 74 |
+
#include <ATen/ops/arcsinh_ops.h>
|
| 75 |
+
#include <ATen/ops/arctan2_ops.h>
|
| 76 |
+
#include <ATen/ops/arctan_ops.h>
|
| 77 |
+
#include <ATen/ops/arctanh_ops.h>
|
| 78 |
+
#include <ATen/ops/argmax_ops.h>
|
| 79 |
+
#include <ATen/ops/argmin_ops.h>
|
| 80 |
+
#include <ATen/ops/argsort_ops.h>
|
| 81 |
+
#include <ATen/ops/argwhere_ops.h>
|
| 82 |
+
#include <ATen/ops/as_strided_ops.h>
|
| 83 |
+
#include <ATen/ops/as_strided_scatter_ops.h>
|
| 84 |
+
#include <ATen/ops/asin_ops.h>
|
| 85 |
+
#include <ATen/ops/asinh_ops.h>
|
| 86 |
+
#include <ATen/ops/atan2_ops.h>
|
| 87 |
+
#include <ATen/ops/atan_ops.h>
|
| 88 |
+
#include <ATen/ops/atanh_ops.h>
|
| 89 |
+
#include <ATen/ops/baddbmm_ops.h>
|
| 90 |
+
#include <ATen/ops/bernoulli_ops.h>
|
| 91 |
+
#include <ATen/ops/bincount_ops.h>
|
| 92 |
+
#include <ATen/ops/bitwise_and_ops.h>
|
| 93 |
+
#include <ATen/ops/bitwise_left_shift_ops.h>
|
| 94 |
+
#include <ATen/ops/bitwise_not_ops.h>
|
| 95 |
+
#include <ATen/ops/bitwise_or_ops.h>
|
| 96 |
+
#include <ATen/ops/bitwise_right_shift_ops.h>
|
| 97 |
+
#include <ATen/ops/bitwise_xor_ops.h>
|
| 98 |
+
#include <ATen/ops/bmm_ops.h>
|
| 99 |
+
#include <ATen/ops/broadcast_to_ops.h>
|
| 100 |
+
#include <ATen/ops/cauchy_ops.h>
|
| 101 |
+
#include <ATen/ops/ccol_indices_ops.h>
|
| 102 |
+
#include <ATen/ops/ceil_ops.h>
|
| 103 |
+
#include <ATen/ops/chalf_ops.h>
|
| 104 |
+
#include <ATen/ops/cholesky_inverse_ops.h>
|
| 105 |
+
#include <ATen/ops/cholesky_ops.h>
|
| 106 |
+
#include <ATen/ops/cholesky_solve_ops.h>
|
| 107 |
+
#include <ATen/ops/chunk_ops.h>
|
| 108 |
+
#include <ATen/ops/clamp_max_ops.h>
|
| 109 |
+
#include <ATen/ops/clamp_min_ops.h>
|
| 110 |
+
#include <ATen/ops/clamp_ops.h>
|
| 111 |
+
#include <ATen/ops/clip_ops.h>
|
| 112 |
+
#include <ATen/ops/clone_ops.h>
|
| 113 |
+
#include <ATen/ops/coalesce_ops.h>
|
| 114 |
+
#include <ATen/ops/col_indices_ops.h>
|
| 115 |
+
#include <ATen/ops/conj_ops.h>
|
| 116 |
+
#include <ATen/ops/conj_physical_ops.h>
|
| 117 |
+
#include <ATen/ops/contiguous_ops.h>
|
| 118 |
+
#include <ATen/ops/copy_ops.h>
|
| 119 |
+
#include <ATen/ops/copysign_ops.h>
|
| 120 |
+
#include <ATen/ops/corrcoef_ops.h>
|
| 121 |
+
#include <ATen/ops/cos_ops.h>
|
| 122 |
+
#include <ATen/ops/cosh_ops.h>
|
| 123 |
+
#include <ATen/ops/count_nonzero_ops.h>
|
| 124 |
+
#include <ATen/ops/cov_ops.h>
|
| 125 |
+
#include <ATen/ops/cross_ops.h>
|
| 126 |
+
#include <ATen/ops/crow_indices_ops.h>
|
| 127 |
+
#include <ATen/ops/cummax_ops.h>
|
| 128 |
+
#include <ATen/ops/cummin_ops.h>
|
| 129 |
+
#include <ATen/ops/cumprod_ops.h>
|
| 130 |
+
#include <ATen/ops/cumsum_ops.h>
|
| 131 |
+
#include <ATen/ops/data_ops.h>
|
| 132 |
+
#include <ATen/ops/deg2rad_ops.h>
|
| 133 |
+
#include <ATen/ops/dense_dim_ops.h>
|
| 134 |
+
#include <ATen/ops/dequantize_ops.h>
|
| 135 |
+
#include <ATen/ops/det_ops.h>
|
| 136 |
+
#include <ATen/ops/detach_ops.h>
|
| 137 |
+
#include <ATen/ops/diag_embed_ops.h>
|
| 138 |
+
#include <ATen/ops/diag_ops.h>
|
| 139 |
+
#include <ATen/ops/diagflat_ops.h>
|
| 140 |
+
#include <ATen/ops/diagonal_ops.h>
|
| 141 |
+
#include <ATen/ops/diagonal_scatter_ops.h>
|
| 142 |
+
#include <ATen/ops/diff_ops.h>
|
| 143 |
+
#include <ATen/ops/digamma_ops.h>
|
| 144 |
+
#include <ATen/ops/dist_ops.h>
|
| 145 |
+
#include <ATen/ops/div_ops.h>
|
| 146 |
+
#include <ATen/ops/divide_ops.h>
|
| 147 |
+
#include <ATen/ops/dot_ops.h>
|
| 148 |
+
#include <ATen/ops/dsplit_ops.h>
|
| 149 |
+
#include <ATen/ops/eq_ops.h>
|
| 150 |
+
#include <ATen/ops/equal_ops.h>
|
| 151 |
+
#include <ATen/ops/erf_ops.h>
|
| 152 |
+
#include <ATen/ops/erfc_ops.h>
|
| 153 |
+
#include <ATen/ops/erfinv_ops.h>
|
| 154 |
+
#include <ATen/ops/exp2_ops.h>
|
| 155 |
+
#include <ATen/ops/exp_ops.h>
|
| 156 |
+
#include <ATen/ops/expand_as_ops.h>
|
| 157 |
+
#include <ATen/ops/expand_ops.h>
|
| 158 |
+
#include <ATen/ops/expm1_ops.h>
|
| 159 |
+
#include <ATen/ops/exponential_ops.h>
|
| 160 |
+
#include <ATen/ops/fill_diagonal_ops.h>
|
| 161 |
+
#include <ATen/ops/fill_ops.h>
|
| 162 |
+
#include <ATen/ops/fix_ops.h>
|
| 163 |
+
#include <ATen/ops/flatten_ops.h>
|
| 164 |
+
#include <ATen/ops/flip_ops.h>
|
| 165 |
+
#include <ATen/ops/fliplr_ops.h>
|
| 166 |
+
#include <ATen/ops/flipud_ops.h>
|
| 167 |
+
#include <ATen/ops/float_power_ops.h>
|
| 168 |
+
#include <ATen/ops/floor_divide_ops.h>
|
| 169 |
+
#include <ATen/ops/floor_ops.h>
|
| 170 |
+
#include <ATen/ops/fmax_ops.h>
|
| 171 |
+
#include <ATen/ops/fmin_ops.h>
|
| 172 |
+
#include <ATen/ops/fmod_ops.h>
|
| 173 |
+
#include <ATen/ops/frac_ops.h>
|
| 174 |
+
#include <ATen/ops/frexp_ops.h>
|
| 175 |
+
#include <ATen/ops/gather_ops.h>
|
| 176 |
+
#include <ATen/ops/gcd_ops.h>
|
| 177 |
+
#include <ATen/ops/ge_ops.h>
|
| 178 |
+
#include <ATen/ops/geometric_ops.h>
|
| 179 |
+
#include <ATen/ops/geqrf_ops.h>
|
| 180 |
+
#include <ATen/ops/ger_ops.h>
|
| 181 |
+
#include <ATen/ops/greater_equal_ops.h>
|
| 182 |
+
#include <ATen/ops/greater_ops.h>
|
| 183 |
+
#include <ATen/ops/gt_ops.h>
|
| 184 |
+
#include <ATen/ops/hardshrink_backward_ops.h>
|
| 185 |
+
#include <ATen/ops/hardshrink_ops.h>
|
| 186 |
+
#include <ATen/ops/heaviside_ops.h>
|
| 187 |
+
#include <ATen/ops/histc_ops.h>
|
| 188 |
+
#include <ATen/ops/histogram_ops.h>
|
| 189 |
+
#include <ATen/ops/hsplit_ops.h>
|
| 190 |
+
#include <ATen/ops/hypot_ops.h>
|
| 191 |
+
#include <ATen/ops/i0_ops.h>
|
| 192 |
+
#include <ATen/ops/igamma_ops.h>
|
| 193 |
+
#include <ATen/ops/igammac_ops.h>
|
| 194 |
+
#include <ATen/ops/index_add_ops.h>
|
| 195 |
+
#include <ATen/ops/index_copy_ops.h>
|
| 196 |
+
#include <ATen/ops/index_fill_ops.h>
|
| 197 |
+
#include <ATen/ops/index_ops.h>
|
| 198 |
+
#include <ATen/ops/index_put_ops.h>
|
| 199 |
+
#include <ATen/ops/index_reduce_ops.h>
|
| 200 |
+
#include <ATen/ops/index_select_ops.h>
|
| 201 |
+
#include <ATen/ops/indices_ops.h>
|
| 202 |
+
#include <ATen/ops/inner_ops.h>
|
| 203 |
+
#include <ATen/ops/int_repr_ops.h>
|
| 204 |
+
#include <ATen/ops/inverse_ops.h>
|
| 205 |
+
#include <ATen/ops/is_coalesced_ops.h>
|
| 206 |
+
#include <ATen/ops/is_complex_ops.h>
|
| 207 |
+
#include <ATen/ops/is_conj_ops.h>
|
| 208 |
+
#include <ATen/ops/is_distributed_ops.h>
|
| 209 |
+
#include <ATen/ops/is_floating_point_ops.h>
|
| 210 |
+
#include <ATen/ops/is_inference_ops.h>
|
| 211 |
+
#include <ATen/ops/is_leaf_ops.h>
|
| 212 |
+
#include <ATen/ops/is_neg_ops.h>
|
| 213 |
+
#include <ATen/ops/is_nonzero_ops.h>
|
| 214 |
+
#include <ATen/ops/is_pinned_ops.h>
|
| 215 |
+
#include <ATen/ops/is_same_size_ops.h>
|
| 216 |
+
#include <ATen/ops/is_set_to_ops.h>
|
| 217 |
+
#include <ATen/ops/is_signed_ops.h>
|
| 218 |
+
#include <ATen/ops/isclose_ops.h>
|
| 219 |
+
#include <ATen/ops/isfinite_ops.h>
|
| 220 |
+
#include <ATen/ops/isinf_ops.h>
|
| 221 |
+
#include <ATen/ops/isnan_ops.h>
|
| 222 |
+
#include <ATen/ops/isneginf_ops.h>
|
| 223 |
+
#include <ATen/ops/isposinf_ops.h>
|
| 224 |
+
#include <ATen/ops/isreal_ops.h>
|
| 225 |
+
#include <ATen/ops/istft_ops.h>
|
| 226 |
+
#include <ATen/ops/item_ops.h>
|
| 227 |
+
#include <ATen/ops/kron_ops.h>
|
| 228 |
+
#include <ATen/ops/kthvalue_ops.h>
|
| 229 |
+
#include <ATen/ops/lcm_ops.h>
|
| 230 |
+
#include <ATen/ops/ldexp_ops.h>
|
| 231 |
+
#include <ATen/ops/le_ops.h>
|
| 232 |
+
#include <ATen/ops/lerp_ops.h>
|
| 233 |
+
#include <ATen/ops/less_equal_ops.h>
|
| 234 |
+
#include <ATen/ops/less_ops.h>
|
| 235 |
+
#include <ATen/ops/lgamma_ops.h>
|
| 236 |
+
#include <ATen/ops/log10_ops.h>
|
| 237 |
+
#include <ATen/ops/log1p_ops.h>
|
| 238 |
+
#include <ATen/ops/log2_ops.h>
|
| 239 |
+
#include <ATen/ops/log_normal_ops.h>
|
| 240 |
+
#include <ATen/ops/log_ops.h>
|
| 241 |
+
#include <ATen/ops/log_softmax_ops.h>
|
| 242 |
+
#include <ATen/ops/logaddexp2_ops.h>
|
| 243 |
+
#include <ATen/ops/logaddexp_ops.h>
|
| 244 |
+
#include <ATen/ops/logcumsumexp_ops.h>
|
| 245 |
+
#include <ATen/ops/logdet_ops.h>
|
| 246 |
+
#include <ATen/ops/logical_and_ops.h>
|
| 247 |
+
#include <ATen/ops/logical_not_ops.h>
|
| 248 |
+
#include <ATen/ops/logical_or_ops.h>
|
| 249 |
+
#include <ATen/ops/logical_xor_ops.h>
|
| 250 |
+
#include <ATen/ops/logit_ops.h>
|
| 251 |
+
#include <ATen/ops/logsumexp_ops.h>
|
| 252 |
+
#include <ATen/ops/lshift_ops.h>
|
| 253 |
+
#include <ATen/ops/lt_ops.h>
|
| 254 |
+
#include <ATen/ops/lu_solve_ops.h>
|
| 255 |
+
#include <ATen/ops/mH_ops.h>
|
| 256 |
+
#include <ATen/ops/mT_ops.h>
|
| 257 |
+
#include <ATen/ops/masked_fill_ops.h>
|
| 258 |
+
#include <ATen/ops/masked_scatter_ops.h>
|
| 259 |
+
#include <ATen/ops/masked_select_ops.h>
|
| 260 |
+
#include <ATen/ops/matmul_ops.h>
|
| 261 |
+
#include <ATen/ops/matrix_H_ops.h>
|
| 262 |
+
#include <ATen/ops/matrix_exp_ops.h>
|
| 263 |
+
#include <ATen/ops/matrix_power_ops.h>
|
| 264 |
+
#include <ATen/ops/max_ops.h>
|
| 265 |
+
#include <ATen/ops/maximum_ops.h>
|
| 266 |
+
#include <ATen/ops/mean_ops.h>
|
| 267 |
+
#include <ATen/ops/median_ops.h>
|
| 268 |
+
#include <ATen/ops/min_ops.h>
|
| 269 |
+
#include <ATen/ops/minimum_ops.h>
|
| 270 |
+
#include <ATen/ops/mm_ops.h>
|
| 271 |
+
#include <ATen/ops/mode_ops.h>
|
| 272 |
+
#include <ATen/ops/moveaxis_ops.h>
|
| 273 |
+
#include <ATen/ops/movedim_ops.h>
|
| 274 |
+
#include <ATen/ops/msort_ops.h>
|
| 275 |
+
#include <ATen/ops/mul_ops.h>
|
| 276 |
+
#include <ATen/ops/multinomial_ops.h>
|
| 277 |
+
#include <ATen/ops/multiply_ops.h>
|
| 278 |
+
#include <ATen/ops/mv_ops.h>
|
| 279 |
+
#include <ATen/ops/mvlgamma_ops.h>
|
| 280 |
+
#include <ATen/ops/nan_to_num_ops.h>
|
| 281 |
+
#include <ATen/ops/nanmean_ops.h>
|
| 282 |
+
#include <ATen/ops/nanmedian_ops.h>
|
| 283 |
+
#include <ATen/ops/nanquantile_ops.h>
|
| 284 |
+
#include <ATen/ops/nansum_ops.h>
|
| 285 |
+
#include <ATen/ops/narrow_copy_ops.h>
|
| 286 |
+
#include <ATen/ops/narrow_ops.h>
|
| 287 |
+
#include <ATen/ops/ne_ops.h>
|
| 288 |
+
#include <ATen/ops/neg_ops.h>
|
| 289 |
+
#include <ATen/ops/negative_ops.h>
|
| 290 |
+
#include <ATen/ops/new_empty_ops.h>
|
| 291 |
+
#include <ATen/ops/new_empty_strided_ops.h>
|
| 292 |
+
#include <ATen/ops/new_full_ops.h>
|
| 293 |
+
#include <ATen/ops/new_ones_ops.h>
|
| 294 |
+
#include <ATen/ops/new_zeros_ops.h>
|
| 295 |
+
#include <ATen/ops/nextafter_ops.h>
|
| 296 |
+
#include <ATen/ops/nonzero_numpy_ops.h>
|
| 297 |
+
#include <ATen/ops/nonzero_ops.h>
|
| 298 |
+
#include <ATen/ops/nonzero_static_ops.h>
|
| 299 |
+
#include <ATen/ops/norm_ops.h>
|
| 300 |
+
#include <ATen/ops/normal_ops.h>
|
| 301 |
+
#include <ATen/ops/not_equal_ops.h>
|
| 302 |
+
#include <ATen/ops/numpy_T_ops.h>
|
| 303 |
+
#include <ATen/ops/or_ops.h>
|
| 304 |
+
#include <ATen/ops/orgqr_ops.h>
|
| 305 |
+
#include <ATen/ops/ormqr_ops.h>
|
| 306 |
+
#include <ATen/ops/outer_ops.h>
|
| 307 |
+
#include <ATen/ops/output_nr_ops.h>
|
| 308 |
+
#include <ATen/ops/permute_ops.h>
|
| 309 |
+
#include <ATen/ops/pin_memory_ops.h>
|
| 310 |
+
#include <ATen/ops/pinverse_ops.h>
|
| 311 |
+
#include <ATen/ops/polygamma_ops.h>
|
| 312 |
+
#include <ATen/ops/positive_ops.h>
|
| 313 |
+
#include <ATen/ops/pow_ops.h>
|
| 314 |
+
#include <ATen/ops/prelu_ops.h>
|
| 315 |
+
#include <ATen/ops/prod_ops.h>
|
| 316 |
+
#include <ATen/ops/put_ops.h>
|
| 317 |
+
#include <ATen/ops/q_per_channel_axis_ops.h>
|
| 318 |
+
#include <ATen/ops/q_per_channel_scales_ops.h>
|
| 319 |
+
#include <ATen/ops/q_per_channel_zero_points_ops.h>
|
| 320 |
+
#include <ATen/ops/q_scale_ops.h>
|
| 321 |
+
#include <ATen/ops/q_zero_point_ops.h>
|
| 322 |
+
#include <ATen/ops/qr_ops.h>
|
| 323 |
+
#include <ATen/ops/qscheme_ops.h>
|
| 324 |
+
#include <ATen/ops/quantile_ops.h>
|
| 325 |
+
#include <ATen/ops/rad2deg_ops.h>
|
| 326 |
+
#include <ATen/ops/random_ops.h>
|
| 327 |
+
#include <ATen/ops/ravel_ops.h>
|
| 328 |
+
#include <ATen/ops/reciprocal_ops.h>
|
| 329 |
+
#include <ATen/ops/record_stream_ops.h>
|
| 330 |
+
#include <ATen/ops/refine_names_ops.h>
|
| 331 |
+
#include <ATen/ops/relu_ops.h>
|
| 332 |
+
#include <ATen/ops/remainder_ops.h>
|
| 333 |
+
#include <ATen/ops/rename_ops.h>
|
| 334 |
+
#include <ATen/ops/renorm_ops.h>
|
| 335 |
+
#include <ATen/ops/repeat_interleave_ops.h>
|
| 336 |
+
#include <ATen/ops/repeat_ops.h>
|
| 337 |
+
#include <ATen/ops/requires_grad_ops.h>
|
| 338 |
+
#include <ATen/ops/reshape_as_ops.h>
|
| 339 |
+
#include <ATen/ops/reshape_ops.h>
|
| 340 |
+
#include <ATen/ops/resize_as_ops.h>
|
| 341 |
+
#include <ATen/ops/resize_as_sparse_ops.h>
|
| 342 |
+
#include <ATen/ops/resize_ops.h>
|
| 343 |
+
#include <ATen/ops/resolve_conj_ops.h>
|
| 344 |
+
#include <ATen/ops/resolve_neg_ops.h>
|
| 345 |
+
#include <ATen/ops/retain_grad_ops.h>
|
| 346 |
+
#include <ATen/ops/retains_grad_ops.h>
|
| 347 |
+
#include <ATen/ops/roll_ops.h>
|
| 348 |
+
#include <ATen/ops/rot90_ops.h>
|
| 349 |
+
#include <ATen/ops/round_ops.h>
|
| 350 |
+
#include <ATen/ops/row_indices_ops.h>
|
| 351 |
+
#include <ATen/ops/rshift_ops.h>
|
| 352 |
+
#include <ATen/ops/rsqrt_ops.h>
|
| 353 |
+
#include <ATen/ops/scatter_add_ops.h>
|
| 354 |
+
#include <ATen/ops/scatter_ops.h>
|
| 355 |
+
#include <ATen/ops/scatter_reduce_ops.h>
|
| 356 |
+
#include <ATen/ops/select_ops.h>
|
| 357 |
+
#include <ATen/ops/select_scatter_ops.h>
|
| 358 |
+
#include <ATen/ops/set_data_ops.h>
|
| 359 |
+
#include <ATen/ops/set_ops.h>
|
| 360 |
+
#include <ATen/ops/sgn_ops.h>
|
| 361 |
+
#include <ATen/ops/sigmoid_ops.h>
|
| 362 |
+
#include <ATen/ops/sign_ops.h>
|
| 363 |
+
#include <ATen/ops/signbit_ops.h>
|
| 364 |
+
#include <ATen/ops/sin_ops.h>
|
| 365 |
+
#include <ATen/ops/sinc_ops.h>
|
| 366 |
+
#include <ATen/ops/sinh_ops.h>
|
| 367 |
+
#include <ATen/ops/size_ops.h>
|
| 368 |
+
#include <ATen/ops/slice_inverse_ops.h>
|
| 369 |
+
#include <ATen/ops/slice_ops.h>
|
| 370 |
+
#include <ATen/ops/slice_scatter_ops.h>
|
| 371 |
+
#include <ATen/ops/slogdet_ops.h>
|
| 372 |
+
#include <ATen/ops/smm_ops.h>
|
| 373 |
+
#include <ATen/ops/softmax_ops.h>
|
| 374 |
+
#include <ATen/ops/sort_ops.h>
|
| 375 |
+
#include <ATen/ops/sparse_dim_ops.h>
|
| 376 |
+
#include <ATen/ops/sparse_mask_ops.h>
|
| 377 |
+
#include <ATen/ops/sparse_resize_and_clear_ops.h>
|
| 378 |
+
#include <ATen/ops/sparse_resize_ops.h>
|
| 379 |
+
#include <ATen/ops/split_ops.h>
|
| 380 |
+
#include <ATen/ops/split_with_sizes_ops.h>
|
| 381 |
+
#include <ATen/ops/sqrt_ops.h>
|
| 382 |
+
#include <ATen/ops/square_ops.h>
|
| 383 |
+
#include <ATen/ops/squeeze_ops.h>
|
| 384 |
+
#include <ATen/ops/sspaddmm_ops.h>
|
| 385 |
+
#include <ATen/ops/std_ops.h>
|
| 386 |
+
#include <ATen/ops/stft_ops.h>
|
| 387 |
+
#include <ATen/ops/stride_ops.h>
|
| 388 |
+
#include <ATen/ops/sub_ops.h>
|
| 389 |
+
#include <ATen/ops/subtract_ops.h>
|
| 390 |
+
#include <ATen/ops/sum_ops.h>
|
| 391 |
+
#include <ATen/ops/sum_to_size_ops.h>
|
| 392 |
+
#include <ATen/ops/svd_ops.h>
|
| 393 |
+
#include <ATen/ops/swapaxes_ops.h>
|
| 394 |
+
#include <ATen/ops/swapdims_ops.h>
|
| 395 |
+
#include <ATen/ops/t_ops.h>
|
| 396 |
+
#include <ATen/ops/take_along_dim_ops.h>
|
| 397 |
+
#include <ATen/ops/take_ops.h>
|
| 398 |
+
#include <ATen/ops/tan_ops.h>
|
| 399 |
+
#include <ATen/ops/tanh_ops.h>
|
| 400 |
+
#include <ATen/ops/tensor_split_ops.h>
|
| 401 |
+
#include <ATen/ops/tile_ops.h>
|
| 402 |
+
#include <ATen/ops/to_dense_ops.h>
|
| 403 |
+
#include <ATen/ops/to_mkldnn_ops.h>
|
| 404 |
+
#include <ATen/ops/to_ops.h>
|
| 405 |
+
#include <ATen/ops/to_padded_tensor_ops.h>
|
| 406 |
+
#include <ATen/ops/to_sparse_bsc_ops.h>
|
| 407 |
+
#include <ATen/ops/to_sparse_bsr_ops.h>
|
| 408 |
+
#include <ATen/ops/to_sparse_csc_ops.h>
|
| 409 |
+
#include <ATen/ops/to_sparse_csr_ops.h>
|
| 410 |
+
#include <ATen/ops/to_sparse_ops.h>
|
| 411 |
+
#include <ATen/ops/topk_ops.h>
|
| 412 |
+
#include <ATen/ops/trace_ops.h>
|
| 413 |
+
#include <ATen/ops/transpose_ops.h>
|
| 414 |
+
#include <ATen/ops/triangular_solve_ops.h>
|
| 415 |
+
#include <ATen/ops/tril_ops.h>
|
| 416 |
+
#include <ATen/ops/triu_ops.h>
|
| 417 |
+
#include <ATen/ops/true_divide_ops.h>
|
| 418 |
+
#include <ATen/ops/trunc_ops.h>
|
| 419 |
+
#include <ATen/ops/type_as_ops.h>
|
| 420 |
+
#include <ATen/ops/unbind_ops.h>
|
| 421 |
+
#include <ATen/ops/unflatten_ops.h>
|
| 422 |
+
#include <ATen/ops/unfold_ops.h>
|
| 423 |
+
#include <ATen/ops/uniform_ops.h>
|
| 424 |
+
#include <ATen/ops/unsafe_chunk_ops.h>
|
| 425 |
+
#include <ATen/ops/unsafe_split_ops.h>
|
| 426 |
+
#include <ATen/ops/unsafe_split_with_sizes_ops.h>
|
| 427 |
+
#include <ATen/ops/unsqueeze_ops.h>
|
| 428 |
+
#include <ATen/ops/values_ops.h>
|
| 429 |
+
#include <ATen/ops/var_ops.h>
|
| 430 |
+
#include <ATen/ops/vdot_ops.h>
|
| 431 |
+
#include <ATen/ops/view_as_ops.h>
|
| 432 |
+
#include <ATen/ops/view_ops.h>
|
| 433 |
+
#include <ATen/ops/vsplit_ops.h>
|
| 434 |
+
#include <ATen/ops/where_ops.h>
|
| 435 |
+
#include <ATen/ops/xlogy_ops.h>
|
| 436 |
+
#include <ATen/ops/xor_ops.h>
|
| 437 |
+
#include <ATen/ops/zero_ops.h>
|
| 438 |
+
|
| 439 |
+
namespace at {
|
| 440 |
+
namespace _ops {
|
| 441 |
+
|
| 442 |
+
} // namespace _ops
|
| 443 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NamedTensorUtils.h
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/NamedTensor.h>
|
| 3 |
+
#include <ATen/TensorNames.h>
|
| 4 |
+
#include <ATen/WrapDimUtilsMulti.h>
|
| 5 |
+
|
| 6 |
+
#include <ATen/core/DimVector.h>
|
| 7 |
+
#include <ATen/core/Tensor.h>
|
| 8 |
+
#include <functional>
|
| 9 |
+
|
| 10 |
+
namespace at {
|
| 11 |
+
|
| 12 |
+
using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
|
| 13 |
+
|
| 14 |
+
inline bool has_names(const ITensorListRef& tensors) {
|
| 15 |
+
return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& t) {
|
| 16 |
+
return t.has_names();
|
| 17 |
+
});
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
// Converts dim to an positional index. Errors if `dim` cannot be used to
|
| 21 |
+
// refer to any dimension of tensor.
|
| 22 |
+
TORCH_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim);
|
| 23 |
+
TORCH_API std::vector<int64_t> dimnames_to_positions(
|
| 24 |
+
const Tensor& tensor,
|
| 25 |
+
DimnameList dims);
|
| 26 |
+
|
| 27 |
+
// Unifies two DimnameList to produce a third. This is useful for implementing
|
| 28 |
+
// the named inference rule for binary broadcasting operations like add.
|
| 29 |
+
//
|
| 30 |
+
// There are three main constraints:
|
| 31 |
+
// 1) Check matching: Names must match positionally from the right.
|
| 32 |
+
// 2) Check misaligned: If a name `n` is in `names`, then it must appear at
|
| 33 |
+
// the same index from the right in other.
|
| 34 |
+
// 3) The output names are obtained by unifying the names individually from the
|
| 35 |
+
// right.
|
| 36 |
+
TORCH_API std::vector<Dimname> unify_from_right(
|
| 37 |
+
DimnameList names,
|
| 38 |
+
DimnameList other,
|
| 39 |
+
const char* action = "broadcast");
|
| 40 |
+
|
| 41 |
+
[[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) {
|
| 42 |
+
TORCH_CHECK(
|
| 43 |
+
false,
|
| 44 |
+
op_name,
|
| 45 |
+
": You passed a dimname (string) to this op in place of a dimension "
|
| 46 |
+
"index but it does not yet support this behavior. Please pass a dimension "
|
| 47 |
+
"index to work around this.");
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
// [NOTE] Writing name inference rules
|
| 51 |
+
//
|
| 52 |
+
// Operators that support named tensors are either composed of operations that
|
| 53 |
+
// support named tensors or implement some name inference rule. An op that
|
| 54 |
+
// implements its own name inference rule generally looks like the following:
|
| 55 |
+
//
|
| 56 |
+
// Tensor op(...) {
|
| 57 |
+
// perform_shape_checks(...);
|
| 58 |
+
// # (1)
|
| 59 |
+
// auto maybe_outnames = compute_outnames(...);
|
| 60 |
+
// auto result = [&]() {
|
| 61 |
+
// NoNamesGuard guard;
|
| 62 |
+
// return op_impl(...);
|
| 63 |
+
// }();
|
| 64 |
+
// # (2)
|
| 65 |
+
// propagate_names_if_nonempty(result, maybe_outnames);
|
| 66 |
+
//
|
| 67 |
+
// Each op has (1) a compute outnames step and (2) a propagate names step.
|
| 68 |
+
//
|
| 69 |
+
// compute_outnames is responsible for checking that input names match and
|
| 70 |
+
// determining what the output names should be. It returns either:
|
| 71 |
+
// - {} (if the inputs tensors are all unnamed)
|
| 72 |
+
// - non-empty outnames.
|
| 73 |
+
//
|
| 74 |
+
// propagate_names_if_nonempty propagates the outnames if they exist to the
|
| 75 |
+
// result tensors.
|
| 76 |
+
//
|
| 77 |
+
// The {} case is an optimization; if the user does not use named tensors they
|
| 78 |
+
// pay no perf cost for it.
|
| 79 |
+
|
| 80 |
+
namespace namedinference {
|
| 81 |
+
|
| 82 |
+
const Tensor& propagate_names_if_present_and_nonempty(
|
| 83 |
+
const Tensor& result,
|
| 84 |
+
c10::optional<DimnameList> maybe_names,
|
| 85 |
+
bool validate_names = false);
|
| 86 |
+
// Propagates `names` to `result` if `names` is not empty.
|
| 87 |
+
// `names` can be empty; see [NOTE] Writing name inference rules
|
| 88 |
+
// If `names` is not empty, `names.size()` should equal `result.dim()`.
|
| 89 |
+
// When in doubt, use this overload instead of the others.
|
| 90 |
+
TORCH_API const Tensor& propagate_names_if_nonempty(
|
| 91 |
+
const Tensor& result,
|
| 92 |
+
DimnameList maybe_names,
|
| 93 |
+
bool validate_names = false);
|
| 94 |
+
|
| 95 |
+
// Propagates `names` to `result`. Only use this if we are certain that there
|
| 96 |
+
// are names to propagate (that names is not empty).
|
| 97 |
+
TORCH_API const Tensor& propagate_names(
|
| 98 |
+
const Tensor& result,
|
| 99 |
+
DimnameList names,
|
| 100 |
+
bool validate_names = false);
|
| 101 |
+
|
| 102 |
+
// Propagates all names from src to result.
|
| 103 |
+
TORCH_API void propagate_names(const Tensor& result, const Tensor& src);
|
| 104 |
+
|
| 105 |
+
// Propagates all names except for those at the excluded_idxs.
|
| 106 |
+
TORCH_API void propagate_names_except(
|
| 107 |
+
const Tensor& result,
|
| 108 |
+
const Tensor& src,
|
| 109 |
+
IntArrayRef excluded_idxs);
|
| 110 |
+
|
| 111 |
+
// Used for reduction ops that have a `keepdim` arg.
|
| 112 |
+
TORCH_API void propagate_names_for_reduction(
|
| 113 |
+
const Tensor& result,
|
| 114 |
+
const Tensor& src,
|
| 115 |
+
IntArrayRef excluded_idxs,
|
| 116 |
+
bool keepdim);
|
| 117 |
+
|
| 118 |
+
TORCH_API void propagate_names_for_expand(
|
| 119 |
+
const Tensor& result,
|
| 120 |
+
const Tensor& self);
|
| 121 |
+
|
| 122 |
+
TORCH_API std::vector<Dimname> compute_cat_outnames(
|
| 123 |
+
const MaterializedITensorListRef& tensors);
|
| 124 |
+
|
| 125 |
+
TORCH_API std::vector<Dimname> compute_broadcast_outnames(
|
| 126 |
+
const Tensor& self,
|
| 127 |
+
const Tensor& other);
|
| 128 |
+
|
| 129 |
+
TORCH_API std::vector<Dimname> broadcast_to_outnames(
|
| 130 |
+
const Tensor& tensor,
|
| 131 |
+
const Tensor& reference_tensor,
|
| 132 |
+
const char* op_name);
|
| 133 |
+
|
| 134 |
+
TORCH_API std::vector<Dimname> compute_matmul_outnames(
|
| 135 |
+
const Tensor& self,
|
| 136 |
+
const Tensor& other);
|
| 137 |
+
|
| 138 |
+
TORCH_API std::vector<Dimname> compute_cdist_outnames(
|
| 139 |
+
const Tensor& self,
|
| 140 |
+
const Tensor& other);
|
| 141 |
+
|
| 142 |
+
TORCH_API std::vector<Dimname> compute_bmm_outnames(
|
| 143 |
+
const Tensor& result,
|
| 144 |
+
const Tensor& self,
|
| 145 |
+
const Tensor& other);
|
| 146 |
+
|
| 147 |
+
TORCH_API std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor);
|
| 148 |
+
TORCH_API std::vector<Dimname> compute_squeeze_outnames(
|
| 149 |
+
const Tensor& tensor,
|
| 150 |
+
std::bitset<dim_bitset_size> dims);
|
| 151 |
+
|
| 152 |
+
std::vector<Dimname> compute_diagonal_outnames(
|
| 153 |
+
const Tensor& tensor,
|
| 154 |
+
int64_t dim1,
|
| 155 |
+
int64_t dim2);
|
| 156 |
+
|
| 157 |
+
// TensorImpl* overloads for Legacy TH/THC code. Use these sparingly.
|
| 158 |
+
|
| 159 |
+
TORCH_API TensorImpl* propagate_names_if_nonempty(
|
| 160 |
+
TensorImpl* result,
|
| 161 |
+
DimnameList maybe_names,
|
| 162 |
+
bool validate_names = false);
|
| 163 |
+
|
| 164 |
+
TORCH_API TensorImpl* propagate_names(
|
| 165 |
+
TensorImpl* result,
|
| 166 |
+
DimnameList names,
|
| 167 |
+
bool validate_names = false);
|
| 168 |
+
|
| 169 |
+
TORCH_API void propagate_names(TensorImpl* result, /*const */ TensorImpl* src);
|
| 170 |
+
|
| 171 |
+
TORCH_API inline void propagate_names(
|
| 172 |
+
const TensorBase& result,
|
| 173 |
+
DimnameList names,
|
| 174 |
+
bool validate_names = false) {
|
| 175 |
+
propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
TORCH_API inline void propagate_names_if_nonempty(
|
| 179 |
+
const TensorBase& result,
|
| 180 |
+
DimnameList names,
|
| 181 |
+
bool validate_names = false) {
|
| 182 |
+
propagate_names_if_nonempty(
|
| 183 |
+
result.unsafeGetTensorImpl(), names, validate_names);
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
TORCH_API inline void propagate_names(
|
| 187 |
+
const TensorBase& result,
|
| 188 |
+
const TensorBase& src) {
|
| 189 |
+
propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
// result = m1 @ m2 + bias
|
| 193 |
+
TORCH_API std::vector<Dimname> propagate_names_for_addmm(
|
| 194 |
+
const Tensor& m1,
|
| 195 |
+
const Tensor& m2,
|
| 196 |
+
const Tensor& bias);
|
| 197 |
+
|
| 198 |
+
TORCH_API std::vector<Dimname> propagate_names_for_addmv(
|
| 199 |
+
const Tensor& mat,
|
| 200 |
+
const Tensor& vec,
|
| 201 |
+
const Tensor& bias);
|
| 202 |
+
|
| 203 |
+
TORCH_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);
|
| 204 |
+
|
| 205 |
+
TORCH_API std::vector<Dimname> compute_baddbmm_outnames(
|
| 206 |
+
const Tensor& result,
|
| 207 |
+
const Tensor& self,
|
| 208 |
+
const Tensor& other,
|
| 209 |
+
const Tensor& bias);
|
| 210 |
+
|
| 211 |
+
TORCH_API bool are_names_equal(TensorImpl* self, TensorImpl* other);
|
| 212 |
+
|
| 213 |
+
} // namespace namedinference
|
| 214 |
+
|
| 215 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/NativeFunctions.h
ADDED
|
@@ -0,0 +1,1317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// @generated by torchgen/gen.py from NativeFunctions.h
|
| 4 |
+
|
| 5 |
+
#ifdef TORCH_ASSERT_NO_OPERATORS
|
| 6 |
+
#error This change adds a dependency on native_functions.yaml, \
|
| 7 |
+
meaning the file will need to be re-compiled every time an operator \
|
| 8 |
+
is changed or added. Consider if your change would be better placed in \
|
| 9 |
+
another file, or if a more specific header might achieve the same goal. \
|
| 10 |
+
See NOTE: [Tensor vs. TensorBase]
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
|
| 14 |
+
#error This change adds a dependency on all pytorch operators, meaning the \
|
| 15 |
+
file will need to be re-compiled every time an operator is changed or added. \
|
| 16 |
+
Consider including a specific operator from <ATen/ops/{my_operator}_native.h> \
|
| 17 |
+
and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
#include <c10/core/Scalar.h>
|
| 21 |
+
#include <c10/core/Storage.h>
|
| 22 |
+
#include <c10/core/TensorOptions.h>
|
| 23 |
+
#include <c10/util/Deprecated.h>
|
| 24 |
+
#include <c10/util/Optional.h>
|
| 25 |
+
#include <c10/core/QScheme.h>
|
| 26 |
+
#include <ATen/core/Reduction.h>
|
| 27 |
+
#include <ATen/core/Tensor.h>
|
| 28 |
+
#include <tuple>
|
| 29 |
+
#include <vector>
|
| 30 |
+
|
| 31 |
+
#include <ATen/ops/_adaptive_avg_pool2d_native.h>
|
| 32 |
+
#include <ATen/ops/_adaptive_avg_pool2d_backward_native.h>
|
| 33 |
+
#include <ATen/ops/_adaptive_avg_pool3d_native.h>
|
| 34 |
+
#include <ATen/ops/_adaptive_avg_pool3d_backward_native.h>
|
| 35 |
+
#include <ATen/ops/_add_batch_dim_native.h>
|
| 36 |
+
#include <ATen/ops/_add_relu_native.h>
|
| 37 |
+
#include <ATen/ops/_addmm_activation_native.h>
|
| 38 |
+
#include <ATen/ops/_aminmax_native.h>
|
| 39 |
+
#include <ATen/ops/_amp_foreach_non_finite_check_and_unscale_native.h>
|
| 40 |
+
#include <ATen/ops/_amp_update_scale_native.h>
|
| 41 |
+
#include <ATen/ops/_assert_async_native.h>
|
| 42 |
+
#include <ATen/ops/_assert_scalar_native.h>
|
| 43 |
+
#include <ATen/ops/_assert_tensor_metadata_native.h>
|
| 44 |
+
#include <ATen/ops/_autocast_to_full_precision_native.h>
|
| 45 |
+
#include <ATen/ops/_autocast_to_reduced_precision_native.h>
|
| 46 |
+
#include <ATen/ops/_backward_native.h>
|
| 47 |
+
#include <ATen/ops/_batch_norm_impl_index_native.h>
|
| 48 |
+
#include <ATen/ops/_batch_norm_impl_index_backward_native.h>
|
| 49 |
+
#include <ATen/ops/_cast_Byte_native.h>
|
| 50 |
+
#include <ATen/ops/_cast_Char_native.h>
|
| 51 |
+
#include <ATen/ops/_cast_Double_native.h>
|
| 52 |
+
#include <ATen/ops/_cast_Float_native.h>
|
| 53 |
+
#include <ATen/ops/_cast_Half_native.h>
|
| 54 |
+
#include <ATen/ops/_cast_Int_native.h>
|
| 55 |
+
#include <ATen/ops/_cast_Long_native.h>
|
| 56 |
+
#include <ATen/ops/_cast_Short_native.h>
|
| 57 |
+
#include <ATen/ops/_cdist_backward_native.h>
|
| 58 |
+
#include <ATen/ops/_cdist_forward_native.h>
|
| 59 |
+
#include <ATen/ops/_cholesky_solve_helper_native.h>
|
| 60 |
+
#include <ATen/ops/_choose_qparams_per_tensor_native.h>
|
| 61 |
+
#include <ATen/ops/_chunk_cat_native.h>
|
| 62 |
+
#include <ATen/ops/_coalesce_native.h>
|
| 63 |
+
#include <ATen/ops/_coalesced_native.h>
|
| 64 |
+
#include <ATen/ops/_compute_linear_combination_native.h>
|
| 65 |
+
#include <ATen/ops/_conj_native.h>
|
| 66 |
+
#include <ATen/ops/_conj_copy_native.h>
|
| 67 |
+
#include <ATen/ops/_conj_physical_native.h>
|
| 68 |
+
#include <ATen/ops/_conv_depthwise2d_native.h>
|
| 69 |
+
#include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
|
| 70 |
+
#include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
|
| 71 |
+
#include <ATen/ops/_convert_weight_to_int4pack_native.h>
|
| 72 |
+
#include <ATen/ops/_convolution_native.h>
|
| 73 |
+
#include <ATen/ops/_convolution_double_backward_native.h>
|
| 74 |
+
#include <ATen/ops/_convolution_mode_native.h>
|
| 75 |
+
#include <ATen/ops/_copy_from_native.h>
|
| 76 |
+
#include <ATen/ops/_copy_from_and_resize_native.h>
|
| 77 |
+
#include <ATen/ops/_cslt_compress_native.h>
|
| 78 |
+
#include <ATen/ops/_cslt_sparse_mm_native.h>
|
| 79 |
+
#include <ATen/ops/_cslt_sparse_mm_search_native.h>
|
| 80 |
+
#include <ATen/ops/_ctc_loss_native.h>
|
| 81 |
+
#include <ATen/ops/_ctc_loss_backward_native.h>
|
| 82 |
+
#include <ATen/ops/_cudnn_ctc_loss_native.h>
|
| 83 |
+
#include <ATen/ops/_cudnn_init_dropout_state_native.h>
|
| 84 |
+
#include <ATen/ops/_cudnn_rnn_native.h>
|
| 85 |
+
#include <ATen/ops/_cudnn_rnn_backward_native.h>
|
| 86 |
+
#include <ATen/ops/_cudnn_rnn_flatten_weight_native.h>
|
| 87 |
+
#include <ATen/ops/_cufft_clear_plan_cache_native.h>
|
| 88 |
+
#include <ATen/ops/_cufft_get_plan_cache_max_size_native.h>
|
| 89 |
+
#include <ATen/ops/_cufft_get_plan_cache_size_native.h>
|
| 90 |
+
#include <ATen/ops/_cufft_set_plan_cache_max_size_native.h>
|
| 91 |
+
#include <ATen/ops/_cummax_helper_native.h>
|
| 92 |
+
#include <ATen/ops/_cummin_helper_native.h>
|
| 93 |
+
#include <ATen/ops/_debug_has_internal_overlap_native.h>
|
| 94 |
+
#include <ATen/ops/_dimI_native.h>
|
| 95 |
+
#include <ATen/ops/_dimV_native.h>
|
| 96 |
+
#include <ATen/ops/_dim_arange_native.h>
|
| 97 |
+
#include <ATen/ops/_dirichlet_grad_native.h>
|
| 98 |
+
#include <ATen/ops/_efficient_attention_backward_native.h>
|
| 99 |
+
#include <ATen/ops/_efficient_attention_forward_native.h>
|
| 100 |
+
#include <ATen/ops/_efficientzerotensor_native.h>
|
| 101 |
+
#include <ATen/ops/_embedding_bag_native.h>
|
| 102 |
+
#include <ATen/ops/_embedding_bag_backward_native.h>
|
| 103 |
+
#include <ATen/ops/_embedding_bag_dense_backward_native.h>
|
| 104 |
+
#include <ATen/ops/_embedding_bag_forward_only_native.h>
|
| 105 |
+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_native.h>
|
| 106 |
+
#include <ATen/ops/_embedding_bag_sparse_backward_native.h>
|
| 107 |
+
#include <ATen/ops/_empty_affine_quantized_native.h>
|
| 108 |
+
#include <ATen/ops/_empty_per_channel_affine_quantized_native.h>
|
| 109 |
+
#include <ATen/ops/_euclidean_dist_native.h>
|
| 110 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_native.h>
|
| 111 |
+
#include <ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_native.h>
|
| 112 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_native.h>
|
| 113 |
+
#include <ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_native.h>
|
| 114 |
+
#include <ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_native.h>
|
| 115 |
+
#include <ATen/ops/_fft_c2c_native.h>
|
| 116 |
+
#include <ATen/ops/_fft_c2r_native.h>
|
| 117 |
+
#include <ATen/ops/_fft_r2c_native.h>
|
| 118 |
+
#include <ATen/ops/_fill_mem_eff_dropout_mask_native.h>
|
| 119 |
+
#include <ATen/ops/_flash_attention_backward_native.h>
|
| 120 |
+
#include <ATen/ops/_flash_attention_forward_native.h>
|
| 121 |
+
#include <ATen/ops/_foobar_native.h>
|
| 122 |
+
#include <ATen/ops/_foreach_abs_native.h>
|
| 123 |
+
#include <ATen/ops/_foreach_acos_native.h>
|
| 124 |
+
#include <ATen/ops/_foreach_add_native.h>
|
| 125 |
+
#include <ATen/ops/_foreach_addcdiv_native.h>
|
| 126 |
+
#include <ATen/ops/_foreach_addcmul_native.h>
|
| 127 |
+
#include <ATen/ops/_foreach_asin_native.h>
|
| 128 |
+
#include <ATen/ops/_foreach_atan_native.h>
|
| 129 |
+
#include <ATen/ops/_foreach_ceil_native.h>
|
| 130 |
+
#include <ATen/ops/_foreach_clamp_max_native.h>
|
| 131 |
+
#include <ATen/ops/_foreach_clamp_min_native.h>
|
| 132 |
+
#include <ATen/ops/_foreach_copy_native.h>
|
| 133 |
+
#include <ATen/ops/_foreach_cos_native.h>
|
| 134 |
+
#include <ATen/ops/_foreach_cosh_native.h>
|
| 135 |
+
#include <ATen/ops/_foreach_div_native.h>
|
| 136 |
+
#include <ATen/ops/_foreach_erf_native.h>
|
| 137 |
+
#include <ATen/ops/_foreach_erfc_native.h>
|
| 138 |
+
#include <ATen/ops/_foreach_exp_native.h>
|
| 139 |
+
#include <ATen/ops/_foreach_expm1_native.h>
|
| 140 |
+
#include <ATen/ops/_foreach_floor_native.h>
|
| 141 |
+
#include <ATen/ops/_foreach_frac_native.h>
|
| 142 |
+
#include <ATen/ops/_foreach_lerp_native.h>
|
| 143 |
+
#include <ATen/ops/_foreach_lgamma_native.h>
|
| 144 |
+
#include <ATen/ops/_foreach_log_native.h>
|
| 145 |
+
#include <ATen/ops/_foreach_log10_native.h>
|
| 146 |
+
#include <ATen/ops/_foreach_log1p_native.h>
|
| 147 |
+
#include <ATen/ops/_foreach_log2_native.h>
|
| 148 |
+
#include <ATen/ops/_foreach_maximum_native.h>
|
| 149 |
+
#include <ATen/ops/_foreach_minimum_native.h>
|
| 150 |
+
#include <ATen/ops/_foreach_mul_native.h>
|
| 151 |
+
#include <ATen/ops/_foreach_neg_native.h>
|
| 152 |
+
#include <ATen/ops/_foreach_norm_native.h>
|
| 153 |
+
#include <ATen/ops/_foreach_pow_native.h>
|
| 154 |
+
#include <ATen/ops/_foreach_reciprocal_native.h>
|
| 155 |
+
#include <ATen/ops/_foreach_round_native.h>
|
| 156 |
+
#include <ATen/ops/_foreach_sigmoid_native.h>
|
| 157 |
+
#include <ATen/ops/_foreach_sign_native.h>
|
| 158 |
+
#include <ATen/ops/_foreach_sin_native.h>
|
| 159 |
+
#include <ATen/ops/_foreach_sinh_native.h>
|
| 160 |
+
#include <ATen/ops/_foreach_sqrt_native.h>
|
| 161 |
+
#include <ATen/ops/_foreach_sub_native.h>
|
| 162 |
+
#include <ATen/ops/_foreach_tan_native.h>
|
| 163 |
+
#include <ATen/ops/_foreach_tanh_native.h>
|
| 164 |
+
#include <ATen/ops/_foreach_trunc_native.h>
|
| 165 |
+
#include <ATen/ops/_foreach_zero_native.h>
|
| 166 |
+
#include <ATen/ops/_functional_assert_async_native.h>
|
| 167 |
+
#include <ATen/ops/_functional_assert_scalar_native.h>
|
| 168 |
+
#include <ATen/ops/_functional_sym_constrain_range_native.h>
|
| 169 |
+
#include <ATen/ops/_functional_sym_constrain_range_for_size_native.h>
|
| 170 |
+
#include <ATen/ops/_fused_adam_native.h>
|
| 171 |
+
#include <ATen/ops/_fused_adamw_native.h>
|
| 172 |
+
#include <ATen/ops/_fused_dropout_native.h>
|
| 173 |
+
#include <ATen/ops/_fused_moving_avg_obs_fq_helper_native.h>
|
| 174 |
+
#include <ATen/ops/_fused_sdp_choice_native.h>
|
| 175 |
+
#include <ATen/ops/_fused_sgd_native.h>
|
| 176 |
+
#include <ATen/ops/_fw_primal_native.h>
|
| 177 |
+
#include <ATen/ops/_fw_primal_copy_native.h>
|
| 178 |
+
#include <ATen/ops/_gather_sparse_backward_native.h>
|
| 179 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_native.h>
|
| 180 |
+
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_native.h>
|
| 181 |
+
#include <ATen/ops/_has_compatible_shallow_copy_type_native.h>
|
| 182 |
+
#include <ATen/ops/_has_same_storage_numel_native.h>
|
| 183 |
+
#include <ATen/ops/_histogramdd_bin_edges_native.h>
|
| 184 |
+
#include <ATen/ops/_histogramdd_from_bin_cts_native.h>
|
| 185 |
+
#include <ATen/ops/_histogramdd_from_bin_tensors_native.h>
|
| 186 |
+
#include <ATen/ops/_index_put_impl_native.h>
|
| 187 |
+
#include <ATen/ops/_indices_native.h>
|
| 188 |
+
#include <ATen/ops/_indices_copy_native.h>
|
| 189 |
+
#include <ATen/ops/_int_mm_native.h>
|
| 190 |
+
#include <ATen/ops/_is_all_true_native.h>
|
| 191 |
+
#include <ATen/ops/_is_any_true_native.h>
|
| 192 |
+
#include <ATen/ops/_is_zerotensor_native.h>
|
| 193 |
+
#include <ATen/ops/_lazy_clone_native.h>
|
| 194 |
+
#include <ATen/ops/_linalg_check_errors_native.h>
|
| 195 |
+
#include <ATen/ops/_linalg_det_native.h>
|
| 196 |
+
#include <ATen/ops/_linalg_eigh_native.h>
|
| 197 |
+
#include <ATen/ops/_linalg_eigvals_native.h>
|
| 198 |
+
#include <ATen/ops/_linalg_slogdet_native.h>
|
| 199 |
+
#include <ATen/ops/_linalg_solve_ex_native.h>
|
| 200 |
+
#include <ATen/ops/_linalg_svd_native.h>
|
| 201 |
+
#include <ATen/ops/_local_scalar_dense_native.h>
|
| 202 |
+
#include <ATen/ops/_log_softmax_native.h>
|
| 203 |
+
#include <ATen/ops/_log_softmax_backward_data_native.h>
|
| 204 |
+
#include <ATen/ops/_logcumsumexp_native.h>
|
| 205 |
+
#include <ATen/ops/_lstm_mps_native.h>
|
| 206 |
+
#include <ATen/ops/_lu_with_info_native.h>
|
| 207 |
+
#include <ATen/ops/_make_dep_token_native.h>
|
| 208 |
+
#include <ATen/ops/_make_dual_native.h>
|
| 209 |
+
#include <ATen/ops/_make_dual_copy_native.h>
|
| 210 |
+
#include <ATen/ops/_make_per_channel_quantized_tensor_native.h>
|
| 211 |
+
#include <ATen/ops/_make_per_tensor_quantized_tensor_native.h>
|
| 212 |
+
#include <ATen/ops/_masked_scale_native.h>
|
| 213 |
+
#include <ATen/ops/_masked_softmax_native.h>
|
| 214 |
+
#include <ATen/ops/_masked_softmax_backward_native.h>
|
| 215 |
+
#include <ATen/ops/_mixed_dtypes_linear_native.h>
|
| 216 |
+
#include <ATen/ops/_mkldnn_reshape_native.h>
|
| 217 |
+
#include <ATen/ops/_mkldnn_transpose_native.h>
|
| 218 |
+
#include <ATen/ops/_mps_convolution_native.h>
|
| 219 |
+
#include <ATen/ops/_mps_convolution_transpose_native.h>
|
| 220 |
+
#include <ATen/ops/_native_batch_norm_legit_native.h>
|
| 221 |
+
#include <ATen/ops/_native_batch_norm_legit_no_training_native.h>
|
| 222 |
+
#include <ATen/ops/_native_multi_head_attention_native.h>
|
| 223 |
+
#include <ATen/ops/_neg_view_native.h>
|
| 224 |
+
#include <ATen/ops/_neg_view_copy_native.h>
|
| 225 |
+
#include <ATen/ops/_nested_from_padded_native.h>
|
| 226 |
+
#include <ATen/ops/_nested_from_padded_and_nested_example_native.h>
|
| 227 |
+
#include <ATen/ops/_nested_get_jagged_dummy_native.h>
|
| 228 |
+
#include <ATen/ops/_nested_get_lengths_native.h>
|
| 229 |
+
#include <ATen/ops/_nested_get_offsets_native.h>
|
| 230 |
+
#include <ATen/ops/_nested_get_ragged_idx_native.h>
|
| 231 |
+
#include <ATen/ops/_nested_get_values_native.h>
|
| 232 |
+
#include <ATen/ops/_nested_get_values_copy_native.h>
|
| 233 |
+
#include <ATen/ops/_nested_select_backward_native.h>
|
| 234 |
+
#include <ATen/ops/_nested_sum_backward_native.h>
|
| 235 |
+
#include <ATen/ops/_nested_tensor_from_mask_native.h>
|
| 236 |
+
#include <ATen/ops/_nested_tensor_from_mask_left_aligned_native.h>
|
| 237 |
+
#include <ATen/ops/_nested_tensor_from_tensor_list_native.h>
|
| 238 |
+
#include <ATen/ops/_nested_tensor_size_native.h>
|
| 239 |
+
#include <ATen/ops/_nested_tensor_softmax_with_shape_native.h>
|
| 240 |
+
#include <ATen/ops/_nested_tensor_storage_offsets_native.h>
|
| 241 |
+
#include <ATen/ops/_nested_tensor_strides_native.h>
|
| 242 |
+
#include <ATen/ops/_nested_view_from_buffer_native.h>
|
| 243 |
+
#include <ATen/ops/_nested_view_from_buffer_copy_native.h>
|
| 244 |
+
#include <ATen/ops/_nested_view_from_jagged_native.h>
|
| 245 |
+
#include <ATen/ops/_nested_view_from_jagged_copy_native.h>
|
| 246 |
+
#include <ATen/ops/_new_zeros_with_same_feature_meta_native.h>
|
| 247 |
+
#include <ATen/ops/_nnpack_available_native.h>
|
| 248 |
+
#include <ATen/ops/_nnpack_spatial_convolution_native.h>
|
| 249 |
+
#include <ATen/ops/_nnz_native.h>
|
| 250 |
+
#include <ATen/ops/_pack_padded_sequence_native.h>
|
| 251 |
+
#include <ATen/ops/_pack_padded_sequence_backward_native.h>
|
| 252 |
+
#include <ATen/ops/_pad_circular_native.h>
|
| 253 |
+
#include <ATen/ops/_pad_enum_native.h>
|
| 254 |
+
#include <ATen/ops/_pad_packed_sequence_native.h>
|
| 255 |
+
#include <ATen/ops/_pdist_backward_native.h>
|
| 256 |
+
#include <ATen/ops/_pdist_forward_native.h>
|
| 257 |
+
#include <ATen/ops/_pin_memory_native.h>
|
| 258 |
+
#include <ATen/ops/_prelu_kernel_native.h>
|
| 259 |
+
#include <ATen/ops/_prelu_kernel_backward_native.h>
|
| 260 |
+
#include <ATen/ops/_print_native.h>
|
| 261 |
+
#include <ATen/ops/_propagate_xla_data_native.h>
|
| 262 |
+
#include <ATen/ops/_remove_batch_dim_native.h>
|
| 263 |
+
#include <ATen/ops/_reshape_alias_native.h>
|
| 264 |
+
#include <ATen/ops/_reshape_alias_copy_native.h>
|
| 265 |
+
#include <ATen/ops/_reshape_copy_native.h>
|
| 266 |
+
#include <ATen/ops/_reshape_from_tensor_native.h>
|
| 267 |
+
#include <ATen/ops/_resize_output_native.h>
|
| 268 |
+
#include <ATen/ops/_rowwise_prune_native.h>
|
| 269 |
+
#include <ATen/ops/_sample_dirichlet_native.h>
|
| 270 |
+
#include <ATen/ops/_saturate_weight_to_fp16_native.h>
|
| 271 |
+
#include <ATen/ops/_scaled_dot_product_attention_math_native.h>
|
| 272 |
+
#include <ATen/ops/_scaled_dot_product_cudnn_attention_native.h>
|
| 273 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_native.h>
|
| 274 |
+
#include <ATen/ops/_scaled_dot_product_efficient_attention_backward_native.h>
|
| 275 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_native.h>
|
| 276 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_native.h>
|
| 277 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_native.h>
|
| 278 |
+
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_native.h>
|
| 279 |
+
#include <ATen/ops/_scaled_mm_native.h>
|
| 280 |
+
#include <ATen/ops/_segment_reduce_backward_native.h>
|
| 281 |
+
#include <ATen/ops/_shape_as_tensor_native.h>
|
| 282 |
+
#include <ATen/ops/_slow_conv2d_backward_native.h>
|
| 283 |
+
#include <ATen/ops/_slow_conv2d_forward_native.h>
|
| 284 |
+
#include <ATen/ops/_sobol_engine_draw_native.h>
|
| 285 |
+
#include <ATen/ops/_sobol_engine_ff_native.h>
|
| 286 |
+
#include <ATen/ops/_sobol_engine_initialize_state_native.h>
|
| 287 |
+
#include <ATen/ops/_sobol_engine_scramble_native.h>
|
| 288 |
+
#include <ATen/ops/_softmax_native.h>
|
| 289 |
+
#include <ATen/ops/_softmax_backward_data_native.h>
|
| 290 |
+
#include <ATen/ops/_sparse_addmm_native.h>
|
| 291 |
+
#include <ATen/ops/_sparse_broadcast_to_native.h>
|
| 292 |
+
#include <ATen/ops/_sparse_broadcast_to_copy_native.h>
|
| 293 |
+
#include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
|
| 294 |
+
#include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
|
| 295 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
|
| 296 |
+
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
|
| 297 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_native.h>
|
| 298 |
+
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_native.h>
|
| 299 |
+
#include <ATen/ops/_sparse_csc_tensor_unsafe_native.h>
|
| 300 |
+
#include <ATen/ops/_sparse_csr_prod_native.h>
|
| 301 |
+
#include <ATen/ops/_sparse_csr_sum_native.h>
|
| 302 |
+
#include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
|
| 303 |
+
#include <ATen/ops/_sparse_log_softmax_native.h>
|
| 304 |
+
#include <ATen/ops/_sparse_log_softmax_backward_data_native.h>
|
| 305 |
+
#include <ATen/ops/_sparse_mask_projection_native.h>
|
| 306 |
+
#include <ATen/ops/_sparse_mm_native.h>
|
| 307 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_native.h>
|
| 308 |
+
#include <ATen/ops/_sparse_mm_reduce_impl_backward_native.h>
|
| 309 |
+
#include <ATen/ops/_sparse_semi_structured_linear_native.h>
|
| 310 |
+
#include <ATen/ops/_sparse_softmax_native.h>
|
| 311 |
+
#include <ATen/ops/_sparse_softmax_backward_data_native.h>
|
| 312 |
+
#include <ATen/ops/_sparse_sparse_matmul_native.h>
|
| 313 |
+
#include <ATen/ops/_sparse_sum_native.h>
|
| 314 |
+
#include <ATen/ops/_sparse_sum_backward_native.h>
|
| 315 |
+
#include <ATen/ops/_spdiags_native.h>
|
| 316 |
+
#include <ATen/ops/_stack_native.h>
|
| 317 |
+
#include <ATen/ops/_standard_gamma_native.h>
|
| 318 |
+
#include <ATen/ops/_standard_gamma_grad_native.h>
|
| 319 |
+
#include <ATen/ops/_test_ambiguous_defaults_native.h>
|
| 320 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_native.h>
|
| 321 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_native.h>
|
| 322 |
+
#include <ATen/ops/_test_autograd_multiple_dispatch_view_copy_native.h>
|
| 323 |
+
#include <ATen/ops/_test_check_tensor_native.h>
|
| 324 |
+
#include <ATen/ops/_test_functorch_fallback_native.h>
|
| 325 |
+
#include <ATen/ops/_test_optional_filled_intlist_native.h>
|
| 326 |
+
#include <ATen/ops/_test_optional_floatlist_native.h>
|
| 327 |
+
#include <ATen/ops/_test_optional_intlist_native.h>
|
| 328 |
+
#include <ATen/ops/_test_parallel_materialize_native.h>
|
| 329 |
+
#include <ATen/ops/_test_serialization_subcmul_native.h>
|
| 330 |
+
#include <ATen/ops/_test_string_default_native.h>
|
| 331 |
+
#include <ATen/ops/_test_warn_in_autograd_native.h>
|
| 332 |
+
#include <ATen/ops/_thnn_differentiable_gru_cell_backward_native.h>
|
| 333 |
+
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward_native.h>
|
| 334 |
+
#include <ATen/ops/_thnn_fused_gru_cell_native.h>
|
| 335 |
+
#include <ATen/ops/_thnn_fused_gru_cell_backward_native.h>
|
| 336 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_native.h>
|
| 337 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_native.h>
|
| 338 |
+
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_native.h>
|
| 339 |
+
#include <ATen/ops/_to_copy_native.h>
|
| 340 |
+
#include <ATen/ops/_to_cpu_native.h>
|
| 341 |
+
#include <ATen/ops/_to_dense_native.h>
|
| 342 |
+
#include <ATen/ops/_to_sparse_native.h>
|
| 343 |
+
#include <ATen/ops/_to_sparse_bsc_native.h>
|
| 344 |
+
#include <ATen/ops/_to_sparse_bsr_native.h>
|
| 345 |
+
#include <ATen/ops/_to_sparse_csc_native.h>
|
| 346 |
+
#include <ATen/ops/_to_sparse_csr_native.h>
|
| 347 |
+
#include <ATen/ops/_to_sparse_semi_structured_native.h>
|
| 348 |
+
#include <ATen/ops/_transform_bias_rescale_qkv_native.h>
|
| 349 |
+
#include <ATen/ops/_transformer_encoder_layer_fwd_native.h>
|
| 350 |
+
#include <ATen/ops/_trilinear_native.h>
|
| 351 |
+
#include <ATen/ops/_triton_multi_head_attention_native.h>
|
| 352 |
+
#include <ATen/ops/_triton_scaled_dot_attention_native.h>
|
| 353 |
+
#include <ATen/ops/_unique_native.h>
|
| 354 |
+
#include <ATen/ops/_unique2_native.h>
|
| 355 |
+
#include <ATen/ops/_unpack_dual_native.h>
|
| 356 |
+
#include <ATen/ops/_unsafe_index_native.h>
|
| 357 |
+
#include <ATen/ops/_unsafe_index_put_native.h>
|
| 358 |
+
#include <ATen/ops/_unsafe_view_native.h>
|
| 359 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_native.h>
|
| 360 |
+
#include <ATen/ops/_upsample_bicubic2d_aa_backward_native.h>
|
| 361 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_native.h>
|
| 362 |
+
#include <ATen/ops/_upsample_bilinear2d_aa_backward_native.h>
|
| 363 |
+
#include <ATen/ops/_upsample_nearest_exact1d_native.h>
|
| 364 |
+
#include <ATen/ops/_upsample_nearest_exact1d_backward_native.h>
|
| 365 |
+
#include <ATen/ops/_upsample_nearest_exact2d_native.h>
|
| 366 |
+
#include <ATen/ops/_upsample_nearest_exact2d_backward_native.h>
|
| 367 |
+
#include <ATen/ops/_upsample_nearest_exact3d_native.h>
|
| 368 |
+
#include <ATen/ops/_upsample_nearest_exact3d_backward_native.h>
|
| 369 |
+
#include <ATen/ops/_use_cudnn_ctc_loss_native.h>
|
| 370 |
+
#include <ATen/ops/_use_cudnn_rnn_flatten_weight_native.h>
|
| 371 |
+
#include <ATen/ops/_validate_compressed_sparse_indices_native.h>
|
| 372 |
+
#include <ATen/ops/_validate_sparse_bsc_tensor_args_native.h>
|
| 373 |
+
#include <ATen/ops/_validate_sparse_bsr_tensor_args_native.h>
|
| 374 |
+
#include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
|
| 375 |
+
#include <ATen/ops/_validate_sparse_coo_tensor_args_native.h>
|
| 376 |
+
#include <ATen/ops/_validate_sparse_csc_tensor_args_native.h>
|
| 377 |
+
#include <ATen/ops/_validate_sparse_csr_tensor_args_native.h>
|
| 378 |
+
#include <ATen/ops/_values_native.h>
|
| 379 |
+
#include <ATen/ops/_values_copy_native.h>
|
| 380 |
+
#include <ATen/ops/_version_native.h>
|
| 381 |
+
#include <ATen/ops/_weight_int4pack_mm_native.h>
|
| 382 |
+
#include <ATen/ops/_weight_int8pack_mm_native.h>
|
| 383 |
+
#include <ATen/ops/_weight_norm_native.h>
|
| 384 |
+
#include <ATen/ops/_weight_norm_differentiable_backward_native.h>
|
| 385 |
+
#include <ATen/ops/_weight_norm_interface_native.h>
|
| 386 |
+
#include <ATen/ops/_weight_norm_interface_backward_native.h>
|
| 387 |
+
#include <ATen/ops/abs_native.h>
|
| 388 |
+
#include <ATen/ops/absolute_native.h>
|
| 389 |
+
#include <ATen/ops/acos_native.h>
|
| 390 |
+
#include <ATen/ops/acosh_native.h>
|
| 391 |
+
#include <ATen/ops/adaptive_avg_pool1d_native.h>
|
| 392 |
+
#include <ATen/ops/adaptive_avg_pool2d_native.h>
|
| 393 |
+
#include <ATen/ops/adaptive_avg_pool3d_native.h>
|
| 394 |
+
#include <ATen/ops/adaptive_avg_pool3d_backward_native.h>
|
| 395 |
+
#include <ATen/ops/adaptive_max_pool1d_native.h>
|
| 396 |
+
#include <ATen/ops/adaptive_max_pool2d_native.h>
|
| 397 |
+
#include <ATen/ops/adaptive_max_pool2d_backward_native.h>
|
| 398 |
+
#include <ATen/ops/adaptive_max_pool3d_native.h>
|
| 399 |
+
#include <ATen/ops/adaptive_max_pool3d_backward_native.h>
|
| 400 |
+
#include <ATen/ops/add_native.h>
|
| 401 |
+
#include <ATen/ops/addbmm_native.h>
|
| 402 |
+
#include <ATen/ops/addcdiv_native.h>
|
| 403 |
+
#include <ATen/ops/addcmul_native.h>
|
| 404 |
+
#include <ATen/ops/addmm_native.h>
|
| 405 |
+
#include <ATen/ops/addmv_native.h>
|
| 406 |
+
#include <ATen/ops/addr_native.h>
|
| 407 |
+
#include <ATen/ops/adjoint_native.h>
|
| 408 |
+
#include <ATen/ops/affine_grid_generator_native.h>
|
| 409 |
+
#include <ATen/ops/affine_grid_generator_backward_native.h>
|
| 410 |
+
#include <ATen/ops/alias_native.h>
|
| 411 |
+
#include <ATen/ops/alias_copy_native.h>
|
| 412 |
+
#include <ATen/ops/align_as_native.h>
|
| 413 |
+
#include <ATen/ops/align_tensors_native.h>
|
| 414 |
+
#include <ATen/ops/align_to_native.h>
|
| 415 |
+
#include <ATen/ops/all_native.h>
|
| 416 |
+
#include <ATen/ops/allclose_native.h>
|
| 417 |
+
#include <ATen/ops/alpha_dropout_native.h>
|
| 418 |
+
#include <ATen/ops/amax_native.h>
|
| 419 |
+
#include <ATen/ops/amin_native.h>
|
| 420 |
+
#include <ATen/ops/aminmax_native.h>
|
| 421 |
+
#include <ATen/ops/and_native.h>
|
| 422 |
+
#include <ATen/ops/angle_native.h>
|
| 423 |
+
#include <ATen/ops/any_native.h>
|
| 424 |
+
#include <ATen/ops/arange_native.h>
|
| 425 |
+
#include <ATen/ops/arccos_native.h>
|
| 426 |
+
#include <ATen/ops/arccosh_native.h>
|
| 427 |
+
#include <ATen/ops/arcsin_native.h>
|
| 428 |
+
#include <ATen/ops/arcsinh_native.h>
|
| 429 |
+
#include <ATen/ops/arctan_native.h>
|
| 430 |
+
#include <ATen/ops/arctan2_native.h>
|
| 431 |
+
#include <ATen/ops/arctanh_native.h>
|
| 432 |
+
#include <ATen/ops/argmax_native.h>
|
| 433 |
+
#include <ATen/ops/argmin_native.h>
|
| 434 |
+
#include <ATen/ops/argsort_native.h>
|
| 435 |
+
#include <ATen/ops/argwhere_native.h>
|
| 436 |
+
#include <ATen/ops/as_strided_native.h>
|
| 437 |
+
#include <ATen/ops/as_strided_copy_native.h>
|
| 438 |
+
#include <ATen/ops/as_strided_scatter_native.h>
|
| 439 |
+
#include <ATen/ops/asin_native.h>
|
| 440 |
+
#include <ATen/ops/asinh_native.h>
|
| 441 |
+
#include <ATen/ops/atan_native.h>
|
| 442 |
+
#include <ATen/ops/atan2_native.h>
|
| 443 |
+
#include <ATen/ops/atanh_native.h>
|
| 444 |
+
#include <ATen/ops/atleast_1d_native.h>
|
| 445 |
+
#include <ATen/ops/atleast_2d_native.h>
|
| 446 |
+
#include <ATen/ops/atleast_3d_native.h>
|
| 447 |
+
#include <ATen/ops/avg_pool1d_native.h>
|
| 448 |
+
#include <ATen/ops/avg_pool2d_native.h>
|
| 449 |
+
#include <ATen/ops/avg_pool2d_backward_native.h>
|
| 450 |
+
#include <ATen/ops/avg_pool3d_native.h>
|
| 451 |
+
#include <ATen/ops/avg_pool3d_backward_native.h>
|
| 452 |
+
#include <ATen/ops/baddbmm_native.h>
|
| 453 |
+
#include <ATen/ops/bartlett_window_native.h>
|
| 454 |
+
#include <ATen/ops/batch_norm_native.h>
|
| 455 |
+
#include <ATen/ops/batch_norm_backward_elemt_native.h>
|
| 456 |
+
#include <ATen/ops/batch_norm_backward_reduce_native.h>
|
| 457 |
+
#include <ATen/ops/batch_norm_elemt_native.h>
|
| 458 |
+
#include <ATen/ops/batch_norm_gather_stats_native.h>
|
| 459 |
+
#include <ATen/ops/batch_norm_gather_stats_with_counts_native.h>
|
| 460 |
+
#include <ATen/ops/batch_norm_stats_native.h>
|
| 461 |
+
#include <ATen/ops/batch_norm_update_stats_native.h>
|
| 462 |
+
#include <ATen/ops/bernoulli_native.h>
|
| 463 |
+
#include <ATen/ops/bilinear_native.h>
|
| 464 |
+
#include <ATen/ops/binary_cross_entropy_native.h>
|
| 465 |
+
#include <ATen/ops/binary_cross_entropy_backward_native.h>
|
| 466 |
+
#include <ATen/ops/binary_cross_entropy_with_logits_native.h>
|
| 467 |
+
#include <ATen/ops/bincount_native.h>
|
| 468 |
+
#include <ATen/ops/binomial_native.h>
|
| 469 |
+
#include <ATen/ops/bitwise_and_native.h>
|
| 470 |
+
#include <ATen/ops/bitwise_left_shift_native.h>
|
| 471 |
+
#include <ATen/ops/bitwise_not_native.h>
|
| 472 |
+
#include <ATen/ops/bitwise_or_native.h>
|
| 473 |
+
#include <ATen/ops/bitwise_right_shift_native.h>
|
| 474 |
+
#include <ATen/ops/bitwise_xor_native.h>
|
| 475 |
+
#include <ATen/ops/blackman_window_native.h>
|
| 476 |
+
#include <ATen/ops/block_diag_native.h>
|
| 477 |
+
#include <ATen/ops/bmm_native.h>
|
| 478 |
+
#include <ATen/ops/broadcast_tensors_native.h>
|
| 479 |
+
#include <ATen/ops/broadcast_to_native.h>
|
| 480 |
+
#include <ATen/ops/bucketize_native.h>
|
| 481 |
+
#include <ATen/ops/can_cast_native.h>
|
| 482 |
+
#include <ATen/ops/cartesian_prod_native.h>
|
| 483 |
+
#include <ATen/ops/cat_native.h>
|
| 484 |
+
#include <ATen/ops/cauchy_native.h>
|
| 485 |
+
#include <ATen/ops/ccol_indices_native.h>
|
| 486 |
+
#include <ATen/ops/ccol_indices_copy_native.h>
|
| 487 |
+
#include <ATen/ops/cdist_native.h>
|
| 488 |
+
#include <ATen/ops/ceil_native.h>
|
| 489 |
+
#include <ATen/ops/celu_native.h>
|
| 490 |
+
#include <ATen/ops/chain_matmul_native.h>
|
| 491 |
+
#include <ATen/ops/chalf_native.h>
|
| 492 |
+
#include <ATen/ops/channel_shuffle_native.h>
|
| 493 |
+
#include <ATen/ops/cholesky_native.h>
|
| 494 |
+
#include <ATen/ops/cholesky_inverse_native.h>
|
| 495 |
+
#include <ATen/ops/cholesky_solve_native.h>
|
| 496 |
+
#include <ATen/ops/choose_qparams_optimized_native.h>
|
| 497 |
+
#include <ATen/ops/chunk_native.h>
|
| 498 |
+
#include <ATen/ops/clamp_native.h>
|
| 499 |
+
#include <ATen/ops/clamp_max_native.h>
|
| 500 |
+
#include <ATen/ops/clamp_min_native.h>
|
| 501 |
+
#include <ATen/ops/clip_native.h>
|
| 502 |
+
#include <ATen/ops/clone_native.h>
|
| 503 |
+
#include <ATen/ops/coalesce_native.h>
|
| 504 |
+
#include <ATen/ops/col2im_native.h>
|
| 505 |
+
#include <ATen/ops/col_indices_native.h>
|
| 506 |
+
#include <ATen/ops/col_indices_copy_native.h>
|
| 507 |
+
#include <ATen/ops/column_stack_native.h>
|
| 508 |
+
#include <ATen/ops/combinations_native.h>
|
| 509 |
+
#include <ATen/ops/complex_native.h>
|
| 510 |
+
#include <ATen/ops/concat_native.h>
|
| 511 |
+
#include <ATen/ops/concatenate_native.h>
|
| 512 |
+
#include <ATen/ops/conj_native.h>
|
| 513 |
+
#include <ATen/ops/conj_physical_native.h>
|
| 514 |
+
#include <ATen/ops/constant_pad_nd_native.h>
|
| 515 |
+
#include <ATen/ops/contiguous_native.h>
|
| 516 |
+
#include <ATen/ops/conv1d_native.h>
|
| 517 |
+
#include <ATen/ops/conv2d_native.h>
|
| 518 |
+
#include <ATen/ops/conv3d_native.h>
|
| 519 |
+
#include <ATen/ops/conv_depthwise3d_native.h>
|
| 520 |
+
#include <ATen/ops/conv_tbc_native.h>
|
| 521 |
+
#include <ATen/ops/conv_tbc_backward_native.h>
|
| 522 |
+
#include <ATen/ops/conv_transpose1d_native.h>
|
| 523 |
+
#include <ATen/ops/conv_transpose2d_native.h>
|
| 524 |
+
#include <ATen/ops/conv_transpose3d_native.h>
|
| 525 |
+
#include <ATen/ops/convolution_native.h>
|
| 526 |
+
#include <ATen/ops/convolution_backward_native.h>
|
| 527 |
+
#include <ATen/ops/convolution_backward_overrideable_native.h>
|
| 528 |
+
#include <ATen/ops/convolution_overrideable_native.h>
|
| 529 |
+
#include <ATen/ops/copy_native.h>
|
| 530 |
+
#include <ATen/ops/copy_sparse_to_sparse_native.h>
|
| 531 |
+
#include <ATen/ops/copysign_native.h>
|
| 532 |
+
#include <ATen/ops/corrcoef_native.h>
|
| 533 |
+
#include <ATen/ops/cos_native.h>
|
| 534 |
+
#include <ATen/ops/cosh_native.h>
|
| 535 |
+
#include <ATen/ops/cosine_embedding_loss_native.h>
|
| 536 |
+
#include <ATen/ops/cosine_similarity_native.h>
|
| 537 |
+
#include <ATen/ops/count_nonzero_native.h>
|
| 538 |
+
#include <ATen/ops/cov_native.h>
|
| 539 |
+
#include <ATen/ops/cross_native.h>
|
| 540 |
+
#include <ATen/ops/cross_entropy_loss_native.h>
|
| 541 |
+
#include <ATen/ops/crow_indices_native.h>
|
| 542 |
+
#include <ATen/ops/crow_indices_copy_native.h>
|
| 543 |
+
#include <ATen/ops/ctc_loss_native.h>
|
| 544 |
+
#include <ATen/ops/cudnn_affine_grid_generator_native.h>
|
| 545 |
+
#include <ATen/ops/cudnn_affine_grid_generator_backward_native.h>
|
| 546 |
+
#include <ATen/ops/cudnn_batch_norm_native.h>
|
| 547 |
+
#include <ATen/ops/cudnn_batch_norm_backward_native.h>
|
| 548 |
+
#include <ATen/ops/cudnn_convolution_native.h>
|
| 549 |
+
#include <ATen/ops/cudnn_convolution_add_relu_native.h>
|
| 550 |
+
#include <ATen/ops/cudnn_convolution_relu_native.h>
|
| 551 |
+
#include <ATen/ops/cudnn_convolution_transpose_native.h>
|
| 552 |
+
#include <ATen/ops/cudnn_grid_sampler_native.h>
|
| 553 |
+
#include <ATen/ops/cudnn_grid_sampler_backward_native.h>
|
| 554 |
+
#include <ATen/ops/cudnn_is_acceptable_native.h>
|
| 555 |
+
#include <ATen/ops/cummax_native.h>
|
| 556 |
+
#include <ATen/ops/cummaxmin_backward_native.h>
|
| 557 |
+
#include <ATen/ops/cummin_native.h>
|
| 558 |
+
#include <ATen/ops/cumprod_native.h>
|
| 559 |
+
#include <ATen/ops/cumprod_backward_native.h>
|
| 560 |
+
#include <ATen/ops/cumsum_native.h>
|
| 561 |
+
#include <ATen/ops/cumulative_trapezoid_native.h>
|
| 562 |
+
#include <ATen/ops/data_native.h>
|
| 563 |
+
#include <ATen/ops/deg2rad_native.h>
|
| 564 |
+
#include <ATen/ops/dense_dim_native.h>
|
| 565 |
+
#include <ATen/ops/dequantize_native.h>
|
| 566 |
+
#include <ATen/ops/det_native.h>
|
| 567 |
+
#include <ATen/ops/detach_native.h>
|
| 568 |
+
#include <ATen/ops/detach_copy_native.h>
|
| 569 |
+
#include <ATen/ops/diag_native.h>
|
| 570 |
+
#include <ATen/ops/diag_embed_native.h>
|
| 571 |
+
#include <ATen/ops/diagflat_native.h>
|
| 572 |
+
#include <ATen/ops/diagonal_native.h>
|
| 573 |
+
#include <ATen/ops/diagonal_backward_native.h>
|
| 574 |
+
#include <ATen/ops/diagonal_copy_native.h>
|
| 575 |
+
#include <ATen/ops/diagonal_scatter_native.h>
|
| 576 |
+
#include <ATen/ops/diff_native.h>
|
| 577 |
+
#include <ATen/ops/digamma_native.h>
|
| 578 |
+
#include <ATen/ops/dist_native.h>
|
| 579 |
+
#include <ATen/ops/div_native.h>
|
| 580 |
+
#include <ATen/ops/divide_native.h>
|
| 581 |
+
#include <ATen/ops/dot_native.h>
|
| 582 |
+
#include <ATen/ops/dropout_native.h>
|
| 583 |
+
#include <ATen/ops/dsplit_native.h>
|
| 584 |
+
#include <ATen/ops/dstack_native.h>
|
| 585 |
+
#include <ATen/ops/einsum_native.h>
|
| 586 |
+
#include <ATen/ops/elu_native.h>
|
| 587 |
+
#include <ATen/ops/elu_backward_native.h>
|
| 588 |
+
#include <ATen/ops/embedding_native.h>
|
| 589 |
+
#include <ATen/ops/embedding_backward_native.h>
|
| 590 |
+
#include <ATen/ops/embedding_bag_native.h>
|
| 591 |
+
#include <ATen/ops/embedding_dense_backward_native.h>
|
| 592 |
+
#include <ATen/ops/embedding_renorm_native.h>
|
| 593 |
+
#include <ATen/ops/embedding_sparse_backward_native.h>
|
| 594 |
+
#include <ATen/ops/empty_native.h>
|
| 595 |
+
#include <ATen/ops/empty_like_native.h>
|
| 596 |
+
#include <ATen/ops/empty_permuted_native.h>
|
| 597 |
+
#include <ATen/ops/empty_quantized_native.h>
|
| 598 |
+
#include <ATen/ops/empty_strided_native.h>
|
| 599 |
+
#include <ATen/ops/eq_native.h>
|
| 600 |
+
#include <ATen/ops/equal_native.h>
|
| 601 |
+
#include <ATen/ops/erf_native.h>
|
| 602 |
+
#include <ATen/ops/erfc_native.h>
|
| 603 |
+
#include <ATen/ops/erfinv_native.h>
|
| 604 |
+
#include <ATen/ops/exp_native.h>
|
| 605 |
+
#include <ATen/ops/exp2_native.h>
|
| 606 |
+
#include <ATen/ops/expand_native.h>
|
| 607 |
+
#include <ATen/ops/expand_as_native.h>
|
| 608 |
+
#include <ATen/ops/expand_copy_native.h>
|
| 609 |
+
#include <ATen/ops/expm1_native.h>
|
| 610 |
+
#include <ATen/ops/exponential_native.h>
|
| 611 |
+
#include <ATen/ops/eye_native.h>
|
| 612 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_native.h>
|
| 613 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_native.h>
|
| 614 |
+
#include <ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_native.h>
|
| 615 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_native.h>
|
| 616 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_native.h>
|
| 617 |
+
#include <ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_native.h>
|
| 618 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_native.h>
|
| 619 |
+
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h>
|
| 620 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_native.h>
|
| 621 |
+
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_native.h>
|
| 622 |
+
#include <ATen/ops/fbgemm_linear_quantize_weight_native.h>
|
| 623 |
+
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_native.h>
|
| 624 |
+
#include <ATen/ops/fbgemm_pack_quantized_matrix_native.h>
|
| 625 |
+
#include <ATen/ops/feature_alpha_dropout_native.h>
|
| 626 |
+
#include <ATen/ops/feature_dropout_native.h>
|
| 627 |
+
#include <ATen/ops/fft_fft_native.h>
|
| 628 |
+
#include <ATen/ops/fft_fft2_native.h>
|
| 629 |
+
#include <ATen/ops/fft_fftfreq_native.h>
|
| 630 |
+
#include <ATen/ops/fft_fftn_native.h>
|
| 631 |
+
#include <ATen/ops/fft_fftshift_native.h>
|
| 632 |
+
#include <ATen/ops/fft_hfft_native.h>
|
| 633 |
+
#include <ATen/ops/fft_hfft2_native.h>
|
| 634 |
+
#include <ATen/ops/fft_hfftn_native.h>
|
| 635 |
+
#include <ATen/ops/fft_ifft_native.h>
|
| 636 |
+
#include <ATen/ops/fft_ifft2_native.h>
|
| 637 |
+
#include <ATen/ops/fft_ifftn_native.h>
|
| 638 |
+
#include <ATen/ops/fft_ifftshift_native.h>
|
| 639 |
+
#include <ATen/ops/fft_ihfft_native.h>
|
| 640 |
+
#include <ATen/ops/fft_ihfft2_native.h>
|
| 641 |
+
#include <ATen/ops/fft_ihfftn_native.h>
|
| 642 |
+
#include <ATen/ops/fft_irfft_native.h>
|
| 643 |
+
#include <ATen/ops/fft_irfft2_native.h>
|
| 644 |
+
#include <ATen/ops/fft_irfftn_native.h>
|
| 645 |
+
#include <ATen/ops/fft_rfft_native.h>
|
| 646 |
+
#include <ATen/ops/fft_rfft2_native.h>
|
| 647 |
+
#include <ATen/ops/fft_rfftfreq_native.h>
|
| 648 |
+
#include <ATen/ops/fft_rfftn_native.h>
|
| 649 |
+
#include <ATen/ops/fill_native.h>
|
| 650 |
+
#include <ATen/ops/fill_diagonal_native.h>
|
| 651 |
+
#include <ATen/ops/fix_native.h>
|
| 652 |
+
#include <ATen/ops/flatten_native.h>
|
| 653 |
+
#include <ATen/ops/flatten_dense_tensors_native.h>
|
| 654 |
+
#include <ATen/ops/flip_native.h>
|
| 655 |
+
#include <ATen/ops/fliplr_native.h>
|
| 656 |
+
#include <ATen/ops/flipud_native.h>
|
| 657 |
+
#include <ATen/ops/float_power_native.h>
|
| 658 |
+
#include <ATen/ops/floor_native.h>
|
| 659 |
+
#include <ATen/ops/floor_divide_native.h>
|
| 660 |
+
#include <ATen/ops/fmax_native.h>
|
| 661 |
+
#include <ATen/ops/fmin_native.h>
|
| 662 |
+
#include <ATen/ops/fmod_native.h>
|
| 663 |
+
#include <ATen/ops/frac_native.h>
|
| 664 |
+
#include <ATen/ops/fractional_max_pool2d_native.h>
|
| 665 |
+
#include <ATen/ops/fractional_max_pool2d_backward_native.h>
|
| 666 |
+
#include <ATen/ops/fractional_max_pool3d_native.h>
|
| 667 |
+
#include <ATen/ops/fractional_max_pool3d_backward_native.h>
|
| 668 |
+
#include <ATen/ops/frexp_native.h>
|
| 669 |
+
#include <ATen/ops/frobenius_norm_native.h>
|
| 670 |
+
#include <ATen/ops/from_file_native.h>
|
| 671 |
+
#include <ATen/ops/full_native.h>
|
| 672 |
+
#include <ATen/ops/full_like_native.h>
|
| 673 |
+
#include <ATen/ops/fused_moving_avg_obs_fake_quant_native.h>
|
| 674 |
+
#include <ATen/ops/gather_native.h>
|
| 675 |
+
#include <ATen/ops/gather_backward_native.h>
|
| 676 |
+
#include <ATen/ops/gcd_native.h>
|
| 677 |
+
#include <ATen/ops/ge_native.h>
|
| 678 |
+
#include <ATen/ops/gelu_native.h>
|
| 679 |
+
#include <ATen/ops/gelu_backward_native.h>
|
| 680 |
+
#include <ATen/ops/geometric_native.h>
|
| 681 |
+
#include <ATen/ops/geqrf_native.h>
|
| 682 |
+
#include <ATen/ops/ger_native.h>
|
| 683 |
+
#include <ATen/ops/glu_native.h>
|
| 684 |
+
#include <ATen/ops/glu_backward_native.h>
|
| 685 |
+
#include <ATen/ops/glu_backward_jvp_native.h>
|
| 686 |
+
#include <ATen/ops/glu_jvp_native.h>
|
| 687 |
+
#include <ATen/ops/gradient_native.h>
|
| 688 |
+
#include <ATen/ops/greater_native.h>
|
| 689 |
+
#include <ATen/ops/greater_equal_native.h>
|
| 690 |
+
#include <ATen/ops/grid_sampler_native.h>
|
| 691 |
+
#include <ATen/ops/grid_sampler_2d_native.h>
|
| 692 |
+
#include <ATen/ops/grid_sampler_2d_backward_native.h>
|
| 693 |
+
#include <ATen/ops/grid_sampler_3d_native.h>
|
| 694 |
+
#include <ATen/ops/grid_sampler_3d_backward_native.h>
|
| 695 |
+
#include <ATen/ops/group_norm_native.h>
|
| 696 |
+
#include <ATen/ops/gru_native.h>
|
| 697 |
+
#include <ATen/ops/gru_cell_native.h>
|
| 698 |
+
#include <ATen/ops/gt_native.h>
|
| 699 |
+
#include <ATen/ops/hamming_window_native.h>
|
| 700 |
+
#include <ATen/ops/hann_window_native.h>
|
| 701 |
+
#include <ATen/ops/hardshrink_native.h>
|
| 702 |
+
#include <ATen/ops/hardshrink_backward_native.h>
|
| 703 |
+
#include <ATen/ops/hardsigmoid_native.h>
|
| 704 |
+
#include <ATen/ops/hardsigmoid_backward_native.h>
|
| 705 |
+
#include <ATen/ops/hardswish_native.h>
|
| 706 |
+
#include <ATen/ops/hardswish_backward_native.h>
|
| 707 |
+
#include <ATen/ops/hardtanh_native.h>
|
| 708 |
+
#include <ATen/ops/hardtanh_backward_native.h>
|
| 709 |
+
#include <ATen/ops/heaviside_native.h>
|
| 710 |
+
#include <ATen/ops/hinge_embedding_loss_native.h>
|
| 711 |
+
#include <ATen/ops/histc_native.h>
|
| 712 |
+
#include <ATen/ops/histogram_native.h>
|
| 713 |
+
#include <ATen/ops/histogramdd_native.h>
|
| 714 |
+
#include <ATen/ops/hsplit_native.h>
|
| 715 |
+
#include <ATen/ops/hspmm_native.h>
|
| 716 |
+
#include <ATen/ops/hstack_native.h>
|
| 717 |
+
#include <ATen/ops/huber_loss_native.h>
|
| 718 |
+
#include <ATen/ops/huber_loss_backward_native.h>
|
| 719 |
+
#include <ATen/ops/hypot_native.h>
|
| 720 |
+
#include <ATen/ops/i0_native.h>
|
| 721 |
+
#include <ATen/ops/igamma_native.h>
|
| 722 |
+
#include <ATen/ops/igammac_native.h>
|
| 723 |
+
#include <ATen/ops/im2col_native.h>
|
| 724 |
+
#include <ATen/ops/imag_native.h>
|
| 725 |
+
#include <ATen/ops/index_native.h>
|
| 726 |
+
#include <ATen/ops/index_add_native.h>
|
| 727 |
+
#include <ATen/ops/index_copy_native.h>
|
| 728 |
+
#include <ATen/ops/index_fill_native.h>
|
| 729 |
+
#include <ATen/ops/index_put_native.h>
|
| 730 |
+
#include <ATen/ops/index_reduce_native.h>
|
| 731 |
+
#include <ATen/ops/index_select_native.h>
|
| 732 |
+
#include <ATen/ops/index_select_backward_native.h>
|
| 733 |
+
#include <ATen/ops/indices_native.h>
|
| 734 |
+
#include <ATen/ops/indices_copy_native.h>
|
| 735 |
+
#include <ATen/ops/infinitely_differentiable_gelu_backward_native.h>
|
| 736 |
+
#include <ATen/ops/inner_native.h>
|
| 737 |
+
#include <ATen/ops/instance_norm_native.h>
|
| 738 |
+
#include <ATen/ops/int_repr_native.h>
|
| 739 |
+
#include <ATen/ops/inverse_native.h>
|
| 740 |
+
#include <ATen/ops/is_coalesced_native.h>
|
| 741 |
+
#include <ATen/ops/is_complex_native.h>
|
| 742 |
+
#include <ATen/ops/is_conj_native.h>
|
| 743 |
+
#include <ATen/ops/is_distributed_native.h>
|
| 744 |
+
#include <ATen/ops/is_floating_point_native.h>
|
| 745 |
+
#include <ATen/ops/is_inference_native.h>
|
| 746 |
+
#include <ATen/ops/is_leaf_native.h>
|
| 747 |
+
#include <ATen/ops/is_neg_native.h>
|
| 748 |
+
#include <ATen/ops/is_nonzero_native.h>
|
| 749 |
+
#include <ATen/ops/is_pinned_native.h>
|
| 750 |
+
#include <ATen/ops/is_same_size_native.h>
|
| 751 |
+
#include <ATen/ops/is_set_to_native.h>
|
| 752 |
+
#include <ATen/ops/is_signed_native.h>
|
| 753 |
+
#include <ATen/ops/is_vulkan_available_native.h>
|
| 754 |
+
#include <ATen/ops/isclose_native.h>
|
| 755 |
+
#include <ATen/ops/isfinite_native.h>
|
| 756 |
+
#include <ATen/ops/isin_native.h>
|
| 757 |
+
#include <ATen/ops/isinf_native.h>
|
| 758 |
+
#include <ATen/ops/isnan_native.h>
|
| 759 |
+
#include <ATen/ops/isneginf_native.h>
|
| 760 |
+
#include <ATen/ops/isposinf_native.h>
|
| 761 |
+
#include <ATen/ops/isreal_native.h>
|
| 762 |
+
#include <ATen/ops/istft_native.h>
|
| 763 |
+
#include <ATen/ops/item_native.h>
|
| 764 |
+
#include <ATen/ops/kaiser_window_native.h>
|
| 765 |
+
#include <ATen/ops/kl_div_native.h>
|
| 766 |
+
#include <ATen/ops/kron_native.h>
|
| 767 |
+
#include <ATen/ops/kthvalue_native.h>
|
| 768 |
+
#include <ATen/ops/l1_loss_native.h>
|
| 769 |
+
#include <ATen/ops/layer_norm_native.h>
|
| 770 |
+
#include <ATen/ops/lcm_native.h>
|
| 771 |
+
#include <ATen/ops/ldexp_native.h>
|
| 772 |
+
#include <ATen/ops/le_native.h>
|
| 773 |
+
#include <ATen/ops/leaky_relu_native.h>
|
| 774 |
+
#include <ATen/ops/leaky_relu_backward_native.h>
|
| 775 |
+
#include <ATen/ops/lerp_native.h>
|
| 776 |
+
#include <ATen/ops/less_native.h>
|
| 777 |
+
#include <ATen/ops/less_equal_native.h>
|
| 778 |
+
#include <ATen/ops/lgamma_native.h>
|
| 779 |
+
#include <ATen/ops/lift_native.h>
|
| 780 |
+
#include <ATen/ops/lift_fresh_native.h>
|
| 781 |
+
#include <ATen/ops/lift_fresh_copy_native.h>
|
| 782 |
+
#include <ATen/ops/linalg_cholesky_native.h>
|
| 783 |
+
#include <ATen/ops/linalg_cholesky_ex_native.h>
|
| 784 |
+
#include <ATen/ops/linalg_cond_native.h>
|
| 785 |
+
#include <ATen/ops/linalg_cross_native.h>
|
| 786 |
+
#include <ATen/ops/linalg_det_native.h>
|
| 787 |
+
#include <ATen/ops/linalg_diagonal_native.h>
|
| 788 |
+
#include <ATen/ops/linalg_eig_native.h>
|
| 789 |
+
#include <ATen/ops/linalg_eigh_native.h>
|
| 790 |
+
#include <ATen/ops/linalg_eigvals_native.h>
|
| 791 |
+
#include <ATen/ops/linalg_eigvalsh_native.h>
|
| 792 |
+
#include <ATen/ops/linalg_householder_product_native.h>
|
| 793 |
+
#include <ATen/ops/linalg_inv_native.h>
|
| 794 |
+
#include <ATen/ops/linalg_inv_ex_native.h>
|
| 795 |
+
#include <ATen/ops/linalg_ldl_factor_native.h>
|
| 796 |
+
#include <ATen/ops/linalg_ldl_factor_ex_native.h>
|
| 797 |
+
#include <ATen/ops/linalg_ldl_solve_native.h>
|
| 798 |
+
#include <ATen/ops/linalg_lstsq_native.h>
|
| 799 |
+
#include <ATen/ops/linalg_lu_native.h>
|
| 800 |
+
#include <ATen/ops/linalg_lu_factor_native.h>
|
| 801 |
+
#include <ATen/ops/linalg_lu_factor_ex_native.h>
|
| 802 |
+
#include <ATen/ops/linalg_lu_solve_native.h>
|
| 803 |
+
#include <ATen/ops/linalg_matmul_native.h>
|
| 804 |
+
#include <ATen/ops/linalg_matrix_exp_native.h>
|
| 805 |
+
#include <ATen/ops/linalg_matrix_norm_native.h>
|
| 806 |
+
#include <ATen/ops/linalg_matrix_power_native.h>
|
| 807 |
+
#include <ATen/ops/linalg_matrix_rank_native.h>
|
| 808 |
+
#include <ATen/ops/linalg_multi_dot_native.h>
|
| 809 |
+
#include <ATen/ops/linalg_norm_native.h>
|
| 810 |
+
#include <ATen/ops/linalg_pinv_native.h>
|
| 811 |
+
#include <ATen/ops/linalg_qr_native.h>
|
| 812 |
+
#include <ATen/ops/linalg_slogdet_native.h>
|
| 813 |
+
#include <ATen/ops/linalg_solve_native.h>
|
| 814 |
+
#include <ATen/ops/linalg_solve_ex_native.h>
|
| 815 |
+
#include <ATen/ops/linalg_solve_triangular_native.h>
|
| 816 |
+
#include <ATen/ops/linalg_svd_native.h>
|
| 817 |
+
#include <ATen/ops/linalg_svdvals_native.h>
|
| 818 |
+
#include <ATen/ops/linalg_tensorinv_native.h>
|
| 819 |
+
#include <ATen/ops/linalg_tensorsolve_native.h>
|
| 820 |
+
#include <ATen/ops/linalg_vander_native.h>
|
| 821 |
+
#include <ATen/ops/linalg_vecdot_native.h>
|
| 822 |
+
#include <ATen/ops/linalg_vector_norm_native.h>
|
| 823 |
+
#include <ATen/ops/linear_native.h>
|
| 824 |
+
#include <ATen/ops/linear_backward_native.h>
|
| 825 |
+
#include <ATen/ops/linspace_native.h>
|
| 826 |
+
#include <ATen/ops/log_native.h>
|
| 827 |
+
#include <ATen/ops/log10_native.h>
|
| 828 |
+
#include <ATen/ops/log1p_native.h>
|
| 829 |
+
#include <ATen/ops/log2_native.h>
|
| 830 |
+
#include <ATen/ops/log_normal_native.h>
|
| 831 |
+
#include <ATen/ops/log_sigmoid_native.h>
|
| 832 |
+
#include <ATen/ops/log_sigmoid_backward_native.h>
|
| 833 |
+
#include <ATen/ops/log_sigmoid_forward_native.h>
|
| 834 |
+
#include <ATen/ops/log_softmax_native.h>
|
| 835 |
+
#include <ATen/ops/logaddexp_native.h>
|
| 836 |
+
#include <ATen/ops/logaddexp2_native.h>
|
| 837 |
+
#include <ATen/ops/logcumsumexp_native.h>
|
| 838 |
+
#include <ATen/ops/logdet_native.h>
|
| 839 |
+
#include <ATen/ops/logical_and_native.h>
|
| 840 |
+
#include <ATen/ops/logical_not_native.h>
|
| 841 |
+
#include <ATen/ops/logical_or_native.h>
|
| 842 |
+
#include <ATen/ops/logical_xor_native.h>
|
| 843 |
+
#include <ATen/ops/logit_native.h>
|
| 844 |
+
#include <ATen/ops/logit_backward_native.h>
|
| 845 |
+
#include <ATen/ops/logspace_native.h>
|
| 846 |
+
#include <ATen/ops/logsumexp_native.h>
|
| 847 |
+
#include <ATen/ops/lshift_native.h>
|
| 848 |
+
#include <ATen/ops/lstm_native.h>
|
| 849 |
+
#include <ATen/ops/lstm_cell_native.h>
|
| 850 |
+
#include <ATen/ops/lstm_mps_backward_native.h>
|
| 851 |
+
#include <ATen/ops/lt_native.h>
|
| 852 |
+
#include <ATen/ops/lu_solve_native.h>
|
| 853 |
+
#include <ATen/ops/lu_unpack_native.h>
|
| 854 |
+
#include <ATen/ops/mH_native.h>
|
| 855 |
+
#include <ATen/ops/mT_native.h>
|
| 856 |
+
#include <ATen/ops/margin_ranking_loss_native.h>
|
| 857 |
+
#include <ATen/ops/masked_fill_native.h>
|
| 858 |
+
#include <ATen/ops/masked_scatter_native.h>
|
| 859 |
+
#include <ATen/ops/masked_scatter_backward_native.h>
|
| 860 |
+
#include <ATen/ops/masked_select_native.h>
|
| 861 |
+
#include <ATen/ops/masked_select_backward_native.h>
|
| 862 |
+
#include <ATen/ops/matmul_native.h>
|
| 863 |
+
#include <ATen/ops/matmul_backward_native.h>
|
| 864 |
+
#include <ATen/ops/matrix_H_native.h>
|
| 865 |
+
#include <ATen/ops/matrix_exp_native.h>
|
| 866 |
+
#include <ATen/ops/matrix_exp_backward_native.h>
|
| 867 |
+
#include <ATen/ops/matrix_power_native.h>
|
| 868 |
+
#include <ATen/ops/max_native.h>
|
| 869 |
+
#include <ATen/ops/max_pool1d_native.h>
|
| 870 |
+
#include <ATen/ops/max_pool1d_with_indices_native.h>
|
| 871 |
+
#include <ATen/ops/max_pool2d_native.h>
|
| 872 |
+
#include <ATen/ops/max_pool2d_backward_native.h>
|
| 873 |
+
#include <ATen/ops/max_pool2d_with_indices_native.h>
|
| 874 |
+
#include <ATen/ops/max_pool2d_with_indices_backward_native.h>
|
| 875 |
+
#include <ATen/ops/max_pool3d_native.h>
|
| 876 |
+
#include <ATen/ops/max_pool3d_with_indices_native.h>
|
| 877 |
+
#include <ATen/ops/max_pool3d_with_indices_backward_native.h>
|
| 878 |
+
#include <ATen/ops/max_unpool2d_native.h>
|
| 879 |
+
#include <ATen/ops/max_unpool3d_native.h>
|
| 880 |
+
#include <ATen/ops/maximum_native.h>
|
| 881 |
+
#include <ATen/ops/mean_native.h>
|
| 882 |
+
#include <ATen/ops/median_native.h>
|
| 883 |
+
#include <ATen/ops/meshgrid_native.h>
|
| 884 |
+
#include <ATen/ops/min_native.h>
|
| 885 |
+
#include <ATen/ops/minimum_native.h>
|
| 886 |
+
#include <ATen/ops/miopen_batch_norm_native.h>
|
| 887 |
+
#include <ATen/ops/miopen_batch_norm_backward_native.h>
|
| 888 |
+
#include <ATen/ops/miopen_convolution_native.h>
|
| 889 |
+
#include <ATen/ops/miopen_convolution_add_relu_native.h>
|
| 890 |
+
#include <ATen/ops/miopen_convolution_relu_native.h>
|
| 891 |
+
#include <ATen/ops/miopen_convolution_transpose_native.h>
|
| 892 |
+
#include <ATen/ops/miopen_depthwise_convolution_native.h>
|
| 893 |
+
#include <ATen/ops/miopen_rnn_native.h>
|
| 894 |
+
#include <ATen/ops/miopen_rnn_backward_native.h>
|
| 895 |
+
#include <ATen/ops/mish_native.h>
|
| 896 |
+
#include <ATen/ops/mish_backward_native.h>
|
| 897 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_native.h>
|
| 898 |
+
#include <ATen/ops/mkldnn_adaptive_avg_pool2d_backward_native.h>
|
| 899 |
+
#include <ATen/ops/mkldnn_convolution_native.h>
|
| 900 |
+
#include <ATen/ops/mkldnn_linear_native.h>
|
| 901 |
+
#include <ATen/ops/mkldnn_linear_backward_native.h>
|
| 902 |
+
#include <ATen/ops/mkldnn_linear_backward_input_native.h>
|
| 903 |
+
#include <ATen/ops/mkldnn_linear_backward_weights_native.h>
|
| 904 |
+
#include <ATen/ops/mkldnn_max_pool2d_native.h>
|
| 905 |
+
#include <ATen/ops/mkldnn_max_pool2d_backward_native.h>
|
| 906 |
+
#include <ATen/ops/mkldnn_max_pool3d_native.h>
|
| 907 |
+
#include <ATen/ops/mkldnn_max_pool3d_backward_native.h>
|
| 908 |
+
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
|
| 909 |
+
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
|
| 910 |
+
#include <ATen/ops/mkldnn_rnn_layer_native.h>
|
| 911 |
+
#include <ATen/ops/mkldnn_rnn_layer_backward_native.h>
|
| 912 |
+
#include <ATen/ops/mm_native.h>
|
| 913 |
+
#include <ATen/ops/mode_native.h>
|
| 914 |
+
#include <ATen/ops/moveaxis_native.h>
|
| 915 |
+
#include <ATen/ops/movedim_native.h>
|
| 916 |
+
#include <ATen/ops/mps_convolution_backward_native.h>
|
| 917 |
+
#include <ATen/ops/mps_convolution_transpose_backward_native.h>
|
| 918 |
+
#include <ATen/ops/mse_loss_native.h>
|
| 919 |
+
#include <ATen/ops/mse_loss_backward_native.h>
|
| 920 |
+
#include <ATen/ops/msort_native.h>
|
| 921 |
+
#include <ATen/ops/mul_native.h>
|
| 922 |
+
#include <ATen/ops/multi_margin_loss_native.h>
|
| 923 |
+
#include <ATen/ops/multi_margin_loss_backward_native.h>
|
| 924 |
+
#include <ATen/ops/multilabel_margin_loss_native.h>
|
| 925 |
+
#include <ATen/ops/multilabel_margin_loss_backward_native.h>
|
| 926 |
+
#include <ATen/ops/multilabel_margin_loss_forward_native.h>
|
| 927 |
+
#include <ATen/ops/multinomial_native.h>
|
| 928 |
+
#include <ATen/ops/multiply_native.h>
|
| 929 |
+
#include <ATen/ops/mv_native.h>
|
| 930 |
+
#include <ATen/ops/mvlgamma_native.h>
|
| 931 |
+
#include <ATen/ops/nan_to_num_native.h>
|
| 932 |
+
#include <ATen/ops/nanmean_native.h>
|
| 933 |
+
#include <ATen/ops/nanmedian_native.h>
|
| 934 |
+
#include <ATen/ops/nanquantile_native.h>
|
| 935 |
+
#include <ATen/ops/nansum_native.h>
|
| 936 |
+
#include <ATen/ops/narrow_native.h>
|
| 937 |
+
#include <ATen/ops/narrow_copy_native.h>
|
| 938 |
+
#include <ATen/ops/native_batch_norm_native.h>
|
| 939 |
+
#include <ATen/ops/native_batch_norm_backward_native.h>
|
| 940 |
+
#include <ATen/ops/native_channel_shuffle_native.h>
|
| 941 |
+
#include <ATen/ops/native_dropout_native.h>
|
| 942 |
+
#include <ATen/ops/native_dropout_backward_native.h>
|
| 943 |
+
#include <ATen/ops/native_group_norm_native.h>
|
| 944 |
+
#include <ATen/ops/native_group_norm_backward_native.h>
|
| 945 |
+
#include <ATen/ops/native_layer_norm_native.h>
|
| 946 |
+
#include <ATen/ops/native_layer_norm_backward_native.h>
|
| 947 |
+
#include <ATen/ops/native_norm_native.h>
|
| 948 |
+
#include <ATen/ops/ne_native.h>
|
| 949 |
+
#include <ATen/ops/neg_native.h>
|
| 950 |
+
#include <ATen/ops/negative_native.h>
|
| 951 |
+
#include <ATen/ops/nested_to_padded_tensor_native.h>
|
| 952 |
+
#include <ATen/ops/new_empty_native.h>
|
| 953 |
+
#include <ATen/ops/new_empty_strided_native.h>
|
| 954 |
+
#include <ATen/ops/new_full_native.h>
|
| 955 |
+
#include <ATen/ops/new_ones_native.h>
|
| 956 |
+
#include <ATen/ops/new_zeros_native.h>
|
| 957 |
+
#include <ATen/ops/nextafter_native.h>
|
| 958 |
+
#include <ATen/ops/nll_loss_native.h>
|
| 959 |
+
#include <ATen/ops/nll_loss2d_native.h>
|
| 960 |
+
#include <ATen/ops/nll_loss2d_backward_native.h>
|
| 961 |
+
#include <ATen/ops/nll_loss2d_forward_native.h>
|
| 962 |
+
#include <ATen/ops/nll_loss_backward_native.h>
|
| 963 |
+
#include <ATen/ops/nll_loss_forward_native.h>
|
| 964 |
+
#include <ATen/ops/nll_loss_nd_native.h>
|
| 965 |
+
#include <ATen/ops/nonzero_native.h>
|
| 966 |
+
#include <ATen/ops/nonzero_numpy_native.h>
|
| 967 |
+
#include <ATen/ops/nonzero_static_native.h>
|
| 968 |
+
#include <ATen/ops/norm_native.h>
|
| 969 |
+
#include <ATen/ops/norm_except_dim_native.h>
|
| 970 |
+
#include <ATen/ops/normal_native.h>
|
| 971 |
+
#include <ATen/ops/not_equal_native.h>
|
| 972 |
+
#include <ATen/ops/nuclear_norm_native.h>
|
| 973 |
+
#include <ATen/ops/numpy_T_native.h>
|
| 974 |
+
#include <ATen/ops/one_hot_native.h>
|
| 975 |
+
#include <ATen/ops/ones_native.h>
|
| 976 |
+
#include <ATen/ops/ones_like_native.h>
|
| 977 |
+
#include <ATen/ops/or_native.h>
|
| 978 |
+
#include <ATen/ops/orgqr_native.h>
|
| 979 |
+
#include <ATen/ops/ormqr_native.h>
|
| 980 |
+
#include <ATen/ops/outer_native.h>
|
| 981 |
+
#include <ATen/ops/output_nr_native.h>
|
| 982 |
+
#include <ATen/ops/pad_native.h>
|
| 983 |
+
#include <ATen/ops/pad_sequence_native.h>
|
| 984 |
+
#include <ATen/ops/pairwise_distance_native.h>
|
| 985 |
+
#include <ATen/ops/pdist_native.h>
|
| 986 |
+
#include <ATen/ops/permute_native.h>
|
| 987 |
+
#include <ATen/ops/permute_copy_native.h>
|
| 988 |
+
#include <ATen/ops/pin_memory_native.h>
|
| 989 |
+
#include <ATen/ops/pinverse_native.h>
|
| 990 |
+
#include <ATen/ops/pixel_shuffle_native.h>
|
| 991 |
+
#include <ATen/ops/pixel_unshuffle_native.h>
|
| 992 |
+
#include <ATen/ops/poisson_native.h>
|
| 993 |
+
#include <ATen/ops/poisson_nll_loss_native.h>
|
| 994 |
+
#include <ATen/ops/polar_native.h>
|
| 995 |
+
#include <ATen/ops/polygamma_native.h>
|
| 996 |
+
#include <ATen/ops/positive_native.h>
|
| 997 |
+
#include <ATen/ops/pow_native.h>
|
| 998 |
+
#include <ATen/ops/prelu_native.h>
|
| 999 |
+
#include <ATen/ops/prod_native.h>
|
| 1000 |
+
#include <ATen/ops/promote_types_native.h>
|
| 1001 |
+
#include <ATen/ops/put_native.h>
|
| 1002 |
+
#include <ATen/ops/q_per_channel_axis_native.h>
|
| 1003 |
+
#include <ATen/ops/q_per_channel_scales_native.h>
|
| 1004 |
+
#include <ATen/ops/q_per_channel_zero_points_native.h>
|
| 1005 |
+
#include <ATen/ops/q_scale_native.h>
|
| 1006 |
+
#include <ATen/ops/q_zero_point_native.h>
|
| 1007 |
+
#include <ATen/ops/qr_native.h>
|
| 1008 |
+
#include <ATen/ops/qscheme_native.h>
|
| 1009 |
+
#include <ATen/ops/quantile_native.h>
|
| 1010 |
+
#include <ATen/ops/quantize_per_channel_native.h>
|
| 1011 |
+
#include <ATen/ops/quantize_per_tensor_native.h>
|
| 1012 |
+
#include <ATen/ops/quantize_per_tensor_dynamic_native.h>
|
| 1013 |
+
#include <ATen/ops/quantized_batch_norm_native.h>
|
| 1014 |
+
#include <ATen/ops/quantized_gru_cell_native.h>
|
| 1015 |
+
#include <ATen/ops/quantized_lstm_cell_native.h>
|
| 1016 |
+
#include <ATen/ops/quantized_max_pool1d_native.h>
|
| 1017 |
+
#include <ATen/ops/quantized_max_pool2d_native.h>
|
| 1018 |
+
#include <ATen/ops/quantized_max_pool3d_native.h>
|
| 1019 |
+
#include <ATen/ops/quantized_rnn_relu_cell_native.h>
|
| 1020 |
+
#include <ATen/ops/quantized_rnn_tanh_cell_native.h>
|
| 1021 |
+
#include <ATen/ops/rad2deg_native.h>
|
| 1022 |
+
#include <ATen/ops/rand_native.h>
|
| 1023 |
+
#include <ATen/ops/rand_like_native.h>
|
| 1024 |
+
#include <ATen/ops/randint_native.h>
|
| 1025 |
+
#include <ATen/ops/randint_like_native.h>
|
| 1026 |
+
#include <ATen/ops/randn_native.h>
|
| 1027 |
+
#include <ATen/ops/randn_like_native.h>
|
| 1028 |
+
#include <ATen/ops/random_native.h>
|
| 1029 |
+
#include <ATen/ops/randperm_native.h>
|
| 1030 |
+
#include <ATen/ops/range_native.h>
|
| 1031 |
+
#include <ATen/ops/ravel_native.h>
|
| 1032 |
+
#include <ATen/ops/real_native.h>
|
| 1033 |
+
#include <ATen/ops/reciprocal_native.h>
|
| 1034 |
+
#include <ATen/ops/record_stream_native.h>
|
| 1035 |
+
#include <ATen/ops/refine_names_native.h>
|
| 1036 |
+
#include <ATen/ops/reflection_pad1d_native.h>
|
| 1037 |
+
#include <ATen/ops/reflection_pad1d_backward_native.h>
|
| 1038 |
+
#include <ATen/ops/reflection_pad2d_native.h>
|
| 1039 |
+
#include <ATen/ops/reflection_pad2d_backward_native.h>
|
| 1040 |
+
#include <ATen/ops/reflection_pad3d_native.h>
|
| 1041 |
+
#include <ATen/ops/reflection_pad3d_backward_native.h>
|
| 1042 |
+
#include <ATen/ops/relu_native.h>
|
| 1043 |
+
#include <ATen/ops/relu6_native.h>
|
| 1044 |
+
#include <ATen/ops/remainder_native.h>
|
| 1045 |
+
#include <ATen/ops/rename_native.h>
|
| 1046 |
+
#include <ATen/ops/renorm_native.h>
|
| 1047 |
+
#include <ATen/ops/repeat_native.h>
|
| 1048 |
+
#include <ATen/ops/repeat_interleave_native.h>
|
| 1049 |
+
#include <ATen/ops/replication_pad1d_native.h>
|
| 1050 |
+
#include <ATen/ops/replication_pad1d_backward_native.h>
|
| 1051 |
+
#include <ATen/ops/replication_pad2d_native.h>
|
| 1052 |
+
#include <ATen/ops/replication_pad2d_backward_native.h>
|
| 1053 |
+
#include <ATen/ops/replication_pad3d_native.h>
|
| 1054 |
+
#include <ATen/ops/replication_pad3d_backward_native.h>
|
| 1055 |
+
#include <ATen/ops/requires_grad_native.h>
|
| 1056 |
+
#include <ATen/ops/reshape_native.h>
|
| 1057 |
+
#include <ATen/ops/reshape_as_native.h>
|
| 1058 |
+
#include <ATen/ops/resize_native.h>
|
| 1059 |
+
#include <ATen/ops/resize_as_native.h>
|
| 1060 |
+
#include <ATen/ops/resize_as_sparse_native.h>
|
| 1061 |
+
#include <ATen/ops/resolve_conj_native.h>
|
| 1062 |
+
#include <ATen/ops/resolve_neg_native.h>
|
| 1063 |
+
#include <ATen/ops/result_type_native.h>
|
| 1064 |
+
#include <ATen/ops/retain_grad_native.h>
|
| 1065 |
+
#include <ATen/ops/retains_grad_native.h>
|
| 1066 |
+
#include <ATen/ops/rnn_relu_native.h>
|
| 1067 |
+
#include <ATen/ops/rnn_relu_cell_native.h>
|
| 1068 |
+
#include <ATen/ops/rnn_tanh_native.h>
|
| 1069 |
+
#include <ATen/ops/rnn_tanh_cell_native.h>
|
| 1070 |
+
#include <ATen/ops/roll_native.h>
|
| 1071 |
+
#include <ATen/ops/rot90_native.h>
|
| 1072 |
+
#include <ATen/ops/round_native.h>
|
| 1073 |
+
#include <ATen/ops/row_indices_native.h>
|
| 1074 |
+
#include <ATen/ops/row_indices_copy_native.h>
|
| 1075 |
+
#include <ATen/ops/row_stack_native.h>
|
| 1076 |
+
#include <ATen/ops/rrelu_native.h>
|
| 1077 |
+
#include <ATen/ops/rrelu_with_noise_native.h>
|
| 1078 |
+
#include <ATen/ops/rrelu_with_noise_backward_native.h>
|
| 1079 |
+
#include <ATen/ops/rshift_native.h>
|
| 1080 |
+
#include <ATen/ops/rsqrt_native.h>
|
| 1081 |
+
#include <ATen/ops/rsub_native.h>
|
| 1082 |
+
#include <ATen/ops/scalar_tensor_native.h>
|
| 1083 |
+
#include <ATen/ops/scaled_dot_product_attention_native.h>
|
| 1084 |
+
#include <ATen/ops/scatter_native.h>
|
| 1085 |
+
#include <ATen/ops/scatter_add_native.h>
|
| 1086 |
+
#include <ATen/ops/scatter_reduce_native.h>
|
| 1087 |
+
#include <ATen/ops/searchsorted_native.h>
|
| 1088 |
+
#include <ATen/ops/segment_reduce_native.h>
|
| 1089 |
+
#include <ATen/ops/select_native.h>
|
| 1090 |
+
#include <ATen/ops/select_backward_native.h>
|
| 1091 |
+
#include <ATen/ops/select_copy_native.h>
|
| 1092 |
+
#include <ATen/ops/select_scatter_native.h>
|
| 1093 |
+
#include <ATen/ops/selu_native.h>
|
| 1094 |
+
#include <ATen/ops/set_native.h>
|
| 1095 |
+
#include <ATen/ops/set_data_native.h>
|
| 1096 |
+
#include <ATen/ops/sgn_native.h>
|
| 1097 |
+
#include <ATen/ops/sigmoid_native.h>
|
| 1098 |
+
#include <ATen/ops/sigmoid_backward_native.h>
|
| 1099 |
+
#include <ATen/ops/sign_native.h>
|
| 1100 |
+
#include <ATen/ops/signbit_native.h>
|
| 1101 |
+
#include <ATen/ops/silu_native.h>
|
| 1102 |
+
#include <ATen/ops/silu_backward_native.h>
|
| 1103 |
+
#include <ATen/ops/sin_native.h>
|
| 1104 |
+
#include <ATen/ops/sinc_native.h>
|
| 1105 |
+
#include <ATen/ops/sinh_native.h>
|
| 1106 |
+
#include <ATen/ops/size_native.h>
|
| 1107 |
+
#include <ATen/ops/slice_native.h>
|
| 1108 |
+
#include <ATen/ops/slice_backward_native.h>
|
| 1109 |
+
#include <ATen/ops/slice_copy_native.h>
|
| 1110 |
+
#include <ATen/ops/slice_inverse_native.h>
|
| 1111 |
+
#include <ATen/ops/slice_scatter_native.h>
|
| 1112 |
+
#include <ATen/ops/slogdet_native.h>
|
| 1113 |
+
#include <ATen/ops/slow_conv3d_native.h>
|
| 1114 |
+
#include <ATen/ops/slow_conv3d_forward_native.h>
|
| 1115 |
+
#include <ATen/ops/slow_conv_dilated2d_native.h>
|
| 1116 |
+
#include <ATen/ops/slow_conv_dilated3d_native.h>
|
| 1117 |
+
#include <ATen/ops/slow_conv_transpose2d_native.h>
|
| 1118 |
+
#include <ATen/ops/slow_conv_transpose3d_native.h>
|
| 1119 |
+
#include <ATen/ops/smm_native.h>
|
| 1120 |
+
#include <ATen/ops/smooth_l1_loss_native.h>
|
| 1121 |
+
#include <ATen/ops/smooth_l1_loss_backward_native.h>
|
| 1122 |
+
#include <ATen/ops/soft_margin_loss_native.h>
|
| 1123 |
+
#include <ATen/ops/soft_margin_loss_backward_native.h>
|
| 1124 |
+
#include <ATen/ops/softmax_native.h>
|
| 1125 |
+
#include <ATen/ops/softplus_native.h>
|
| 1126 |
+
#include <ATen/ops/softplus_backward_native.h>
|
| 1127 |
+
#include <ATen/ops/softshrink_native.h>
|
| 1128 |
+
#include <ATen/ops/softshrink_backward_native.h>
|
| 1129 |
+
#include <ATen/ops/sort_native.h>
|
| 1130 |
+
#include <ATen/ops/sparse_bsc_tensor_native.h>
|
| 1131 |
+
#include <ATen/ops/sparse_bsr_tensor_native.h>
|
| 1132 |
+
#include <ATen/ops/sparse_compressed_tensor_native.h>
|
| 1133 |
+
#include <ATen/ops/sparse_coo_tensor_native.h>
|
| 1134 |
+
#include <ATen/ops/sparse_csc_tensor_native.h>
|
| 1135 |
+
#include <ATen/ops/sparse_csr_tensor_native.h>
|
| 1136 |
+
#include <ATen/ops/sparse_dim_native.h>
|
| 1137 |
+
#include <ATen/ops/sparse_mask_native.h>
|
| 1138 |
+
#include <ATen/ops/sparse_resize_native.h>
|
| 1139 |
+
#include <ATen/ops/sparse_resize_and_clear_native.h>
|
| 1140 |
+
#include <ATen/ops/sparse_sampled_addmm_native.h>
|
| 1141 |
+
#include <ATen/ops/special_airy_ai_native.h>
|
| 1142 |
+
#include <ATen/ops/special_bessel_j0_native.h>
|
| 1143 |
+
#include <ATen/ops/special_bessel_j1_native.h>
|
| 1144 |
+
#include <ATen/ops/special_bessel_y0_native.h>
|
| 1145 |
+
#include <ATen/ops/special_bessel_y1_native.h>
|
| 1146 |
+
#include <ATen/ops/special_chebyshev_polynomial_t_native.h>
|
| 1147 |
+
#include <ATen/ops/special_chebyshev_polynomial_u_native.h>
|
| 1148 |
+
#include <ATen/ops/special_chebyshev_polynomial_v_native.h>
|
| 1149 |
+
#include <ATen/ops/special_chebyshev_polynomial_w_native.h>
|
| 1150 |
+
#include <ATen/ops/special_digamma_native.h>
|
| 1151 |
+
#include <ATen/ops/special_entr_native.h>
|
| 1152 |
+
#include <ATen/ops/special_erf_native.h>
|
| 1153 |
+
#include <ATen/ops/special_erfc_native.h>
|
| 1154 |
+
#include <ATen/ops/special_erfcx_native.h>
|
| 1155 |
+
#include <ATen/ops/special_erfinv_native.h>
|
| 1156 |
+
#include <ATen/ops/special_exp2_native.h>
|
| 1157 |
+
#include <ATen/ops/special_expit_native.h>
|
| 1158 |
+
#include <ATen/ops/special_expm1_native.h>
|
| 1159 |
+
#include <ATen/ops/special_gammainc_native.h>
|
| 1160 |
+
#include <ATen/ops/special_gammaincc_native.h>
|
| 1161 |
+
#include <ATen/ops/special_gammaln_native.h>
|
| 1162 |
+
#include <ATen/ops/special_hermite_polynomial_h_native.h>
|
| 1163 |
+
#include <ATen/ops/special_hermite_polynomial_he_native.h>
|
| 1164 |
+
#include <ATen/ops/special_i0_native.h>
|
| 1165 |
+
#include <ATen/ops/special_i0e_native.h>
|
| 1166 |
+
#include <ATen/ops/special_i1_native.h>
|
| 1167 |
+
#include <ATen/ops/special_i1e_native.h>
|
| 1168 |
+
#include <ATen/ops/special_laguerre_polynomial_l_native.h>
|
| 1169 |
+
#include <ATen/ops/special_legendre_polynomial_p_native.h>
|
| 1170 |
+
#include <ATen/ops/special_log1p_native.h>
|
| 1171 |
+
#include <ATen/ops/special_log_ndtr_native.h>
|
| 1172 |
+
#include <ATen/ops/special_log_softmax_native.h>
|
| 1173 |
+
#include <ATen/ops/special_logit_native.h>
|
| 1174 |
+
#include <ATen/ops/special_logsumexp_native.h>
|
| 1175 |
+
#include <ATen/ops/special_modified_bessel_i0_native.h>
|
| 1176 |
+
#include <ATen/ops/special_modified_bessel_i1_native.h>
|
| 1177 |
+
#include <ATen/ops/special_modified_bessel_k0_native.h>
|
| 1178 |
+
#include <ATen/ops/special_modified_bessel_k1_native.h>
|
| 1179 |
+
#include <ATen/ops/special_multigammaln_native.h>
|
| 1180 |
+
#include <ATen/ops/special_ndtr_native.h>
|
| 1181 |
+
#include <ATen/ops/special_ndtri_native.h>
|
| 1182 |
+
#include <ATen/ops/special_polygamma_native.h>
|
| 1183 |
+
#include <ATen/ops/special_psi_native.h>
|
| 1184 |
+
#include <ATen/ops/special_round_native.h>
|
| 1185 |
+
#include <ATen/ops/special_scaled_modified_bessel_k0_native.h>
|
| 1186 |
+
#include <ATen/ops/special_scaled_modified_bessel_k1_native.h>
|
| 1187 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_t_native.h>
|
| 1188 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_u_native.h>
|
| 1189 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_v_native.h>
|
| 1190 |
+
#include <ATen/ops/special_shifted_chebyshev_polynomial_w_native.h>
|
| 1191 |
+
#include <ATen/ops/special_sinc_native.h>
|
| 1192 |
+
#include <ATen/ops/special_softmax_native.h>
|
| 1193 |
+
#include <ATen/ops/special_spherical_bessel_j0_native.h>
|
| 1194 |
+
#include <ATen/ops/special_xlog1py_native.h>
|
| 1195 |
+
#include <ATen/ops/special_xlogy_native.h>
|
| 1196 |
+
#include <ATen/ops/special_zeta_native.h>
|
| 1197 |
+
#include <ATen/ops/split_native.h>
|
| 1198 |
+
#include <ATen/ops/split_copy_native.h>
|
| 1199 |
+
#include <ATen/ops/split_with_sizes_native.h>
|
| 1200 |
+
#include <ATen/ops/split_with_sizes_copy_native.h>
|
| 1201 |
+
#include <ATen/ops/sqrt_native.h>
|
| 1202 |
+
#include <ATen/ops/square_native.h>
|
| 1203 |
+
#include <ATen/ops/squeeze_native.h>
|
| 1204 |
+
#include <ATen/ops/squeeze_copy_native.h>
|
| 1205 |
+
#include <ATen/ops/sspaddmm_native.h>
|
| 1206 |
+
#include <ATen/ops/stack_native.h>
|
| 1207 |
+
#include <ATen/ops/std_native.h>
|
| 1208 |
+
#include <ATen/ops/std_mean_native.h>
|
| 1209 |
+
#include <ATen/ops/stft_native.h>
|
| 1210 |
+
#include <ATen/ops/stride_native.h>
|
| 1211 |
+
#include <ATen/ops/sub_native.h>
|
| 1212 |
+
#include <ATen/ops/subtract_native.h>
|
| 1213 |
+
#include <ATen/ops/sum_native.h>
|
| 1214 |
+
#include <ATen/ops/sum_to_size_native.h>
|
| 1215 |
+
#include <ATen/ops/svd_native.h>
|
| 1216 |
+
#include <ATen/ops/swapaxes_native.h>
|
| 1217 |
+
#include <ATen/ops/swapdims_native.h>
|
| 1218 |
+
#include <ATen/ops/sym_constrain_range_native.h>
|
| 1219 |
+
#include <ATen/ops/sym_constrain_range_for_size_native.h>
|
| 1220 |
+
#include <ATen/ops/sym_numel_native.h>
|
| 1221 |
+
#include <ATen/ops/sym_size_native.h>
|
| 1222 |
+
#include <ATen/ops/sym_storage_offset_native.h>
|
| 1223 |
+
#include <ATen/ops/sym_stride_native.h>
|
| 1224 |
+
#include <ATen/ops/t_native.h>
|
| 1225 |
+
#include <ATen/ops/t_copy_native.h>
|
| 1226 |
+
#include <ATen/ops/take_native.h>
|
| 1227 |
+
#include <ATen/ops/take_along_dim_native.h>
|
| 1228 |
+
#include <ATen/ops/tan_native.h>
|
| 1229 |
+
#include <ATen/ops/tanh_native.h>
|
| 1230 |
+
#include <ATen/ops/tanh_backward_native.h>
|
| 1231 |
+
#include <ATen/ops/tensor_split_native.h>
|
| 1232 |
+
#include <ATen/ops/tensordot_native.h>
|
| 1233 |
+
#include <ATen/ops/thnn_conv2d_native.h>
|
| 1234 |
+
#include <ATen/ops/threshold_native.h>
|
| 1235 |
+
#include <ATen/ops/threshold_backward_native.h>
|
| 1236 |
+
#include <ATen/ops/tile_native.h>
|
| 1237 |
+
#include <ATen/ops/to_native.h>
|
| 1238 |
+
#include <ATen/ops/to_dense_native.h>
|
| 1239 |
+
#include <ATen/ops/to_dense_backward_native.h>
|
| 1240 |
+
#include <ATen/ops/to_mkldnn_native.h>
|
| 1241 |
+
#include <ATen/ops/to_mkldnn_backward_native.h>
|
| 1242 |
+
#include <ATen/ops/to_padded_tensor_native.h>
|
| 1243 |
+
#include <ATen/ops/to_sparse_native.h>
|
| 1244 |
+
#include <ATen/ops/to_sparse_bsc_native.h>
|
| 1245 |
+
#include <ATen/ops/to_sparse_bsr_native.h>
|
| 1246 |
+
#include <ATen/ops/to_sparse_csc_native.h>
|
| 1247 |
+
#include <ATen/ops/to_sparse_csr_native.h>
|
| 1248 |
+
#include <ATen/ops/topk_native.h>
|
| 1249 |
+
#include <ATen/ops/trace_native.h>
|
| 1250 |
+
#include <ATen/ops/trace_backward_native.h>
|
| 1251 |
+
#include <ATen/ops/transpose_native.h>
|
| 1252 |
+
#include <ATen/ops/transpose_copy_native.h>
|
| 1253 |
+
#include <ATen/ops/trapezoid_native.h>
|
| 1254 |
+
#include <ATen/ops/trapz_native.h>
|
| 1255 |
+
#include <ATen/ops/triangular_solve_native.h>
|
| 1256 |
+
#include <ATen/ops/tril_native.h>
|
| 1257 |
+
#include <ATen/ops/tril_indices_native.h>
|
| 1258 |
+
#include <ATen/ops/triplet_margin_loss_native.h>
|
| 1259 |
+
#include <ATen/ops/triu_native.h>
|
| 1260 |
+
#include <ATen/ops/triu_indices_native.h>
|
| 1261 |
+
#include <ATen/ops/true_divide_native.h>
|
| 1262 |
+
#include <ATen/ops/trunc_native.h>
|
| 1263 |
+
#include <ATen/ops/type_as_native.h>
|
| 1264 |
+
#include <ATen/ops/unbind_native.h>
|
| 1265 |
+
#include <ATen/ops/unbind_copy_native.h>
|
| 1266 |
+
#include <ATen/ops/unflatten_native.h>
|
| 1267 |
+
#include <ATen/ops/unflatten_dense_tensors_native.h>
|
| 1268 |
+
#include <ATen/ops/unfold_native.h>
|
| 1269 |
+
#include <ATen/ops/unfold_backward_native.h>
|
| 1270 |
+
#include <ATen/ops/unfold_copy_native.h>
|
| 1271 |
+
#include <ATen/ops/uniform_native.h>
|
| 1272 |
+
#include <ATen/ops/unique_consecutive_native.h>
|
| 1273 |
+
#include <ATen/ops/unique_dim_native.h>
|
| 1274 |
+
#include <ATen/ops/unique_dim_consecutive_native.h>
|
| 1275 |
+
#include <ATen/ops/unsafe_chunk_native.h>
|
| 1276 |
+
#include <ATen/ops/unsafe_split_native.h>
|
| 1277 |
+
#include <ATen/ops/unsafe_split_with_sizes_native.h>
|
| 1278 |
+
#include <ATen/ops/unsqueeze_native.h>
|
| 1279 |
+
#include <ATen/ops/unsqueeze_copy_native.h>
|
| 1280 |
+
#include <ATen/ops/upsample_bicubic2d_native.h>
|
| 1281 |
+
#include <ATen/ops/upsample_bicubic2d_backward_native.h>
|
| 1282 |
+
#include <ATen/ops/upsample_bilinear2d_native.h>
|
| 1283 |
+
#include <ATen/ops/upsample_bilinear2d_backward_native.h>
|
| 1284 |
+
#include <ATen/ops/upsample_linear1d_native.h>
|
| 1285 |
+
#include <ATen/ops/upsample_linear1d_backward_native.h>
|
| 1286 |
+
#include <ATen/ops/upsample_nearest1d_native.h>
|
| 1287 |
+
#include <ATen/ops/upsample_nearest1d_backward_native.h>
|
| 1288 |
+
#include <ATen/ops/upsample_nearest2d_native.h>
|
| 1289 |
+
#include <ATen/ops/upsample_nearest2d_backward_native.h>
|
| 1290 |
+
#include <ATen/ops/upsample_nearest3d_native.h>
|
| 1291 |
+
#include <ATen/ops/upsample_nearest3d_backward_native.h>
|
| 1292 |
+
#include <ATen/ops/upsample_trilinear3d_native.h>
|
| 1293 |
+
#include <ATen/ops/upsample_trilinear3d_backward_native.h>
|
| 1294 |
+
#include <ATen/ops/value_selecting_reduction_backward_native.h>
|
| 1295 |
+
#include <ATen/ops/values_native.h>
|
| 1296 |
+
#include <ATen/ops/values_copy_native.h>
|
| 1297 |
+
#include <ATen/ops/vander_native.h>
|
| 1298 |
+
#include <ATen/ops/var_native.h>
|
| 1299 |
+
#include <ATen/ops/var_mean_native.h>
|
| 1300 |
+
#include <ATen/ops/vdot_native.h>
|
| 1301 |
+
#include <ATen/ops/view_native.h>
|
| 1302 |
+
#include <ATen/ops/view_as_native.h>
|
| 1303 |
+
#include <ATen/ops/view_as_complex_native.h>
|
| 1304 |
+
#include <ATen/ops/view_as_complex_copy_native.h>
|
| 1305 |
+
#include <ATen/ops/view_as_real_native.h>
|
| 1306 |
+
#include <ATen/ops/view_as_real_copy_native.h>
|
| 1307 |
+
#include <ATen/ops/view_copy_native.h>
|
| 1308 |
+
#include <ATen/ops/vsplit_native.h>
|
| 1309 |
+
#include <ATen/ops/vstack_native.h>
|
| 1310 |
+
#include <ATen/ops/where_native.h>
|
| 1311 |
+
#include <ATen/ops/xlogy_native.h>
|
| 1312 |
+
#include <ATen/ops/xor_native.h>
|
| 1313 |
+
#include <ATen/ops/zero_native.h>
|
| 1314 |
+
#include <ATen/ops/zeros_native.h>
|
| 1315 |
+
#include <ATen/ops/zeros_like_native.h>
|
| 1316 |
+
|
| 1317 |
+
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/ParallelOpenMP.h
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <algorithm>
|
| 4 |
+
#include <atomic>
|
| 5 |
+
#include <cstddef>
|
| 6 |
+
#include <exception>
|
| 7 |
+
|
| 8 |
+
#ifdef _OPENMP
|
| 9 |
+
#define INTRA_OP_PARALLEL
|
| 10 |
+
|
| 11 |
+
#include <omp.h>
|
| 12 |
+
#endif
|
| 13 |
+
|
| 14 |
+
#ifdef _OPENMP
|
| 15 |
+
namespace at::internal {
|
| 16 |
+
template <typename F>
|
| 17 |
+
inline void invoke_parallel(
|
| 18 |
+
int64_t begin,
|
| 19 |
+
int64_t end,
|
| 20 |
+
int64_t grain_size,
|
| 21 |
+
const F& f) {
|
| 22 |
+
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
|
| 23 |
+
std::exception_ptr eptr;
|
| 24 |
+
|
| 25 |
+
#pragma omp parallel
|
| 26 |
+
{
|
| 27 |
+
// choose number of tasks based on grain size and number of threads
|
| 28 |
+
// can't use num_threads clause due to bugs in GOMP's thread pool (See
|
| 29 |
+
// #32008)
|
| 30 |
+
int64_t num_threads = omp_get_num_threads();
|
| 31 |
+
if (grain_size > 0) {
|
| 32 |
+
num_threads = std::min(num_threads, divup((end - begin), grain_size));
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
int64_t tid = omp_get_thread_num();
|
| 36 |
+
int64_t chunk_size = divup((end - begin), num_threads);
|
| 37 |
+
int64_t begin_tid = begin + tid * chunk_size;
|
| 38 |
+
if (begin_tid < end) {
|
| 39 |
+
try {
|
| 40 |
+
internal::ThreadIdGuard tid_guard(tid);
|
| 41 |
+
f(begin_tid, std::min(end, chunk_size + begin_tid));
|
| 42 |
+
} catch (...) {
|
| 43 |
+
if (!err_flag.test_and_set()) {
|
| 44 |
+
eptr = std::current_exception();
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
if (eptr) {
|
| 50 |
+
std::rethrow_exception(eptr);
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
} // namespace at::internal
|
| 54 |
+
#endif // _OPENMP
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/RedispatchFunctions.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SmallVector.h
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <c10/util/SmallVector.h>
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/SparseCsrTensorUtils.h
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/SparseCsrTensorImpl.h>
|
| 4 |
+
#include <ATen/SparseTensorImpl.h>
|
| 5 |
+
#include <ATen/core/Tensor.h>
|
| 6 |
+
|
| 7 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 8 |
+
#include <ATen/Functions.h>
|
| 9 |
+
#include <ATen/NativeFunctions.h>
|
| 10 |
+
#include <ATen/Operators.h>
|
| 11 |
+
#else
|
| 12 |
+
#include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
|
| 13 |
+
#include <ATen/ops/resize_as_sparse_native.h>
|
| 14 |
+
#endif
|
| 15 |
+
|
| 16 |
+
#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
|
| 17 |
+
[&] { \
|
| 18 |
+
const auto& the_layout = LAYOUT; \
|
| 19 |
+
switch (the_layout) { \
|
| 20 |
+
case kSparseCsr: \
|
| 21 |
+
case kSparseCsc: \
|
| 22 |
+
case kSparseBsr: \
|
| 23 |
+
case kSparseBsc: \
|
| 24 |
+
return __VA_ARGS__(); \
|
| 25 |
+
default: \
|
| 26 |
+
AT_ERROR( \
|
| 27 |
+
NAME, \
|
| 28 |
+
" expected sparse compressed tensor layout but got ", \
|
| 29 |
+
the_layout); \
|
| 30 |
+
} \
|
| 31 |
+
}()
|
| 32 |
+
|
| 33 |
+
#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( \
|
| 34 |
+
LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION) \
|
| 35 |
+
[&]() { \
|
| 36 |
+
const auto& the_layout = LAYOUT; \
|
| 37 |
+
switch (the_layout) { \
|
| 38 |
+
case kSparseCsr: \
|
| 39 |
+
case kSparseBsr: \
|
| 40 |
+
return (ROW_DIM_ACTION)(); \
|
| 41 |
+
case kSparseCsc: \
|
| 42 |
+
case kSparseBsc: \
|
| 43 |
+
return (COLUMN_DIM_ACTION)(); \
|
| 44 |
+
default: \
|
| 45 |
+
AT_ERROR( \
|
| 46 |
+
NAME, \
|
| 47 |
+
" expected sparse compressed tensor layout but got ", \
|
| 48 |
+
the_layout); \
|
| 49 |
+
} \
|
| 50 |
+
}()
|
| 51 |
+
|
| 52 |
+
#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( \
|
| 53 |
+
LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION) \
|
| 54 |
+
[&]() { \
|
| 55 |
+
const auto& the_layout = LAYOUT; \
|
| 56 |
+
switch (the_layout) { \
|
| 57 |
+
case kSparseCsr: \
|
| 58 |
+
case kSparseCsc: \
|
| 59 |
+
return (NO_BLOCK_ACTION)(); \
|
| 60 |
+
case kSparseBsr: \
|
| 61 |
+
case kSparseBsc: \
|
| 62 |
+
return (BLOCK_ACTION)(); \
|
| 63 |
+
default: \
|
| 64 |
+
AT_ERROR( \
|
| 65 |
+
NAME, \
|
| 66 |
+
" expected sparse compressed tensor layout but got ", \
|
| 67 |
+
the_layout); \
|
| 68 |
+
} \
|
| 69 |
+
}()
|
| 70 |
+
|
| 71 |
+
#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS( \
|
| 72 |
+
LAYOUT, NAME, ROW_DIM_ACTION) \
|
| 73 |
+
[&]() { \
|
| 74 |
+
const auto& the_layout = LAYOUT; \
|
| 75 |
+
switch (the_layout) { \
|
| 76 |
+
case kSparseCsr: \
|
| 77 |
+
case kSparseBsr: \
|
| 78 |
+
return (ROW_DIM_ACTION)(); \
|
| 79 |
+
default: \
|
| 80 |
+
AT_ERROR( \
|
| 81 |
+
NAME, \
|
| 82 |
+
" expected sparse row compressed tensor layout but got ", \
|
| 83 |
+
the_layout); \
|
| 84 |
+
} \
|
| 85 |
+
}()
|
| 86 |
+
|
| 87 |
+
#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS( \
|
| 88 |
+
LAYOUT, NAME, COL_DIM_ACTION) \
|
| 89 |
+
[&]() { \
|
| 90 |
+
const auto& the_layout = LAYOUT; \
|
| 91 |
+
switch (the_layout) { \
|
| 92 |
+
case kSparseCsc: \
|
| 93 |
+
case kSparseBsc: \
|
| 94 |
+
return (COL_DIM_ACTION)(); \
|
| 95 |
+
default: \
|
| 96 |
+
AT_ERROR( \
|
| 97 |
+
NAME, \
|
| 98 |
+
" expected sparse column compressed tensor layout but got ", \
|
| 99 |
+
the_layout); \
|
| 100 |
+
} \
|
| 101 |
+
}()
|
| 102 |
+
|
| 103 |
+
#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
|
| 104 |
+
[&]() { \
|
| 105 |
+
const auto& the_layout = LAYOUT; \
|
| 106 |
+
switch (the_layout) { \
|
| 107 |
+
case kSparseCsr: \
|
| 108 |
+
case kSparseCsc: \
|
| 109 |
+
return (ACTION)(); \
|
| 110 |
+
default: \
|
| 111 |
+
AT_ERROR( \
|
| 112 |
+
NAME, \
|
| 113 |
+
" expected sparse compressed (non-block) tensor layout but got ", \
|
| 114 |
+
the_layout); \
|
| 115 |
+
} \
|
| 116 |
+
}()
|
| 117 |
+
|
| 118 |
+
#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
|
| 119 |
+
[&]() { \
|
| 120 |
+
const auto& the_layout = LAYOUT; \
|
| 121 |
+
switch (the_layout) { \
|
| 122 |
+
case kSparseBsr: \
|
| 123 |
+
case kSparseBsc: \
|
| 124 |
+
return (ACTION)(); \
|
| 125 |
+
default: \
|
| 126 |
+
AT_ERROR( \
|
| 127 |
+
NAME, \
|
| 128 |
+
" expected sparse compressed block tensor layout but got ", \
|
| 129 |
+
the_layout); \
|
| 130 |
+
} \
|
| 131 |
+
}()
|
| 132 |
+
|
| 133 |
+
#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
|
| 134 |
+
AT_DISPATCH_SWITCH( \
|
| 135 |
+
TYPE, \
|
| 136 |
+
NAME, \
|
| 137 |
+
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
|
| 138 |
+
kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))
|
| 139 |
+
|
| 140 |
+
namespace at::sparse_csr {
|
| 141 |
+
|
| 142 |
+
using SparseCsrTensor = Tensor;
|
| 143 |
+
|
| 144 |
+
inline bool is_sparse_compressed(const Layout& layout) {
|
| 145 |
+
switch (layout) {
|
| 146 |
+
case kSparseCsr:
|
| 147 |
+
case kSparseCsc:
|
| 148 |
+
case kSparseBsr:
|
| 149 |
+
case kSparseBsc:
|
| 150 |
+
return true;
|
| 151 |
+
default:;
|
| 152 |
+
}
|
| 153 |
+
return false;
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
inline bool is_sparse_compressed(const Tensor& self) {
|
| 157 |
+
return is_sparse_compressed(self.layout());
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
|
| 161 |
+
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
|
| 162 |
+
self.layout(), "get_sparse_csr_impl", [&] {});
|
| 163 |
+
return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
inline std::string layoutToString(
|
| 167 |
+
Layout layout,
|
| 168 |
+
bool upper = false,
|
| 169 |
+
bool lower = false) {
|
| 170 |
+
switch (layout) {
|
| 171 |
+
case kSparseCsr:
|
| 172 |
+
return (upper ? "CSR" : (lower ? "csr" : "Csr"));
|
| 173 |
+
case kSparseCsc:
|
| 174 |
+
return (upper ? "CSC" : (lower ? "csc" : "Csc"));
|
| 175 |
+
case kSparseBsr:
|
| 176 |
+
return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
|
| 177 |
+
case kSparseBsc:
|
| 178 |
+
return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
|
| 179 |
+
default:
|
| 180 |
+
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
|
| 181 |
+
return "";
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
inline bool isCompressedRow(Layout layout) {
|
| 186 |
+
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
|
| 187 |
+
layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
inline bool isCompressedColumn(Layout layout) {
|
| 191 |
+
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
|
| 192 |
+
layout,
|
| 193 |
+
"isCompressedColumn",
|
| 194 |
+
[&] { return false; },
|
| 195 |
+
[&] { return true; });
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
inline std::string compressedIndicesName(Layout layout) {
|
| 199 |
+
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
|
| 200 |
+
layout,
|
| 201 |
+
"compressedIndicesName",
|
| 202 |
+
[&] { return "crow_indices"; },
|
| 203 |
+
[&] { return "ccol_indices"; });
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
inline std::string plainIndicesName(Layout layout) {
|
| 207 |
+
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
|
| 208 |
+
layout,
|
| 209 |
+
"plainIndicesName",
|
| 210 |
+
[&] { return "col_indices"; },
|
| 211 |
+
[&] { return "row_indices"; });
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
inline std::string compressedDimName(Layout layout) {
|
| 215 |
+
switch (layout) {
|
| 216 |
+
case kSparseCsr:
|
| 217 |
+
return "row";
|
| 218 |
+
case kSparseCsc:
|
| 219 |
+
return "column";
|
| 220 |
+
case kSparseBsr:
|
| 221 |
+
return "row block";
|
| 222 |
+
case kSparseBsc:
|
| 223 |
+
return "column block";
|
| 224 |
+
default:
|
| 225 |
+
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
|
| 226 |
+
return "";
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
inline std::string plainDimName(Layout layout) {
|
| 231 |
+
switch (layout) {
|
| 232 |
+
case kSparseCsr:
|
| 233 |
+
return "column";
|
| 234 |
+
case kSparseCsc:
|
| 235 |
+
return "row";
|
| 236 |
+
case kSparseBsr:
|
| 237 |
+
return "column block";
|
| 238 |
+
case kSparseBsc:
|
| 239 |
+
return "row block";
|
| 240 |
+
default:
|
| 241 |
+
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
|
| 242 |
+
return "";
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
inline size_t rowDimension(Layout layout, IntArrayRef size) {
|
| 247 |
+
return size.size() - (isCompressedRow(layout) ? 2 : 1);
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
inline size_t columnDimension(Layout layout, IntArrayRef size) {
|
| 251 |
+
return size.size() - (isCompressedColumn(layout) ? 2 : 1);
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
inline size_t compressedDimension(
|
| 255 |
+
Layout layout,
|
| 256 |
+
IntArrayRef size,
|
| 257 |
+
size_t dense_ndim = 0) {
|
| 258 |
+
return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
inline size_t plainDimension(
|
| 262 |
+
Layout layout,
|
| 263 |
+
IntArrayRef size,
|
| 264 |
+
size_t dense_ndim = 0) {
|
| 265 |
+
return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
inline int64_t numBatchDimensions(Tensor const& self) {
|
| 269 |
+
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
|
| 270 |
+
self.layout(),
|
| 271 |
+
"numBatchDimensions",
|
| 272 |
+
[&self] { return self.crow_indices().dim() - 1; },
|
| 273 |
+
[&self] { return self.ccol_indices().dim() - 1; });
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
|
| 277 |
+
return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
|
| 278 |
+
self.layout(),
|
| 279 |
+
"getCompressedPlainIndices",
|
| 280 |
+
[&self] {
|
| 281 |
+
return std::make_pair(self.crow_indices(), self.col_indices());
|
| 282 |
+
},
|
| 283 |
+
[&self] {
|
| 284 |
+
return std::make_pair(self.ccol_indices(), self.row_indices());
|
| 285 |
+
});
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
inline Layout flip_compressed_layout(Layout layout) {
|
| 289 |
+
switch (layout) {
|
| 290 |
+
case kSparseCsr:
|
| 291 |
+
return kSparseCsc;
|
| 292 |
+
case kSparseCsc:
|
| 293 |
+
return kSparseCsr;
|
| 294 |
+
case kSparseBsr:
|
| 295 |
+
return kSparseBsc;
|
| 296 |
+
case kSparseBsc:
|
| 297 |
+
return kSparseBsr;
|
| 298 |
+
default:
|
| 299 |
+
TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
|
| 300 |
+
return kSparseCsr;
|
| 301 |
+
}
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
inline DimVector getBlockSize(Tensor const& self) {
|
| 305 |
+
int64_t n_batch = numBatchDimensions(self);
|
| 306 |
+
return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
|
| 310 |
+
if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
|
| 311 |
+
int64_t n_batch = numBatchDimensions(self);
|
| 312 |
+
return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
|
| 313 |
+
} else {
|
| 314 |
+
return {};
|
| 315 |
+
}
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
template <typename binary_op_t, typename binary_op_out_t>
|
| 319 |
+
inline bool only_sparse_compressed_binary_op_trivial_cases(
|
| 320 |
+
const Tensor& self,
|
| 321 |
+
const Tensor& other,
|
| 322 |
+
const Scalar& alpha,
|
| 323 |
+
Tensor& out,
|
| 324 |
+
const binary_op_t& binary_op,
|
| 325 |
+
const binary_op_out_t& binary_op_out) {
|
| 326 |
+
// Only sparse compressed! Just like the name says :)
|
| 327 |
+
TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(self));
|
| 328 |
+
TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(other));
|
| 329 |
+
TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(out));
|
| 330 |
+
|
| 331 |
+
// Bypass BLAS if there are matches in (self, other, out)
|
| 332 |
+
if (self.is_same(out) && self.is_same(other)) {
|
| 333 |
+
binary_op_out(self.values(), other.values(), alpha);
|
| 334 |
+
return true;
|
| 335 |
+
}
|
| 336 |
+
if (self.is_same(other)) {
|
| 337 |
+
auto [compressed_indices, plain_indices] =
|
| 338 |
+
at::sparse_csr::getCompressedPlainIndices(self);
|
| 339 |
+
static_cast<SparseCsrTensorImpl*>(out.unsafeGetTensorImpl())
|
| 340 |
+
->set_member_tensors(
|
| 341 |
+
compressed_indices,
|
| 342 |
+
plain_indices,
|
| 343 |
+
binary_op(self.values(), other.values(), alpha),
|
| 344 |
+
self.sizes());
|
| 345 |
+
return true;
|
| 346 |
+
}
|
| 347 |
+
return false;
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
inline bool only_sparse_compressed_add_trivial_cases(
|
| 351 |
+
const Tensor& self,
|
| 352 |
+
const Tensor& other,
|
| 353 |
+
const Scalar& alpha,
|
| 354 |
+
Tensor& out) {
|
| 355 |
+
return only_sparse_compressed_binary_op_trivial_cases(
|
| 356 |
+
self,
|
| 357 |
+
other,
|
| 358 |
+
alpha,
|
| 359 |
+
out,
|
| 360 |
+
[](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
|
| 361 |
+
return v1.add(v2, alpha);
|
| 362 |
+
},
|
| 363 |
+
[](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
|
| 364 |
+
return v1.add_(v2, alpha);
|
| 365 |
+
});
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
inline Tensor to_type(const Tensor& input, ScalarType dtype) {
|
| 369 |
+
auto [compressed_indices, plain_indices] =
|
| 370 |
+
at::sparse_csr::getCompressedPlainIndices(input);
|
| 371 |
+
return at::_sparse_compressed_tensor_unsafe(
|
| 372 |
+
compressed_indices,
|
| 373 |
+
plain_indices,
|
| 374 |
+
std::move(input.values()).to(dtype),
|
| 375 |
+
input.sizes(),
|
| 376 |
+
dtype,
|
| 377 |
+
input.layout(),
|
| 378 |
+
input.device(),
|
| 379 |
+
input.options().pinned_memory_opt());
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
template <typename acc_t, typename scalar_t>
|
| 383 |
+
inline std::tuple<Tensor, Tensor> create_acc_buffer(
|
| 384 |
+
TensorOptions option,
|
| 385 |
+
ScalarType type,
|
| 386 |
+
int64_t nnz = -1) {
|
| 387 |
+
Tensor new_values, new_values_acc;
|
| 388 |
+
constexpr bool need_acc = !std::is_same_v<scalar_t, acc_t>;
|
| 389 |
+
bool is_integral = at::isIntegralType(type, /*includeBool=*/true);
|
| 390 |
+
if constexpr (need_acc) {
|
| 391 |
+
auto acc_dtype = CppTypeToScalarType<acc_t>::value;
|
| 392 |
+
new_values_acc = at::empty({}, option.dtype(acc_dtype));
|
| 393 |
+
new_values = is_integral ? new_values_acc : at::empty({}, option);
|
| 394 |
+
} else {
|
| 395 |
+
new_values = new_values_acc = at::empty({}, option);
|
| 396 |
+
}
|
| 397 |
+
if (nnz != -1) {
|
| 398 |
+
return std::make_tuple(
|
| 399 |
+
new_values.resize_(nnz), new_values_acc.resize_(nnz));
|
| 400 |
+
} else {
|
| 401 |
+
return std::make_tuple(new_values, new_values_acc);
|
| 402 |
+
}
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
inline void copy_from_acc_buffer(Tensor& new_values, Tensor& new_values_acc) {
|
| 406 |
+
if (!new_values_acc.is_same(new_values)) {
|
| 407 |
+
new_values.copy_(new_values_acc);
|
| 408 |
+
}
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
} // namespace at::sparse_csr
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TensorGeometry.h
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/core/TensorBase.h>
|
| 4 |
+
#include <c10/core/WrapDimMinimal.h>
|
| 5 |
+
|
| 6 |
+
namespace at {
|
| 7 |
+
|
| 8 |
+
// Return if the tensor geometry represented by `sizes` and `strides` is
|
| 9 |
+
// contiguous Although we cache is_contiguous in tensor now, this is till useful
|
| 10 |
+
// because it allows checking if a particular geometry is contiguous without
|
| 11 |
+
// explicitly constructing a tensor, e.g., when you want to choose a kernel
|
| 12 |
+
// strategy based on whether a subgeometry is contiguous.
|
| 13 |
+
TORCH_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides);
|
| 14 |
+
|
| 15 |
+
struct TORCH_API TensorGeometry {
|
| 16 |
+
TensorGeometry() = default;
|
| 17 |
+
|
| 18 |
+
explicit TensorGeometry(c10::SymIntArrayRef sizes)
|
| 19 |
+
: sizes_(sizes.vec()),
|
| 20 |
+
strides_(sizes.size()),
|
| 21 |
+
has_symbolic_sizes_strides_(
|
| 22 |
+
!c10::asIntArrayRefSlowOpt(sizes).has_value()) {
|
| 23 |
+
int64_t dim = static_cast<int64_t>(sizes.size());
|
| 24 |
+
c10::SymInt expected_stride = 1;
|
| 25 |
+
for (int64_t i = dim - 1; i >= 0; i--) {
|
| 26 |
+
strides_[i] = expected_stride;
|
| 27 |
+
expected_stride *= sizes_[i];
|
| 28 |
+
}
|
| 29 |
+
numel_ = expected_stride;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
explicit TensorGeometry(const TensorBase& t)
|
| 33 |
+
: sizes_(t.sym_sizes().vec()),
|
| 34 |
+
strides_(t.sym_strides().vec()),
|
| 35 |
+
storage_offset_(t.sym_storage_offset()),
|
| 36 |
+
numel_(t.sym_numel()),
|
| 37 |
+
has_symbolic_sizes_strides_(
|
| 38 |
+
t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}
|
| 39 |
+
|
| 40 |
+
// true if the tensor is contiguous
|
| 41 |
+
bool is_contiguous() const;
|
| 42 |
+
|
| 43 |
+
int64_t dim() const {
|
| 44 |
+
return static_cast<int64_t>(sizes_.size());
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
int64_t size(int64_t dim) const {
|
| 48 |
+
TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
|
| 49 |
+
dim = c10::maybe_wrap_dim(dim, this->dim());
|
| 50 |
+
return sizes_.at(static_cast<size_t>(dim)).as_int_unchecked();
|
| 51 |
+
}
|
| 52 |
+
c10::IntArrayRef sizes() const {
|
| 53 |
+
TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
|
| 54 |
+
return c10::asIntArrayRefUnchecked(sizes_);
|
| 55 |
+
}
|
| 56 |
+
int64_t stride(int64_t dim) const {
|
| 57 |
+
TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
|
| 58 |
+
dim = c10::maybe_wrap_dim(dim, this->dim());
|
| 59 |
+
return strides_.at(static_cast<size_t>(dim)).as_int_unchecked();
|
| 60 |
+
}
|
| 61 |
+
c10::IntArrayRef strides() const {
|
| 62 |
+
TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
|
| 63 |
+
return c10::asIntArrayRefUnchecked(strides_);
|
| 64 |
+
}
|
| 65 |
+
int64_t storage_offset() const {
|
| 66 |
+
TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
|
| 67 |
+
return storage_offset_.as_int_unchecked();
|
| 68 |
+
}
|
| 69 |
+
int64_t numel() const {
|
| 70 |
+
TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
|
| 71 |
+
return numel_.as_int_unchecked();
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
c10::SymInt sym_size(int64_t dim) const {
|
| 75 |
+
dim = c10::maybe_wrap_dim(dim, this->dim());
|
| 76 |
+
return sizes_.at(static_cast<size_t>(dim));
|
| 77 |
+
}
|
| 78 |
+
c10::SymIntArrayRef sym_sizes() const {
|
| 79 |
+
return sizes_;
|
| 80 |
+
}
|
| 81 |
+
c10::SymInt sym_stride(int64_t dim) const {
|
| 82 |
+
dim = c10::maybe_wrap_dim(dim, this->dim());
|
| 83 |
+
return strides_.at(static_cast<size_t>(dim));
|
| 84 |
+
}
|
| 85 |
+
c10::SymIntArrayRef sym_strides() const {
|
| 86 |
+
return strides_;
|
| 87 |
+
}
|
| 88 |
+
c10::SymInt sym_storage_offset() const {
|
| 89 |
+
return storage_offset_;
|
| 90 |
+
}
|
| 91 |
+
c10::SymInt sym_numel() const {
|
| 92 |
+
return numel_;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
TensorGeometry transpose(int64_t dim0, int64_t dim1) {
|
| 96 |
+
TensorGeometry r = *this; // copy
|
| 97 |
+
TORCH_CHECK(
|
| 98 |
+
dim0 < dim(),
|
| 99 |
+
"transpose: dim0=",
|
| 100 |
+
dim0,
|
| 101 |
+
" out of range (dim=",
|
| 102 |
+
dim(),
|
| 103 |
+
")")
|
| 104 |
+
TORCH_CHECK(
|
| 105 |
+
dim1 < dim(),
|
| 106 |
+
"transpose: dim1=",
|
| 107 |
+
dim1,
|
| 108 |
+
" out of range (dim=",
|
| 109 |
+
dim(),
|
| 110 |
+
")")
|
| 111 |
+
std::swap(r.sizes_[dim0], r.sizes_[dim1]);
|
| 112 |
+
std::swap(r.strides_[dim0], r.strides_[dim1]);
|
| 113 |
+
return r;
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
std::vector<c10::SymInt>& mutable_sizes() {
|
| 117 |
+
return sizes_;
|
| 118 |
+
}
|
| 119 |
+
std::vector<c10::SymInt>& mutable_strides() {
|
| 120 |
+
return strides_;
|
| 121 |
+
}
|
| 122 |
+
c10::SymInt& mutable_storage_offset() {
|
| 123 |
+
return storage_offset_;
|
| 124 |
+
}
|
| 125 |
+
void recompute() {
|
| 126 |
+
// recalculate numel after a change
|
| 127 |
+
c10::SymInt numel = 1;
|
| 128 |
+
for (const auto& i : sizes_) {
|
| 129 |
+
numel = numel * i;
|
| 130 |
+
}
|
| 131 |
+
numel_ = std::move(numel);
|
| 132 |
+
has_symbolic_sizes_strides_ =
|
| 133 |
+
!c10::asIntArrayRefSlowOpt(sizes_).has_value();
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
private:
|
| 137 |
+
std::vector<c10::SymInt> sizes_;
|
| 138 |
+
std::vector<c10::SymInt> strides_;
|
| 139 |
+
c10::SymInt storage_offset_;
|
| 140 |
+
c10::SymInt numel_;
|
| 141 |
+
bool has_symbolic_sizes_strides_{false};
|
| 142 |
+
};
|
| 143 |
+
|
| 144 |
+
} // namespace at
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/TracerMode.h
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/impl/LocalDispatchKeySet.h>
|
| 4 |
+
#include <c10/macros/Export.h>
|
| 5 |
+
#include <c10/macros/Macros.h>
|
| 6 |
+
|
| 7 |
+
// NOTE [Tracing Mode Switches]
|
| 8 |
+
//
|
| 9 |
+
// Historically, tracing function was controlled by two switches:
|
| 10 |
+
//
|
| 11 |
+
// - `AutoDispatchBelowADInplaceOrView` guard
|
| 12 |
+
//
|
| 13 |
+
// Tracing function used to be script-generated inside `VariableType_*.cpp`
|
| 14 |
+
// kernels, sharing the same `Autograd` dispatch key with autograd function.
|
| 15 |
+
// Therefore, before tracing function was moved out of VariableType,
|
| 16 |
+
// `AutoDispatchBelowADInplaceOrView` guard can also disable tracing as a
|
| 17 |
+
// side effect of disabling `Autograd` dispatching.
|
| 18 |
+
//
|
| 19 |
+
// - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h`
|
| 20 |
+
//
|
| 21 |
+
// It stores tracing data in a `TracingState` object in TLS. If the
|
| 22 |
+
// `TracingState` object in TLS is `null`, then tracing is paused.
|
| 23 |
+
//
|
| 24 |
+
// The `TracingState` object is created in `tracer::trace()` - the main
|
| 25 |
+
// entrance of tracing function. It's temporarily set to `null` inside
|
| 26 |
+
// generated VariableType (now TraceType) to bypass tracing for intermediate
|
| 27 |
+
// ops (ops being called by other ops). After the intermediate op call
|
| 28 |
+
// finishes it's set back to the original `TracingState` object.
|
| 29 |
+
//
|
| 30 |
+
// The `TracingState` obect in TLS can also be read/written via its Python
|
| 31 |
+
// binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs,
|
| 32 |
+
// which are also exposed as `TORCH_API`.
|
| 33 |
+
//
|
| 34 |
+
// Two new switches were introduced since tracing function was moved out of
|
| 35 |
+
// VariableType:
|
| 36 |
+
//
|
| 37 |
+
// - `tracer::impl::set_dispatch_enabled()` API
|
| 38 |
+
//
|
| 39 |
+
// Unlike the special `Autograd` dispatch key which is included in dispatch
|
| 40 |
+
// key set by default, `Tracer` dispatch key is off by default. The
|
| 41 |
+
// dispatching switch can be toggled via this new API.
|
| 42 |
+
//
|
| 43 |
+
// - `tracer::impl::NoTracerDispatchMode` guard
|
| 44 |
+
//
|
| 45 |
+
// It's used to cover the old semantics of `AutoDispatchBelowADInplaceOrView`
|
| 46 |
+
// after tracing was moved out of VariableType.
|
| 47 |
+
//
|
| 48 |
+
// Before tracing function was moved out of VariableType, tracing was enabled
|
| 49 |
+
// when the following conditions are satisfied:
|
| 50 |
+
//
|
| 51 |
+
// 1) `TracingState` object in TLS != null;
|
| 52 |
+
// - Either inside the execution scope of `tracer::trace()`, or
|
| 53 |
+
// - Eagerly called `setTracingState()` with non-null object.
|
| 54 |
+
// 2) Not inside `AutoDispatchBelowADInplaceOrView` scope;
|
| 55 |
+
//
|
| 56 |
+
// After:
|
| 57 |
+
//
|
| 58 |
+
// 1) `TracingState` object in TLS != null;
|
| 59 |
+
// 2) Has called `tracer::impl::set_dispatch_enabled(true)`;
|
| 60 |
+
// 3) Not inside `tracer::impl::NonDispatchGuard` scope;
|
| 61 |
+
//
|
| 62 |
+
// [TODOs]
|
| 63 |
+
//
|
| 64 |
+
// - `setTracingState()` v.s. `tracer::impl::set_dispatch_enabled()`
|
| 65 |
+
//
|
| 66 |
+
// Currently `set_dispatch_enabled()` is set/unset inside `setTracingState()`
|
| 67 |
+
// to keep the semantics exactly the same as before - it's confusing to keep
|
| 68 |
+
// both switches, though. We should consider simplifying/limiting the exposed
|
| 69 |
+
// `setTracingState()` Python/C++ APIs (and other APIs calling it) so that
|
| 70 |
+
// these two can be unified.
|
| 71 |
+
//
|
| 72 |
+
// - `AutoDispatchBelowADInplaceOrView` v.s.
|
| 73 |
+
// `tracer::impl::NoTracerDispatchMode`
|
| 74 |
+
//
|
| 75 |
+
// We don't need to always set both guards together to keep semantics
|
| 76 |
+
// unchanged. For the follow use cases of `AutoDispatchBelowADInplaceOrView`
|
| 77 |
+
// we don't need set the new tracer guard:
|
| 78 |
+
//
|
| 79 |
+
// * Script-generated VariableType kernels. The guard is not necessary as
|
| 80 |
+
// tracing is already disabled explicitly by `setTracingState(null)` in
|
| 81 |
+
// generated TraceType kernels - we could keep it as is or use the new guard
|
| 82 |
+
// instead.
|
| 83 |
+
//
|
| 84 |
+
// * Custom ops. Will be handled by fallback kernel for `Tracer`.
|
| 85 |
+
//
|
| 86 |
+
// * Functions that are not likely to be called in tracing context (no python
|
| 87 |
+
// binding / not an operator), e.g.: all mobile forward() wrappers, test
|
| 88 |
+
// binaries, and etc.
|
| 89 |
+
//
|
| 90 |
+
// * Where new threads are spawned, e.g.: ATen/native/ConvolutionMM2d.cpp.
|
| 91 |
+
// It's not necessary as tracing is off by default.
|
| 92 |
+
//
|
| 93 |
+
// For the rest of cases we might need have both:
|
| 94 |
+
//
|
| 95 |
+
// * Functions that might be reachable from eager mode python (especially
|
| 96 |
+
// factory methods), e.g.:
|
| 97 |
+
// `internal_new_from_data()` in `torch/csrc/utils/tensor_new.cpp`.
|
| 98 |
+
// Without the new guard it will add `aten::empty` to the traced graph.
|
| 99 |
+
//
|
| 100 |
+
// * Some manually maintained functions, e.g.:
|
| 101 |
+
// `torch/csrc/autograd/VariableTypeManual.cpp`.
|
| 102 |
+
// Set the new guard if it's not obvious whether `setTracingState(null)`
|
| 103 |
+
// has been called before it reaches the `AutoDispatchBelowADInplaceOrView`
|
| 104 |
+
// guard.
|
| 105 |
+
//
|
| 106 |
+
// We might need tweak the usage of the new guard to optimize/fix things.
|
| 107 |
+
// It should only affect the correctness of tracing function, because the
|
| 108 |
+
// guard is essentially no-op when the master `setTracingState()` switch is
|
| 109 |
+
// off.
|
| 110 |
+
|
| 111 |
+
// TODO: move this from `at::` to `jit::torch::` after
|
| 112 |
+
// `aten/src/ATen/cpp_custom_type_hack.h` is removed.
|
| 113 |
+
|
| 114 |
+
namespace at::tracer::impl {

// True iff the `Tracer` dispatch key is active for the current thread:
// it must be explicitly included in the thread-local dispatch key set and
// not excluded (e.g. by a NoTracerDispatchMode guard).
static inline bool is_dispatch_enabled() {
  return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) &&
      !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer);
}

// Toggles thread-local inclusion of the `Tracer` dispatch key.
// Asserts the key is not currently excluded, i.e. tracing must not be
// (re-)enabled while inside a NoTracerDispatchMode scope.
static inline void set_dispatch_enabled(bool enabled) {
  TORCH_INTERNAL_ASSERT(
      !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer),
      "Cannot enable tracing within the scope of NoTracerDispatchMode!");
  c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Tracer, enabled);
}

// RAII guard that excludes the `Tracer` dispatch key for its lifetime,
// pausing tracer dispatching inside the scope (see the file comment above
// for how this replaces the old AutoDispatchBelowADInplaceOrView behavior).
struct NoTracerDispatchMode {
  c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer};
};

} // namespace at::tracer::impl
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/VmapGeneratedPlumbing.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/autocast_mode.h
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/ATen.h>
|
| 4 |
+
#include <ATen/NativeFunctions.h>
|
| 5 |
+
#include <ATen/Operators.h>
|
| 6 |
+
#include <torch/library.h>
|
| 7 |
+
|
| 8 |
+
#include <c10/core/impl/LocalDispatchKeySet.h>
|
| 9 |
+
#include <c10/util/intrusive_ptr.h>
|
| 10 |
+
|
| 11 |
+
namespace at::autocast {
|
| 12 |
+
|
| 13 |
+
// Global (thread-local) on/off switch and cast-cache controls for the
// default (CUDA) autocast pass. Implementations live in autocast_mode.cpp.
TORCH_API bool is_enabled();
TORCH_API void set_enabled(bool enabled);
TORCH_API void clear_cache();
// Autocast context nesting counters; presumably used to decide when the
// cast cache can be dropped — confirm against autocast_mode.cpp.
TORCH_API int increment_nesting();
TORCH_API int decrement_nesting();
// Per-backend enable flags and target dtypes. "gpu" maps to the CUDA pass.
TORCH_API bool is_cpu_enabled();
TORCH_API void set_cpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_gpu_dtype();
TORCH_API at::ScalarType get_autocast_cpu_dtype();
TORCH_API void set_autocast_gpu_dtype(at::ScalarType dtype);
TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype);
// XPU autocast state.
TORCH_API bool is_xpu_enabled();
TORCH_API void set_xpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_xpu_dtype();
TORCH_API void set_autocast_xpu_dtype(at::ScalarType dtype);
// IPU autocast state.
TORCH_API bool is_ipu_enabled();
TORCH_API void set_ipu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_ipu_dtype();
TORCH_API void set_autocast_ipu_dtype(at::ScalarType dtype);
// HPU autocast state.
TORCH_API bool is_hpu_enabled();
TORCH_API void set_hpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_hpu_dtype();
TORCH_API void set_autocast_hpu_dtype(at::ScalarType dtype);
// XLA autocast state.
TORCH_API bool is_xla_enabled();
TORCH_API void set_xla_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_xla_dtype();
TORCH_API void set_autocast_xla_dtype(at::ScalarType dtype);
// PrivateUse1 (out-of-tree backend) autocast state.
TORCH_API bool is_privateuseone_enabled();
TORCH_API void set_privateuseone_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_privateuseone_dtype();
TORCH_API void set_autocast_privateuseone_dtype(at::ScalarType dtype);
// Controls whether cached_cast() may cache results — see autocast_mode.cpp.
TORCH_API bool is_autocast_cache_enabled();
TORCH_API void set_autocast_cache_enabled(bool enabled);
|
| 46 |
+
|
| 47 |
+
namespace {
|
| 48 |
+
inline bool is_autocast_eligible(
|
| 49 |
+
const Tensor& tensor,
|
| 50 |
+
c10::DeviceType device_type) {
|
| 51 |
+
switch (device_type) {
|
| 52 |
+
case c10::DeviceType::CUDA:
|
| 53 |
+
return (tensor.is_cuda() || tensor.is_xla()) &&
|
| 54 |
+
tensor.is_floating_point();
|
| 55 |
+
case c10::DeviceType::CPU:
|
| 56 |
+
return (tensor.is_cpu() || tensor.is_mkldnn()) &&
|
| 57 |
+
tensor.is_floating_point();
|
| 58 |
+
case c10::DeviceType::XPU:
|
| 59 |
+
return tensor.is_xpu() && tensor.is_floating_point();
|
| 60 |
+
case c10::DeviceType::IPU:
|
| 61 |
+
return tensor.is_ipu() && tensor.is_floating_point();
|
| 62 |
+
case c10::DeviceType::HPU:
|
| 63 |
+
return tensor.is_hpu() && tensor.is_floating_point();
|
| 64 |
+
case c10::DeviceType::XLA:
|
| 65 |
+
return tensor.is_xla() && tensor.is_floating_point();
|
| 66 |
+
case c10::DeviceType::PrivateUse1:
|
| 67 |
+
return tensor.is_privateuseone() && tensor.is_floating_point();
|
| 68 |
+
default:
|
| 69 |
+
return false;
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
} // namespace
|
| 73 |
+
|
| 74 |
+
// Maps a device type onto the dispatch key implementing autocast for that
// backend. Device types with no autocast support raise std::runtime_error.
inline DispatchKey get_autocast_dispatch_key_from_device_type(
    c10::DeviceType device_type) {
  switch (device_type) {
    case c10::DeviceType::CUDA: return DispatchKey::Autocast;
    case c10::DeviceType::CPU: return DispatchKey::AutocastCPU;
    case c10::DeviceType::XPU: return DispatchKey::AutocastXPU;
    case c10::DeviceType::IPU: return DispatchKey::AutocastIPU;
    case c10::DeviceType::HPU: return DispatchKey::AutocastHPU;
    case c10::DeviceType::XLA: return DispatchKey::AutocastXLA;
    case c10::DeviceType::PrivateUse1: return DispatchKey::AutocastPrivateUse1;
    default:
      throw std::runtime_error(
          "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
  }
}
|
| 96 |
+
|
| 97 |
+
// Returns the configured lower-precision floating dtype for the backend's
// autocast pass (fp16 by default on CUDA; user-configurable per backend).
// Device types with no autocast support raise std::runtime_error.
inline at::ScalarType get_lower_precision_fp_from_device_type(
    c10::DeviceType device_type) {
  switch (device_type) {
    case c10::DeviceType::CUDA: return get_autocast_gpu_dtype();
    case c10::DeviceType::CPU: return get_autocast_cpu_dtype();
    case c10::DeviceType::XPU: return get_autocast_xpu_dtype();
    case c10::DeviceType::IPU: return get_autocast_ipu_dtype();
    case c10::DeviceType::HPU: return get_autocast_hpu_dtype();
    case c10::DeviceType::XLA: return get_autocast_xla_dtype();
    case c10::DeviceType::PrivateUse1:
      return get_autocast_privateuseone_dtype();
    default:
      throw std::runtime_error(
          "unknown device type for autocast in get_lower_precision_fp_from_device_type");
  }
}
|
| 119 |
+
|
| 120 |
+
/********************************************************************
|
| 121 |
+
Logic to extract the promote type from any Tensor or TensorList args.
|
| 122 |
+
********************************************************************/
|
| 123 |
+
|
| 124 |
+
// Overload to catch Tensor args.
|
| 125 |
+
// If nextArg is floating-point, compare its scalar_type with our
|
| 126 |
+
// current best guess for the promote type, and update if necessary.
|
| 127 |
+
// Overload to catch Tensor args.
// If nextArg is an autocast-eligible floating-point tensor, compare its
// scalar_type with the current best guess for the promote type and widen
// if necessary. Double tensors are ignored; a `current` of double is a
// logic error (errors out via AT_ERROR).
inline at::ScalarType prioritize(
    at::ScalarType current,
    const Tensor& nextArg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  if (current == at::kDouble) {
    AT_ERROR("promote type is double in at::autocast::prioritize");
    return current;
  }
  at::ScalarType lower_precision_fp =
      get_lower_precision_fp_from_device_type(device_type);
  if (is_autocast_eligible(nextArg, device_type)) {
    auto next = nextArg.scalar_type();
    if (next == at::kDouble) {
      return current; // ignores double tensors
    } else if (current == at::kFloat || next == at::kFloat) {
      return at::kFloat; // prioritizes float over lower_precision_fp
    } else if (current == lower_precision_fp && next == lower_precision_fp) {
      return lower_precision_fp;
    } else {
      // Neither float nor the device's lower-precision type: unexpected.
      AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
      return current;
    }
  } else {
    return current;
  }
}

// Overload to catch TensorList args (for e.g. cat, stack).
// Reuses the Tensor overload above to process each Tensor in the list.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const TensorList& list,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  for (const auto& tensor : list) {
    current = prioritize(current, tensor, device_type);
  }
  return current;
}

// Same as the TensorList overload, for ITensorListRef arguments.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const ITensorListRef& list,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  for (const auto& tensor : list) {
    current = prioritize(current, tensor, device_type);
  }
  return current;
}

// Template to catch non-Tensor args (no-op that returns current best guess).
template <typename T>
inline at::ScalarType prioritize(
    at::ScalarType current,
    T nextArg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  return current;
}
|
| 184 |
+
|
| 185 |
+
// Overload for the tail case.
|
| 186 |
+
// Tail case of the variadic recursion below: no args left, the accumulated
// `current` is the final promote type.
inline at::ScalarType promote_type(
    at::ScalarType current,
    c10::DeviceType device_type) {
  return current;
}

// Unpack args and determine if incoming lower_precision_fp tensors need to be
// promoted to float32. Non-Tensor arguments are ignored (see the prioritize
// no-op template above).
template <typename Arg0, typename... Args>
inline at::ScalarType promote_type(
    at::ScalarType current,
    c10::DeviceType device_type,
    Arg0 arg0,
    Args... args) {
  auto new_current = prioritize(current, arg0, device_type);
  return promote_type(new_current, device_type, args...);
}
|
| 203 |
+
|
| 204 |
+
/****************************************************
|
| 205 |
+
Logic to apply cached casting to any Tensor argument.
|
| 206 |
+
****************************************************/
|
| 207 |
+
// True when `arg` should be cast by the autocast wrappers: it must be a
// defined, autocast-eligible tensor that is not double precision.
inline bool is_eligible(
    const Tensor& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  return (
      arg.defined() && is_autocast_eligible(arg, device_type) &&
      (arg.scalar_type() != at::kDouble));
}

// Overload to catch Tensor args.
// Casts `arg` to `to_type` if eligible, possibly reusing a cached result
// (implementation in autocast_mode.cpp).
TORCH_API Tensor cached_cast(
    at::ScalarType to_type,
    const Tensor& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA);

// Overload to process optional<Tensor>: casts the payload if present,
// forwards nullopt otherwise.
inline c10::optional<Tensor> cached_cast(
    at::ScalarType to_type,
    const c10::optional<Tensor>& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  if (arg.has_value()) {
    return cached_cast(to_type, *arg, device_type);
  } else {
    return c10::nullopt;
  }
}

// Overload to process TensorLists: casts element-wise into a new vector.
inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const TensorList& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  std::vector<Tensor> vec;
  vec.reserve(arg.size());
  for (const auto& t : arg) {
    vec.emplace_back(cached_cast(to_type, t, device_type));
  }
  return vec;
}

// Same as the TensorList overload, for ITensorListRef arguments.
inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const ITensorListRef& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  std::vector<Tensor> vec;
  vec.reserve(arg.size());
  for (const auto& t : arg) {
    vec.emplace_back(cached_cast(to_type, t, device_type));
  }
  return vec;
}

// Template to catch non-Tensor args: passed through unchanged.
template <typename T>
inline T cached_cast(
    at::ScalarType to_type,
    T arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  return arg;
}
|
| 266 |
+
|
| 267 |
+
/*******************************************************
|
| 268 |
+
Logic to flip an output dtype flag.
|
| 269 |
+
Keep it simple for now by assuming only one such flag is
|
| 270 |
+
present in the argument list. If I ever need a function
|
| 271 |
+
with more than flag I'll figure out something else.
|
| 272 |
+
The policy is:
|
| 273 |
+
If the user has explicity specified a dtype, respect it.
|
| 274 |
+
Otherwise, set it to the autocast type.
|
| 275 |
+
********************************************************/
|
| 276 |
+
|
| 277 |
+
// Overload to catch dtype flags
|
| 278 |
+
// Overload to catch dtype flags.
// Policy: if the user explicitly specified a dtype, respect it; otherwise
// substitute the autocast type `to_type`.
c10::optional<ScalarType> inline set_opt_dtype(
    at::ScalarType to_type,
    const c10::optional<ScalarType>& dtype) {
  return dtype.has_value() ? dtype : to_type;
}

// Template to catch other (non-dtype) args: passed through unchanged.
template <typename T>
inline T set_opt_dtype(at::ScalarType to_type, T arg) {
  return arg;
}

// True if the first Tensor argument is autocast-eligible; trailing args are
// only there so callers can forward an entire argument pack.
template <typename... Args>
inline bool firstarg_is_eligible(
    c10::DeviceType device_type,
    const Tensor& arg,
    Args... args) {
  return is_eligible(arg, device_type);
}

// Chooses the redispatch dtype based on the first Tensor argument:
// `to_type` if the tensor is eligible for autocasting, otherwise the
// tensor's own scalar type (trailing args are ignored).
template <typename... Args>
inline at::ScalarType type_from_firstarg(
    c10::DeviceType device_type,
    at::ScalarType to_type,
    const Tensor& arg,
    Args... args) {
  return (is_eligible(arg, device_type) ? to_type : arg.scalar_type());
}
|
| 306 |
+
|
| 307 |
+
// Policies correspond to op categories that need code-divergent handling.
|
| 308 |
+
// Wrapper templates below are specialized based on a policy template parameter.
|
| 309 |
+
// Policies correspond to op categories that need code-divergent handling.
// Wrapper templates below are specialized based on a policy template
// parameter (see the WrapFunction_ specializations).
enum class CastPolicy : uint8_t {
  lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
                          // running the op. Currently, lower_precision_fp is
                          // fp16 for AutocastCUDA, and is defined by user
                          // (default bf16) for AutocastCPU or other device.
  fp32, // Cast all inputs to at::kFloat before running the op.
  fp32_set_opt_dtype, // Treats functions (like softmax) that
                      // 1. we'd like to run in fp32 and
                      // 2. have a c10::optional<ScalarType> arg that controls
                      // the output type.
                      // fp32_set_opt_dtype wrappers' policy is: if the output
                      // type is already set, don't touch it, otherwise, set
                      // it to at::kFloat.
  fp32_append_dtype, // Treats functions (like norm) that
                     // 1. we'd like to run in fp32 and
                     // 2. have some overloads that accept an output type and
                     // other overloads that don't.
                     // fp32_append_dtype wrappers wrap the overloads that
                     // don't have an output dtype.
                     // The wrapper policy is: append at::kFloat to the args,
                     // and redispatch to the type-aware overload.
  promote, // Run in the widest dtype among several args.
};
|
| 332 |
+
|
| 333 |
+
/********************************************************************************************************
|
| 334 |
+
Templates to provide wrapper functions
|
| 335 |
+
|
| 336 |
+
I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to
|
| 337 |
+
extract args and return type. (see also
|
| 338 |
+
https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer)
|
| 339 |
+
|
| 340 |
+
This strategy uses an exterior "WrapFunction" that extracts arguments on behalf
|
| 341 |
+
of (in my case several specializations of) an interior "WrapFunction_".
|
| 342 |
+
Interior WrapFunction_ specializations are defined for each CastPolicy.
|
| 343 |
+
********************************************************************************************************/
|
| 344 |
+
|
| 345 |
+
// Base template for WrapFunction_, which is specialized to contain a "call"
|
| 346 |
+
// method each CastPolicy
|
| 347 |
+
// Base template for WrapFunction_, which is specialized to contain a "call"
// method for each CastPolicy.
template <
    CastPolicy policy,
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class ArgList>
struct WrapFunction_ {};

// CastPolicy::lower_precision_fp General_DeviceType
// Casts every eligible Tensor argument to the device's lower-precision
// floating type, excludes the autocast key (so the redispatch does not
// re-enter this wrapper), then calls the target function F.
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::lower_precision_fp,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    return (*F)(cached_cast(
        get_lower_precision_fp_from_device_type(device_type),
        args,
        device_type)...);
  }
};

// CastPolicy::fp32 General_DeviceType
// Casts every eligible Tensor argument to at::kFloat and redispatches.
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::fp32,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    return (*F)(cached_cast(at::kFloat, args, device_type)...);
  }
};

// CastPolicy::fp32_set_opt_dtype General_DeviceType
// For ops with an optional<ScalarType> output-dtype arg: if the first Tensor
// arg is eligible, fill an unset dtype flag with at::kFloat; otherwise call
// F with the args untouched.
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::fp32_set_opt_dtype,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    if (firstarg_is_eligible(device_type, args...)) {
      return (*F)(set_opt_dtype(at::kFloat, args)...);
    } else {
      // If ineligible, calls F with unaltered args. Does not set opt dtype,
      // because setting opt dtype explicitly may interfere with internal
      // implicit promotion decisions.
      return (*F)(args...);
    }
  }
};

// CastPolicy::fp32_append_dtype General_DeviceType
// Appends an output dtype (at::kFloat when the first Tensor arg is eligible,
// its own type otherwise) and redispatches to the dtype-aware overload F.
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::fp32_append_dtype,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    at::ScalarType out_type =
        type_from_firstarg(device_type, at::kFloat, args...);
    return (*F)(args..., out_type);
  }
};

// CastPolicy::promote General_DeviceType
// Computes the widest dtype among the Tensor args (lower_precision_fp or
// kFloat — see promote_type) and casts every eligible arg to it.
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::promote,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    auto to_type = promote_type(
        get_lower_precision_fp_from_device_type(device_type),
        device_type,
        args...);
    return (*F)(cached_cast(to_type, args, device_type)...);
  }
};

// Wrapper to infer return_type and parameter_types for WrapFunction_
// (imitating core/boxing/impl/WrapFunctionIntoFunctor.h)
template <
    CastPolicy policy,
    c10::DeviceType device_type,
    class Registered, // The signature for which we're registering. The
                      // dispatcher's calling code invokes our registered
                      // functions with arguments matching Registered, so we
                      // register WrapFunction_::call methods with a matching
                      // signature to properly field those arguments.
                      // guts::function_traits below extracts return_type and
                      // parameter_types from Registered, which WrapFunction_
                      // templates above use to declare their call methods.
    class Redispatch, // The signature for the function we're redispatching to.
                      // In most cases this is the same as Registered, but for
                      // some ops (for example, ops where we append a dtype)
                      // it's useful to redispatch to a function with a
                      // different signature.
    Redispatch* F> // The actual function we're redispatching to.
struct WrapFunction final {
  using type = WrapFunction_<
      policy,
      device_type,
      Redispatch,
      F,
      typename guts::function_traits<Registered>::return_type,
      typename guts::function_traits<Registered>::parameter_types>;
};
|
| 505 |
+
|
| 506 |
+
/*****************************************************************************************************************
|
| 507 |
+
This section performs load-time registration for autocast wrappers.
|
| 508 |
+
|
| 509 |
+
It's debatable at what level operations should be patched. We'd like casts to
|
| 510 |
+
be autograd-exposed and precede autograd history recording, so that for
|
| 511 |
+
lower_precision_fp ops, input tensors are saved for backward in
|
| 512 |
+
lower_precision_fp rather than fp32. Saving inputs in lower_precision_fp
|
| 513 |
+
can significantly reduce a model's memory footprint.
|
| 514 |
+
|
| 515 |
+
Option 1 (strawman): Patch only at the level of explicit calls into
|
| 516 |
+
cudnn/cublas (cudnn_convolution, etc), because those are the code paths that are
|
| 517 |
+
guaranteed to use Tensor Cores, therefore they're the ones that will benefit
|
| 518 |
+
most from lower_precision_fp. Potential pitfall: convolutions (and other ops)
|
| 519 |
+
are wrapped in several layers of at::* calls. If one of those happens to record
|
| 520 |
+
autograd history, then we've lost the opportunity to save inputs in
|
| 521 |
+
lower_precision_fp.
|
| 522 |
+
|
| 523 |
+
Option 2: Patch the Python-exposed surface of calls, to make 100% sure autograd
|
| 524 |
+
history recording can't sneak in ahead of autocast. This mirrors Apex most
|
| 525 |
+
closely.
|
| 526 |
+
|
| 527 |
+
I think Option 2 is the right answer for all ops, not just convolutions. Option
|
| 528 |
+
2 is what I implement here.
|
| 529 |
+
*****************************************************************************************************************/
|
| 530 |
+
|
| 531 |
+
/********************************************************************************************************************
|
| 532 |
+
Explicit registration for out-of-place ops
|
| 533 |
+
|
| 534 |
+
The stuff below could be codegenned. Ed said
|
| 535 |
+
> you are going to have to write the function definition at some point, I
|
| 536 |
+
wouldn't try to get clever about it Therefore, for the moment, this is all
|
| 537 |
+
copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
|
| 538 |
+
********************************************************************************************************************/
|
| 539 |
+
|
| 540 |
+
} // namespace at::autocast
|
| 541 |
+
|
| 542 |
+
#define ADD_NS(RAW_OP) at::RAW_OP
|
| 543 |
+
|
| 544 |
+
// Common cases where registration signature matches redispatch signature
|
| 545 |
+
// (that's why SIGNATURE is repeated in the WrapFunction instantiation)
|
| 546 |
+
#define KERNEL(DISPATCHKEY, OP, POLICY) \
|
| 547 |
+
m.impl( \
|
| 548 |
+
TORCH_SELECTIVE_NAME("aten::" #OP), \
|
| 549 |
+
&::at::autocast::WrapFunction< \
|
| 550 |
+
::at::autocast::CastPolicy::POLICY, \
|
| 551 |
+
DISPATCHKEY, \
|
| 552 |
+
decltype(ATEN_FN(OP)), \
|
| 553 |
+
decltype(ATEN_FN(OP)), \
|
| 554 |
+
&ATEN_FN(OP)>::type::call);
|
| 555 |
+
|
| 556 |
+
#define KERNEL2(DISPATCHKEY, OP, OVERLOAD, POLICY) \
|
| 557 |
+
m.impl( \
|
| 558 |
+
TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
|
| 559 |
+
&::at::autocast::WrapFunction< \
|
| 560 |
+
::at::autocast::CastPolicy::POLICY, \
|
| 561 |
+
DISPATCHKEY, \
|
| 562 |
+
decltype(ATEN_FN2(OP, OVERLOAD)), \
|
| 563 |
+
decltype(ATEN_FN2(OP, OVERLOAD)), \
|
| 564 |
+
&ATEN_FN2(OP, OVERLOAD)>::type::call);
|
| 565 |
+
|
| 566 |
+
// Less-common but still useful case: redispatching to a function
|
| 567 |
+
// with a new signature (e.g. appending a dtype)
|
| 568 |
+
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
|
| 569 |
+
DISPATCHKEY, \
|
| 570 |
+
REDISPATCH_FUNC, \
|
| 571 |
+
REGISTER_NAME, \
|
| 572 |
+
REGISTER_SIGNATURE, \
|
| 573 |
+
REDISPATCH_SIGNATURE, \
|
| 574 |
+
POLICY) \
|
| 575 |
+
m.impl( \
|
| 576 |
+
TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
|
| 577 |
+
&::at::autocast::WrapFunction< \
|
| 578 |
+
::at::autocast::CastPolicy::POLICY, \
|
| 579 |
+
DISPATCHKEY, \
|
| 580 |
+
REGISTER_SIGNATURE, \
|
| 581 |
+
REDISPATCH_SIGNATURE, \
|
| 582 |
+
&REDISPATCH_FUNC>::type::call);
|
| 583 |
+
|
| 584 |
+
// KERNEL_CPU/KERNEL_CPU2/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU
|
| 585 |
+
// registration for AutocastCPU
|
| 586 |
+
#define KERNEL_CPU(OP, POLICY) KERNEL(c10::DeviceType::CPU, OP, POLICY)
|
| 587 |
+
|
| 588 |
+
#define KERNEL_CPU2(OP, OVERLOAD, POLICY) \
|
| 589 |
+
KERNEL2(c10::DeviceType::CPU, OP, OVERLOAD, POLICY)
|
| 590 |
+
|
| 591 |
+
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU( \
|
| 592 |
+
REDISPATCH_FUNC, \
|
| 593 |
+
REGISTER_NAME, \
|
| 594 |
+
REGISTER_SIGNATURE, \
|
| 595 |
+
REDISPATCH_SIGNATURE, \
|
| 596 |
+
POLICY) \
|
| 597 |
+
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
|
| 598 |
+
c10::DeviceType::CPU, \
|
| 599 |
+
REDISPATCH_FUNC, \
|
| 600 |
+
REGISTER_NAME, \
|
| 601 |
+
REGISTER_SIGNATURE, \
|
| 602 |
+
REDISPATCH_SIGNATURE, \
|
| 603 |
+
POLICY)
|
| 604 |
+
|
| 605 |
+
// KERNEL_CUDA/KERNEL_CUDA2/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA
|
| 606 |
+
// registration for AutocastCUDA
|
| 607 |
+
#define KERNEL_CUDA(OP, POLICY) KERNEL(c10::DeviceType::CUDA, OP, POLICY)
|
| 608 |
+
|
| 609 |
+
#define KERNEL_CUDA2(OP, OVERLOAD, POLICY) \
|
| 610 |
+
KERNEL2(c10::DeviceType::CUDA, OP, OVERLOAD, POLICY)
|
| 611 |
+
|
| 612 |
+
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA( \
|
| 613 |
+
REDISPATCH_FUNC, \
|
| 614 |
+
REGISTER_NAME, \
|
| 615 |
+
REGISTER_SIGNATURE, \
|
| 616 |
+
REDISPATCH_SIGNATURE, \
|
| 617 |
+
POLICY) \
|
| 618 |
+
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
|
| 619 |
+
c10::DeviceType::CUDA, \
|
| 620 |
+
REDISPATCH_FUNC, \
|
| 621 |
+
REGISTER_NAME, \
|
| 622 |
+
REGISTER_SIGNATURE, \
|
| 623 |
+
REDISPATCH_SIGNATURE, \
|
| 624 |
+
POLICY)
|
| 625 |
+
|
| 626 |
+
// KERNEL_PRIVATEUSEONE/KERNEL_PRIVATEUSEONE2/
|
| 627 |
+
// KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE
|
| 628 |
+
// registration for AutocastPrivateUse1
|
| 629 |
+
#define KERNEL_PRIVATEUSEONE(OP, POLICY) \
|
| 630 |
+
KERNEL(c10::DeviceType::PrivateUse1, OP, POLICY)
|
| 631 |
+
|
| 632 |
+
#define KERNEL_PRIVATEUSEONE2(OP, OVERLOAD, POLICY) \
|
| 633 |
+
KERNEL2(c10::DeviceType::PrivateUse1, OP, OVERLOAD, POLICY)
|
| 634 |
+
|
| 635 |
+
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE( \
|
| 636 |
+
REDISPATCH_FUNC, \
|
| 637 |
+
REGISTER_NAME, \
|
| 638 |
+
REGISTER_SIGNATURE, \
|
| 639 |
+
REDISPATCH_SIGNATURE, \
|
| 640 |
+
POLICY) \
|
| 641 |
+
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
|
| 642 |
+
c10::DeviceType::PrivateUse1, \
|
| 643 |
+
REDISPATCH_FUNC, \
|
| 644 |
+
REGISTER_NAME, \
|
| 645 |
+
REGISTER_SIGNATURE, \
|
| 646 |
+
REDISPATCH_SIGNATURE, \
|
| 647 |
+
POLICY)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cpp_custom_type_hack.h
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 2 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 3 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 4 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 5 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 6 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 7 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 8 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 9 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 10 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 11 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 12 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 13 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 14 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 15 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 16 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 17 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 18 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 19 |
+
|
| 20 |
+
// YOU ARE IN THE WRONG PLACE! TURN BACK NOW!
|
| 21 |
+
|
| 22 |
+
// This code was a temporary hack to enable embedding arbitrary C++ structures
|
| 23 |
+
// into Tensors. THIS IS UNSAFE AND IS NOT SUPPORTED. IF YOU USE THIS CODE,
|
| 24 |
+
// IT __WILL__ BREAK.
|
| 25 |
+
|
| 26 |
+
// This code has been superseded by custom classes:
|
| 27 |
+
// https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html
|
| 28 |
+
|
| 29 |
+
// Please use custom classes and **DO NOT ADD MORE CALLSITES TO THINGS DEFINED
|
| 30 |
+
// IN THIS FILE**.
|
| 31 |
+
|
| 32 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 33 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 34 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 35 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 36 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 37 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 38 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 39 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 40 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 41 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 42 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 43 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 44 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 45 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 46 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 47 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 48 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 49 |
+
// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP
|
| 50 |
+
|
| 51 |
+
#include <ATen/TracerMode.h>
|
| 52 |
+
#include <ATen/core/Tensor.h>
|
| 53 |
+
|
| 54 |
+
#ifndef AT_PER_OPERATOR_HEADERS
|
| 55 |
+
#include <ATen/Functions.h>
|
| 56 |
+
#else
|
| 57 |
+
#include <ATen/ops/empty.h>
|
| 58 |
+
#endif
|
| 59 |
+
|
| 60 |
+
namespace at::cpp_custom_type_hack {
|
| 61 |
+
|
| 62 |
+
template <typename T>
|
| 63 |
+
[[deprecated(
|
| 64 |
+
"Use custom classes instead: "
|
| 65 |
+
"https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] bool
|
| 66 |
+
isa(const Tensor& packed) {
|
| 67 |
+
return (packed.scalar_type() == kByte) &&
|
| 68 |
+
(packed.storage().data_ptr().get_deleter() ==
|
| 69 |
+
caffe2::TypeMeta::Make<T>().deleteFn());
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
template <typename T>
|
| 73 |
+
[[deprecated(
|
| 74 |
+
"Use custom classes instead: "
|
| 75 |
+
"https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] T&
|
| 76 |
+
cast(const Tensor& packed) {
|
| 77 |
+
TORCH_CHECK(
|
| 78 |
+
packed.scalar_type() == kByte, "Expected temporary cpp type wrapper");
|
| 79 |
+
TORCH_CHECK(
|
| 80 |
+
packed.storage().data_ptr().get_deleter() ==
|
| 81 |
+
caffe2::TypeMeta::Make<T>().deleteFn(),
|
| 82 |
+
"Expected temporary cpp type wrapper of type ",
|
| 83 |
+
caffe2::TypeMeta::TypeName<T>());
|
| 84 |
+
return *reinterpret_cast<T*>(packed.storage().data_ptr().get());
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
template <typename T>
|
| 88 |
+
[[deprecated(
|
| 89 |
+
"Use custom classes instead: "
|
| 90 |
+
"https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] Tensor
|
| 91 |
+
create(std::unique_ptr<T> ptr, TensorOptions options) {
|
| 92 |
+
// None of this should trace, so turn off Tracer dispatching
|
| 93 |
+
at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
|
| 94 |
+
at::tracer::impl::NoTracerDispatchMode tracer_guard;
|
| 95 |
+
|
| 96 |
+
// We store this instance away in a Tensor and register a deleter function
|
| 97 |
+
// so that we do not leak memory. On the other side, we pull out the storage's
|
| 98 |
+
// data_ptr and get the right typed pointer.
|
| 99 |
+
void* raw_ptr = ptr.release();
|
| 100 |
+
at::DataPtr at_ptr(
|
| 101 |
+
raw_ptr, raw_ptr, caffe2::TypeMeta::Make<T>().deleteFn(), at::kCPU);
|
| 102 |
+
|
| 103 |
+
// size doesn't really matter, but we can align it to the actual size
|
| 104 |
+
// returning variables because one likely want to use this hack from python
|
| 105 |
+
auto retval = at::empty({sizeof(T)}, options.device(kCPU).dtype(at::kByte));
|
| 106 |
+
retval.storage().set_data_ptr_noswap(std::move(at_ptr));
|
| 107 |
+
return retval;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
} // namespace at::cpp_custom_type_hack
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/ApplyGridUtils.cuh
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 2 |
+
|
| 3 |
+
#include <cuda_runtime.h>
|
| 4 |
+
|
| 5 |
+
namespace at::cuda {
|
| 6 |
+
|
| 7 |
+
/**
|
| 8 |
+
Computes ceil(a / b)
|
| 9 |
+
*/
|
| 10 |
+
template <typename T>
|
| 11 |
+
__host__ __device__ __forceinline__ T ATenCeilDiv(T a, T b) {
|
| 12 |
+
return (a + b - 1) / b;
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
namespace {
|
| 16 |
+
|
| 17 |
+
// Threads per block for our apply kernel
|
| 18 |
+
// FIXME: use occupancy calculator instead
|
| 19 |
+
constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512;
|
| 20 |
+
constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4;
|
| 21 |
+
|
| 22 |
+
template <int step = 1>
|
| 23 |
+
inline bool getApplyGrid(uint64_t totalElements, dim3& grid, c10::DeviceIndex curDevice, int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
|
| 24 |
+
if (curDevice == -1) return false;
|
| 25 |
+
uint64_t numel_per_thread = static_cast<uint64_t>(max_threads_per_block) * static_cast<uint64_t>(step);
|
| 26 |
+
uint64_t numBlocks = ATenCeilDiv(totalElements, numel_per_thread);
|
| 27 |
+
uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0];
|
| 28 |
+
if (numBlocks > maxGridX)
|
| 29 |
+
numBlocks = maxGridX;
|
| 30 |
+
grid = dim3(numBlocks);
|
| 31 |
+
return true;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
constexpr int getApplyBlocksPerSM() {
|
| 35 |
+
return AT_APPLY_BLOCKS_PER_SM;
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
constexpr int getApplyBlockSize() {
|
| 39 |
+
return AT_APPLY_THREADS_PER_BLOCK;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
inline dim3 getApplyBlock(int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
|
| 43 |
+
return dim3(max_threads_per_block);
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
} // anonymous namespace
|
| 47 |
+
} // namespace at::cuda
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/AsmUtils.cuh
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <cstdint>
|
| 3 |
+
|
| 4 |
+
// Collection of direct PTX functions
|
| 5 |
+
|
| 6 |
+
namespace at::cuda {
|
| 7 |
+
|
| 8 |
+
template <typename T>
|
| 9 |
+
struct Bitfield {};
|
| 10 |
+
|
| 11 |
+
template <>
|
| 12 |
+
struct Bitfield<unsigned int> {
|
| 13 |
+
static __device__ __host__ __forceinline__
|
| 14 |
+
unsigned int getBitfield(unsigned int val, int pos, int len) {
|
| 15 |
+
#if !defined(__CUDA_ARCH__)
|
| 16 |
+
pos &= 0xff;
|
| 17 |
+
len &= 0xff;
|
| 18 |
+
|
| 19 |
+
unsigned int m = (1u << len) - 1u;
|
| 20 |
+
return (val >> pos) & m;
|
| 21 |
+
#else
|
| 22 |
+
unsigned int ret;
|
| 23 |
+
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
|
| 24 |
+
return ret;
|
| 25 |
+
#endif
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
static __device__ __host__ __forceinline__
|
| 29 |
+
unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
|
| 30 |
+
#if !defined(__CUDA_ARCH__)
|
| 31 |
+
pos &= 0xff;
|
| 32 |
+
len &= 0xff;
|
| 33 |
+
|
| 34 |
+
unsigned int m = (1u << len) - 1u;
|
| 35 |
+
toInsert &= m;
|
| 36 |
+
toInsert <<= pos;
|
| 37 |
+
m <<= pos;
|
| 38 |
+
|
| 39 |
+
return (val & ~m) | toInsert;
|
| 40 |
+
#else
|
| 41 |
+
unsigned int ret;
|
| 42 |
+
asm("bfi.b32 %0, %1, %2, %3, %4;" :
|
| 43 |
+
"=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
|
| 44 |
+
return ret;
|
| 45 |
+
#endif
|
| 46 |
+
}
|
| 47 |
+
};
|
| 48 |
+
|
| 49 |
+
template <>
|
| 50 |
+
struct Bitfield<uint64_t> {
|
| 51 |
+
static __device__ __host__ __forceinline__
|
| 52 |
+
uint64_t getBitfield(uint64_t val, int pos, int len) {
|
| 53 |
+
#if !defined(__CUDA_ARCH__)
|
| 54 |
+
pos &= 0xff;
|
| 55 |
+
len &= 0xff;
|
| 56 |
+
|
| 57 |
+
uint64_t m = (1u << len) - 1u;
|
| 58 |
+
return (val >> pos) & m;
|
| 59 |
+
#else
|
| 60 |
+
uint64_t ret;
|
| 61 |
+
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
|
| 62 |
+
return ret;
|
| 63 |
+
#endif
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
static __device__ __host__ __forceinline__
|
| 67 |
+
uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) {
|
| 68 |
+
#if !defined(__CUDA_ARCH__)
|
| 69 |
+
pos &= 0xff;
|
| 70 |
+
len &= 0xff;
|
| 71 |
+
|
| 72 |
+
uint64_t m = (1u << len) - 1u;
|
| 73 |
+
toInsert &= m;
|
| 74 |
+
toInsert <<= pos;
|
| 75 |
+
m <<= pos;
|
| 76 |
+
|
| 77 |
+
return (val & ~m) | toInsert;
|
| 78 |
+
#else
|
| 79 |
+
uint64_t ret;
|
| 80 |
+
asm("bfi.b64 %0, %1, %2, %3, %4;" :
|
| 81 |
+
"=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
|
| 82 |
+
return ret;
|
| 83 |
+
#endif
|
| 84 |
+
}
|
| 85 |
+
};
|
| 86 |
+
|
| 87 |
+
__device__ __forceinline__ int getLaneId() {
|
| 88 |
+
#if defined(USE_ROCM)
|
| 89 |
+
return __lane_id();
|
| 90 |
+
#else
|
| 91 |
+
int laneId;
|
| 92 |
+
asm("mov.s32 %0, %%laneid;" : "=r"(laneId) );
|
| 93 |
+
return laneId;
|
| 94 |
+
#endif
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
#if defined(USE_ROCM)
|
| 98 |
+
__device__ __forceinline__ unsigned long long int getLaneMaskLt() {
|
| 99 |
+
const std::uint64_t m = (1ull << getLaneId()) - 1ull;
|
| 100 |
+
return m;
|
| 101 |
+
}
|
| 102 |
+
#else
|
| 103 |
+
__device__ __forceinline__ unsigned getLaneMaskLt() {
|
| 104 |
+
unsigned mask;
|
| 105 |
+
asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask));
|
| 106 |
+
return mask;
|
| 107 |
+
}
|
| 108 |
+
#endif
|
| 109 |
+
|
| 110 |
+
#if defined (USE_ROCM)
|
| 111 |
+
__device__ __forceinline__ unsigned long long int getLaneMaskLe() {
|
| 112 |
+
std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1));
|
| 113 |
+
return m;
|
| 114 |
+
}
|
| 115 |
+
#else
|
| 116 |
+
__device__ __forceinline__ unsigned getLaneMaskLe() {
|
| 117 |
+
unsigned mask;
|
| 118 |
+
asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
|
| 119 |
+
return mask;
|
| 120 |
+
}
|
| 121 |
+
#endif
|
| 122 |
+
|
| 123 |
+
#if defined(USE_ROCM)
|
| 124 |
+
__device__ __forceinline__ unsigned long long int getLaneMaskGt() {
|
| 125 |
+
const std::uint64_t m = getLaneMaskLe();
|
| 126 |
+
return m ? ~m : m;
|
| 127 |
+
}
|
| 128 |
+
#else
|
| 129 |
+
__device__ __forceinline__ unsigned getLaneMaskGt() {
|
| 130 |
+
unsigned mask;
|
| 131 |
+
asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask));
|
| 132 |
+
return mask;
|
| 133 |
+
}
|
| 134 |
+
#endif
|
| 135 |
+
|
| 136 |
+
#if defined(USE_ROCM)
|
| 137 |
+
__device__ __forceinline__ unsigned long long int getLaneMaskGe() {
|
| 138 |
+
const std::uint64_t m = getLaneMaskLt();
|
| 139 |
+
return ~m;
|
| 140 |
+
}
|
| 141 |
+
#else
|
| 142 |
+
__device__ __forceinline__ unsigned getLaneMaskGe() {
|
| 143 |
+
unsigned mask;
|
| 144 |
+
asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask));
|
| 145 |
+
return mask;
|
| 146 |
+
}
|
| 147 |
+
#endif
|
| 148 |
+
|
| 149 |
+
} // namespace at::cuda
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDABlas.h
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
/*
|
| 3 |
+
Provides a subset of CUDA BLAS functions as templates:
|
| 4 |
+
|
| 5 |
+
gemm<Dtype>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
|
| 6 |
+
ldc)
|
| 7 |
+
|
| 8 |
+
gemv<Dtype>(transa, m, n, alpha, a, lda, x, incx, beta, y, incy)
|
| 9 |
+
|
| 10 |
+
dot<Dtype>(n, x, incx, y, incy, result)
|
| 11 |
+
|
| 12 |
+
where Dtype is double, float, at::Half or at::BFloat16 (ROCm, NOT for dot).
|
| 13 |
+
The functions are available in at::cuda::blas namespace.
|
| 14 |
+
*/
|
| 15 |
+
|
| 16 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 17 |
+
#include <ATen/OpMathType.h>
|
| 18 |
+
|
| 19 |
+
namespace at::cuda::blas {
|
| 20 |
+
|
| 21 |
+
// RAII guard that sets the CuBLAS pointer mode and restores it to
|
| 22 |
+
// its previous value when the guard is destroyed
|
| 23 |
+
class PointerModeGuard {
|
| 24 |
+
public:
|
| 25 |
+
PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) :
|
| 26 |
+
handle(handle) {
|
| 27 |
+
TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode));
|
| 28 |
+
TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode));
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
~PointerModeGuard() {
|
| 32 |
+
cublasSetPointerMode(handle, previous_mode);
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
private:
|
| 36 |
+
cublasHandle_t handle;
|
| 37 |
+
cublasPointerMode_t previous_mode;
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
/* LEVEL 3 BLAS FUNCTIONS */
|
| 41 |
+
|
| 42 |
+
#define CUDABLAS_GEMM_ARGTYPES(Dtype) \
|
| 43 |
+
char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
|
| 44 |
+
const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type<Dtype> beta,\
|
| 45 |
+
Dtype *c, int64_t ldc
|
| 46 |
+
|
| 47 |
+
#define CUDABLAS_GEMM_ARGS(Dtype) transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc
|
| 48 |
+
|
| 49 |
+
template <typename Dtype>
|
| 50 |
+
inline void gemm(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
| 51 |
+
AT_ERROR("at::cuda::blas::gemm: not implemented for ", typeid(Dtype).name());
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
template <>
|
| 55 |
+
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double));
|
| 56 |
+
template <>
|
| 57 |
+
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float));
|
| 58 |
+
template <>
|
| 59 |
+
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
|
| 60 |
+
template <>
|
| 61 |
+
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
|
| 62 |
+
template <>
|
| 63 |
+
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
|
| 64 |
+
template <>
|
| 65 |
+
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
|
| 66 |
+
|
| 67 |
+
template <typename Dtype>
|
| 68 |
+
inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
| 69 |
+
AT_ERROR("at::cuda::blas::gemm_internal: not implemented for ", typeid(Dtype).name());
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
template <>
|
| 73 |
+
void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double));
|
| 74 |
+
template <>
|
| 75 |
+
void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float));
|
| 76 |
+
template <>
|
| 77 |
+
void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
|
| 78 |
+
template <>
|
| 79 |
+
void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
|
| 80 |
+
template <>
|
| 81 |
+
void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
|
| 82 |
+
template <>
|
| 83 |
+
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
|
| 84 |
+
|
| 85 |
+
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
| 86 |
+
enum GEMMAndBiasActivationEpilogue {
|
| 87 |
+
None,
|
| 88 |
+
RELU,
|
| 89 |
+
GELU,
|
| 90 |
+
};
|
| 91 |
+
|
| 92 |
+
// NOTE: GELU activation is not supported prior to CUDA 11.4 and will
|
| 93 |
+
// do nothing if passed in that case.
|
| 94 |
+
template <typename Dtype>
|
| 95 |
+
void gemm_and_bias(
|
| 96 |
+
bool transpose_mat1,
|
| 97 |
+
bool transpose_mat2,
|
| 98 |
+
int64_t m,
|
| 99 |
+
int64_t n,
|
| 100 |
+
int64_t k,
|
| 101 |
+
at::opmath_type<Dtype> alpha_val,
|
| 102 |
+
const Dtype* mat1_ptr,
|
| 103 |
+
int64_t mat1_ld,
|
| 104 |
+
const Dtype* mat2_ptr,
|
| 105 |
+
int64_t mat2_ld,
|
| 106 |
+
const Dtype* bias,
|
| 107 |
+
Dtype* result_ptr,
|
| 108 |
+
int64_t result_ld,
|
| 109 |
+
GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None);
|
| 110 |
+
|
| 111 |
+
void int8_gemm(
|
| 112 |
+
bool transpose_mat1,
|
| 113 |
+
bool transpose_mat2,
|
| 114 |
+
int64_t m,
|
| 115 |
+
int64_t n,
|
| 116 |
+
int64_t k,
|
| 117 |
+
const int8_t* mat1_ptr,
|
| 118 |
+
int64_t mat1_ld,
|
| 119 |
+
const int8_t* mat2_ptr,
|
| 120 |
+
int64_t mat2_ld,
|
| 121 |
+
int32_t* result_ptr,
|
| 122 |
+
int64_t result_ld);
|
| 123 |
+
|
| 124 |
+
void scaled_gemm(
|
| 125 |
+
char transa,
|
| 126 |
+
char transb,
|
| 127 |
+
int64_t m,
|
| 128 |
+
int64_t n,
|
| 129 |
+
int64_t k,
|
| 130 |
+
const void* mat1_ptr,
|
| 131 |
+
const void* mat1_scale_ptr,
|
| 132 |
+
int64_t mat1_ld,
|
| 133 |
+
ScalarType mat1_dtype,
|
| 134 |
+
const void* mat2_ptr,
|
| 135 |
+
const void* mat2_scale_ptr,
|
| 136 |
+
int64_t mat2_ld,
|
| 137 |
+
ScalarType mat2_dtype,
|
| 138 |
+
const void* bias_ptr,
|
| 139 |
+
ScalarType bias_dtype,
|
| 140 |
+
void* result_ptr,
|
| 141 |
+
const void* result_scale_ptr,
|
| 142 |
+
int64_t result_ld,
|
| 143 |
+
ScalarType result_dtype,
|
| 144 |
+
void* amax_ptr,
|
| 145 |
+
bool use_fast_accum);
|
| 146 |
+
#endif
|
| 147 |
+
|
| 148 |
+
#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \
|
| 149 |
+
char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
|
| 150 |
+
const Dtype *a, int64_t lda, int64_t stridea, \
|
| 151 |
+
const Dtype *b, int64_t ldb, int64_t strideb, \
|
| 152 |
+
at::opmath_type<Dtype> beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches
|
| 153 |
+
|
| 154 |
+
#define CUDABLAS_BGEMM_ARGS(Dtype) \
|
| 155 |
+
transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, num_batches
|
| 156 |
+
|
| 157 |
+
template <typename Dtype>
|
| 158 |
+
inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
| 159 |
+
AT_ERROR("at::cuda::blas::bgemm: not implemented for ", typeid(Dtype).name());
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
template <>
|
| 163 |
+
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double));
|
| 164 |
+
template <>
|
| 165 |
+
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float));
|
| 166 |
+
template <>
|
| 167 |
+
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
|
| 168 |
+
template <>
|
| 169 |
+
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
|
| 170 |
+
template <>
|
| 171 |
+
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
|
| 172 |
+
template <>
|
| 173 |
+
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
|
| 174 |
+
|
| 175 |
+
template <typename Dtype>
|
| 176 |
+
inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
| 177 |
+
AT_ERROR("at::cuda::blas::bgemm_internal: not implemented for ", typeid(Dtype).name());
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
template <>
|
| 181 |
+
void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double));
|
| 182 |
+
template <>
|
| 183 |
+
void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float));
|
| 184 |
+
template <>
|
| 185 |
+
void bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
|
| 186 |
+
template <>
|
| 187 |
+
void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
|
| 188 |
+
template <>
|
| 189 |
+
void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
|
| 190 |
+
template <>
|
| 191 |
+
void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
|
| 192 |
+
|
| 193 |
+
#if defined(USE_ROCM) && ROCM_VERSION <= 50500
|
| 194 |
+
// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
|
| 195 |
+
#define CUDABLAS_TRSM_ARGTYPES(Dtype) \
|
| 196 |
+
hipblasHandle_t handle, hipblasSideMode_t side, hipblasFillMode_t uplo, \
|
| 197 |
+
hipblasOperation_t trans, hipblasDiagType_t diag, int m, int n, \
|
| 198 |
+
const Dtype *alpha, Dtype *A, int lda, Dtype *B, int ldb
|
| 199 |
+
#else
|
| 200 |
+
#define CUDABLAS_TRSM_ARGTYPES(Dtype) \
|
| 201 |
+
cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
|
| 202 |
+
cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \
|
| 203 |
+
const Dtype *alpha, const Dtype *A, int lda, Dtype *B, int ldb
|
| 204 |
+
#endif
|
| 205 |
+
|
| 206 |
+
template <typename Dtype>
|
| 207 |
+
inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) {
|
| 208 |
+
TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::trsm: not implemented for ", typeid(Dtype).name());
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
template <>
|
| 212 |
+
TORCH_CUDA_CU_API void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float));
|
| 213 |
+
template <>
|
| 214 |
+
TORCH_CUDA_CU_API void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double));
|
| 215 |
+
template <>
|
| 216 |
+
TORCH_CUDA_CU_API void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>));
|
| 217 |
+
template <>
|
| 218 |
+
TORCH_CUDA_CU_API void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>));
|
| 219 |
+
|
| 220 |
+
#define CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype) \
|
| 221 |
+
cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
|
| 222 |
+
cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \
|
| 223 |
+
const Dtype *alpha, Dtype *A[], int lda, Dtype *B[], int ldb, \
|
| 224 |
+
int batchCount
|
| 225 |
+
|
| 226 |
+
template <typename Dtype>
|
| 227 |
+
inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) {
|
| 228 |
+
TORCH_INTERNAL_ASSERT(
|
| 229 |
+
false,
|
| 230 |
+
"at::cuda::blas::trsmBatched: not implemented for ",
|
| 231 |
+
typeid(Dtype).name());
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
template <>
|
| 235 |
+
TORCH_CUDA_CU_API void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float));
|
| 236 |
+
template <>
|
| 237 |
+
TORCH_CUDA_CU_API void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double));
|
| 238 |
+
template <>
|
| 239 |
+
TORCH_CUDA_CU_API void trsmBatched<c10::complex<float>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>));
|
| 240 |
+
template <>
|
| 241 |
+
TORCH_CUDA_CU_API void trsmBatched<c10::complex<double>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>));
|
| 242 |
+
|
| 243 |
+
/* LEVEL 2 BLAS FUNCTIONS */
|
| 244 |
+
|
| 245 |
+
#define CUDABLAS_GEMV_ARGTYPES(Dtype) \
|
| 246 |
+
char trans, int64_t m, int64_t n, Dtype alpha, const Dtype *a, int64_t lda, \
|
| 247 |
+
const Dtype *x, int64_t incx, Dtype beta, Dtype *y, int64_t incy
|
| 248 |
+
|
| 249 |
+
template <typename Dtype>
|
| 250 |
+
inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) {
|
| 251 |
+
AT_ERROR("at::cuda::blas::gemv: not implemented for ", typeid(Dtype).name());
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
template <>
|
| 255 |
+
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
|
| 256 |
+
template <>
|
| 257 |
+
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
|
| 258 |
+
template <>
|
| 259 |
+
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
|
| 260 |
+
template <>
|
| 261 |
+
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
|
| 262 |
+
template <>
|
| 263 |
+
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
|
| 264 |
+
template <>
|
| 265 |
+
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
|
| 266 |
+
|
| 267 |
+
/* LEVEL 1 BLAS FUNCTIONS */
|
| 268 |
+
|
| 269 |
+
#define CUDABLAS_DOT_ARGTYPES(Dtype) \
|
| 270 |
+
cublasHandle_t handle, int n, const Dtype *x, int incx, const Dtype *y, \
|
| 271 |
+
int incy, Dtype *result
|
| 272 |
+
|
| 273 |
+
template <typename Dtype>
|
| 274 |
+
inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
|
| 275 |
+
AT_ERROR("at::cuda::blas::dot: not implemented for ", typeid(Dtype).name());
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
template <>
|
| 279 |
+
void dot<double>(CUDABLAS_DOT_ARGTYPES(double));
|
| 280 |
+
template <>
|
| 281 |
+
void dot<float>(CUDABLAS_DOT_ARGTYPES(float));
|
| 282 |
+
template <>
|
| 283 |
+
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half));
|
| 284 |
+
template <>
|
| 285 |
+
void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16));
|
| 286 |
+
template <>
|
| 287 |
+
void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
|
| 288 |
+
template <>
|
| 289 |
+
void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
|
| 290 |
+
|
| 291 |
+
template <typename Dtype>
|
| 292 |
+
inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
|
| 293 |
+
AT_ERROR("at::cuda::blas::vdot: not implemented for ", typeid(Dtype).name());
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
template <>
|
| 297 |
+
void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
|
| 298 |
+
template <>
|
| 299 |
+
void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
|
| 300 |
+
|
| 301 |
+
#define CUDABLAS_GETRS_ARGTYPES(Dtype) \
|
| 302 |
+
cublasHandle_t handle, cublasOperation_t trans, \
|
| 303 |
+
int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \
|
| 304 |
+
Dtype** dB_array, int ldb, int* info_array, int batchsize
|
| 305 |
+
|
| 306 |
+
template<class Dtype>
|
| 307 |
+
void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) {
|
| 308 |
+
TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::getrsBatched: not implemented for ",
|
| 309 |
+
typeid(Dtype).name());
|
| 310 |
+
}
|
| 311 |
+
template<>
|
| 312 |
+
TORCH_CUDA_CU_API void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float));
|
| 313 |
+
template<>
|
| 314 |
+
TORCH_CUDA_CU_API void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double));
|
| 315 |
+
template<>
|
| 316 |
+
TORCH_CUDA_CU_API void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>));
|
| 317 |
+
template<>
|
| 318 |
+
TORCH_CUDA_CU_API void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>));
|
| 319 |
+
|
| 320 |
+
#define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype) \
|
| 321 |
+
cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \
|
| 322 |
+
Dtype **tau_array, int *info, int batchsize
|
| 323 |
+
|
| 324 |
+
template <class Dtype>
|
| 325 |
+
void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
|
| 326 |
+
TORCH_INTERNAL_ASSERT(
|
| 327 |
+
false,
|
| 328 |
+
"at::cuda::blas::geqrfBatched: not implemented for ",
|
| 329 |
+
typeid(Dtype).name());
|
| 330 |
+
}
|
| 331 |
+
template <>
|
| 332 |
+
TORCH_CUDA_CU_API void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float));
|
| 333 |
+
template <>
|
| 334 |
+
TORCH_CUDA_CU_API void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double));
|
| 335 |
+
template <>
|
| 336 |
+
TORCH_CUDA_CU_API void geqrfBatched<c10::complex<double>>(
|
| 337 |
+
CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>));
|
| 338 |
+
template <>
|
| 339 |
+
TORCH_CUDA_CU_API void geqrfBatched<c10::complex<float>>(
|
| 340 |
+
CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>));
|
| 341 |
+
|
| 342 |
+
#define CUDABLAS_GETRF_ARGTYPES(Dtype) \
|
| 343 |
+
int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize
|
| 344 |
+
|
| 345 |
+
template<class Dtype>
|
| 346 |
+
void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) {
|
| 347 |
+
TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented for ", typeid(Dtype).name());
|
| 348 |
+
}
|
| 349 |
+
template<>
|
| 350 |
+
TORCH_CUDA_CU_API void getrfBatched<float>(CUDABLAS_GETRF_ARGTYPES(float));
|
| 351 |
+
template<>
|
| 352 |
+
TORCH_CUDA_CU_API void getrfBatched<double>(CUDABLAS_GETRF_ARGTYPES(double));
|
| 353 |
+
template<>
|
| 354 |
+
TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<double>));
|
| 355 |
+
template<>
|
| 356 |
+
TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<float>));
|
| 357 |
+
|
| 358 |
+
#define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype) \
|
| 359 |
+
cublasHandle_t handle, cublasOperation_t trans, int m, int n, int nrhs, Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize
|
| 360 |
+
|
| 361 |
+
template <class Dtype>
|
| 362 |
+
void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) {
|
| 363 |
+
TORCH_INTERNAL_ASSERT(false, "at::cuda::blas::gelsBatched: not implemented for ", typeid(Dtype).name());
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
template<>
|
| 367 |
+
TORCH_CUDA_CU_API void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double));
|
| 368 |
+
template<>
|
| 369 |
+
TORCH_CUDA_CU_API void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float));
|
| 370 |
+
template<>
|
| 371 |
+
TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>));
|
| 372 |
+
template<>
|
| 373 |
+
TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));
|
| 374 |
+
|
| 375 |
+
} // namespace at::cuda::blas
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAConfig.h
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// Test these using #if AT_CUDNN_ENABLED(), not #ifdef, so that it's
|
| 4 |
+
// obvious if you forgot to include Config.h
|
| 5 |
+
// c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined
|
| 6 |
+
//
|
| 7 |
+
// NB: This header MUST NOT be included from other headers; it should
|
| 8 |
+
// only be included from C++ files.
|
| 9 |
+
#define AT_CUDNN_ENABLED() 1
|
| 10 |
+
#define AT_CUSPARSELT_ENABLED() 1
|
| 11 |
+
#define AT_ROCM_ENABLED() 0
|
| 12 |
+
#define AT_MAGMA_ENABLED() 1
|
| 13 |
+
|
| 14 |
+
// Needed for hipMAGMA to correctly identify implementation
|
| 15 |
+
#if (AT_ROCM_ENABLED() && AT_MAGMA_ENABLED())
|
| 16 |
+
#define HAVE_HIP 1
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
#define NVCC_FLAGS_EXTRA "-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_90,code=sm_90"
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAContextLight.h
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
// Light-weight version of CUDAContext.h with fewer transitive includes
|
| 3 |
+
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
#include <cuda_runtime_api.h>
|
| 7 |
+
#include <cusparse.h>
|
| 8 |
+
#include <cublas_v2.h>
|
| 9 |
+
|
| 10 |
+
// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
|
| 11 |
+
// added bf16 support
|
| 12 |
+
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
| 13 |
+
#include <cublasLt.h>
|
| 14 |
+
#endif
|
| 15 |
+
|
| 16 |
+
#ifdef CUDART_VERSION
|
| 17 |
+
#include <cusolverDn.h>
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
#if defined(USE_ROCM) && ROCM_VERSION >= 50300
|
| 21 |
+
#include <hipsolver/hipsolver.h>
|
| 22 |
+
#endif
|
| 23 |
+
|
| 24 |
+
#include <c10/core/Allocator.h>
|
| 25 |
+
#include <c10/cuda/CUDAFunctions.h>
|
| 26 |
+
|
| 27 |
+
namespace c10 {
|
| 28 |
+
struct Allocator;
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
namespace at::cuda {
|
| 32 |
+
|
| 33 |
+
/*
|
| 34 |
+
A common CUDA interface for ATen.
|
| 35 |
+
|
| 36 |
+
This interface is distinct from CUDAHooks, which defines an interface that links
|
| 37 |
+
to both CPU-only and CUDA builds. That interface is intended for runtime
|
| 38 |
+
dispatch and should be used from files that are included in both CPU-only and
|
| 39 |
+
CUDA builds.
|
| 40 |
+
|
| 41 |
+
CUDAContext, on the other hand, should be preferred by files only included in
|
| 42 |
+
CUDA builds. It is intended to expose CUDA functionality in a consistent
|
| 43 |
+
manner.
|
| 44 |
+
|
| 45 |
+
This means there is some overlap between the CUDAContext and CUDAHooks, but
|
| 46 |
+
the choice of which to use is simple: use CUDAContext when in a CUDA-only file,
|
| 47 |
+
use CUDAHooks otherwise.
|
| 48 |
+
|
| 49 |
+
Note that CUDAContext simply defines an interface with no associated class.
|
| 50 |
+
It is expected that the modules whose functions compose this interface will
|
| 51 |
+
manage their own state. There is only a single CUDA context/state.
|
| 52 |
+
*/
|
| 53 |
+
|
| 54 |
+
/**
|
| 55 |
+
* DEPRECATED: use device_count() instead
|
| 56 |
+
*/
|
| 57 |
+
inline int64_t getNumGPUs() {
|
| 58 |
+
return c10::cuda::device_count();
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
/**
|
| 62 |
+
* CUDA is available if we compiled with CUDA, and there are one or more
|
| 63 |
+
* devices. If we compiled with CUDA but there is a driver problem, etc.,
|
| 64 |
+
* this function will report CUDA is not available (rather than raise an error.)
|
| 65 |
+
*/
|
| 66 |
+
inline bool is_available() {
|
| 67 |
+
return c10::cuda::device_count() > 0;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
TORCH_CUDA_CPP_API cudaDeviceProp* getCurrentDeviceProperties();
|
| 71 |
+
|
| 72 |
+
TORCH_CUDA_CPP_API int warp_size();
|
| 73 |
+
|
| 74 |
+
TORCH_CUDA_CPP_API cudaDeviceProp* getDeviceProperties(c10::DeviceIndex device);
|
| 75 |
+
|
| 76 |
+
TORCH_CUDA_CPP_API bool canDeviceAccessPeer(
|
| 77 |
+
c10::DeviceIndex device,
|
| 78 |
+
c10::DeviceIndex peer_device);
|
| 79 |
+
|
| 80 |
+
TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
|
| 81 |
+
|
| 82 |
+
/* Handles */
|
| 83 |
+
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
|
| 84 |
+
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
|
| 85 |
+
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
| 86 |
+
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
|
| 87 |
+
#endif
|
| 88 |
+
|
| 89 |
+
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
|
| 90 |
+
|
| 91 |
+
#if defined(CUDART_VERSION) || defined(USE_ROCM) && ROCM_VERSION >= 50300
|
| 92 |
+
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
|
| 93 |
+
#endif
|
| 94 |
+
|
| 95 |
+
} // namespace at::cuda
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDADevice.h
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/cuda/Exceptions.h>
|
| 4 |
+
|
| 5 |
+
#include <cuda.h>
|
| 6 |
+
#include <cuda_runtime.h>
|
| 7 |
+
|
| 8 |
+
namespace at::cuda {
|
| 9 |
+
|
| 10 |
+
inline Device getDeviceFromPtr(void* ptr) {
|
| 11 |
+
cudaPointerAttributes attr{};
|
| 12 |
+
|
| 13 |
+
AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));
|
| 14 |
+
|
| 15 |
+
#if !defined(USE_ROCM)
|
| 16 |
+
TORCH_CHECK(attr.type != cudaMemoryTypeUnregistered,
|
| 17 |
+
"The specified pointer resides on host memory and is not registered with any CUDA device.");
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
+
return {c10::DeviceType::CUDA, static_cast<DeviceIndex>(attr.device)};
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
} // namespace at::cuda
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAEvent.h
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/cuda/ATenCUDAGeneral.h>
|
| 4 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 5 |
+
#include <c10/core/impl/GPUTrace.h>
|
| 6 |
+
#include <c10/cuda/CUDAStream.h>
|
| 7 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 8 |
+
#include <ATen/cuda/Exceptions.h>
|
| 9 |
+
#include <c10/util/Exception.h>
|
| 10 |
+
|
| 11 |
+
#include <cuda_runtime_api.h>
|
| 12 |
+
|
| 13 |
+
#include <cstdint>
|
| 14 |
+
#include <utility>
|
| 15 |
+
|
| 16 |
+
namespace at::cuda {
|
| 17 |
+
|
| 18 |
+
/*
|
| 19 |
+
* CUDAEvents are movable not copyable wrappers around CUDA's events.
|
| 20 |
+
*
|
| 21 |
+
* CUDAEvents are constructed lazily when first recorded unless it is
|
| 22 |
+
* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this
|
| 23 |
+
* device is acquired from the first recording stream. However, if reconstructed
|
| 24 |
+
* from a handle, the device should be explicitly specified; or if ipc_handle() is
|
| 25 |
+
* called before the event is ever recorded, it will use the current device.
|
| 26 |
+
* Later streams that record the event must match this device.
|
| 27 |
+
*/
|
| 28 |
+
struct TORCH_CUDA_CPP_API CUDAEvent {
|
| 29 |
+
// Constructors
|
| 30 |
+
// Default value for `flags` is specified below - it's cudaEventDisableTiming
|
| 31 |
+
CUDAEvent() noexcept = default;
|
| 32 |
+
CUDAEvent(unsigned int flags) noexcept : flags_{flags} {}
|
| 33 |
+
|
| 34 |
+
CUDAEvent(
|
| 35 |
+
DeviceIndex device_index, const cudaIpcEventHandle_t* handle) {
|
| 36 |
+
device_index_ = device_index;
|
| 37 |
+
CUDAGuard guard(device_index_);
|
| 38 |
+
|
| 39 |
+
AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle));
|
| 40 |
+
is_created_ = true;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
// Note: event destruction done on creating device to avoid creating a
|
| 44 |
+
// CUDA context on other devices.
|
| 45 |
+
~CUDAEvent() {
|
| 46 |
+
try {
|
| 47 |
+
if (is_created_) {
|
| 48 |
+
CUDAGuard guard(device_index_);
|
| 49 |
+
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
| 50 |
+
if (C10_UNLIKELY(interp)) {
|
| 51 |
+
(*interp)->trace_gpu_event_deletion(reinterpret_cast<uintptr_t>(event_));
|
| 52 |
+
}
|
| 53 |
+
AT_CUDA_CHECK(cudaEventDestroy(event_));
|
| 54 |
+
}
|
| 55 |
+
} catch (...) { /* No throw */ }
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
CUDAEvent(const CUDAEvent&) = delete;
|
| 59 |
+
CUDAEvent& operator=(const CUDAEvent&) = delete;
|
| 60 |
+
|
| 61 |
+
CUDAEvent(CUDAEvent&& other) noexcept { moveHelper(std::move(other)); }
|
| 62 |
+
CUDAEvent& operator=(CUDAEvent&& other) noexcept {
|
| 63 |
+
if (this != &other) {
|
| 64 |
+
moveHelper(std::move(other));
|
| 65 |
+
}
|
| 66 |
+
return *this;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
operator cudaEvent_t() const { return event(); }
|
| 70 |
+
|
| 71 |
+
// Less than operator (to allow use in sets)
|
| 72 |
+
friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) {
|
| 73 |
+
return left.event_ < right.event_;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
optional<at::Device> device() const {
|
| 77 |
+
if (is_created_) {
|
| 78 |
+
return at::Device(at::kCUDA, device_index_);
|
| 79 |
+
} else {
|
| 80 |
+
return {};
|
| 81 |
+
}
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
bool isCreated() const { return is_created_; }
|
| 85 |
+
DeviceIndex device_index() const {return device_index_;}
|
| 86 |
+
cudaEvent_t event() const { return event_; }
|
| 87 |
+
|
| 88 |
+
// Note: cudaEventQuery can be safely called from any device
|
| 89 |
+
bool query() const {
|
| 90 |
+
if (!is_created_) {
|
| 91 |
+
return true;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
cudaError_t err = cudaEventQuery(event_);
|
| 95 |
+
if (err == cudaSuccess) {
|
| 96 |
+
return true;
|
| 97 |
+
} else if (err != cudaErrorNotReady) {
|
| 98 |
+
C10_CUDA_CHECK(err);
|
| 99 |
+
} else {
|
| 100 |
+
// ignore and clear the error if not ready
|
| 101 |
+
(void)cudaGetLastError();
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
return false;
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
void record() { record(getCurrentCUDAStream()); }
|
| 108 |
+
|
| 109 |
+
void recordOnce(const CUDAStream& stream) {
|
| 110 |
+
if (!was_recorded_) record(stream);
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
// Note: cudaEventRecord must be called on the same device as the event.
|
| 114 |
+
void record(const CUDAStream& stream) {
|
| 115 |
+
if (!is_created_) {
|
| 116 |
+
createEvent(stream.device_index());
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_,
|
| 120 |
+
" does not match recording stream's device ", stream.device_index(), ".");
|
| 121 |
+
CUDAGuard guard(device_index_);
|
| 122 |
+
AT_CUDA_CHECK(cudaEventRecord(event_, stream));
|
| 123 |
+
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
| 124 |
+
if (C10_UNLIKELY(interp)) {
|
| 125 |
+
(*interp)->trace_gpu_event_record(
|
| 126 |
+
reinterpret_cast<uintptr_t>(event_),
|
| 127 |
+
reinterpret_cast<uintptr_t>(stream.stream())
|
| 128 |
+
);
|
| 129 |
+
}
|
| 130 |
+
was_recorded_ = true;
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
// Note: cudaStreamWaitEvent must be called on the same device as the stream.
|
| 134 |
+
// The event has no actual GPU resources associated with it.
|
| 135 |
+
void block(const CUDAStream& stream) {
|
| 136 |
+
if (is_created_) {
|
| 137 |
+
CUDAGuard guard(stream.device_index());
|
| 138 |
+
AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0));
|
| 139 |
+
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
| 140 |
+
if (C10_UNLIKELY(interp)) {
|
| 141 |
+
(*interp)->trace_gpu_event_wait(
|
| 142 |
+
reinterpret_cast<uintptr_t>(event_),
|
| 143 |
+
reinterpret_cast<uintptr_t>(stream.stream())
|
| 144 |
+
);
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
// Note: cudaEventElapsedTime can be safely called from any device
|
| 150 |
+
float elapsed_time(const CUDAEvent& other) const {
|
| 151 |
+
TORCH_CHECK(is_created_ && other.isCreated(),
|
| 152 |
+
"Both events must be recorded before calculating elapsed time.");
|
| 153 |
+
float time_ms = 0;
|
| 154 |
+
// raise cudaErrorNotReady if either event is recorded but not yet completed
|
| 155 |
+
AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_));
|
| 156 |
+
return time_ms;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
// Note: cudaEventSynchronize can be safely called from any device
|
| 160 |
+
void synchronize() const {
|
| 161 |
+
if (is_created_) {
|
| 162 |
+
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
| 163 |
+
if (C10_UNLIKELY(interp)) {
|
| 164 |
+
(*interp)->trace_gpu_event_synchronization(reinterpret_cast<uintptr_t>(event_));
|
| 165 |
+
}
|
| 166 |
+
AT_CUDA_CHECK(cudaEventSynchronize(event_));
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
// Note: cudaIpcGetEventHandle must be called on the same device as the event
|
| 171 |
+
void ipc_handle(cudaIpcEventHandle_t * handle) {
|
| 172 |
+
if (!is_created_) {
|
| 173 |
+
// this CUDAEvent object was initially constructed from flags but event_
|
| 174 |
+
// is not created yet.
|
| 175 |
+
createEvent(getCurrentCUDAStream().device_index());
|
| 176 |
+
}
|
| 177 |
+
CUDAGuard guard(device_index_);
|
| 178 |
+
AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_));
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
private:
|
| 182 |
+
unsigned int flags_ = cudaEventDisableTiming;
|
| 183 |
+
bool is_created_ = false;
|
| 184 |
+
bool was_recorded_ = false;
|
| 185 |
+
DeviceIndex device_index_ = -1;
|
| 186 |
+
cudaEvent_t event_{};
|
| 187 |
+
|
| 188 |
+
void createEvent(DeviceIndex device_index) {
|
| 189 |
+
device_index_ = device_index;
|
| 190 |
+
CUDAGuard guard(device_index_);
|
| 191 |
+
AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
|
| 192 |
+
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
|
| 193 |
+
if (C10_UNLIKELY(interp)) {
|
| 194 |
+
(*interp)->trace_gpu_event_creation(reinterpret_cast<uintptr_t>(event_));
|
| 195 |
+
}
|
| 196 |
+
is_created_ = true;
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
void moveHelper(CUDAEvent&& other) {
|
| 200 |
+
std::swap(flags_, other.flags_);
|
| 201 |
+
std::swap(is_created_, other.is_created_);
|
| 202 |
+
std::swap(was_recorded_, other.was_recorded_);
|
| 203 |
+
std::swap(device_index_, other.device_index_);
|
| 204 |
+
std::swap(event_, other.event_);
|
| 205 |
+
}
|
| 206 |
+
};
|
| 207 |
+
|
| 208 |
+
} // namespace at::cuda
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDASparse.h
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 4 |
+
#if defined(USE_ROCM)
|
| 5 |
+
#include <hipsparse/hipsparse-version.h>
|
| 6 |
+
#define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch)
|
| 7 |
+
#endif
|
| 8 |
+
|
| 9 |
+
// cuSparse Generic API added in CUDA 10.1
|
| 10 |
+
// Windows support added in CUDA 11.0
|
| 11 |
+
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && ((CUSPARSE_VERSION >= 10300) || (CUSPARSE_VERSION >= 11000 && defined(_WIN32)))
|
| 12 |
+
#define AT_USE_CUSPARSE_GENERIC_API() 1
|
| 13 |
+
#else
|
| 14 |
+
#define AT_USE_CUSPARSE_GENERIC_API() 0
|
| 15 |
+
#endif
|
| 16 |
+
|
| 17 |
+
// cuSparse Generic API descriptor pointers were changed to const in CUDA 12.0
|
| 18 |
+
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
|
| 19 |
+
(CUSPARSE_VERSION < 12000)
|
| 20 |
+
#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 1
|
| 21 |
+
#else
|
| 22 |
+
#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 0
|
| 23 |
+
#endif
|
| 24 |
+
|
| 25 |
+
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
|
| 26 |
+
(CUSPARSE_VERSION >= 12000)
|
| 27 |
+
#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 1
|
| 28 |
+
#else
|
| 29 |
+
#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0
|
| 30 |
+
#endif
|
| 31 |
+
|
| 32 |
+
#if defined(USE_ROCM)
|
| 33 |
+
// hipSparse const API added in v2.4.0
|
| 34 |
+
#if HIPSPARSE_VERSION >= 200400
|
| 35 |
+
#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 1
|
| 36 |
+
#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
|
| 37 |
+
#define AT_USE_HIPSPARSE_GENERIC_API() 1
|
| 38 |
+
#else
|
| 39 |
+
#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
|
| 40 |
+
#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 1
|
| 41 |
+
#define AT_USE_HIPSPARSE_GENERIC_API() 1
|
| 42 |
+
#endif
|
| 43 |
+
#else // USE_ROCM
|
| 44 |
+
#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
|
| 45 |
+
#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
|
| 46 |
+
#define AT_USE_HIPSPARSE_GENERIC_API() 0
|
| 47 |
+
#endif // USE_ROCM
|
| 48 |
+
|
| 49 |
+
// cuSparse Generic API spsv function was added in CUDA 11.3.0
|
| 50 |
+
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
|
| 51 |
+
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1
|
| 52 |
+
#else
|
| 53 |
+
#define AT_USE_CUSPARSE_GENERIC_SPSV() 0
|
| 54 |
+
#endif
|
| 55 |
+
|
| 56 |
+
// cuSparse Generic API spsm function was added in CUDA 11.3.1
|
| 57 |
+
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11600)
|
| 58 |
+
#define AT_USE_CUSPARSE_GENERIC_SPSM() 1
|
| 59 |
+
#else
|
| 60 |
+
#define AT_USE_CUSPARSE_GENERIC_SPSM() 0
|
| 61 |
+
#endif
|
| 62 |
+
|
| 63 |
+
// cuSparse Generic API sddmm function was added in CUDA 11.2.1 (cuSparse version 11400)
|
| 64 |
+
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11400)
|
| 65 |
+
#define AT_USE_CUSPARSE_GENERIC_SDDMM() 1
|
| 66 |
+
#else
|
| 67 |
+
#define AT_USE_CUSPARSE_GENERIC_SDDMM() 0
|
| 68 |
+
#endif
|
| 69 |
+
|
| 70 |
+
// BSR triangular solve functions were added in hipSPARSE 1.11.2 (ROCm 4.5.0)
|
| 71 |
+
#if defined(CUDART_VERSION) || \
|
| 72 |
+
(defined(USE_ROCM) && ROCM_VERSION >= 40500 )
|
| 73 |
+
#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 1
|
| 74 |
+
#else
|
| 75 |
+
#define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 0
|
| 76 |
+
#endif
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDASparseDescriptors.h
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/Tensor.h>
|
| 4 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 5 |
+
#include <ATen/cuda/CUDASparse.h>
|
| 6 |
+
|
| 7 |
+
#include <c10/core/ScalarType.h>
|
| 8 |
+
|
| 9 |
+
#if defined(USE_ROCM)
|
| 10 |
+
#include <type_traits>
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
namespace at::cuda::sparse {
|
| 14 |
+
|
| 15 |
+
template <typename T, cusparseStatus_t (*destructor)(T*)>
|
| 16 |
+
struct CuSparseDescriptorDeleter {
|
| 17 |
+
void operator()(T* x) {
|
| 18 |
+
if (x != nullptr) {
|
| 19 |
+
TORCH_CUDASPARSE_CHECK(destructor(x));
|
| 20 |
+
}
|
| 21 |
+
}
|
| 22 |
+
};
|
| 23 |
+
|
| 24 |
+
template <typename T, cusparseStatus_t (*destructor)(T*)>
|
| 25 |
+
class CuSparseDescriptor {
|
| 26 |
+
public:
|
| 27 |
+
T* descriptor() const {
|
| 28 |
+
return descriptor_.get();
|
| 29 |
+
}
|
| 30 |
+
T* descriptor() {
|
| 31 |
+
return descriptor_.get();
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
protected:
|
| 35 |
+
std::unique_ptr<T, CuSparseDescriptorDeleter<T, destructor>> descriptor_;
|
| 36 |
+
};
|
| 37 |
+
|
| 38 |
+
#if AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
|
| 39 |
+
template <typename T, cusparseStatus_t (*destructor)(const T*)>
|
| 40 |
+
struct ConstCuSparseDescriptorDeleter {
|
| 41 |
+
void operator()(T* x) {
|
| 42 |
+
if (x != nullptr) {
|
| 43 |
+
TORCH_CUDASPARSE_CHECK(destructor(x));
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
};
|
| 47 |
+
|
| 48 |
+
template <typename T, cusparseStatus_t (*destructor)(const T*)>
|
| 49 |
+
class ConstCuSparseDescriptor {
|
| 50 |
+
public:
|
| 51 |
+
T* descriptor() const {
|
| 52 |
+
return descriptor_.get();
|
| 53 |
+
}
|
| 54 |
+
T* descriptor() {
|
| 55 |
+
return descriptor_.get();
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
protected:
|
| 59 |
+
std::unique_ptr<T, ConstCuSparseDescriptorDeleter<T, destructor>> descriptor_;
|
| 60 |
+
};
|
| 61 |
+
#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS || AT_USE_HIPSPARSE_CONST_DESCRIPTORS
|
| 62 |
+
|
| 63 |
+
#if defined(USE_ROCM)
|
| 64 |
+
using cusparseMatDescr = std::remove_pointer<hipsparseMatDescr_t>::type;
|
| 65 |
+
using cusparseDnMatDescr = std::remove_pointer<hipsparseDnMatDescr_t>::type;
|
| 66 |
+
using cusparseDnVecDescr = std::remove_pointer<hipsparseDnVecDescr_t>::type;
|
| 67 |
+
using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type;
|
| 68 |
+
using cusparseSpMatDescr = std::remove_pointer<hipsparseSpMatDescr_t>::type;
|
| 69 |
+
using cusparseSpGEMMDescr = std::remove_pointer<hipsparseSpGEMMDescr_t>::type;
|
| 70 |
+
#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
|
| 71 |
+
using bsrsv2Info = std::remove_pointer<bsrsv2Info_t>::type;
|
| 72 |
+
using bsrsm2Info = std::remove_pointer<bsrsm2Info_t>::type;
|
| 73 |
+
#endif
|
| 74 |
+
#endif
|
| 75 |
+
|
| 76 |
+
// NOTE: This is only needed for CUDA 11 and earlier, since CUDA 12 introduced
|
| 77 |
+
// API for const descriptors
|
| 78 |
+
cusparseStatus_t destroyConstDnMat(const cusparseDnMatDescr* dnMatDescr);
|
| 79 |
+
|
| 80 |
+
class TORCH_CUDA_CPP_API CuSparseMatDescriptor
|
| 81 |
+
: public CuSparseDescriptor<cusparseMatDescr, &cusparseDestroyMatDescr> {
|
| 82 |
+
public:
|
| 83 |
+
CuSparseMatDescriptor() {
|
| 84 |
+
cusparseMatDescr_t raw_descriptor;
|
| 85 |
+
TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor));
|
| 86 |
+
descriptor_.reset(raw_descriptor);
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
CuSparseMatDescriptor(bool upper, bool unit) {
|
| 90 |
+
cusparseFillMode_t fill_mode =
|
| 91 |
+
upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER;
|
| 92 |
+
cusparseDiagType_t diag_type =
|
| 93 |
+
unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT;
|
| 94 |
+
cusparseMatDescr_t raw_descriptor;
|
| 95 |
+
TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&raw_descriptor));
|
| 96 |
+
TORCH_CUDASPARSE_CHECK(cusparseSetMatFillMode(raw_descriptor, fill_mode));
|
| 97 |
+
TORCH_CUDASPARSE_CHECK(cusparseSetMatDiagType(raw_descriptor, diag_type));
|
| 98 |
+
descriptor_.reset(raw_descriptor);
|
| 99 |
+
}
|
| 100 |
+
};
|
| 101 |
+
|
| 102 |
+
#if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
|
| 103 |
+
|
| 104 |
+
class TORCH_CUDA_CPP_API CuSparseBsrsv2Info
|
| 105 |
+
: public CuSparseDescriptor<bsrsv2Info, &cusparseDestroyBsrsv2Info> {
|
| 106 |
+
public:
|
| 107 |
+
CuSparseBsrsv2Info() {
|
| 108 |
+
bsrsv2Info_t raw_descriptor;
|
| 109 |
+
TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsv2Info(&raw_descriptor));
|
| 110 |
+
descriptor_.reset(raw_descriptor);
|
| 111 |
+
}
|
| 112 |
+
};
|
| 113 |
+
|
| 114 |
+
class TORCH_CUDA_CPP_API CuSparseBsrsm2Info
|
| 115 |
+
: public CuSparseDescriptor<bsrsm2Info, &cusparseDestroyBsrsm2Info> {
|
| 116 |
+
public:
|
| 117 |
+
CuSparseBsrsm2Info() {
|
| 118 |
+
bsrsm2Info_t raw_descriptor;
|
| 119 |
+
TORCH_CUDASPARSE_CHECK(cusparseCreateBsrsm2Info(&raw_descriptor));
|
| 120 |
+
descriptor_.reset(raw_descriptor);
|
| 121 |
+
}
|
| 122 |
+
};
|
| 123 |
+
|
| 124 |
+
#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
|
| 125 |
+
|
| 126 |
+
#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
|
| 127 |
+
|
| 128 |
+
cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type);
|
| 129 |
+
|
| 130 |
+
#if AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS()
|
| 131 |
+
class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
|
| 132 |
+
: public CuSparseDescriptor<cusparseDnMatDescr, &cusparseDestroyDnMat> {
|
| 133 |
+
public:
|
| 134 |
+
explicit CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1);
|
| 135 |
+
};
|
| 136 |
+
|
| 137 |
+
class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor
|
| 138 |
+
: public CuSparseDescriptor<const cusparseDnMatDescr, &destroyConstDnMat> {
|
| 139 |
+
public:
|
| 140 |
+
explicit CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1);
|
| 141 |
+
cusparseDnMatDescr* unsafe_mutable_descriptor() const {
|
| 142 |
+
return const_cast<cusparseDnMatDescr*>(descriptor());
|
| 143 |
+
}
|
| 144 |
+
cusparseDnMatDescr* unsafe_mutable_descriptor() {
|
| 145 |
+
return const_cast<cusparseDnMatDescr*>(descriptor());
|
| 146 |
+
}
|
| 147 |
+
};
|
| 148 |
+
|
| 149 |
+
class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
|
| 150 |
+
: public CuSparseDescriptor<cusparseDnVecDescr, &cusparseDestroyDnVec> {
|
| 151 |
+
public:
|
| 152 |
+
explicit CuSparseDnVecDescriptor(const Tensor& input);
|
| 153 |
+
};
|
| 154 |
+
|
| 155 |
+
class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
|
| 156 |
+
: public CuSparseDescriptor<cusparseSpMatDescr, &cusparseDestroySpMat> {};
|
| 157 |
+
|
| 158 |
+
#elif AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
|
| 159 |
+
class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
|
| 160 |
+
: public ConstCuSparseDescriptor<
|
| 161 |
+
cusparseDnMatDescr,
|
| 162 |
+
&cusparseDestroyDnMat> {
|
| 163 |
+
public:
|
| 164 |
+
explicit CuSparseDnMatDescriptor(
|
| 165 |
+
const Tensor& input,
|
| 166 |
+
int64_t batch_offset = -1);
|
| 167 |
+
};
|
| 168 |
+
|
| 169 |
+
class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor
|
| 170 |
+
: public ConstCuSparseDescriptor<
|
| 171 |
+
const cusparseDnMatDescr,
|
| 172 |
+
&destroyConstDnMat> {
|
| 173 |
+
public:
|
| 174 |
+
explicit CuSparseConstDnMatDescriptor(
|
| 175 |
+
const Tensor& input,
|
| 176 |
+
int64_t batch_offset = -1);
|
| 177 |
+
cusparseDnMatDescr* unsafe_mutable_descriptor() const {
|
| 178 |
+
return const_cast<cusparseDnMatDescr*>(descriptor());
|
| 179 |
+
}
|
| 180 |
+
cusparseDnMatDescr* unsafe_mutable_descriptor() {
|
| 181 |
+
return const_cast<cusparseDnMatDescr*>(descriptor());
|
| 182 |
+
}
|
| 183 |
+
};
|
| 184 |
+
|
| 185 |
+
class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
|
| 186 |
+
: public ConstCuSparseDescriptor<
|
| 187 |
+
cusparseDnVecDescr,
|
| 188 |
+
&cusparseDestroyDnVec> {
|
| 189 |
+
public:
|
| 190 |
+
explicit CuSparseDnVecDescriptor(const Tensor& input);
|
| 191 |
+
};
|
| 192 |
+
|
| 193 |
+
class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
|
| 194 |
+
: public ConstCuSparseDescriptor<
|
| 195 |
+
cusparseSpMatDescr,
|
| 196 |
+
&cusparseDestroySpMat> {};
|
| 197 |
+
#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
|
| 198 |
+
|
| 199 |
+
class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor
|
| 200 |
+
: public CuSparseSpMatDescriptor {
|
| 201 |
+
public:
|
| 202 |
+
explicit CuSparseSpMatCsrDescriptor(const Tensor& input, int64_t batch_offset = -1);
|
| 203 |
+
|
| 204 |
+
std::tuple<int64_t, int64_t, int64_t> get_size() {
|
| 205 |
+
int64_t rows, cols, nnz;
|
| 206 |
+
TORCH_CUDASPARSE_CHECK(cusparseSpMatGetSize(
|
| 207 |
+
this->descriptor(),
|
| 208 |
+
&rows,
|
| 209 |
+
&cols,
|
| 210 |
+
&nnz));
|
| 211 |
+
return std::make_tuple(rows, cols, nnz);
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
void set_tensor(const Tensor& input) {
|
| 215 |
+
auto crow_indices = input.crow_indices();
|
| 216 |
+
auto col_indices = input.col_indices();
|
| 217 |
+
auto values = input.values();
|
| 218 |
+
|
| 219 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(crow_indices.is_contiguous());
|
| 220 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(col_indices.is_contiguous());
|
| 221 |
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous());
|
| 222 |
+
TORCH_CUDASPARSE_CHECK(cusparseCsrSetPointers(
|
| 223 |
+
this->descriptor(),
|
| 224 |
+
crow_indices.data_ptr(),
|
| 225 |
+
col_indices.data_ptr(),
|
| 226 |
+
values.data_ptr()));
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
#if AT_USE_CUSPARSE_GENERIC_SPSV()
|
| 230 |
+
void set_mat_fill_mode(bool upper) {
|
| 231 |
+
cusparseFillMode_t fill_mode =
|
| 232 |
+
upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER;
|
| 233 |
+
TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
|
| 234 |
+
this->descriptor(),
|
| 235 |
+
CUSPARSE_SPMAT_FILL_MODE,
|
| 236 |
+
&fill_mode,
|
| 237 |
+
sizeof(fill_mode)));
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
void set_mat_diag_type(bool unit) {
|
| 241 |
+
cusparseDiagType_t diag_type =
|
| 242 |
+
unit ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT;
|
| 243 |
+
TORCH_CUDASPARSE_CHECK(cusparseSpMatSetAttribute(
|
| 244 |
+
this->descriptor(),
|
| 245 |
+
CUSPARSE_SPMAT_DIAG_TYPE,
|
| 246 |
+
&diag_type,
|
| 247 |
+
sizeof(diag_type)));
|
| 248 |
+
}
|
| 249 |
+
#endif
|
| 250 |
+
};
|
| 251 |
+
|
| 252 |
+
#if AT_USE_CUSPARSE_GENERIC_SPSV()
|
| 253 |
+
class TORCH_CUDA_CPP_API CuSparseSpSVDescriptor
|
| 254 |
+
: public CuSparseDescriptor<cusparseSpSVDescr, &cusparseSpSV_destroyDescr> {
|
| 255 |
+
public:
|
| 256 |
+
CuSparseSpSVDescriptor() {
|
| 257 |
+
cusparseSpSVDescr_t raw_descriptor;
|
| 258 |
+
TORCH_CUDASPARSE_CHECK(cusparseSpSV_createDescr(&raw_descriptor));
|
| 259 |
+
descriptor_.reset(raw_descriptor);
|
| 260 |
+
}
|
| 261 |
+
};
|
| 262 |
+
#endif
|
| 263 |
+
|
| 264 |
+
#if AT_USE_CUSPARSE_GENERIC_SPSM()
|
| 265 |
+
class TORCH_CUDA_CPP_API CuSparseSpSMDescriptor
|
| 266 |
+
: public CuSparseDescriptor<cusparseSpSMDescr, &cusparseSpSM_destroyDescr> {
|
| 267 |
+
public:
|
| 268 |
+
CuSparseSpSMDescriptor() {
|
| 269 |
+
cusparseSpSMDescr_t raw_descriptor;
|
| 270 |
+
TORCH_CUDASPARSE_CHECK(cusparseSpSM_createDescr(&raw_descriptor));
|
| 271 |
+
descriptor_.reset(raw_descriptor);
|
| 272 |
+
}
|
| 273 |
+
};
|
| 274 |
+
#endif
|
| 275 |
+
|
| 276 |
+
#if (defined(USE_ROCM) && ROCM_VERSION >= 50200) || !defined(USE_ROCM)
|
| 277 |
+
class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor
|
| 278 |
+
: public CuSparseDescriptor<cusparseSpGEMMDescr, &cusparseSpGEMM_destroyDescr> {
|
| 279 |
+
public:
|
| 280 |
+
CuSparseSpGEMMDescriptor() {
|
| 281 |
+
cusparseSpGEMMDescr_t raw_descriptor;
|
| 282 |
+
TORCH_CUDASPARSE_CHECK(cusparseSpGEMM_createDescr(&raw_descriptor));
|
| 283 |
+
descriptor_.reset(raw_descriptor);
|
| 284 |
+
}
|
| 285 |
+
};
|
| 286 |
+
#endif
|
| 287 |
+
|
| 288 |
+
#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
|
| 289 |
+
|
| 290 |
+
} // namespace at::cuda::sparse
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CUDAUtils.h
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 4 |
+
|
| 5 |
+
namespace at::cuda {
|
| 6 |
+
|
| 7 |
+
// Check if every tensor in a list of tensors matches the current
|
| 8 |
+
// device.
|
| 9 |
+
inline bool check_device(ArrayRef<Tensor> ts) {
|
| 10 |
+
if (ts.empty()) {
|
| 11 |
+
return true;
|
| 12 |
+
}
|
| 13 |
+
Device curDevice = Device(kCUDA, current_device());
|
| 14 |
+
for (const Tensor& t : ts) {
|
| 15 |
+
if (t.device() != curDevice) return false;
|
| 16 |
+
}
|
| 17 |
+
return true;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
} // namespace at::cuda
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/CachingHostAllocator.h
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <c10/core/Allocator.h>
|
| 4 |
+
#include <c10/cuda/CUDAStream.h>
|
| 5 |
+
|
| 6 |
+
namespace at::cuda {
|
| 7 |
+
|
| 8 |
+
//
|
| 9 |
+
// A caching allocator for CUDA host allocations (pinned memory).
|
| 10 |
+
//
|
| 11 |
+
// This provides a drop-in replacement for THCudaHostAllocator, which re-uses
|
| 12 |
+
// freed pinned (page-locked) memory allocations. This avoids device
|
| 13 |
+
// synchronizations due to cudaFreeHost calls.
|
| 14 |
+
//
|
| 15 |
+
// To ensure correct behavior, THCCachingHostAllocator_recordEvent must be
|
| 16 |
+
// called anytime a pointer from this allocator is used in a cudaMemcpyAsync
|
| 17 |
+
// call between host and device, and passed the corresponding context from the
|
| 18 |
+
// allocation. This is currently invoked by at::native::copy_kernel_cuda.
|
| 19 |
+
//
|
| 20 |
+
// Note that this allocator does not split larger allocations into smaller
|
| 21 |
+
// blocks, unlike the caching device allocator.
|
| 22 |
+
//
|
| 23 |
+
TORCH_CUDA_CPP_API c10::Allocator* getCachingHostAllocator();
|
| 24 |
+
|
| 25 |
+
// Records an event in the specified stream. The allocation corresponding to the
|
| 26 |
+
// input `ptr`/`ctx` will not be re-used until the event has occurred.
|
| 27 |
+
TORCH_CUDA_CPP_API bool
|
| 28 |
+
CachingHostAllocator_recordEvent(void* ptr, void* ctx, c10::cuda::CUDAStream stream);
|
| 29 |
+
|
| 30 |
+
// Releases cached pinned memory allocations via cudaHostFree
|
| 31 |
+
TORCH_CUDA_CPP_API void CachingHostAllocator_emptyCache();
|
| 32 |
+
|
| 33 |
+
inline TORCH_CUDA_CPP_API at::DataPtr HostAlloc(size_t size) {
|
| 34 |
+
return getCachingHostAllocator()->allocate(size);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
} // namespace at::cuda
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/DeviceUtils.cuh
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cuda.h>
|
| 4 |
+
#include <c10/util/complex.h>
|
| 5 |
+
#include <c10/util/Half.h>
|
| 6 |
+
|
| 7 |
+
__device__ __forceinline__ unsigned int ACTIVE_MASK()
|
| 8 |
+
{
|
| 9 |
+
#if !defined(USE_ROCM)
|
| 10 |
+
return __activemask();
|
| 11 |
+
#else
|
| 12 |
+
// will be ignored anyway
|
| 13 |
+
return 0xffffffff;
|
| 14 |
+
#endif
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
__device__ __forceinline__ void WARP_SYNC(unsigned mask = 0xffffffff) {
|
| 18 |
+
#if !defined(USE_ROCM)
|
| 19 |
+
return __syncwarp(mask);
|
| 20 |
+
#endif
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
#if defined(USE_ROCM)
|
| 24 |
+
__device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
|
| 25 |
+
{
|
| 26 |
+
return __ballot(predicate);
|
| 27 |
+
}
|
| 28 |
+
#else
|
| 29 |
+
__device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff)
|
| 30 |
+
{
|
| 31 |
+
#if !defined(USE_ROCM)
|
| 32 |
+
return __ballot_sync(mask, predicate);
|
| 33 |
+
#else
|
| 34 |
+
return __ballot(predicate);
|
| 35 |
+
#endif
|
| 36 |
+
}
|
| 37 |
+
#endif
|
| 38 |
+
|
| 39 |
+
template <typename T>
|
| 40 |
+
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
|
| 41 |
+
{
|
| 42 |
+
#if !defined(USE_ROCM)
|
| 43 |
+
return __shfl_xor_sync(mask, value, laneMask, width);
|
| 44 |
+
#else
|
| 45 |
+
return __shfl_xor(value, laneMask, width);
|
| 46 |
+
#endif
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
template <typename T>
|
| 50 |
+
__device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSize, unsigned int mask = 0xffffffff)
|
| 51 |
+
{
|
| 52 |
+
#if !defined(USE_ROCM)
|
| 53 |
+
return __shfl_sync(mask, value, srcLane, width);
|
| 54 |
+
#else
|
| 55 |
+
return __shfl(value, srcLane, width);
|
| 56 |
+
#endif
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
template <typename T>
|
| 60 |
+
__device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
|
| 61 |
+
{
|
| 62 |
+
#if !defined(USE_ROCM)
|
| 63 |
+
return __shfl_up_sync(mask, value, delta, width);
|
| 64 |
+
#else
|
| 65 |
+
return __shfl_up(value, delta, width);
|
| 66 |
+
#endif
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
template <typename T>
|
| 70 |
+
__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
|
| 71 |
+
{
|
| 72 |
+
#if !defined(USE_ROCM)
|
| 73 |
+
return __shfl_down_sync(mask, value, delta, width);
|
| 74 |
+
#else
|
| 75 |
+
return __shfl_down(value, delta, width);
|
| 76 |
+
#endif
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
#if defined(USE_ROCM)
|
| 80 |
+
template<>
|
| 81 |
+
__device__ __forceinline__ int64_t WARP_SHFL_DOWN<int64_t>(int64_t value, unsigned int delta, int width , unsigned int mask)
|
| 82 |
+
{
|
| 83 |
+
//(HIP doesn't support int64_t). Trick from https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
|
| 84 |
+
int2 a = *reinterpret_cast<int2*>(&value);
|
| 85 |
+
a.x = __shfl_down(a.x, delta);
|
| 86 |
+
a.y = __shfl_down(a.y, delta);
|
| 87 |
+
return *reinterpret_cast<int64_t*>(&a);
|
| 88 |
+
}
|
| 89 |
+
#endif
|
| 90 |
+
|
| 91 |
+
template<>
|
| 92 |
+
__device__ __forceinline__ c10::Half WARP_SHFL_DOWN<c10::Half>(c10::Half value, unsigned int delta, int width, unsigned int mask)
|
| 93 |
+
{
|
| 94 |
+
return c10::Half(WARP_SHFL_DOWN<unsigned short>(value.x, delta, width, mask), c10::Half::from_bits_t{});
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
template <typename T>
|
| 98 |
+
__device__ __forceinline__ c10::complex<T> WARP_SHFL_DOWN(c10::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
|
| 99 |
+
{
|
| 100 |
+
#if !defined(USE_ROCM)
|
| 101 |
+
return c10::complex<T>(
|
| 102 |
+
__shfl_down_sync(mask, value.real_, delta, width),
|
| 103 |
+
__shfl_down_sync(mask, value.imag_, delta, width));
|
| 104 |
+
#else
|
| 105 |
+
return c10::complex<T>(
|
| 106 |
+
__shfl_down(value.real_, delta, width),
|
| 107 |
+
__shfl_down(value.imag_, delta, width));
|
| 108 |
+
#endif
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
/**
|
| 112 |
+
* For CC 3.5+, perform a load using __ldg
|
| 113 |
+
*/
|
| 114 |
+
template <typename T>
|
| 115 |
+
__device__ __forceinline__ T doLdg(const T* p) {
|
| 116 |
+
#if __CUDA_ARCH__ >= 350 && !defined(USE_ROCM)
|
| 117 |
+
return __ldg(p);
|
| 118 |
+
#else
|
| 119 |
+
return *p;
|
| 120 |
+
#endif
|
| 121 |
+
}
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/include/ATen/cuda/EmptyTensor.h
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
#include <ATen/core/TensorBase.h>
|
| 3 |
+
|
| 4 |
+
namespace at::detail {
|
| 5 |
+
|
| 6 |
+
TORCH_CUDA_CPP_API TensorBase empty_cuda(
|
| 7 |
+
IntArrayRef size,
|
| 8 |
+
ScalarType dtype,
|
| 9 |
+
c10::optional<Device> device_opt,
|
| 10 |
+
c10::optional<c10::MemoryFormat> memory_format_opt);
|
| 11 |
+
|
| 12 |
+
TORCH_CUDA_CPP_API TensorBase empty_cuda(
|
| 13 |
+
IntArrayRef size,
|
| 14 |
+
c10::optional<ScalarType> dtype_opt,
|
| 15 |
+
c10::optional<Layout> layout_opt,
|
| 16 |
+
c10::optional<Device> device_opt,
|
| 17 |
+
c10::optional<bool> pin_memory_opt,
|
| 18 |
+
c10::optional<c10::MemoryFormat> memory_format_opt);
|
| 19 |
+
|
| 20 |
+
TORCH_CUDA_CPP_API TensorBase empty_cuda(
|
| 21 |
+
IntArrayRef size,
|
| 22 |
+
const TensorOptions &options);
|
| 23 |
+
|
| 24 |
+
TORCH_CUDA_CPP_API TensorBase empty_strided_cuda(
|
| 25 |
+
IntArrayRef size,
|
| 26 |
+
IntArrayRef stride,
|
| 27 |
+
ScalarType dtype,
|
| 28 |
+
c10::optional<Device> device_opt);
|
| 29 |
+
|
| 30 |
+
TORCH_CUDA_CPP_API TensorBase empty_strided_cuda(
|
| 31 |
+
IntArrayRef size,
|
| 32 |
+
IntArrayRef stride,
|
| 33 |
+
c10::optional<ScalarType> dtype_opt,
|
| 34 |
+
c10::optional<Layout> layout_opt,
|
| 35 |
+
c10::optional<Device> device_opt,
|
| 36 |
+
c10::optional<bool> pin_memory_opt);
|
| 37 |
+
|
| 38 |
+
TORCH_CUDA_CPP_API TensorBase empty_strided_cuda(
|
| 39 |
+
IntArrayRef size,
|
| 40 |
+
IntArrayRef stride,
|
| 41 |
+
const TensorOptions &options);
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
} // namespace at::detail
|